diff --git a/.circleci/cimodel/data/pytorch_build_data.py b/.circleci/cimodel/data/pytorch_build_data.py index 5156e69b77af..39f2a208a5ec 100644 --- a/.circleci/cimodel/data/pytorch_build_data.py +++ b/.circleci/cimodel/data/pytorch_build_data.py @@ -84,7 +84,11 @@ ("gcc", [ ("9", [ ("3.8", [ - ("coverage", [XImportant(True)]), + ("coverage", [ + (True, [ + ("shard_test", [XImportant(True)]), + ]), + ]), ]), ]), ]), diff --git a/.circleci/cimodel/data/pytorch_build_definitions.py b/.circleci/cimodel/data/pytorch_build_definitions.py index 0c03fac487d6..75b0e8812e1b 100644 --- a/.circleci/cimodel/data/pytorch_build_definitions.py +++ b/.circleci/cimodel/data/pytorch_build_definitions.py @@ -272,6 +272,7 @@ def instantiate_configs(): compiler_version = fc.find_prop("compiler_version") is_xla = fc.find_prop("is_xla") or False is_asan = fc.find_prop("is_asan") or False + is_coverage = fc.find_prop("is_coverage") or False is_onnx = fc.find_prop("is_onnx") or False is_pure_torch = fc.find_prop("is_pure_torch") or False is_vulkan = fc.find_prop("is_vulkan") or False @@ -311,6 +312,10 @@ def instantiate_configs(): python_version = fc.find_prop("pyver") parms_list[0] = fc.find_prop("abbreviated_pyver") + if is_coverage: + parms_list_ignored_for_docker_image.append("coverage") + python_version = fc.find_prop("pyver") + if is_onnx: parms_list.append("onnx") python_version = fc.find_prop("pyver") @@ -325,7 +330,6 @@ def instantiate_configs(): is_important = fc.find_prop("is_important") or False parallel_backend = fc.find_prop("parallel_backend") or None build_only = fc.find_prop("build_only") or False - is_coverage = fc.find_prop("is_coverage") or False shard_test = fc.find_prop("shard_test") or False # TODO: fix pure_torch python test packaging issue. if shard_test: @@ -333,9 +337,6 @@ def instantiate_configs(): restrict_phases.extend(["test1", "test2"]) if build_only or is_pure_torch: restrict_phases = ["build"] - if is_coverage and restrict_phases is None: - restrict_phases = ["build", "coverage_test"] - gpu_resource = None if cuda_version and cuda_version != "10": diff --git a/.circleci/config.yml b/.circleci/config.yml index 2ba5046c9034..2d367af02d3f 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -655,9 +655,11 @@ jobs: echo "Retrieving test reports" docker cp $id:/var/lib/jenkins/workspace/test/test-reports ./ || echo 'No test reports found!' if [[ ${BUILD_ENVIRONMENT} == *"coverage"* ]]; then - echo "Retrieving coverage report" + echo "Retrieving Python coverage report" docker cp $id:/var/lib/jenkins/workspace/test/.coverage ./test docker cp $id:/var/lib/jenkins/workspace/test/coverage.xml ./test + echo "Retrieving C++ coverage report" + docker cp $id:/var/lib/jenkins/workspace/build/coverage.info ./test python3 -mpip install codecov python3 -mcodecov fi @@ -969,7 +971,7 @@ jobs: - run: name: Build - no_output_timeout: "1h" + no_output_timeout: "90m" command: | # Do not set -u here; there is some problem with CircleCI # variable expansion with PROMPT_COMMAND @@ -6877,16 +6879,23 @@ workflows: docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-bionic-py3.6-clang9" resource_class: large - pytorch_linux_build: - name: pytorch_linux_bionic_py3_8_gcc9_build + name: pytorch_linux_bionic_py3_8_gcc9_coverage_build requires: - "docker-pytorch-linux-bionic-py3.8-gcc9" - build_environment: "pytorch-linux-bionic-py3.8-gcc9-build" + build_environment: "pytorch-linux-bionic-py3.8-gcc9-coverage-build" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-bionic-py3.8-gcc9" + - pytorch_linux_test: + name: pytorch_linux_bionic_py3_8_gcc9_coverage_test1 + requires: + - pytorch_linux_bionic_py3_8_gcc9_coverage_build + build_environment: "pytorch-linux-bionic-py3.8-gcc9-coverage-test1" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-bionic-py3.8-gcc9" + resource_class: large - pytorch_linux_test: - name: pytorch_linux_bionic_py3_8_gcc9_coverage_test + name: pytorch_linux_bionic_py3_8_gcc9_coverage_test2 requires: - - pytorch_linux_bionic_py3_8_gcc9_build - build_environment: "pytorch-linux-bionic-py3.8-gcc9-coverage_test" + - pytorch_linux_bionic_py3_8_gcc9_coverage_build + build_environment: "pytorch-linux-bionic-py3.8-gcc9-coverage-test2" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-bionic-py3.8-gcc9" resource_class: large - pytorch_linux_build: diff --git a/.circleci/docker/centos-rocm/Dockerfile b/.circleci/docker/centos-rocm/Dockerfile index 1bc7b0deea32..a94a7167a7f4 100644 --- a/.circleci/docker/centos-rocm/Dockerfile +++ b/.circleci/docker/centos-rocm/Dockerfile @@ -27,7 +27,7 @@ RUN rm install_glibc.sh ADD ./common/install_user.sh install_user.sh RUN bash ./install_user.sh && rm install_user.sh -# Install conda +# Install conda and other packages (e.g., numpy, coverage, pytest) ENV PATH /opt/conda/bin:$PATH ARG ANACONDA_PYTHON_VERSION ADD ./common/install_conda.sh install_conda.sh diff --git a/.circleci/docker/common/install_cache.sh b/.circleci/docker/common/install_cache.sh index 17931375b6f0..fc1630272472 100644 --- a/.circleci/docker/common/install_cache.sh +++ b/.circleci/docker/common/install_cache.sh @@ -2,6 +2,28 @@ set -ex +install_ubuntu() { + echo "Preparing to build sccache from source" + apt-get update + apt-get install -y cargo pkg-config libssl-dev + echo "Checking out sccache repo" + git clone https://github.com/pytorch/sccache + cd sccache + echo "Building sccache" + cargo build --release + cp target/release/sccache /opt/cache/bin + echo "Cleaning up" + cd .. + rm -rf sccache + apt-get remove -y cargo rustc + apt-get autoclean && apt-get clean +} + +install_binary() { + echo "Downloading sccache binary from S3 repo" + curl --retry 3 https://s3.amazonaws.com/ossci-linux/sccache -o /opt/cache/bin/sccache +} + mkdir -p /opt/cache/bin mkdir -p /opt/cache/lib sed -e 's|PATH="\(.*\)"|PATH="/opt/cache/bin:\1"|g' -i /etc/environment @@ -11,12 +33,20 @@ export PATH="/opt/cache/bin:$PATH" if [ -n "$ROCM_VERSION" ]; then curl --retry 3 http://repo.radeon.com/misc/.sccache_amd/sccache -o /opt/cache/bin/sccache else - curl --retry 3 https://s3.amazonaws.com/ossci-linux/sccache -o /opt/cache/bin/sccache + ID=$(grep -oP '(?<=^ID=).+' /etc/os-release | tr -d '"') + case "$ID" in + ubuntu) + install_ubuntu + ;; + *) + install_binary + ;; + esac fi chmod a+x /opt/cache/bin/sccache function write_sccache_stub() { - printf "#!/bin/sh\nexec sccache $(which $1) \"\$@\"" > "/opt/cache/bin/$1" + printf "#!/bin/sh\nif [ \$(ps -p \$PPID -o comm=) != sccache ]; then\n exec sccache $(which $1) \"\$@\"\nelse\n exec $(which $1) \"\$@\"\nfi" > "/opt/cache/bin/$1" chmod a+x "/opt/cache/bin/$1" } @@ -38,8 +68,8 @@ if [ -n "$CUDA_VERSION" ]; then # where CUDA is installed. Instead, we install an nvcc symlink outside # of the PATH, and set CUDA_NVCC_EXECUTABLE so that we make use of it. - printf "#!/bin/sh\nexec sccache $(which nvcc) \"\$@\"" > /opt/cache/lib/nvcc - chmod a+x /opt/cache/lib/nvcc + write_sccache_stub nvcc + mv /opt/cache/bin/nvcc /opt/cache/lib/ fi if [ -n "$ROCM_VERSION" ]; then diff --git a/.circleci/docker/common/install_conda.sh b/.circleci/docker/common/install_conda.sh index db8f1a457ecf..c63e28029f07 100755 --- a/.circleci/docker/common/install_conda.sh +++ b/.circleci/docker/common/install_conda.sh @@ -96,13 +96,13 @@ if [ -n "$ANACONDA_PYTHON_VERSION" ]; then # TODO: This isn't working atm conda_install nnpack -c killeent - # Install some other packages + # Install some other packages, including those needed for Python test reporting # TODO: Why is scipy pinned # numba & llvmlite is pinned because of https://github.com/numba/numba/issues/4368 # scikit-learn is pinned because of # https://github.com/scikit-learn/scikit-learn/issues/14485 (affects gcc 5.5 # only) - as_jenkins pip install --progress-bar off pytest scipy==1.1.0 scikit-learn==0.20.3 scikit-image librosa>=0.6.2 psutil numba==0.46.0 llvmlite==0.30.0 + as_jenkins pip install --progress-bar off pytest scipy==1.1.0 scikit-learn==0.20.3 scikit-image librosa>=0.6.2 psutil numba==0.46.0 llvmlite==0.30.0 unittest-xml-reporting coverage popd fi diff --git a/.circleci/docker/common/install_gcc.sh b/.circleci/docker/common/install_gcc.sh index 48f17989f978..0e86df1c778c 100644 --- a/.circleci/docker/common/install_gcc.sh +++ b/.circleci/docker/common/install_gcc.sh @@ -15,6 +15,7 @@ if [ -n "$GCC_VERSION" ]; then update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-"$GCC_VERSION" 50 update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-"$GCC_VERSION" 50 + update-alternatives --install /usr/bin/gcov gcov /usr/bin/gcov-"$GCC_VERSION" 50 # Cleanup package manager apt-get autoclean && apt-get clean diff --git a/.circleci/docker/common/install_lcov.sh b/.circleci/docker/common/install_lcov.sh index eea6e1a5bc21..b4364698318a 100644 --- a/.circleci/docker/common/install_lcov.sh +++ b/.circleci/docker/common/install_lcov.sh @@ -2,5 +2,7 @@ set -ex -sudo apt-get -qq update -sudo apt-get -qq install lcov +git clone --branch v1.15 https://github.com/linux-test-project/lcov.git +pushd lcov +sudo make install # will be installed in /usr/local/bin/lcov +popd diff --git a/.circleci/docker/ubuntu-cuda/Dockerfile b/.circleci/docker/ubuntu-cuda/Dockerfile index d3a9027d5f06..b767440066e7 100644 --- a/.circleci/docker/ubuntu-cuda/Dockerfile +++ b/.circleci/docker/ubuntu-cuda/Dockerfile @@ -24,7 +24,7 @@ ARG KATEX ADD ./common/install_katex.sh install_katex.sh RUN bash ./install_katex.sh && rm install_katex.sh -# Install conda +# Install conda and other packages (e.g., numpy, coverage, pytest) ENV PATH /opt/conda/bin:$PATH ARG ANACONDA_PYTHON_VERSION ADD ./common/install_conda.sh install_conda.sh diff --git a/.circleci/docker/ubuntu-rocm/Dockerfile b/.circleci/docker/ubuntu-rocm/Dockerfile index 5fd133d08245..761bf0438d7f 100644 --- a/.circleci/docker/ubuntu-rocm/Dockerfile +++ b/.circleci/docker/ubuntu-rocm/Dockerfile @@ -21,7 +21,7 @@ RUN bash ./install_clang.sh && rm install_clang.sh ADD ./common/install_user.sh install_user.sh RUN bash ./install_user.sh && rm install_user.sh -# Install conda +# Install conda and other packages (e.g., numpy, coverage, pytest) ENV PATH /opt/conda/bin:$PATH ARG ANACONDA_PYTHON_VERSION ADD ./common/install_conda.sh install_conda.sh diff --git a/.circleci/docker/ubuntu/Dockerfile b/.circleci/docker/ubuntu/Dockerfile index ca4d3c58dbc6..3938d99ade0b 100644 --- a/.circleci/docker/ubuntu/Dockerfile +++ b/.circleci/docker/ubuntu/Dockerfile @@ -33,7 +33,7 @@ ARG KATEX ADD ./common/install_katex.sh install_katex.sh RUN bash ./install_katex.sh && rm install_katex.sh -# Install conda +# Install conda and other packages (e.g., numpy, coverage, pytest) ENV PATH /opt/conda/bin:$PATH ARG ANACONDA_PYTHON_VERSION ADD ./common/install_conda.sh install_conda.sh diff --git a/.circleci/verbatim-sources/job-specs/binary-job-specs.yml b/.circleci/verbatim-sources/job-specs/binary-job-specs.yml index 489dfefdbff1..f8d1dde4e5ad 100644 --- a/.circleci/verbatim-sources/job-specs/binary-job-specs.yml +++ b/.circleci/verbatim-sources/job-specs/binary-job-specs.yml @@ -174,7 +174,7 @@ - run: name: Build - no_output_timeout: "1h" + no_output_timeout: "90m" command: | # Do not set -u here; there is some problem with CircleCI # variable expansion with PROMPT_COMMAND diff --git a/.circleci/verbatim-sources/job-specs/pytorch-job-specs.yml b/.circleci/verbatim-sources/job-specs/pytorch-job-specs.yml index 868f32fd49fa..f6f37dbb0470 100644 --- a/.circleci/verbatim-sources/job-specs/pytorch-job-specs.yml +++ b/.circleci/verbatim-sources/job-specs/pytorch-job-specs.yml @@ -217,9 +217,11 @@ jobs: echo "Retrieving test reports" docker cp $id:/var/lib/jenkins/workspace/test/test-reports ./ || echo 'No test reports found!' if [[ ${BUILD_ENVIRONMENT} == *"coverage"* ]]; then - echo "Retrieving coverage report" + echo "Retrieving Python coverage report" docker cp $id:/var/lib/jenkins/workspace/test/.coverage ./test docker cp $id:/var/lib/jenkins/workspace/test/coverage.xml ./test + echo "Retrieving C++ coverage report" + docker cp $id:/var/lib/jenkins/workspace/build/coverage.info ./test python3 -mpip install codecov python3 -mcodecov fi diff --git a/.flake8 b/.flake8 index 7ecc6df31754..8be8496e4224 100644 --- a/.flake8 +++ b/.flake8 @@ -12,5 +12,5 @@ ignore = B007,B008, # these ignores are from flake8-comprehensions; please fix! C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415 -per-file-ignores = __init__.py: F401 +per-file-ignores = __init__.py: F401 torch/utils/cpp_extension.py: B950 exclude = docs/src,venv,third_party,caffe2,scripts,docs/caffe2,torch/lib/include,torch/lib/tmp_install,build,torch/include,*.pyi,.git,build,build_test_custom_build,build_code_analyzer diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 18f7f0a1783f..04abbc0275af 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -42,6 +42,10 @@ jobs: run: | sudo apt-get install -y doxygen && pip install -r requirements.txt cd docs/cpp/source && ./check-doxygen.sh + - name: CUDA kernel launch check + run: | + set -eux + python torch/testing/check_kernel_launches.py |& tee ${GITHUB_WORKSPACE}/cuda_kernel_launch_checks.txt flake8-py3: runs-on: ubuntu-latest @@ -65,11 +69,10 @@ jobs: id: get_pr_tip - name: Run flake8 run: | - set -eux + set -eux -o pipefail pip install flake8==3.8.2 flake8-bugbear flake8-comprehensions flake8-executable flake8-pyi==20.5.0 mccabe pycodestyle==2.6.0 pyflakes==2.2.0 flake8 --version - flake8 > ${GITHUB_WORKSPACE}/flake8-output.txt - cat ${GITHUB_WORKSPACE}/flake8-output.txt + flake8 | tee ${GITHUB_WORKSPACE}/flake8-output.txt - name: Add annotations uses: pytorch/add-annotations-github-action@master with: diff --git a/.jenkins/caffe2/test.sh b/.jenkins/caffe2/test.sh index 03583f3805c7..1b26a5980d05 100755 --- a/.jenkins/caffe2/test.sh +++ b/.jenkins/caffe2/test.sh @@ -171,7 +171,7 @@ if [[ "$BUILD_ENVIRONMENT" == *onnx* ]]; then # default pip version is too old(9.0.2), unable to support tag `manylinux2010`. # Fix the pip error: Couldn't find a version that satisfies the requirement pip install --upgrade pip - pip install -q --user ort-nightly==1.5.0.dev202009182 + pip install -q --user onnxruntime==1.5.2 fi "$ROOT_DIR/scripts/onnx/test.sh" fi diff --git a/.jenkins/pytorch/build.sh b/.jenkins/pytorch/build.sh index 3e197d867b6e..b94e797e7010 100755 --- a/.jenkins/pytorch/build.sh +++ b/.jenkins/pytorch/build.sh @@ -42,6 +42,11 @@ if [[ "$BUILD_ENVIRONMENT" == *cuda* ]]; then nvcc --version fi +if [[ "$BUILD_ENVIRONMENT" == *coverage* ]]; then + # enable build option in CMake + export USE_CPP_CODE_COVERAGE=ON +fi + # TODO: Don't run this... pip_install -r requirements.txt || true diff --git a/.jenkins/pytorch/codegen-test.sh b/.jenkins/pytorch/codegen-test.sh new file mode 100755 index 000000000000..3b75999ceb2e --- /dev/null +++ b/.jenkins/pytorch/codegen-test.sh @@ -0,0 +1,60 @@ +#!/usr/bin/env bash + +# This script can also be used to test whether your diff changes any codegen output. +# +# Run it before and after your change: +# .jenkins/pytorch/codegen-test.sh +# .jenkins/pytorch/codegen-test.sh +# +# Then run diff to compare the generated files: +# diff -Naur + +set -eu -o pipefail + +if [ "$#" -eq 0 ]; then + COMPACT_JOB_NAME="${BUILD_ENVIRONMENT}" + source "$(dirname "${BASH_SOURCE[0]}")/common.sh" + OUT="$(dirname "${BASH_SOURCE[0]}")/../../codegen_result" +else + OUT=$1 +fi + +set -x + +rm -rf "$OUT" + +# aten codegen +python -m tools.codegen.gen \ + -d "$OUT"/torch/share/ATen + +# torch codegen +python -m tools.setup_helpers.generate_code \ + --declarations-path "$OUT"/torch/share/ATen/Declarations.yaml \ + --install_dir "$OUT" + +# pyi codegen +mkdir -p "$OUT"/pyi/torch/_C +mkdir -p "$OUT"/pyi/torch/nn +python -m tools.pyi.gen_pyi \ + --declarations-path "$OUT"/torch/share/ATen/Declarations.yaml \ + --out "$OUT"/pyi + +# autograd codegen (called by torch codegen but can run independently) +python -m tools.autograd.gen_autograd \ + "$OUT"/torch/share/ATen/Declarations.yaml \ + "$OUT"/autograd \ + tools/autograd + +# unboxing_wrappers codegen (called by torch codegen but can run independently) +mkdir -p "$OUT"/unboxing_wrappers +python -m tools.jit.gen_unboxing_wrappers \ + "$OUT"/torch/share/ATen/Declarations.yaml \ + "$OUT"/unboxing_wrappers \ + tools/jit/templates + +# annotated_fn_args codegen (called by torch codegen but can run independently) +mkdir -p "$OUT"/annotated_fn_args +python -m tools.autograd.gen_annotated_fn_args \ + "$OUT"/torch/share/ATen/Declarations.yaml \ + "$OUT"/annotated_fn_args \ + tools/autograd diff --git a/.jenkins/pytorch/test.sh b/.jenkins/pytorch/test.sh index 1af459ab8cc8..88bcfc93e19d 100755 --- a/.jenkins/pytorch/test.sh +++ b/.jenkins/pytorch/test.sh @@ -11,17 +11,13 @@ source "$(dirname "${BASH_SOURCE[0]}")/common.sh" echo "Testing pytorch" -if [ -n "${IN_CI}" ]; then - # TODO move this to docker - pip_install unittest-xml-reporting coverage pytest +if [[ "$BUILD_ENVIRONMENT" == *-slow-* ]]; then + export PYTORCH_TEST_WITH_SLOW=1 + export PYTORCH_TEST_SKIP_FAST=1 +fi - if [[ "$BUILD_ENVIRONMENT" == *-slow-* ]]; then - export PYTORCH_TEST_WITH_SLOW=1 - export PYTORCH_TEST_SKIP_FAST=1 - fi - if [[ "$BUILD_ENVIRONMENT" == *coverage* ]]; then - export PYTORCH_COLLECT_COVERAGE=1 - fi +if [[ "$BUILD_ENVIRONMENT" == *coverage* ]]; then + export PYTORCH_COLLECT_COVERAGE=1 fi if [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then @@ -401,10 +397,15 @@ else test_distributed test_benchmarks test_rpc - if [[ "$BUILD_ENVIRONMENT" == *coverage* ]]; then - pushd test - echo "Generating XML coverage report" - time python -mcoverage xml - popd - fi +fi + +if [[ "$BUILD_ENVIRONMENT" == *coverage* ]]; then + pushd test + echo "Generating XML coverage report" + time python -mcoverage xml + popd + pushd build + echo "Generating lcov coverage report for C++ sources" + time lcov --capture --directory . --output-file coverage.info + popd fi diff --git a/.jenkins/pytorch/win-test-helpers/setup_pytorch_env.bat b/.jenkins/pytorch/win-test-helpers/setup_pytorch_env.bat index 34c3698a1307..1e3cfe090abf 100644 --- a/.jenkins/pytorch/win-test-helpers/setup_pytorch_env.bat +++ b/.jenkins/pytorch/win-test-helpers/setup_pytorch_env.bat @@ -39,7 +39,7 @@ if %errorlevel% neq 0 ( exit /b %errorlevel% ) popd :: The version is fixed to avoid flakiness: https://github.com/pytorch/pytorch/issues/31136 -pip install "ninja==1.10.0.post1" future "hypothesis==4.53.2" "librosa>=0.6.2" psutil pillow unittest-xml-reporting pytest +pip install "ninja==1.10.0.post1" future "hypothesis==4.53.2" "librosa>=0.6.2" psutil pillow unittest-xml-reporting pytest coverage if %errorlevel% neq 0 ( exit /b %errorlevel% ) :: No need to install faulthandler since we only test Python >= 3.6 on Windows :: faulthandler is builtin since Python 3.3 diff --git a/.jenkins/pytorch/win-test.sh b/.jenkins/pytorch/win-test.sh index abcd5756d747..adf9b4c82620 100755 --- a/.jenkins/pytorch/win-test.sh +++ b/.jenkins/pytorch/win-test.sh @@ -14,6 +14,10 @@ fi export TMP_DIR="${PWD}/build/win_tmp" export TMP_DIR_WIN=$(cygpath -w "${TMP_DIR}") +export PROJECT_DIR="${PWD}" +export PROJECT_DIR_WIN=$(cygpath -w "${PROJECT_DIR}") +export TEST_DIR="${PWD}/test" +export TEST_DIR_WIN=$(cygpath -w "${TEST_DIR}") export PYTORCH_FINAL_PACKAGE_DIR="/c/users/circleci/workspace/build-results" export PYTORCH_FINAL_PACKAGE_DIR_WIN=$(cygpath -w "${PYTORCH_FINAL_PACKAGE_DIR}") @@ -45,6 +49,7 @@ run_tests() { $SCRIPT_HELPERS_DIR/test_libtorch.bat else if [[ "${JOB_BASE_NAME}" == *-test1 ]]; then + export PYTORCH_COLLECT_COVERAGE=1 $SCRIPT_HELPERS_DIR/test_python_nn.bat "$DETERMINE_FROM" && \ $SCRIPT_HELPERS_DIR/test_libtorch.bat if [[ "${USE_CUDA}" == "1" ]]; then @@ -59,3 +64,16 @@ run_tests() { } run_tests && assert_git_not_dirty && echo "TEST PASSED" + +if [[ "${BUILD_ENVIRONMENT}" == "pytorch-win-vs2019-cuda10-cudnn7-py3" ]] && [[ "${JOB_BASE_NAME}" == *-test1 ]]; then + pushd $TEST_DIR + python -mpip install coverage + echo "Generating XML coverage report" + time python -mcoverage xml + popd + + pushd $PROJECT_DIR + python -mpip install codecov + python -mcodecov + popd +fi diff --git a/BUILD.bazel b/BUILD.bazel index 9eced9b2c563..4ec99d770f70 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -126,18 +126,13 @@ genrule( outs = [ "aten/src/ATen/Declarations.yaml", "aten/src/ATen/BackendSelectRegister.cpp", - "aten/src/ATen/CPUType.h", "aten/src/ATen/CPUType.cpp", "aten/src/ATen/Functions.h", "aten/src/ATen/Functions.cpp", "aten/src/ATen/NativeFunctions.h", - "aten/src/ATen/MkldnnCPUType.h", "aten/src/ATen/MkldnnCPUType.cpp", - "aten/src/ATen/QuantizedCPUType.h", "aten/src/ATen/QuantizedCPUType.cpp", - "aten/src/ATen/SparseCPUType.h", "aten/src/ATen/SparseCPUType.cpp", - "aten/src/ATen/TypeDefault.h", "aten/src/ATen/TypeDefault.cpp", "aten/src/ATen/core/TensorBody.h", "aten/src/ATen/core/TensorMethods.cpp", diff --git a/NOTICE b/NOTICE index 020beaea4c46..5abaac479a75 100644 --- a/NOTICE +++ b/NOTICE @@ -284,6 +284,112 @@ Apache License Version 2.0: incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. +======================================================================= +Cephes's 3-Clause BSD License +======================================================================= + +Code derived from implementations in the Cephes Math Library should mention +its derivation and reference the following license: + + 3-Clause BSD License for the Cephes Math Library + Copyright (c) 2018, Steven Moshier + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + * Neither the name of the nor the + names of its contributors may be used to endorse or promote products + derived from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL Steven Moshier BE LIABLE FOR ANY + DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +======================================================================= +SciPy's 3-Clause BSD License +======================================================================= + +Code derived from implementations in SciPy should mention its derivation +and reference the following license: + + Copyright (c) 2001-2002 Enthought, Inc. 2003-2019, SciPy Developers. + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials provided + with the distribution. + + 3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +======================================================================= +Boost's 1.0 Software License +======================================================================= + +Code derived from implementations in Boost 1.0 should mention its +derivation and reference the following license: + + Boost Software License - Version 1.0 - August 17th, 2003 + + Permission is hereby granted, free of charge, to any person or organization + obtaining a copy of the software and accompanying documentation covered by + this license (the "Software") to use, reproduce, display, distribute, + execute, and transmit the Software, and to prepare derivative works of the + Software, and to permit third-parties to whom the Software is furnished to + do so, all subject to the following: + + The copyright notices in the Software and this entire statement, including + the above license grant, this restriction and the following disclaimer, + must be included in all copies of the Software, in whole or in part, and + all derivative works of the Software, unless such copies or derivative + works are solely in the form of machine-executable object code generated by + a source language processor. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT + SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE + FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, + ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + DEALINGS IN THE SOFTWARE. + END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. diff --git a/aten/src/ATen/BatchedFallback.cpp b/aten/src/ATen/BatchedFallback.cpp index 51f231ee922b..9b39006b106e 100644 --- a/aten/src/ATen/BatchedFallback.cpp +++ b/aten/src/ATen/BatchedFallback.cpp @@ -156,11 +156,12 @@ void batchedTensorInplaceForLoopFallback(const c10::OperatorHandle& op, torch::j auto first_physical_view_sizes = input_physical_views.front().tensor().sizes(); auto batch_sizes = ArrayRef( first_physical_view_sizes.begin(), first_physical_view_sizes.begin() + num_batch_dims); - auto num_batches = std::accumulate( - batch_sizes.begin(), - batch_sizes.end(), - 1, - std::multiplies()); + const auto num_batches = prod_intlist(batch_sizes); + // Without a shape-checking API, we're unable to compute the correct shape of + // the output so we just error out. + TORCH_CHECK(num_batches > 0, + "Batching rule not implemented for ", schema.operator_name(), ". ", + "The fallback path does not support vmap over dims of size 0."); // Strategy: For each batch, we are going to push slices (where applicable) // of the arguments onto `stack`, and call `op`. @@ -288,11 +289,12 @@ void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Sta auto num_batch_dims = input_physical_views.front().numBatchDims(); auto some_sizes = input_physical_views.front().tensor().sizes(); auto batch_sizes = ArrayRef(some_sizes.begin(), some_sizes.begin() + num_batch_dims); - auto num_batches = std::accumulate( - batch_sizes.begin(), - batch_sizes.end(), - 1, - std::multiplies()); + const auto num_batches = prod_intlist(batch_sizes); + // Without a shape-checking API, we're unable to compute the correct shape of + // the output so we just error out. + TORCH_CHECK(num_batches > 0, + "Batching rule not implemented for ", schema.operator_name(), ". ", + "The fallback path does not support vmap over dims of size 0."); // Strategy: For each batch, we are going to push slices (where applicable) // of the arguments onto `stack`, call `op`, and store the result in diff --git a/aten/src/ATen/BatchingRegistrations.cpp b/aten/src/ATen/BatchingRegistrations.cpp index 029a5be521f7..0b180b5059d1 100644 --- a/aten/src/ATen/BatchingRegistrations.cpp +++ b/aten/src/ATen/BatchingRegistrations.cpp @@ -222,7 +222,7 @@ Tensor permute_batching_rule(const Tensor& self, IntArrayRef dims) { VmapDimVector all_dims_physical; all_dims_physical.reserve(self_physical.tensor().dim()); for (int64_t bdim = 0; bdim < self_physical.numBatchDims(); bdim++) { - all_dims_physical.push_back(bdim); + all_dims_physical.push_back(bdim); } all_dims_physical.insert( all_dims_physical.end(), diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index fed5e88e5314..0597f993b608 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -314,10 +314,12 @@ static inline void manual_seed(uint64_t seed) { } // NB: Sometimes we build with CUDA, but we don't have any GPUs // available. In that case, we must not seed CUDA; it will fail! - int num_gpus = detail::getCUDAHooks().getNumGPUs(); + const auto num_gpus = detail::getCUDAHooks().getNumGPUs(); if (hasCUDA() && num_gpus > 0) { for (int i = 0; i < num_gpus; i++) { - auto cuda_gen = globalContext().defaultGenerator(Device(at::kCUDA, i)); + auto cuda_gen = globalContext().defaultGenerator( + Device(at::kCUDA, static_cast(i)) + ); { // See Note [Acquire lock when using random generators] std::lock_guard lock(cuda_gen.mutex()); diff --git a/aten/src/ATen/Dispatch.h b/aten/src/ATen/Dispatch.h index e0fc25c394d3..e0d8a9a525b3 100644 --- a/aten/src/ATen/Dispatch.h +++ b/aten/src/ATen/Dispatch.h @@ -31,6 +31,8 @@ const auto& SCALAR_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = enum_type; \ const auto& UNDERLYING_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \ toUnderlying(enum_type); \ + (void)SCALAR_TYPE; /* Suppress unused-var compiler warning */ \ + /* TODO: Use [[maybe-unused]] when C++17 becomes the standard */ \ return __VA_ARGS__(); \ } diff --git a/aten/src/ATen/LegacyTHFunctionsCPU.cpp b/aten/src/ATen/LegacyTHFunctionsCPU.cpp index f0a55470cc1c..60db3d9a9d4e 100644 --- a/aten/src/ATen/LegacyTHFunctionsCPU.cpp +++ b/aten/src/ATen/LegacyTHFunctionsCPU.cpp @@ -1,7 +1,5 @@ #include -// @generated by aten/src/ATen/gen.py from LegacyTHFunctions.cpp - #include #include #include diff --git a/aten/src/ATen/LegacyTHFunctionsCPU.h b/aten/src/ATen/LegacyTHFunctionsCPU.h index 1bc9b66777bc..b24a83acdd5d 100644 --- a/aten/src/ATen/LegacyTHFunctionsCPU.h +++ b/aten/src/ATen/LegacyTHFunctionsCPU.h @@ -1,7 +1,5 @@ #pragma once -// @generated by aten/src/ATen/gen.py from LegacyTHFunctions.h - #include #include #include diff --git a/aten/src/ATen/LegacyTHFunctionsCUDA.h b/aten/src/ATen/LegacyTHFunctionsCUDA.h index 20717ad43e6f..7b3be6db3d77 100644 --- a/aten/src/ATen/LegacyTHFunctionsCUDA.h +++ b/aten/src/ATen/LegacyTHFunctionsCUDA.h @@ -1,7 +1,5 @@ #pragma once -// @generated by aten/src/ATen/gen.py from LegacyTHFunctions.h - #include #include #include diff --git a/aten/src/ATen/OpaqueTensorImpl.h b/aten/src/ATen/OpaqueTensorImpl.h index a4007c3115dc..b00f80d232db 100644 --- a/aten/src/ATen/OpaqueTensorImpl.h +++ b/aten/src/ATen/OpaqueTensorImpl.h @@ -21,7 +21,7 @@ struct CAFFE2_API OpaqueTensorImpl : public TensorImpl { // public constructor for now... OpaqueTensorImpl( at::DispatchKeySet key_set, - const caffe2::TypeMeta& data_type, + const caffe2::TypeMeta data_type, c10::Device device, OpaqueHandle opaque_handle, c10::IntArrayRef sizes) diff --git a/aten/src/ATen/ScalarOps.cpp b/aten/src/ATen/ScalarOps.cpp new file mode 100644 index 000000000000..7a794cb5c312 --- /dev/null +++ b/aten/src/ATen/ScalarOps.cpp @@ -0,0 +1,40 @@ +// FastPass +#ifdef _MSC_VER +#ifndef _USE_MATH_DEFINES +#define _USE_MATH_DEFINES +#endif +#include +#endif + +#include +#include +#include + +namespace at { +namespace { +template +inline void fill_inplace(Tensor& self, Scalar value_scalar) { + auto value = value_scalar.to(); + scalar_t* dptr = static_cast(self.data_ptr()); + *dptr = value; +} +} + +namespace detail { +Tensor& scalar_fill(Tensor& self, Scalar value) { + AT_DISPATCH_ALL_TYPES_AND3( + kHalf, kBool, kBFloat16, self.scalar_type(), "fill_out", [&]() { + fill_inplace(self, value); + }); + return self; +} + +Tensor scalar_tensor_static(Scalar s, const TensorOptions& options) { + at::tracer::impl::NoTracerDispatchMode tracer_guard; + at::AutoNonVariableTypeMode non_var_type_mode(true); + auto result = at::detail::empty_cpu({}, options); + scalar_fill(result, s); + return result; +} +} // namespace detail +} // namespace at diff --git a/aten/src/ATen/ScalarOps.h b/aten/src/ATen/ScalarOps.h index 8c07a9d618bc..60cee3ea284b 100644 --- a/aten/src/ATen/ScalarOps.h +++ b/aten/src/ATen/ScalarOps.h @@ -4,6 +4,18 @@ #include #include +namespace at { +namespace detail { +// When filling a number to 1-element CPU tensor, we want to skip +// everything but manipulate data ptr directly. +// Ideally this fast pass should be implemented in TensorIterator, +// but we also want to skip compute_types which in not avoidable +// in TensorIterator for now. +Tensor& scalar_fill(Tensor& self, Scalar value); +TORCH_API Tensor scalar_tensor_static(Scalar s, const TensorOptions& options); +} // namespace detail +} // namespace at + // This is in the c10 namespace because we use ADL to find the functions in it. namespace c10 { @@ -11,16 +23,14 @@ namespace c10 { // to implement this without going through Derived Types (which are not part of core). inline at::Tensor scalar_to_tensor(Scalar s, const Device device = at::kCPU) { // This is the fast track we have for CPU scalar tensors. - if (device == at::kCPU) { + if (device == at::kCPU && !s.isComplex()) { if (s.isFloatingPoint()) { - return at::native::scalar_tensor(s, at::device(at::kCPU).dtype(at::kDouble)); + return at::detail::scalar_tensor_static(s, at::device(at::kCPU).dtype(at::kDouble)); } else if (s.isBoolean()) { - return at::native::scalar_tensor(s, at::device(at::kCPU).dtype(at::kBool)); - } else if (s.isComplex()) { - return at::native::scalar_tensor(s, at::device(at::kCPU).dtype(at::kComplexDouble)); + return at::detail::scalar_tensor_static(s, at::device(at::kCPU).dtype(at::kBool)); } else { AT_ASSERT(s.isIntegral(false)); - return at::native::scalar_tensor(s, at::device(at::kCPU).dtype(at::kLong)); + return at::detail::scalar_tensor_static(s, at::device(at::kCPU).dtype(at::kLong)); } } if (s.isFloatingPoint()) { diff --git a/aten/src/ATen/SparseTensorImpl.cpp b/aten/src/ATen/SparseTensorImpl.cpp index 3119c81ac8aa..45492d7b212e 100644 --- a/aten/src/ATen/SparseTensorImpl.cpp +++ b/aten/src/ATen/SparseTensorImpl.cpp @@ -30,12 +30,12 @@ namespace { // // This means that we allocate a [1,0] size indices tensor and a [0] size // values tensor for such an empty tensor. -SparseTensorImpl::SparseTensorImpl(at::DispatchKeySet key_set, const caffe2::TypeMeta& data_type) +SparseTensorImpl::SparseTensorImpl(at::DispatchKeySet key_set, const caffe2::TypeMeta data_type) : SparseTensorImpl(key_set, data_type , at::empty({1, 0}, at::initialTensorOptions().device(sparseTensorSetToDeviceType(key_set)).dtype(ScalarType::Long)) , at::empty({0}, at::initialTensorOptions().device(sparseTensorSetToDeviceType(key_set)).dtype(data_type))) {} -SparseTensorImpl::SparseTensorImpl(at::DispatchKeySet key_set, const caffe2::TypeMeta& data_type, at::Tensor indices, at::Tensor values) +SparseTensorImpl::SparseTensorImpl(at::DispatchKeySet key_set, const caffe2::TypeMeta data_type, at::Tensor indices, at::Tensor values) : TensorImpl(key_set, data_type, values.device()) , sparse_dim_(1) , dense_dim_(0) diff --git a/aten/src/ATen/SparseTensorImpl.h b/aten/src/ATen/SparseTensorImpl.h index bdccb540734f..b8e6bb26bf7f 100644 --- a/aten/src/ATen/SparseTensorImpl.h +++ b/aten/src/ATen/SparseTensorImpl.h @@ -31,7 +31,7 @@ struct CAFFE2_API SparseTensorImpl : public TensorImpl { public: // Public for now... - explicit SparseTensorImpl(at::DispatchKeySet, const caffe2::TypeMeta&); + explicit SparseTensorImpl(at::DispatchKeySet, const caffe2::TypeMeta); int64_t nnz() const { return values_.size(0); } int64_t sparse_dim() const { return sparse_dim_; } @@ -217,7 +217,7 @@ struct CAFFE2_API SparseTensorImpl : public TensorImpl { refresh_numel(); } private: - explicit SparseTensorImpl(at::DispatchKeySet, const caffe2::TypeMeta&, at::Tensor indices, at::Tensor values); + explicit SparseTensorImpl(at::DispatchKeySet, const caffe2::TypeMeta, at::Tensor indices, at::Tensor values); /** * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer / storage_offset) diff --git a/aten/src/ATen/TensorUtils.cpp b/aten/src/ATen/TensorUtils.cpp index 626e0c73e45e..08588f6a8cdd 100644 --- a/aten/src/ATen/TensorUtils.cpp +++ b/aten/src/ATen/TensorUtils.cpp @@ -335,8 +335,7 @@ c10::optional> computeStride( // we use the stride as if it were computed via resize. // This could perhaps be combined with the below code, but the complexity // didn't seem worth it. - int64_t numel = std::accumulate(oldshape.begin(), oldshape.end(), 1, - std::multiplies()); + const int64_t numel = prod_intlist(oldshape); if (numel == 0 && oldshape.equals(newshape)) { return oldstride.vec(); } diff --git a/aten/src/ATen/templates/TypeDefault.h b/aten/src/ATen/TypeDefault.h similarity index 86% rename from aten/src/ATen/templates/TypeDefault.h rename to aten/src/ATen/TypeDefault.h index fb62c7ba6354..7b5d77ba4d22 100644 --- a/aten/src/ATen/templates/TypeDefault.h +++ b/aten/src/ATen/TypeDefault.h @@ -1,7 +1,5 @@ #pragma once -// ${generated_comment} - #include #include #include @@ -29,8 +27,4 @@ struct Quantizer; // to frontend using ConstQuantizerPtr = const c10::intrusive_ptr&; -namespace TypeDefault { - ${type_method_declarations} -} // namespace TypeDefault - } // namespace at diff --git a/aten/src/ATen/Utils.cpp b/aten/src/ATen/Utils.cpp index ccd4e4ba9f2f..8a4fa37e469e 100644 --- a/aten/src/ATen/Utils.cpp +++ b/aten/src/ATen/Utils.cpp @@ -3,6 +3,8 @@ #include #include #include +#include +#include namespace at { @@ -12,4 +14,52 @@ int _crash_if_asan(int arg) { return x[0]; } +namespace detail { +// empty_cpu is used in ScalarOps.h, which can be referenced by other ATen files. Since we want to decouple direct referencing native symbols and only access native symbols through dispatching, we move its implementation here. +Tensor empty_cpu( + IntArrayRef size, + const TensorOptions& options, + c10::optional optional_memory_format) { + TORCH_CHECK( + !(options.has_memory_format() && optional_memory_format.has_value()), + "Cannot set memory_format both in TensorOptions and explicit argument; please delete " + "the redundant setter."); + const MemoryFormat memory_format = + optional_memory_format.value_or( + options.memory_format_opt().value_or( + MemoryFormat::Contiguous)); + + AT_ASSERT(options.device().type() == DeviceType::CPU); + check_size_nonnegative(size); + + c10::Allocator* allocator; + if (options.pinned_memory()) { + allocator = detail::getCUDAHooks().getPinnedMemoryAllocator(); + } else { + allocator = at::getCPUAllocator(); + } + + int64_t nelements = prod_intlist(size); + const caffe2::TypeMeta dtype = options.dtype(); + const int64_t size_bytes = nelements * dtype.itemsize(); + auto storage_impl = c10::make_intrusive( + c10::StorageImpl::use_byte_size_t(), + size_bytes, + allocator->allocate(size_bytes), + allocator, + /*resizeable=*/true); + + auto tensor = detail::make_tensor( + std::move(storage_impl), at::DispatchKey::CPU, dtype); + // Default TensorImpl has size [0] + if (size.size() != 1 || size[0] != 0) { + tensor.unsafeGetTensorImpl()->set_sizes_contiguous(size); + } + + tensor.unsafeGetTensorImpl()->empty_tensor_restride(memory_format); + + return tensor; +} +} // namespace detail + } // at diff --git a/aten/src/ATen/Utils.h b/aten/src/ATen/Utils.h index df0e49920afa..4fe4b632362b 100644 --- a/aten/src/ATen/Utils.h +++ b/aten/src/ATen/Utils.h @@ -93,10 +93,18 @@ inline int64_t sum_intlist(ArrayRef list) { return std::accumulate(list.begin(), list.end(), 0ll); } -inline int64_t prod_intlist(ArrayRef list) { - return std::accumulate(list.begin(), list.end(), 1ll, std::multiplies()); +//std::accumulate infers return type from `init` type, so if `init` type is not enough to hold the result, computation can overflow +//the next 2 functions set `init` type to int64_t to avoid overflow. +template::value, int>::type = 0> +inline int64_t prod_intlist(const C &container){ + return std::accumulate(container.begin(), container.end(), static_cast(1), std::multiplies()); } +template::value_type>::value, int>::type = 0> +inline int64_t prod_intlist(Iter begin, Iter end){ + return std::accumulate(begin, end, static_cast(1), std::multiplies()); +} /** * Utility function to static cast input Generator* to * the backend generator type (CPU/CUDAGeneratorImpl etc.) @@ -120,4 +128,18 @@ static inline T* get_generator_or_default(const c10::optional& gen, c return gen.has_value() && gen->defined() ? check_generator(gen) : check_generator(default_gen); } +inline void check_size_nonnegative(IntArrayRef size) { + for (auto x: size) { + TORCH_CHECK(x >= 0, "Trying to create tensor with negative dimension ", x, ": ", size); + } +} + +namespace detail { +CAFFE2_API +Tensor empty_cpu( + IntArrayRef size, + const TensorOptions& options = {}, + c10::optional memory_format = c10::nullopt); +} // namespace detail + } // at diff --git a/aten/src/ATen/core/Dict_inl.h b/aten/src/ATen/core/Dict_inl.h index f5c997a55b74..6cfbea286f00 100644 --- a/aten/src/ATen/core/Dict_inl.h +++ b/aten/src/ATen/core/Dict_inl.h @@ -38,7 +38,7 @@ namespace detail { inline size_t DictKeyHash::operator()(const IValue& ivalue) const { if (ivalue.isInt()) { - return std::hash()(ivalue.toInt()); + return std::hash()(ivalue.toInt()); } else if (ivalue.isString()) { return std::hash()(ivalue.toStringRef()); } else if (ivalue.isDouble()) { diff --git a/aten/src/ATen/core/NamedRegistrations.cpp b/aten/src/ATen/core/NamedRegistrations.cpp index 33e4ebcfc7dc..187b217604ba 100644 --- a/aten/src/ATen/core/NamedRegistrations.cpp +++ b/aten/src/ATen/core/NamedRegistrations.cpp @@ -210,6 +210,9 @@ TORCH_LIBRARY_IMPL(aten, Named, m) { m.impl("i0", CppFunction::makeFallthrough()); m.impl("i0.out", CppFunction::makeFallthrough()); m.impl("i0_", CppFunction::makeFallthrough()); + m.impl("igamma", CppFunction::makeFallthrough()); + m.impl("igamma.out", CppFunction::makeFallthrough()); + m.impl("igamma_", CppFunction::makeFallthrough()); m.impl("imag", CppFunction::makeFallthrough()); m.impl("index_fill.Dimname_Scalar", CppFunction::makeFallthrough()); m.impl("index_fill.Dimname_Tensor", CppFunction::makeFallthrough()); diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h index da259e82990a..e84ad93de37d 100644 --- a/aten/src/ATen/core/aten_interned_strings.h +++ b/aten/src/ATen/core/aten_interned_strings.h @@ -371,6 +371,8 @@ _(aten, hstack) \ _(aten, hypot) \ _(aten, i0) \ _(aten, i0_) \ +_(aten, igamma) \ +_(aten, igamma_) \ _(aten, ifft) \ _(aten, index) \ _(aten, index_add) \ @@ -736,7 +738,6 @@ _(aten, vander) \ _(aten, var) \ _(aten, view) \ _(aten, view_as) \ -_(aten, vstack) \ _(aten, where) \ _(aten, zero) \ _(aten, zeros) \ @@ -781,6 +782,7 @@ _(attr, ceil_mode) \ _(attr, checked_signal_sizes) \ _(attr, chunks) \ _(attr, columns) \ +_(attr, column_stack) \ _(attr, complex_input) \ _(attr, complex_output) \ _(attr, condition) \ diff --git a/aten/src/ATen/core/blob.h b/aten/src/ATen/core/blob.h index 988e99b2395e..3b6bafa12e62 100644 --- a/aten/src/ATen/core/blob.h +++ b/aten/src/ATen/core/blob.h @@ -51,7 +51,7 @@ class CAFFE2_API Blob final : public c10::intrusive_ptr_target { /** * Returns the meta info of the blob. */ - const TypeMeta& meta() const noexcept { + const TypeMeta meta() const noexcept { return meta_; } @@ -155,7 +155,7 @@ class CAFFE2_API Blob final : public c10::intrusive_ptr_target { TypeMeta::Make::type>())); } - void* ShareExternal(void* allocated, const TypeMeta& meta) { + void* ShareExternal(void* allocated, const TypeMeta meta) { free_(); meta_ = meta; pointer_ = allocated; diff --git a/aten/src/ATen/core/boxing/impl/boxing.h b/aten/src/ATen/core/boxing/impl/boxing.h index d40823555c65..484d462b8ad9 100644 --- a/aten/src/ATen/core/boxing/impl/boxing.h +++ b/aten/src/ATen/core/boxing/impl/boxing.h @@ -97,7 +97,13 @@ using can_unbox = // template struct BoxedKernelWrapper { - static_assert(sizeof(FuncType) == -1, + // The reason we're not just doing straight up static_assert(false, ...) here: + // Basically, the way to make sure a static_assert only fires if a template + // is actually instantiated (rather than every time the file is parsed) is to use + // template parameters in the expression, e.g. FuncType here. However, since + // `sizeof(FuncType) != sizeof(FuncType)` is always false, this has the same + // effect. + static_assert(sizeof(FuncType) != sizeof(FuncType), "Function signature contains one or more unsupported parameter and/or return types. " "Look for a nearby error like " "\"'call' is not a member of 'c10::impl::BoxedKernelWrapper<(your function type), void>'\" " diff --git a/aten/src/ATen/core/builtin_function.h b/aten/src/ATen/core/builtin_function.h index b4804cfebcbe..3d7f70d86877 100644 --- a/aten/src/ATen/core/builtin_function.h +++ b/aten/src/ATen/core/builtin_function.h @@ -10,13 +10,19 @@ struct BuiltinOpFunction : public Function { BuiltinOpFunction( c10::QualifiedName qualname, c10::FunctionSchema schema, - std::function callable) + std::function callable, + std::string doc_string = "") : name_(std::move(qualname)), callable_(std::move(callable)), - schema_(std::move(schema)) { + schema_(std::move(schema)), + doc_string_(std::move(doc_string)) { TORCH_INTERNAL_ASSERT(schema_.returns().size() == 1); } + const std::string& doc_string() const override { + return doc_string_; + } + bool isGraphFunction() const override { return false; } @@ -110,6 +116,8 @@ struct BuiltinOpFunction : public Function { std::function callable_; c10::FunctionSchema schema_; + + std::string doc_string_; }; } // namespace jit diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h index 3ae57341fcf8..a5f9354d7ca2 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.h +++ b/aten/src/ATen/core/dispatch/Dispatcher.h @@ -388,9 +388,9 @@ inline Return Dispatcher::callWithDispatchKey(const TypedOperatorHandle::boxArgs(args...); - guard.before(op.schema().name(), stack, seq_num); + guard.before(op, stack, seq_num); } else { - guard.before(op.schema().name(), seq_num); + guard.before(op, seq_num); } } } @@ -438,9 +438,9 @@ inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const seq_num = at::sequence_number::peek(); } if (guard.needs_inputs) { - guard.before(op.schema().name(), *stack, seq_num); + guard.before(op, *stack, seq_num); } else { - guard.before(op.schema().name(), seq_num); + guard.before(op, seq_num); } } } diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.cpp b/aten/src/ATen/core/dispatch/OperatorEntry.cpp index 97651c9865a1..97ac200d79ef 100644 --- a/aten/src/ATen/core/dispatch/OperatorEntry.cpp +++ b/aten/src/ATen/core/dispatch/OperatorEntry.cpp @@ -83,13 +83,15 @@ std::list::iterator OperatorEntry::registerKernel( // that would also invalidate the old TypedOperatorHandles. if (cpp_signature.has_value()) { if (cpp_signature_.has_value()) { - TORCH_INTERNAL_ASSERT(*cpp_signature == *cpp_signature_, - "Tried to register a kernel (", debug, ") for operator ", name_," for dispatch key ", toString(dispatch_key), - ", but the C++ function signature ", cpp_signature->name(), " mismatched with a previous kernel that had the signature ", - cpp_signature_->name() + TORCH_INTERNAL_ASSERT(*cpp_signature == cpp_signature_->signature, + "Tried to register a kernel (", debug, ") for operator ", name_," (", + (this->schema_.has_value() ? this->schema_->debug : "no debug info"), + ") for dispatch key ", toString(dispatch_key), ", but the C++ function signature ", + cpp_signature->name(), " mismatched with a previous kernel (", cpp_signature_->debug, + ") that had the signature ", cpp_signature_->signature.name() ); } else { - cpp_signature_ = *cpp_signature; + cpp_signature_ = CppSignatureWithDebug { *cpp_signature, debug }; } } @@ -103,7 +105,12 @@ std::list::iterator OperatorEntry::registerKernel( auto& k = dispatch_key.has_value() ? kernels_[*dispatch_key] : kernels_[DispatchKey::Math]; if (k.size() > 0) { - TORCH_WARN("Registering a kernel (", debug, ") for operator ", name_, " for dispatch key ", toString(dispatch_key), " that overwrote a previously registered kernel with the same dispatch key for the same operator."); + TORCH_WARN("Registering a kernel (", debug, ") for operator ", name_, " (", + (this->schema_.has_value() ? this->schema_->debug : "no debug info"), + ") for dispatch key ", toString(dispatch_key), + " that overwrote a previously registered kernel (", + (cpp_signature_.has_value() ? cpp_signature_->debug : "no debug info"), + ") with the same dispatch key for the same operator."); } if (manuallyBoxedKernel_.has_value()) { diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.h b/aten/src/ATen/core/dispatch/OperatorEntry.h index ed4d5f40b97f..26506cb0f76f 100644 --- a/aten/src/ATen/core/dispatch/OperatorEntry.h +++ b/aten/src/ATen/core/dispatch/OperatorEntry.h @@ -157,13 +157,15 @@ class CAFFE2_API OperatorEntry final { // Asserts that the given FuncType is correct for calling this operator in an unboxed way. template void assertSignatureIsCorrect() { - TORCH_INTERNAL_ASSERT(!cpp_signature_.has_value() || (CppSignature::make() == *cpp_signature_), + TORCH_INTERNAL_ASSERT(!cpp_signature_.has_value() || (CppSignature::make() == cpp_signature_->signature), "Tried to access operator ", name_, " with a wrong signature. Accessed with ", CppSignature::make().name(), " but the operator was registered with ", - cpp_signature_->name(), - " (", + cpp_signature_->signature.name(), + " (schema: ", (schema_.has_value() ? schema_->debug : "unknown debug info"), + ", kernel: ", + cpp_signature_->debug, ") This likely happened in a call to OperatorHandle::typed(). Please make sure that the function signature matches the signature in the operator registration call." ); } @@ -230,12 +232,17 @@ class CAFFE2_API OperatorEntry final { AnnotatedKernel missingKernel_; static const AnnotatedKernel ambiguousAutogradOtherKernel_; - // signature_hash_ is set to the hash of the function signature if any of + // cpp_signature_ stores function signature if any of // the kernels was created in a way that allowed us to know the function // signature (i.e. by supplying an unboxed C++ kernel function). - // If this is set, it will be used in unboxed function calls + // If this is set, it will be used to check that future kernel + // registrations match and it will be used in unboxed function calls // to verify their arguments against the known function signature. - c10::optional cpp_signature_; + struct CppSignatureWithDebug { + CppSignature signature; + std::string debug; + }; + c10::optional cpp_signature_; // Whether this operator needs to be observed with RecordFunction const bool is_observed_; diff --git a/aten/src/ATen/core/function.h b/aten/src/ATen/core/function.h index 0cf658b0f701..8264bc57e8e8 100644 --- a/aten/src/ATen/core/function.h +++ b/aten/src/ATen/core/function.h @@ -25,6 +25,11 @@ TORCH_API void preoptimizeGraph(std::shared_ptr& graph); // execution of the function. Method is a wrapper around an // underlying Function that also provides a `self` object. struct TORCH_API Function { + virtual const std::string& doc_string() const { + static const std::string no_doc_string = ""; + return no_doc_string; + } + virtual bool isGraphFunction() const = 0; virtual void run(Stack& stack) = 0; diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index dd099be59dff..c29ff15c2b59 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -56,7 +56,7 @@ namespace c10 { _(prim, ReturnStmt) \ _(prim, BreakStmt) \ _(prim, ContinueStmt) \ - _(prim, LocalVariableScope) \ + _(prim, ListComprehensionScope) \ _(prim, Store) \ _(prim, AutogradZero) \ _(prim, AutogradAnyNonZero) \ @@ -70,6 +70,7 @@ namespace c10 { _(prim, ListConstruct) \ _(prim, ListUnpack) \ _(prim, DictConstruct) \ + _(prim, ModuleDictIndex) \ _(prim, EnumName) \ _(prim, EnumValue) \ _(prim, StringIndex) \ @@ -129,7 +130,7 @@ namespace c10 { _(prim, fork) \ _(prim, forkClosure) \ _(prim, RaiseException) \ - _(prim, Function) \ + _(prim, Closure) \ _(prim, CreateObject) \ _(prim, SetAttr) \ _(prim, GetAttr) \ @@ -268,6 +269,8 @@ namespace c10 { _(aten, bin) \ _(aten, pop) \ _(aten, insert) \ + _(aten, vstack) \ + _(aten, row_stack) \ _(prim, unchecked_unwrap_optional) \ _(aten, __contains__) \ _(prim, BailoutTemplate) \ diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index a1b21ee1ba21..c1da20221e62 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -1989,7 +1989,8 @@ struct CAFFE2_API ClassType : public NamedType { static ClassTypePtr create( c10::optional qualifiedName, std::weak_ptr cu, - bool is_module = false); + bool is_module = false, + std::string doc_string = ""); bool operator==(const Type& rhs) const override { if (auto user_rhs = rhs.cast()) { @@ -2101,6 +2102,13 @@ struct CAFFE2_API ClassType : public NamedType { // valid again. void unsafeRemoveAttribute(const std::string& name); + // [Internal Only] Change the type of an attribute of the ClassType, + // The caller is responsible to make sure the modification is safe: + // it is unsafe to maintain uses of the old type of the attribute, + // and any code that works on the attribute is now invalid. + // Only newly created code is valid again. + void unsafeChangeAttributeType(const std::string& name, TypePtr new_ty); + // Add attribute \p NAME if it doesn't exist or verify that it has a // compatible type otherwise. size_t addOrCheckAttribute( @@ -2175,6 +2183,9 @@ struct CAFFE2_API ClassType : public NamedType { return constantNames_[slot]; } + const std::string& doc_string() const { + return doc_string_; + } IValue getConstant(const std::string& name) const; @@ -2271,7 +2282,8 @@ struct CAFFE2_API ClassType : public NamedType { ClassType( c10::optional name, std::weak_ptr cu, - bool is_module); + bool is_module, + std::string doc_string); std::string annotation_str_impl(TypePrinter printer = nullptr) const override { const auto& n = name().value(); @@ -2306,6 +2318,9 @@ struct CAFFE2_API ClassType : public NamedType { std::vector properties_; bool isModule_ = false; + + // Doc string of class. + std::string doc_string_ = ""; }; struct InterfaceType; diff --git a/aten/src/ATen/core/op_registration/op_registration_test.cpp b/aten/src/ATen/core/op_registration/op_registration_test.cpp index 93b0ffc1b88e..3fd8740d1ab1 100644 --- a/aten/src/ATen/core/op_registration/op_registration_test.cpp +++ b/aten/src/ATen/core/op_registration/op_registration_test.cpp @@ -310,7 +310,8 @@ TEST(OperatorRegistrationTest, givenMultipleKernelsWithSameDispatchKey_whenRegis std::string output = testing::internal::GetCapturedStderr(); EXPECT_THAT(output, testing::HasSubstr("_test::dummy")); EXPECT_THAT(output, testing::HasSubstr("CPU")); - EXPECT_THAT(output, testing::HasSubstr("overwrote a previously registered kernel with the same dispatch key for the same operator")); + EXPECT_THAT(output, testing::HasSubstr("overwrote a previously registered kernel ")); + EXPECT_THAT(output, testing::HasSubstr(" with the same dispatch key for the same operator")); } TEST(OperatorRegistrationTest, givenMultipleKernelsWithSameDispatchKey_whenRegisteringInSameOpCall_thenFails) { @@ -348,7 +349,8 @@ TEST(OperatorRegistrationTest, givenMultipleCatchallKernels_whenRegistering_then std::string output = testing::internal::GetCapturedStderr(); EXPECT_THAT(output, testing::HasSubstr("_test::dummy")); EXPECT_THAT(output, testing::HasSubstr("catch all")); - EXPECT_THAT(output, testing::HasSubstr("overwrote a previously registered kernel with the same dispatch key for the same operator")); + EXPECT_THAT(output, testing::HasSubstr("overwrote a previously registered kernel ")); + EXPECT_THAT(output, testing::HasSubstr(" with the same dispatch key for the same operator")); } TEST(OperatorRegistrationTest, givenMultipleCatchallKernels_whenRegisteringInSameOpCall_thenFails) { @@ -701,7 +703,7 @@ TEST(OperatorRegistrationTest, whenRegisteringMismatchingKernelsInSameOpCall_the auto registrar1 = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options() .kernel(c10::DispatchKey::CPU) .kernel(c10::DispatchKey::CUDA, &called_kernel)); - }, "mismatched with a previous kernel that had the signature"); + }, "mismatched with a previous kernel"); } void backend_fallback_kernel(const c10::OperatorHandle& op, c10::Stack* stack) { @@ -944,7 +946,7 @@ TEST(OperatorRegistrationTest, givenLambdaKernel_whenRegisteringWithMismatchingC expectThrows([] { auto registrar = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options() .kernel(DispatchKey::CPU, [] (const int64_t&) {})); - }, "mismatched with a previous kernel that had the signature"); + }, "mismatched with a previous kernel"); } TEST(OperatorRegistrationTest, givenLambdaKernel_whenRegisteringCatchAllAndBackendWithMismatchingCppSignatures_thenFails) { @@ -953,7 +955,7 @@ TEST(OperatorRegistrationTest, givenLambdaKernel_whenRegisteringCatchAllAndBacke expectThrows([] { auto registrar = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options() .kernel(DispatchKey::CPU, [] (const int64_t&) {})); - }, "mismatched with a previous kernel that had the signature"); + }, "mismatched with a previous kernel"); } TEST(OperatorRegistrationTest, givenLambdaKernel_whenRegisteringBackendAndCatchAllWithMismatchingCppSignatures_thenFails) { @@ -962,7 +964,7 @@ TEST(OperatorRegistrationTest, givenLambdaKernel_whenRegisteringBackendAndCatchA expectThrows([] { auto registrar = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options() .catchAllKernel([] (const int64_t&) {})); - }, "mismatched with a previous kernel that had the signature"); + }, "mismatched with a previous kernel"); } TEST(OperatorRegistrationTest, givenLambdaKernel_whenAccessingWithMismatchingCppSignatures_thenFails) { @@ -989,7 +991,7 @@ TEST(OperatorRegistrationTest, givenTorchLibrary_whenRegisteringWithMismatchingC m.impl("dummy", DispatchKey::CPU, [] (int64_t) {}); expectThrows([&] { m.impl("dummy", DispatchKey::CUDA, [] (const int64_t&) {}); - }, "mismatched with a previous kernel that had the signature"); + }, "mismatched with a previous kernel"); } TEST(OperatorRegistrationTest, givenTorchLibrary_whenAccessingWithMismatchingCppSignatures_thenFails) { diff --git a/aten/src/ATen/core/op_registration/op_whitelist.h b/aten/src/ATen/core/op_registration/op_whitelist.h index c8437e924a3c..26d5533244d7 100644 --- a/aten/src/ATen/core/op_registration/op_whitelist.h +++ b/aten/src/ATen/core/op_registration/op_whitelist.h @@ -36,7 +36,9 @@ namespace impl { // returns true iff whitelist contains item // op_whitelist_contains("a;bc;d", "bc") == true constexpr bool op_whitelist_contains(string_view whitelist, string_view item) { - size_t next = -1; + //Choose a really big value for next so that if something goes wrong + //this code will blow up in a hopefully detectable way. + size_t next = std::numeric_limits::max(); for (size_t cur = 0; cur <= whitelist.size(); cur = next) { next = whitelist.find(';', cur); if (next != string_view::npos) { diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp index 634d706091af..67b7899bb22f 100644 --- a/aten/src/ATen/core/type.cpp +++ b/aten/src/ATen/core/type.cpp @@ -1211,19 +1211,21 @@ InterfaceType::~InterfaceType() = default; ClassTypePtr ClassType::create( c10::optional qualifiedName, std::weak_ptr cu, - bool is_module) { + bool is_module, + std::string doc_string) { return ClassTypePtr( - new ClassType(std::move(qualifiedName), std::move(cu), is_module)); + new ClassType(std::move(qualifiedName), std::move(cu), is_module, std::move(doc_string))); } ClassType::ClassType( c10::optional name, std::weak_ptr cu, - bool is_module = false) + bool is_module = false, + std::string doc_string = "") : NamedType(TypeKind::ClassType, std::move(name)), compilation_unit_(std::move(cu)), - isModule_(is_module) { -} + isModule_(is_module), + doc_string_(std::move(doc_string)) {} const std::vector& ClassType::methods() const { return methods_; @@ -1313,6 +1315,14 @@ void ClassType::unsafeRemoveAttribute(const std::string& name) { AT_ASSERT(attributes_.size() == attributeTypes_.size()); } +void ClassType::unsafeChangeAttributeType(const std::string& name, TypePtr new_ty) { + auto slot = getAttributeSlot(name); + auto old_attr_info = attributes_[slot]; + AT_ASSERT(old_attr_info.getKind() == AttributeKind::REGULAR_ATTRIBUTE); + attributes_[slot] = ClassAttribute(old_attr_info.getKind(), new_ty, old_attr_info.getName()); + attributeTypes_[slot] = new_ty; +} + size_t ClassType::addConstant(const std::string& name, const IValue& value) { checkNotExist(name, "constant"); size_t slot = constantNames_.size(); diff --git a/aten/src/ATen/cpu/vec256/vec256_base.h b/aten/src/ATen/cpu/vec256/vec256_base.h index edce0e3a2cce..807a9d9780f0 100644 --- a/aten/src/ATen/cpu/vec256/vec256_base.h +++ b/aten/src/ATen/cpu/vec256/vec256_base.h @@ -394,6 +394,13 @@ struct Vec256 { Vec256 i0() const { return map(calc_i0); } + Vec256 igamma(const Vec256 &x) const { + Vec256 ret; + for (int64_t i = 0; i < size(); i++) { + ret[i] = calc_igamma(values[i], x[i]); + } + return ret; + } Vec256 neg() const { // NB: the trailing return type is needed because we need to coerce the // return value back to T in the case of unary operator- incuring a diff --git a/aten/src/ATen/cpu/vec256/vec256_bfloat16.h b/aten/src/ATen/cpu/vec256/vec256_bfloat16.h index 37d41676e53c..10bbe139b63f 100644 --- a/aten/src/ATen/cpu/vec256/vec256_bfloat16.h +++ b/aten/src/ATen/cpu/vec256/vec256_bfloat16.h @@ -290,6 +290,25 @@ template <> class Vec256 { auto o2 = _mm256_loadu_ps(tmp2); return cvtfp32_bf16(o1, o2); } + Vec256 igamma(const Vec256 &x) const { + __m256 lo, hi; + __m256 xlo, xhi; + cvtbf16_fp32(values, lo, hi); + cvtbf16_fp32(x.values, xlo, xhi); + __at_align32__ float tmp1[size() / 2], tmp2[size() / 2]; + _mm256_storeu_ps(reinterpret_cast(tmp1), lo); + _mm256_storeu_ps(reinterpret_cast(tmp2), hi); + __at_align32__ float tmpx1[size() / 2], tmpx2[size() / 2]; + _mm256_storeu_ps(reinterpret_cast(tmpx1), xlo); + _mm256_storeu_ps(reinterpret_cast(tmpx2), xhi); + for (int64_t i = 0; i < size() / 2; ++i) { + tmp1[i] = calc_igamma(tmp1[i], tmpx1[i]); + tmp2[i] = calc_igamma(tmp2[i], tmpx2[i]); + } + auto o1 = _mm256_loadu_ps(tmp1); + auto o2 = _mm256_loadu_ps(tmp2); + return cvtfp32_bf16(o1, o2); + } Vec256 log() const { return map(Sleef_logf8_u10); } diff --git a/aten/src/ATen/cpu/vec256/vec256_complex_double.h b/aten/src/ATen/cpu/vec256/vec256_complex_double.h index d2ae6f46b44e..d7f5afd8b67d 100644 --- a/aten/src/ATen/cpu/vec256/vec256_complex_double.h +++ b/aten/src/ATen/cpu/vec256/vec256_complex_double.h @@ -252,6 +252,9 @@ template <> class Vec256> { Vec256> hypot(const Vec256> &b) const { AT_ERROR("not supported for complex numbers"); } + Vec256> igamma(const Vec256> &x) const { + AT_ERROR("not supported for complex numbers"); + } Vec256> neg() const { auto zero = _mm256_setzero_pd(); return _mm256_sub_pd(zero, values); diff --git a/aten/src/ATen/cpu/vec256/vec256_complex_float.h b/aten/src/ATen/cpu/vec256/vec256_complex_float.h index 8b4eba07f421..4df95dbea926 100644 --- a/aten/src/ATen/cpu/vec256/vec256_complex_float.h +++ b/aten/src/ATen/cpu/vec256/vec256_complex_float.h @@ -290,6 +290,9 @@ template <> class Vec256> { Vec256> hypot(const Vec256> &b) const { AT_ERROR("not supported for complex numbers"); } + Vec256> igamma(const Vec256> &x) const { + AT_ERROR("not supported for complex numbers"); + } Vec256> neg() const { auto zero = _mm256_setzero_ps(); return _mm256_sub_ps(zero, values); diff --git a/aten/src/ATen/cpu/vec256/vec256_double.h b/aten/src/ATen/cpu/vec256/vec256_double.h index fcad154e68b2..6b611e8d2e7a 100644 --- a/aten/src/ATen/cpu/vec256/vec256_double.h +++ b/aten/src/ATen/cpu/vec256/vec256_double.h @@ -155,6 +155,16 @@ template <> class Vec256 { Vec256 i0() const { return map(calc_i0); } + Vec256 igamma(const Vec256 &x) const { + __at_align32__ double tmp[size()]; + __at_align32__ double tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igamma(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } Vec256 log() const { return Vec256(Sleef_logd4_u10(values)); } diff --git a/aten/src/ATen/cpu/vec256/vec256_float.h b/aten/src/ATen/cpu/vec256/vec256_float.h index 1ab11ea81529..d83895fdf854 100644 --- a/aten/src/ATen/cpu/vec256/vec256_float.h +++ b/aten/src/ATen/cpu/vec256/vec256_float.h @@ -193,6 +193,16 @@ template <> class Vec256 { Vec256 i0() const { return map(calc_i0); } + Vec256 igamma(const Vec256 &x) const { + __at_align32__ float tmp[size()]; + __at_align32__ float tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igamma(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } Vec256 neg() const { return _mm256_xor_ps(_mm256_set1_ps(-0.f), values); } diff --git a/aten/src/ATen/cpu/vec256/vec256_float_neon.h b/aten/src/ATen/cpu/vec256/vec256_float_neon.h index f98c645a08d6..f410e415277f 100644 --- a/aten/src/ATen/cpu/vec256/vec256_float_neon.h +++ b/aten/src/ATen/cpu/vec256/vec256_float_neon.h @@ -25,6 +25,8 @@ namespace { // https://bugs.llvm.org/show_bug.cgi?id=45824 // Most likely we will do aarch32 support with inline asm. #if defined(__aarch64__) +// See https://github.com/pytorch/pytorch/issues/47098 +#if defined(__clang__) || (__GNUC__ > 8 || (__GNUC__ == 8 && __GNUC_MINOR__ > 3)) #ifdef __BIG_ENDIAN__ #error "Big endian is not supported." @@ -362,6 +364,16 @@ template <> class Vec256 { Vec256 i0() const { return map(calc_i0); } + Vec256 igamma(const Vec256 &x) const { + __at_align32__ float tmp[size()]; + __at_align32__ float tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igamma(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } Vec256 log() const { return map(std::log); } @@ -665,6 +677,7 @@ Vec256 inline fmadd(const Vec256& a, const Vec256& b, const return Vec256(r0, r1); } -#endif +#endif /* defined(__clang__) || (__GNUC__ > 8 || (__GNUC__ == 8 && __GNUC_MINOR__ > 3)) */ +#endif /* defined(aarch64) */ }}} diff --git a/aten/src/ATen/cuda/Exceptions.h b/aten/src/ATen/cuda/Exceptions.h index 80e39c6bc6bc..b82e04fbe1d6 100644 --- a/aten/src/ATen/cuda/Exceptions.h +++ b/aten/src/ATen/cuda/Exceptions.h @@ -79,6 +79,11 @@ const char *cusparseGetErrorString(cusparseStatus_t status); #define AT_CUDA_CHECK(EXPR) C10_CUDA_CHECK(EXPR) +// This should be used directly after every kernel launch to ensure +// the launch happened correctly and provide an early, close-to-source +// diagnostic if it didn't. +#define TORCH_CUDA_KERNEL_LAUNCH_CHECK() AT_CUDA_CHECK(cudaGetLastError()) + // For CUDA Driver API // // This is here instead of in c10 because NVRTC is loaded dynamically via a stub diff --git a/aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp b/aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp index b2d8df49f51b..45ceddcd94e8 100644 --- a/aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp +++ b/aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp @@ -1,7 +1,5 @@ #include -// @generated by aten/src/ATen/gen.py from LegacyTHFunctions.cpp - #include #include #include diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index e7e5659babbb..16f706ca0ed5 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -331,6 +331,7 @@ static void apply_solve(Tensor& b, Tensor& A, std::vector& infos) { auto batch_size = batchCount(A); auto n = A.size(-2); auto nrhs = b.size(-1); + auto lda = std::max(int64_t{1}, n); auto ipiv = at::empty({n}, b.options().dtype(kInt)); auto ipiv_data = ipiv.data_ptr(); @@ -339,7 +340,7 @@ static void apply_solve(Tensor& b, Tensor& A, std::vector& infos) { for (int64_t i = 0; i < batch_size; i++) { scalar_t* A_working_ptr = &A_data[i * A_mat_stride]; scalar_t* b_working_ptr = &b_data[i * b_mat_stride]; - lapackSolve(n, nrhs, A_working_ptr, n, ipiv_data, b_working_ptr, n, &info); + lapackSolve(n, nrhs, A_working_ptr, lda, ipiv_data, b_working_ptr, lda, &info); infos[i] = info; if (info != 0) { return; diff --git a/aten/src/ATen/native/BinaryOps.cpp b/aten/src/ATen/native/BinaryOps.cpp index f8af756773c9..b7916ba3f9c8 100644 --- a/aten/src/ATen/native/BinaryOps.cpp +++ b/aten/src/ATen/native/BinaryOps.cpp @@ -46,6 +46,7 @@ DEFINE_DISPATCH(logaddexp2_stub); DEFINE_DISPATCH(gcd_stub); DEFINE_DISPATCH(lcm_stub); DEFINE_DISPATCH(hypot_stub); +DEFINE_DISPATCH(igamma_stub); DEFINE_DISPATCH(nextafter_stub); DEFINE_DISPATCH(heaviside_stub); @@ -968,6 +969,23 @@ Tensor& hypot_(Tensor& self, const Tensor& other) { return at::hypot_out(self, self, other); } +Tensor& igamma_out(Tensor& result, const Tensor& self, const Tensor& other) { + auto iter = TensorIterator::binary_op(result, self, other); + igamma_stub(iter.device_type(), iter); + return result; +} + +Tensor igamma(const Tensor& self, const Tensor& other) { + Tensor result; + auto iter = TensorIterator::binary_op(result, self, other); + igamma_stub(iter.device_type(), iter); + return iter.output(); +} + +Tensor& igamma_(Tensor& self, const Tensor& other) { + return at::igamma_out(self, self, other); +} + Tensor& nextafter_out(Tensor& result, const Tensor& self, const Tensor& other) { auto iter = TensorIterator::binary_op(result, self, other); nextafter_stub(iter.device_type(), iter); diff --git a/aten/src/ATen/native/BinaryOps.h b/aten/src/ATen/native/BinaryOps.h index 7640c8bd84ac..ee3f023fedc5 100644 --- a/aten/src/ATen/native/BinaryOps.h +++ b/aten/src/ATen/native/BinaryOps.h @@ -10,7 +10,7 @@ namespace at { namespace native { inline void alpha_check(const ScalarType dtype, Scalar alpha) { TORCH_CHECK(! alpha.isBoolean() || dtype == ScalarType::Bool, "Boolean alpha only supported for Boolean results."); - TORCH_CHECK(isFloatingType(dtype) || isComplexType(dtype) + TORCH_CHECK(isFloatingType(dtype) || isComplexType(dtype) || alpha.isIntegral(true), "For integral input tensors, argument alpha must not be a floating point number."); } @@ -68,6 +68,7 @@ DECLARE_DISPATCH(binary_fn, logaddexp2_stub); DECLARE_DISPATCH(binary_fn, gcd_stub); DECLARE_DISPATCH(binary_fn, lcm_stub); DECLARE_DISPATCH(binary_fn, hypot_stub); +DECLARE_DISPATCH(binary_fn, igamma_stub); DECLARE_DISPATCH(binary_fn, nextafter_stub); DECLARE_DISPATCH(binary_fn, heaviside_stub); diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index f48ad9a4a6cb..6dbf1e5535ed 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -237,7 +237,11 @@ auto ConvParams::use_mkldnn(const at::Tensor& input, const at::Tensor& weight) c (input.options().backend() == at::Backend::CPU && input.scalar_type() == kFloat && // only on CPU Float Tensors !transposed && // or transposed tensors - (groups > 1 || weight.size(2) > 3 || input.size(0) > 1 + (is_strided() || is_dilated() || input.size(0) >= 16 || + weight.size(-1) != 1 || weight.size(-2) != 1) && + (groups > 1 + || (weight.size(-1) > 3 && weight.size(-2) > 3) + || input.size(0) > 1 || input.size(0)*input.size(1)*input.size(2)*input.size(3) > 20480)); // for some case, native is faster #endif return false; diff --git a/aten/src/ATen/native/Copy.cpp b/aten/src/ATen/native/Copy.cpp index 0430de87eb77..360069998f19 100644 --- a/aten/src/ATen/native/Copy.cpp +++ b/aten/src/ATen/native/Copy.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -131,7 +132,11 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking) } if (self.device().type() == at::kVulkan || src.device().type() == at::kVulkan) { + #ifdef USE_VULKAN_API + return vulkan::ops::copy_(self, src); + #else return at::vulkan::vulkan_copy_(self, src); + #endif } if (self.device().type() == at::kMetal || src.device().type() == at::kMetal) { diff --git a/aten/src/ATen/native/DispatchStub.h b/aten/src/ATen/native/DispatchStub.h index dc21a505e8c1..63e2462489be 100644 --- a/aten/src/ATen/native/DispatchStub.h +++ b/aten/src/ATen/native/DispatchStub.h @@ -3,7 +3,9 @@ #include #include #include + #include +#include // Implements instruction set specific function dispatch. // diff --git a/aten/src/ATen/native/Distance.cpp b/aten/src/ATen/native/Distance.cpp index b2b760513a1d..91d804687290 100644 --- a/aten/src/ATen/native/Distance.cpp +++ b/aten/src/ATen/native/Distance.cpp @@ -27,7 +27,7 @@ Tensor pdist(const Tensor& self, const double p) { Tensor _euclidean_dist(const Tensor& x1, const Tensor& x2) { /** This function does the fist part of the euclidean distance calculation - * We divide it in two steps to simplify dealing with subgradients in the + * We divide it in two steps to simplify dealing with subgradients in the * backward step */ Tensor x1_norm = x1.pow(2).sum(-1, true); Tensor x1_pad = at::ones_like(x1_norm, LEGACY_CONTIGUOUS_MEMORY_FORMAT); @@ -74,7 +74,7 @@ static Tensor cdist_impl(const Tensor& x1, const Tensor& x2, const double p, c10 std::vector tensor2_expand_size(expand_batch_portion); tensor2_expand_size.insert(tensor2_expand_size.end(), {r2, c2}); - int expand_batch_product = std::accumulate(expand_batch_portion.begin(), expand_batch_portion.end(), 1, std::multiplies()); + const int64_t expand_batch_product = prod_intlist(expand_batch_portion); std::vector tensor1_view{expand_batch_product, r1, c1}; std::vector tensor2_view{expand_batch_product, r2, c2}; @@ -147,8 +147,10 @@ Tensor _cdist_backward(const Tensor& grad, const Tensor& x1, const Tensor& x2, c auto device2 = x2.device().type(); TORCH_CHECK(device2 == kCPU || device2 == kCUDA, "_cdist_backward only supports CPU and CUDA devices, X2 got: ", device2); IntArrayRef batch_tensor1(x1.sizes().data(), std::max(x1.dim() - 2, 0)); - int batch_product = std::accumulate(batch_tensor1.begin(), batch_tensor1.end(), 1, std::multiplies()); - Tensor grad_x1 = at::empty_like(x1, x1.options(), LEGACY_CONTIGUOUS_MEMORY_FORMAT).view({batch_product, n, m}); + const int64_t batch_product = prod_intlist(batch_tensor1); + Tensor grad_x1 = + at::empty_like(x1, x1.options(), LEGACY_CONTIGUOUS_MEMORY_FORMAT) + .view({batch_product, n, m}); cdist_backward_stub(device1, grad_x1, grad, x1, x2, p, cdist); return grad_x1; } diff --git a/aten/src/ATen/native/Embedding.cpp b/aten/src/ATen/native/Embedding.cpp index 3f250ae09909..6589a33ed2f4 100644 --- a/aten/src/ATen/native/Embedding.cpp +++ b/aten/src/ATen/native/Embedding.cpp @@ -17,16 +17,27 @@ Tensor embedding(const Tensor & weight, const Tensor & indices, auto indices_arg = TensorArg(indices, "indices", 1); checkScalarType("embedding", indices_arg, kLong); + auto zerofill_padding = [&](Tensor& embedding) { + if (padding_idx >= 0) { + embedding.masked_fill_((indices == padding_idx).reshape({-1, 1}), 0); + } + }; + // TODO: use tensor.index() after improving perf if (indices.dim() == 1) { - return weight.index_select(0, indices); + auto out = weight.index_select(0, indices); + zerofill_padding(out); + return out; } auto size = indices.sizes().vec(); for (auto d : weight.sizes().slice(1)) { size.push_back(d); } - return weight.index_select(0, indices.reshape(-1)).view(size); + + auto out = weight.index_select(0, indices.reshape(-1)); + zerofill_padding(out); + return out.view(size); } Tensor embedding_backward( diff --git a/aten/src/ATen/native/Fill.cpp b/aten/src/ATen/native/Fill.cpp index 73f7dcd61926..b466ca26fc0c 100644 --- a/aten/src/ATen/native/Fill.cpp +++ b/aten/src/ATen/native/Fill.cpp @@ -4,20 +4,12 @@ #include #include #include +#include namespace at { namespace native { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ fill ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -namespace { - template - inline void fill_fast(Tensor& self, Scalar value_scalar) { - auto value = value_scalar.to(); - scalar_t * dptr = static_cast(self.data_ptr()); - *dptr = value; - } -} // namspace - Tensor& fill_out(Tensor& self, Scalar value) { if (self.is_quantized()) { at::Tensor out = at::ones(self.sizes()).to(kFloat) * value; @@ -26,15 +18,8 @@ Tensor& fill_out(Tensor& self, Scalar value) { self.copy_(out); return self; } - // When filling a number to 1-element CPU tensor, we want to skip - // everything but manipulate data ptr directly. - // Ideally this fast pass should be implemented in TensorIterator, - // but we also want to skip compute_types which in not avoidable - // in TensorIterator for now. if (self.device() == at::kCPU && self.numel() == 1 && !self.is_complex() && !value.isComplex()) { - AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, kBFloat16, self.scalar_type(), "fill_out", [&]() { - fill_fast(self, value);}); - return self; + return at::detail::scalar_fill(self, value); } auto iter = TensorIteratorConfig() .set_check_mem_overlap(false) // Fill is idempotent, so overlap is okay diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 7a8213dc6fcd..1777f00ee68a 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -675,8 +675,8 @@ Tensor matmul( std::vector tensor2_expand_size(expand_batch_portion); tensor2_expand_size.insert(tensor2_expand_size.end(), {m2, p}); - int expand_batch_product = std::accumulate(expand_batch_portion.begin(), expand_batch_portion.end(), - 1, std::multiplies()); + const int64_t expand_batch_product = + prod_intlist(expand_batch_portion); std::vector tensor1_bmm_view({expand_batch_product}); tensor1_bmm_view.insert(tensor1_bmm_view.end(), {n, m1}); @@ -742,7 +742,7 @@ Tensor _allocate_buffer(const Tensor& a, int n_copies, bool is_zero = false) { {n_copies, a.size(0), a.size(1), a.size(2)}, a.options().memory_format(at::MemoryFormat::Contiguous) ); - + if (is_zero) { res.zero_(); } @@ -850,7 +850,7 @@ Tensor compute_T4(const Tensor& A) { auto As = _allocate_buffer(A, 4); // 3 for {I, A, A^2} _fill_matrix_powers(As, A, 3); - + at::native::matmul( // output for A^2 * (I / 2 + A / 6 + A^2 / 24) As.select(0, 3), @@ -1101,7 +1101,7 @@ Tensor mexp_impl( if (!compute_highest_degree_approx) { constexpr std::array< Tensor(*)(const Tensor&), - total_n_degs - 1> + total_n_degs - 1> compute_Ts = { compute_T1, compute_T2, compute_T4, compute_T8, compute_T12 @@ -1192,7 +1192,7 @@ Tensor mexp(const Tensor& a, bool compute_highest_degree_approx = false) { // Based on: // -// Mathias, Roy. +// Mathias, Roy. // A Chain Rule for Matrix Functions and Applications. // SIAM J. Matrix Anal. Appl. 17 (1996): 610-620. // @@ -1227,8 +1227,8 @@ Tensor backward_analytic_function_of_a_matrix( // Mathematics 2019, 7, 1174. // Tensor matrix_exp(const Tensor& a) { - TORCH_CHECK(a.dim() >= 2 - && (at::isFloatingType(a.scalar_type()) + TORCH_CHECK(a.dim() >= 2 + && (at::isFloatingType(a.scalar_type()) || at::isComplexType(a.scalar_type())), "matrix_exp(", a.scalar_type(), "{", a.sizes(), "}): expected a tensor " "of floating or complex types with dim at least 2"); @@ -1391,9 +1391,8 @@ static std::vector make_dim_list(int64_t ndim) { } // Checks for valid arguments to linalg_norm when type(ord) == str -static void check_str_ord_valid(const std::string& str_ord, optional opt_dim, int64_t ndim, optional opt_dtype) { +static void check_str_ord_valid(const std::string& str_ord, optional opt_dim, int64_t ndim) { TORCH_CHECK((str_ord == "nuc") || (str_ord == "fro"), "Invalid norm order: ", str_ord); - TORCH_CHECK(!opt_dtype.has_value(), "ord=\'", str_ord, "\' does not yet support the dtype argument"); bool dims_valid = (ndim == 2 && !opt_dim.has_value()) || (opt_dim.has_value() && opt_dim.value().size() == 2); TORCH_CHECK(dims_valid, "order \"", str_ord, "\" can only be used if either len(dim) == 2 or (self.dim() == 2 and dim is None)"); @@ -1553,14 +1552,15 @@ static Tensor& linalg_norm_out_impl(Tensor& result, const Tensor& self, optional if (opt_str_ord.has_value()) { // 'ord' is string auto str_ord = opt_str_ord.value(); - check_str_ord_valid(str_ord, opt_dim, ndim, opt_dtype); + check_str_ord_valid(str_ord, opt_dim, ndim); + Tensor self_ = opt_dtype.has_value() ? self.to(opt_dtype.value()) : self; if (str_ord == "fro") { - at::frobenius_norm_out(result, self, opt_dim.value_or(IntArrayRef({0, 1})), keepdim); + at::frobenius_norm_out(result, self_, opt_dim.value_or(IntArrayRef({0, 1})), keepdim); } else if (str_ord == "nuc") { if (opt_dim.has_value()) { - at::nuclear_norm_out(result, self, opt_dim.value(), keepdim); + at::nuclear_norm_out(result, self_, opt_dim.value(), keepdim); } else { - at::nuclear_norm_out(result, self, keepdim); + at::nuclear_norm_out(result, self_, keepdim); } } } else { @@ -1602,6 +1602,55 @@ Tensor& linalg_norm_out(Tensor& result, const Tensor& self, std::string ord, opt return linalg_norm_out_impl(result, self, c10::nullopt, ord, opt_dim, keepdim, opt_dtype); } +Tensor linalg_tensorsolve(const Tensor& self, const Tensor& other, optional dims) { + /* + The idea is to reduce the problem to 2D matrix solve. + Step 1. (optional) `self` is permuted with `dims` such that dimensions from `dims` are moved to the right. + For example, if we have 4D input with the shape (1, 2, 3, 4) and dims=(0, 2), + then the result of permutation would have the shape (2, 4, 1, 3). + Step 2. reshape `self` to 2D matrix. + Step 3. solve the matrix equation self.to_2D() @ result = other.to_1D() + Step 4. reshape the result. + */ + int64_t ndim = self.dim(); + Tensor self_ = self; + + // move dimensions of `self_` from `dims` to the end + if (dims.has_value()) { + DimVector dest_axes(dims.value().size()); + std::iota(dest_axes.begin(), dest_axes.end(), ndim - dest_axes.size()); + self_ = at::movedim(self_, dims.value(), dest_axes); + } + + // result_shape is self_.sizes[-(an-other.dim):] + std::vector result_shape = self_.sizes().slice(other.dim(), ndim - other.dim()).vec(); + + int64_t result_product = std::accumulate(result_shape.begin(), result_shape.end(), int64_t{1}, std::multiplies()); + int64_t other_product = std::accumulate(other.sizes().begin(), other.sizes().end(), int64_t{1}, std::multiplies()); + + // Check whether the self tensor can be reshaped to the 2D square matrix + TORCH_CHECK(result_product == other_product, + "Expected self to satisfy the requirement prod(self.shape[other.ndim:]) == prod(self.shape[:other.ndim]), but got ", + result_product, " != ", other_product); + + self_ = self_.reshape({result_product, result_product}); + + // 0th output of at::solve is the solution + // normally `other` would be flattened by at::solve expects 2D input + Tensor result = std::get<0>(at::solve(other.reshape({other.numel(), 1}), self_)); + return result.reshape(result_shape); +} + +Tensor& linalg_tensorsolve_out(Tensor& result, const Tensor& self, const Tensor& other, optional dims) { + TORCH_CHECK(result.scalar_type() == self.scalar_type(), + "result dtype ", result.scalar_type(), " does not match self dtype ", self.scalar_type()); + + Tensor result_tmp = at::linalg_tensorsolve(self, other, dims); + at::native::resize_output(result, result_tmp.sizes()); + result.copy_(result_tmp); + return result; +} + static inline Tensor _chain_matmul_general(TensorList matrices, std::vector>& order, int64_t i, int64_t j) { if (i == j) return matrices[i]; diff --git a/aten/src/ATen/native/Math.h b/aten/src/ATen/native/Math.h index c00ffec94119..dc5530a72813 100644 --- a/aten/src/ATen/native/Math.h +++ b/aten/src/ATen/native/Math.h @@ -381,6 +381,716 @@ static inline float calc_polygamma(int64_t n, float x) { zeta(double(n + 1), x); } +// regularized lower incomplete gamma +// the regularized lower, upper incomplete gamma, as well as their +// helper functions follow SciPy's implementation + +/* References + * [igam1] "The Digital Library of Mathematical Functions", dlmf.nist.gov + * [igam2] Maddock et. al., "Incomplete Gamma Functions", + * https://www.boost.org/doc/libs/1_61_0/libs/math/doc/html/math_toolkit/sf_gamma/igamma.html + */ + +/* + * This implementation of the regularized incomplete gamma functions and + * their helper functions are derived from the implementation of SciPy's + * gammainc, Cephes's igam and igamc, and Boost's Lanczos approximations. + * See NOTICE for the licenses. + */ +template +static scalar_t ratevl(scalar_t x, const scalar_t num[], int64_t M, + const scalar_t denom[], int64_t N) { + // evaluating rational function, i.e., the ratio of two polynomials + // the coefficients for numerator are given by `num` while coeffs for + // denumerator are given by `denom` + + int64_t i, dir; + scalar_t y, num_ans, denom_ans; + scalar_t absx = std::fabs(x); + const scalar_t *p; + + if (absx > 1) { + /* Evaluate as a polynomial in 1/x. */ + dir = -1; + p = num + M; + y = 1 / x; + } + else { + dir = 1; + p = num; + y = x; + } + + /* Evaluate the numerator */ + num_ans = *p; + p += dir; + for (i = 1; i <= M; i++) { + num_ans = num_ans * y + *p; + p += dir; + } + /* Evaluate the denominator */ + if (absx > 1) { + p = denom + N; + } + else { + p = denom; + } + + denom_ans = *p; + p += dir; + for (i = 1; i <= N; i++) { + denom_ans = denom_ans * y + *p; + p += dir; + } + if (absx > 1) { + i = N - M; + return std::pow(x, i) * num_ans / denom_ans; + } + else { + return num_ans / denom_ans; + } +} + +// SciPy's lanczos implementation is taken from Boost +/* (C) Copyright John Maddock 2006. + * Use, modification and distribution are subject to the + * Boost Software License, Version 1.0. See + * https://www.boost.org/LICENSE_1_0.txt or see NOTICE. + */ +template +static scalar_t lanczos_sum_expg_scaled(scalar_t x) { + // lanczos approximation + static const scalar_t lanczos_sum_expg_scaled_num[13] = { + 0.006061842346248906525783753964555936883222, + 0.5098416655656676188125178644804694509993, + 19.51992788247617482847860966235652136208, + 449.9445569063168119446858607650988409623, + 6955.999602515376140356310115515198987526, + 75999.29304014542649875303443598909137092, + 601859.6171681098786670226533699352302507, + 3481712.15498064590882071018964774556468, + 14605578.08768506808414169982791359218571, + 43338889.32467613834773723740590533316085, + 86363131.28813859145546927288977868422342, + 103794043.1163445451906271053616070238554, + 56906521.91347156388090791033559122686859 + }; + static const scalar_t lanczos_sum_expg_scaled_denom[13] = { + 1., + 66., + 1925., + 32670., + 357423., + 2637558., + 13339535., + 45995730., + 105258076., + 150917976., + 120543840., + 39916800., + 0. + }; + return ratevl(x, lanczos_sum_expg_scaled_num, + sizeof(lanczos_sum_expg_scaled_num) / sizeof(lanczos_sum_expg_scaled_num[0]) - 1, + lanczos_sum_expg_scaled_denom, + sizeof(lanczos_sum_expg_scaled_denom) / sizeof(lanczos_sum_expg_scaled_denom[0]) - 1); +} + +template +static scalar_t _igam_helper_fac(scalar_t a, scalar_t x) { + // compute x^a * exp(-a) / gamma(a) + // corrected from (15) and (16) in [igam2] by replacing exp(x - a) with + // exp(a - x). + + scalar_t ax, fac, res, num, numfac; + static scalar_t MAXLOG = std::is_same::value ? + 7.09782712893383996843E2 : 88.72283905206835; + static scalar_t EXP1 = 2.718281828459045; + static scalar_t lanczos_g = 6.024680040776729583740234375; + + if (std::fabs(a - x) > 0.4 * std::fabs(a)) { + ax = a * std::log(x) - x - std::lgamma(a); + if (ax < -MAXLOG) { + return 0.0; + } + return std::exp(ax); + } + + fac = a + lanczos_g - 0.5; + res = std::sqrt(fac / EXP1) / lanczos_sum_expg_scaled(a); + + if ((a < 200) && (x < 200)) { + res *= std::exp(a - x) * std::pow(x / fac, a); + } + else { + num = x - a - lanczos_g + 0.5; + numfac = num / fac; + res *= std::exp(a * (std::log1p(numfac) - numfac) + x * (0.5 - lanczos_g) / fac); + } + return res; +} + +template +static scalar_t _igam_helper_series(scalar_t a, scalar_t x) { + // Compute igam using DLMF 8.11.4. [igam1] + static scalar_t MACHEP = std::is_same::value ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + static int MAXITER = 2000; + + int i; + scalar_t ans, ax, c, r; + + ax = _igam_helper_fac(a, x); + if (ax == 0.0) { + return 0.0; + } + + /* power series */ + r = a; + c = 1.0; + ans = 1.0; + + for (i = 0; i < MAXITER; i++) { + r += 1.0; + c *= x / r; + ans += c; + if (c <= MACHEP * ans) { + break; + } + } + return (ans * ax / a); +} + +template +static scalar_t _igamc_helper_series(scalar_t a, scalar_t x) { + // Compute igamc using DLMF 8.7.3 [igam1]. This is related to the series in + // _igam_helper_series but extra care is taken to avoid cancellation. + + int n; + scalar_t fac = 1; + scalar_t sum = 0; + scalar_t term, logx; + static scalar_t MAXITER = 2000; + static scalar_t MACHEP = std::is_same::value ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + + for (n = 1; n < MAXITER; n++) { + fac *= -x / n; + term = fac / (a + n); + sum += term; + if (std::fabs(term) <= MACHEP * std::fabs(sum)) { + break; + } + } + + logx = std::log(x); + term = -std::expm1(a * logx - std::lgamma(1+a)); + return term - std::exp(a * logx - std::lgamma(a)) * sum; +} + +template +static scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t x, bool igam) { + // Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1] + static const scalar_t d[25][25] = + {{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2, + 1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4, + 3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6, + 8.296711340953086e-7, -1.7665952736826079e-7, 6.7078535434014986e-9, + 1.0261809784240308e-8, -4.3820360184533532e-9, 9.1476995822367902e-10, + -2.551419399494625e-11, -5.8307721325504251e-11, 2.4361948020667416e-11, + -5.0276692801141756e-12, 1.1004392031956135e-13, 3.3717632624009854e-13, + -1.3923887224181621e-13, 2.8534893807047443e-14, -5.1391118342425726e-16, + -1.9752288294349443e-15}, + {-1.8518518518518519e-3, -3.4722222222222222e-3, 2.6455026455026455e-3, + -9.9022633744855967e-4, 2.0576131687242798e-4, -4.0187757201646091e-7, + -1.8098550334489978e-5, 7.6491609160811101e-6, -1.6120900894563446e-6, + 4.6471278028074343e-9, 1.378633446915721e-7, -5.752545603517705e-8, + 1.1951628599778147e-8, -1.7543241719747648e-11, -1.0091543710600413e-9, + 4.1627929918425826e-10, -8.5639070264929806e-11, 6.0672151016047586e-14, + 7.1624989648114854e-12, -2.9331866437714371e-12, 5.9966963656836887e-13, + -2.1671786527323314e-16, -4.9783399723692616e-14, 2.0291628823713425e-14, + -4.13125571381061e-15}, + {4.1335978835978836e-3, -2.6813271604938272e-3, 7.7160493827160494e-4, + 2.0093878600823045e-6, -1.0736653226365161e-4, 5.2923448829120125e-5, + -1.2760635188618728e-5, 3.4235787340961381e-8, 1.3721957309062933e-6, + -6.298992138380055e-7, 1.4280614206064242e-7, -2.0477098421990866e-10, + -1.4092529910867521e-8, 6.228974084922022e-9, -1.3670488396617113e-9, + 9.4283561590146782e-13, 1.2872252400089318e-10, -5.5645956134363321e-11, + 1.1975935546366981e-11, -4.1689782251838635e-15, -1.0940640427884594e-12, + 4.6622399463901357e-13, -9.905105763906906e-14, 1.8931876768373515e-17, + 8.8592218725911273e-15}, + {6.4943415637860082e-4, 2.2947209362139918e-4, -4.6918949439525571e-4, + 2.6772063206283885e-4, -7.5618016718839764e-5, -2.3965051138672967e-7, + 1.1082654115347302e-5, -5.6749528269915966e-6, 1.4230900732435884e-6, + -2.7861080291528142e-11, -1.6958404091930277e-7, 8.0994649053880824e-8, + -1.9111168485973654e-8, 2.3928620439808118e-12, 2.0620131815488798e-9, + -9.4604966618551322e-10, 2.1541049775774908e-10, -1.388823336813903e-14, + -2.1894761681963939e-11, 9.7909989511716851e-12, -2.1782191880180962e-12, + 6.2088195734079014e-17, 2.126978363279737e-13, -9.3446887915174333e-14, + 2.0453671226782849e-14}, + {-8.618882909167117e-4, 7.8403922172006663e-4, -2.9907248030319018e-4, + -1.4638452578843418e-6, 6.6414982154651222e-5, -3.9683650471794347e-5, + 1.1375726970678419e-5, 2.5074972262375328e-10, -1.6954149536558306e-6, + 8.9075075322053097e-7, -2.2929348340008049e-7, 2.956794137544049e-11, + 2.8865829742708784e-8, -1.4189739437803219e-8, 3.4463580499464897e-9, + -2.3024517174528067e-13, -3.9409233028046405e-10, 1.8602338968504502e-10, + -4.356323005056618e-11, 1.2786001016296231e-15, 4.6792750266579195e-12, + -2.1492464706134829e-12, 4.9088156148096522e-13, -6.3385914848915603e-18, + -5.0453320690800944e-14}, + {-3.3679855336635815e-4, -6.9728137583658578e-5, 2.7727532449593921e-4, + -1.9932570516188848e-4, 6.7977804779372078e-5, 1.419062920643967e-7, + -1.3594048189768693e-5, 8.0184702563342015e-6, -2.2914811765080952e-6, + -3.252473551298454e-10, 3.4652846491085265e-7, -1.8447187191171343e-7, + 4.8240967037894181e-8, -1.7989466721743515e-14, -6.3061945000135234e-9, + 3.1624176287745679e-9, -7.8409242536974293e-10, 5.1926791652540407e-15, + 9.3589442423067836e-11, -4.5134262161632782e-11, 1.0799129993116827e-11, + -3.661886712685252e-17, -1.210902069055155e-12, 5.6807435849905643e-13, + -1.3249659916340829e-13}, + {5.3130793646399222e-4, -5.9216643735369388e-4, 2.7087820967180448e-4, + 7.9023532326603279e-7, -8.1539693675619688e-5, 5.6116827531062497e-5, + -1.8329116582843376e-5, -3.0796134506033048e-9, 3.4651553688036091e-6, + -2.0291327396058604e-6, 5.7887928631490037e-7, 2.338630673826657e-13, + -8.8286007463304835e-8, 4.7435958880408128e-8, -1.2545415020710382e-8, + 8.6496488580102925e-14, 1.6846058979264063e-9, -8.5754928235775947e-10, + 2.1598224929232125e-10, -7.6132305204761539e-16, -2.6639822008536144e-11, + 1.3065700536611057e-11, -3.1799163902367977e-12, 4.7109761213674315e-18, + 3.6902800842763467e-13}, + {3.4436760689237767e-4, 5.1717909082605922e-5, -3.3493161081142236e-4, + 2.812695154763237e-4, -1.0976582244684731e-4, -1.2741009095484485e-7, + 2.7744451511563644e-5, -1.8263488805711333e-5, 5.7876949497350524e-6, + 4.9387589339362704e-10, -1.0595367014026043e-6, 6.1667143761104075e-7, + -1.7562973359060462e-7, -1.2974473287015439e-12, 2.695423606288966e-8, + -1.4578352908731271e-8, 3.887645959386175e-9, -3.8810022510194121e-17, + -5.3279941738772867e-10, 2.7437977643314845e-10, -6.9957960920705679e-11, + 2.5899863874868481e-17, 8.8566890996696381e-12, -4.403168815871311e-12, + 1.0865561947091654e-12}, + {-6.5262391859530942e-4, 8.3949872067208728e-4, -4.3829709854172101e-4, + -6.969091458420552e-7, 1.6644846642067548e-4, -1.2783517679769219e-4, + 4.6299532636913043e-5, 4.5579098679227077e-9, -1.0595271125805195e-5, + 6.7833429048651666e-6, -2.1075476666258804e-6, -1.7213731432817145e-11, + 3.7735877416110979e-7, -2.1867506700122867e-7, 6.2202288040189269e-8, + 6.5977038267330006e-16, -9.5903864974256858e-9, 5.2132144922808078e-9, + -1.3991589583935709e-9, 5.382058999060575e-16, 1.9484714275467745e-10, + -1.0127287556389682e-10, 2.6077347197254926e-11, -5.0904186999932993e-18, + -3.3721464474854592e-12}, + {-5.9676129019274625e-4, -7.2048954160200106e-5, 6.7823088376673284e-4, + -6.4014752602627585e-4, 2.7750107634328704e-4, 1.8197008380465151e-7, + -8.4795071170685032e-5, 6.105192082501531e-5, -2.1073920183404862e-5, + -8.8585890141255994e-10, 4.5284535953805377e-6, -2.8427815022504408e-6, + 8.7082341778646412e-7, 3.6886101871706965e-12, -1.5344695190702061e-7, + 8.862466778790695e-8, -2.5184812301826817e-8, -1.0225912098215092e-14, + 3.8969470758154777e-9, -2.1267304792235635e-9, 5.7370135528051385e-10, + -1.887749850169741e-19, -8.0931538694657866e-11, 4.2382723283449199e-11, + -1.1002224534207726e-11}, + {1.3324454494800656e-3, -1.9144384985654775e-3, 1.1089369134596637e-3, + 9.932404122642299e-7, -5.0874501293093199e-4, 4.2735056665392884e-4, + -1.6858853767910799e-4, -8.1301893922784998e-9, 4.5284402370562147e-5, + -3.127053674781734e-5, 1.044986828530338e-5, 4.8435226265680926e-11, + -2.1482565873456258e-6, 1.329369701097492e-6, -4.0295693092101029e-7, + -1.7567877666323291e-13, 7.0145043163668257e-8, -4.040787734999483e-8, + 1.1474026743371963e-8, 3.9642746853563325e-18, -1.7804938269892714e-9, + 9.7480262548731646e-10, -2.6405338676507616e-10, 5.794875163403742e-18, + 3.7647749553543836e-11}, + {1.579727660730835e-3, 1.6251626278391582e-4, -2.0633421035543276e-3, + 2.1389686185689098e-3, -1.0108559391263003e-3, -3.9912705529919201e-7, + 3.6235025084764691e-4, -2.8143901463712154e-4, 1.0449513336495887e-4, + 2.1211418491830297e-9, -2.5779417251947842e-5, 1.7281818956040463e-5, + -5.6413773872904282e-6, -1.1024320105776174e-11, 1.1223224418895175e-6, + -6.8693396379526735e-7, 2.0653236975414887e-7, 4.6714772409838506e-14, + -3.5609886164949055e-8, 2.0470855345905963e-8, -5.8091738633283358e-9, + -1.332821287582869e-16, 9.0354604391335133e-10, -4.9598782517330834e-10, + 1.3481607129399749e-10}, + {-4.0725121195140166e-3, 6.4033628338080698e-3, -4.0410161081676618e-3, + -2.183732802866233e-6, 2.1740441801254639e-3, -1.9700440518418892e-3, + 8.3595469747962458e-4, 1.9445447567109655e-8, -2.5779387120421696e-4, + 1.9009987368139304e-4, -6.7696499937438965e-5, -1.4440629666426572e-10, + 1.5712512518742269e-5, -1.0304008744776893e-5, 3.304517767401387e-6, + 7.9829760242325709e-13, -6.4097794149313004e-7, 3.8894624761300056e-7, + -1.1618347644948869e-7, -2.816808630596451e-15, 1.9878012911297093e-8, + -1.1407719956357511e-8, 3.2355857064185555e-9, 4.1759468293455945e-20, + -5.0423112718105824e-10}, + {-5.9475779383993003e-3, -5.4016476789260452e-4, 8.7910413550767898e-3, + -9.8576315587856125e-3, 5.0134695031021538e-3, 1.2807521786221875e-6, + -2.0626019342754683e-3, 1.7109128573523058e-3, -6.7695312714133799e-4, + -6.9011545676562133e-9, 1.8855128143995902e-4, -1.3395215663491969e-4, + 4.6263183033528039e-5, 4.0034230613321351e-11, -1.0255652921494033e-5, + 6.612086372797651e-6, -2.0913022027253008e-6, -2.0951775649603837e-13, + 3.9756029041993247e-7, -2.3956211978815887e-7, 7.1182883382145864e-8, + 8.925574873053455e-16, -1.2101547235064676e-8, 6.9350618248334386e-9, + -1.9661464453856102e-9}, + {1.7402027787522711e-2, -2.9527880945699121e-2, 2.0045875571402799e-2, + 7.0289515966903407e-6, -1.2375421071343148e-2, 1.1976293444235254e-2, + -5.4156038466518525e-3, -6.3290893396418616e-8, 1.8855118129005065e-3, + -1.473473274825001e-3, 5.5515810097708387e-4, 5.2406834412550662e-10, + -1.4357913535784836e-4, 9.9181293224943297e-5, -3.3460834749478311e-5, + -3.5755837291098993e-12, 7.1560851960630076e-6, -4.5516802628155526e-6, + 1.4236576649271475e-6, 1.8803149082089664e-14, -2.6623403898929211e-7, + 1.5950642189595716e-7, -4.7187514673841102e-8, -6.5107872958755177e-17, + 7.9795091026746235e-9}, + {3.0249124160905891e-2, 2.4817436002649977e-3, -4.9939134373457022e-2, + 5.9915643009307869e-2, -3.2483207601623391e-2, -5.7212968652103441e-6, + 1.5085251778569354e-2, -1.3261324005088445e-2, 5.5515262632426148e-3, + 3.0263182257030016e-8, -1.7229548406756723e-3, 1.2893570099929637e-3, + -4.6845138348319876e-4, -1.830259937893045e-10, 1.1449739014822654e-4, + -7.7378565221244477e-5, 2.5625836246985201e-5, 1.0766165333192814e-12, + -5.3246809282422621e-6, 3.349634863064464e-6, -1.0381253128684018e-6, + -5.608909920621128e-15, 1.9150821930676591e-7, -1.1418365800203486e-7, + 3.3654425209171788e-8}, + {-9.9051020880159045e-2, 1.7954011706123486e-1, -1.2989606383463778e-1, + -3.1478872752284357e-5, 9.0510635276848131e-2, -9.2828824411184397e-2, + 4.4412112839877808e-2, 2.7779236316835888e-7, -1.7229543805449697e-2, + 1.4182925050891573e-2, -5.6214161633747336e-3, -2.39598509186381e-9, + 1.6029634366079908e-3, -1.1606784674435773e-3, 4.1001337768153873e-4, + 1.8365800754090661e-11, -9.5844256563655903e-5, 6.3643062337764708e-5, + -2.076250624489065e-5, -1.1806020912804483e-13, 4.2131808239120649e-6, + -2.6262241337012467e-6, 8.0770620494930662e-7, 6.0125912123632725e-16, + -1.4729737374018841e-7}, + {-1.9994542198219728e-1, -1.5056113040026424e-2, 3.6470239469348489e-1, + -4.6435192311733545e-1, 2.6640934719197893e-1, 3.4038266027147191e-5, + -1.3784338709329624e-1, 1.276467178337056e-1, -5.6213828755200985e-2, + -1.753150885483011e-7, 1.9235592956768113e-2, -1.5088821281095315e-2, + 5.7401854451350123e-3, 1.0622382710310225e-9, -1.5335082692563998e-3, + 1.0819320643228214e-3, -3.7372510193945659e-4, -6.6170909729031985e-12, + 8.4263617380909628e-5, -5.5150706827483479e-5, 1.7769536448348069e-5, + 3.8827923210205533e-14, -3.53513697488768e-6, 2.1865832130045269e-6, + -6.6812849447625594e-7}, + {7.2438608504029431e-1, -1.3918010932653375, 1.0654143352413968, + 1.876173868950258e-4, -8.2705501176152696e-1, 8.9352433347828414e-1, + -4.4971003995291339e-1, -1.6107401567546652e-6, 1.9235590165271091e-1, + -1.6597702160042609e-1, 6.8882222681814333e-2, 1.3910091724608687e-8, + -2.146911561508663e-2, 1.6228980898865892e-2, -5.9796016172584256e-3, + -1.1287469112826745e-10, 1.5167451119784857e-3, -1.0478634293553899e-3, + 3.5539072889126421e-4, 8.1704322111801517e-13, -7.7773013442452395e-5, + 5.0291413897007722e-5, -1.6035083867000518e-5, 1.2469354315487605e-14, + 3.1369106244517615e-6}, + {1.6668949727276811, 1.165462765994632e-1, -3.3288393225018906, + 4.4692325482864037, -2.6977693045875807, -2.600667859891061e-4, + 1.5389017615694539, -1.4937962361134612, 6.8881964633233148e-1, + 1.3077482004552385e-6, -2.5762963325596288e-1, 2.1097676102125449e-1, + -8.3714408359219882e-2, -7.7920428881354753e-9, 2.4267923064833599e-2, + -1.7813678334552311e-2, 6.3970330388900056e-3, 4.9430807090480523e-11, + -1.5554602758465635e-3, 1.0561196919903214e-3, -3.5277184460472902e-4, + 9.3002334645022459e-14, 7.5285855026557172e-5, -4.8186515569156351e-5, + 1.5227271505597605e-5}, + {-6.6188298861372935, 1.3397985455142589e+1, -1.0789350606845146e+1, + -1.4352254537875018e-3, 9.2333694596189809, -1.0456552819547769e+1, + 5.5105526029033471, 1.2024439690716742e-5, -2.5762961164755816, + 2.3207442745387179, -1.0045728797216284, -1.0207833290021914e-7, + 3.3975092171169466e-1, -2.6720517450757468e-1, 1.0235252851562706e-1, + 8.4329730484871625e-10, -2.7998284958442595e-2, 2.0066274144976813e-2, + -7.0554368915086242e-3, 1.9402238183698188e-12, 1.6562888105449611e-3, + -1.1082898580743683e-3, 3.654545161310169e-4, -5.1290032026971794e-11, + -7.6340103696869031e-5}, + {-1.7112706061976095e+1, -1.1208044642899116, 3.7131966511885444e+1, + -5.2298271025348962e+1, 3.3058589696624618e+1, 2.4791298976200222e-3, + -2.061089403411526e+1, 2.088672775145582e+1, -1.0045703956517752e+1, + -1.2238783449063012e-5, 4.0770134274221141, -3.473667358470195, + 1.4329352617312006, 7.1359914411879712e-8, -4.4797257159115612e-1, + 3.4112666080644461e-1, -1.2699786326594923e-1, -2.8953677269081528e-10, + 3.3125776278259863e-2, -2.3274087021036101e-2, 8.0399993503648882e-3, + -1.177805216235265e-9, -1.8321624891071668e-3, 1.2108282933588665e-3, + -3.9479941246822517e-4}, + {7.389033153567425e+1, -1.5680141270402273e+2, 1.322177542759164e+2, + 1.3692876877324546e-2, -1.2366496885920151e+2, 1.4620689391062729e+2, + -8.0365587724865346e+1, -1.1259851148881298e-4, 4.0770132196179938e+1, + -3.8210340013273034e+1, 1.719522294277362e+1, 9.3519707955168356e-7, + -6.2716159907747034, 5.1168999071852637, -2.0319658112299095, + -4.9507215582761543e-9, 5.9626397294332597e-1, -4.4220765337238094e-1, + 1.6079998700166273e-1, -2.4733786203223402e-8, -4.0307574759979762e-2, + 2.7849050747097869e-2, -9.4751858992054221e-3, 6.419922235909132e-6, + 2.1250180774699461e-3}, + {2.1216837098382522e+2, 1.3107863022633868e+1, -4.9698285932871748e+2, + 7.3121595266969204e+2, -4.8213821720890847e+2, -2.8817248692894889e-2, + 3.2616720302947102e+2, -3.4389340280087117e+2, 1.7195193870816232e+2, + 1.4038077378096158e-4, -7.52594195897599e+1, 6.651969984520934e+1, + -2.8447519748152462e+1, -7.613702615875391e-7, 9.5402237105304373, + -7.5175301113311376, 2.8943997568871961, -4.6612194999538201e-7, + -8.0615149598794088e-1, 5.8483006570631029e-1, -2.0845408972964956e-1, + 1.4765818959305817e-4, 5.1000433863753019e-2, -3.3066252141883665e-2, + 1.5109265210467774e-2}, + {-9.8959643098322368e+2, 2.1925555360905233e+3, -1.9283586782723356e+3, + -1.5925738122215253e-1, 1.9569985945919857e+3, -2.4072514765081556e+3, + 1.3756149959336496e+3, 1.2920735237496668e-3, -7.525941715948055e+2, + 7.3171668742208716e+2, -3.4137023466220065e+2, -9.9857390260608043e-6, + 1.3356313181291573e+2, -1.1276295161252794e+2, 4.6310396098204458e+1, + -7.9237387133614756e-6, -1.4510726927018646e+1, 1.1111771248100563e+1, + -4.1690817945270892, 3.1008219800117808e-3, 1.1220095449981468, + -7.6052379926149916e-1, 3.6262236505085254e-1, 2.216867741940747e-1, + 4.8683443692930507e-1}}; + + int k, n, sgn; + int maxpow = 0; + static scalar_t MACHEP = std::is_same::value ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + scalar_t lambda = x / a; + scalar_t sigma = (x - a) / a; + scalar_t eta, res, ck, ckterm, term, absterm; + scalar_t absoldterm = INFINITY; + scalar_t etapow[25] = {1}; + scalar_t sum = 0; + scalar_t afac = 1; + + if (igam) { + sgn = -1; + } + else { + sgn = 1; + } + + if (lambda > 1) { + eta = std::sqrt(-2 * (std::log1p(sigma) - sigma)); + } + else if (lambda < 1) { + eta = -std::sqrt(-2 * (std::log1p(sigma) - sigma)); + } + else { + eta = 0; + } + res = 0.5 * std::erfc(sgn * eta * std::sqrt(a / 2)); + + for (k = 0; k < 25; k++) { + ck = d[k][0]; + for (n = 1; n < 25; n++) { + if (n > maxpow) { + etapow[n] = eta * etapow[n-1]; + maxpow += 1; + } + ckterm = d[k][n]*etapow[n]; + ck += ckterm; + if (std::fabs(ckterm) < MACHEP * std::fabs(ck)) { + break; + } + } + term = ck * afac; + absterm = std::fabs(term); + if (absterm > absoldterm) { + break; + } + sum += term; + if (absterm < MACHEP * std::fabs(sum)) { + break; + } + absoldterm = absterm; + afac /= a; + } + res += sgn * std::exp(-0.5 * a * eta * eta) * sum / std::sqrt(2 * M_PIf * a); + + return res; +} + +template +static scalar_t _igamc_helper_continued_fraction(scalar_t a, scalar_t x) { + // Compute igamc using DLMF 8.9.2. [igam1] + int i; + scalar_t ans, ax, c, yc, r, t, y, z; + scalar_t pk, pkm1, pkm2, qk, qkm1, qkm2; + int MAXITER = 2000; + static scalar_t MACHEP = std::is_same::value ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + static scalar_t BIG = std::is_same::value ? + 4.503599627370496e15 : 16777216.; + static scalar_t BIGINV = std::is_same::value ? + 2.22044604925031308085e-16 : 5.9604644775390625E-8; + + ax = _igam_helper_fac(a, x); + if (ax == 0.0) { + return 0.0; + } + + /* continued fraction */ + y = 1.0 - a; + z = x + y + 1.0; + c = 0.0; + pkm2 = 1.0; + qkm2 = x; + pkm1 = x + 1.0; + qkm1 = z * x; + ans = pkm1 / qkm1; + + for (i = 0; i < MAXITER; i++) { + c += 1.0; + y += 1.0; + z += 2.0; + yc = y * c; + pk = pkm1 * z - pkm2 * yc; + qk = qkm1 * z - qkm2 * yc; + if (qk != 0) { + r = pk / qk; + t = std::fabs((ans - r) / r); + ans = r; + } + else { + t = 1.0; + } + pkm2 = pkm1; + pkm1 = pk; + qkm2 = qkm1; + qkm1 = qk; + if (std::fabs(pk) > BIG) { + pkm2 *= BIGINV; + pkm1 *= BIGINV; + qkm2 *= BIGINV; + qkm1 *= BIGINV; + } + if (t <= MACHEP) { + break; + } + } + return ans * ax; +} + +template +static inline scalar_t calc_igammac(scalar_t a, scalar_t x) { + /* the calculation of the regularized upper incomplete gamma function + * is done differently based on the values of a and x: + * - if x and/or a is at the boundary of defined region, then assign the + * result at the boundary + * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for + * Large Parameter (see DLMF 8.12.4 [igam1]) + * - if x > 1.1 and x < a, using the substraction from the regularized lower + * incomplete gamma + * - otherwise, calculate the series from [igam2] eq (5) + */ + scalar_t absxma_a; + + static scalar_t SMALL = 20.0; + static scalar_t LARGE = 200.0; + static scalar_t SMALLRATIO = 0.3; + static scalar_t LARGERATIO = 4.5; + + // note that in SciPy, a and x are non-negative, with exclusive 0s (i.e., + // at most 1 of them can be 0), where igammac(0, x) = 0.0 iff x > 0. + if ((x < 0) || (a < 0)) { + // out of defined-region of the function + return std::numeric_limits::quiet_NaN(); + } + else if (a == 0) { + if (x > 0) { + return 0.0; + } + else { + return std::numeric_limits::quiet_NaN(); + } + } + else if (x == 0) { + return 1.0; + } + else if (std::isinf(a)) { + if (std::isinf(x)) { + return std::numeric_limits::quiet_NaN(); + } + return 1.0; + } + else if (std::isinf(x)) { + return 0.0; + } + + absxma_a = std::fabs(x - a) / a; + if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) { + return _igam_helper_asymptotic_series(a, x, 0); + } + else if ((a > LARGE) && (absxma_a < LARGERATIO / std::sqrt(a))) { + return _igam_helper_asymptotic_series(a, x, 0); + } + + if (x > 1.1) { + if (x < a) { + return 1.0 - _igam_helper_series(a, x); + } + else { + return _igamc_helper_continued_fraction(a, x); + } + } + else if (x <= 0.5) { + if (-0.4 / std::log(x) < a) { + return 1.0 - _igam_helper_series(a, x); + } + else { + return _igamc_helper_series(a, x); + } + } + else { + if (x * 1.1 < a) { + return 1.0 - _igam_helper_series(a, x); + } + else { + return _igamc_helper_series(a, x); + } + } +} + +template +static inline scalar_t calc_igamma(scalar_t a, scalar_t x) { + /* the calculation of the regularized lower incomplete gamma function + * is done differently based on the values of a and x: + * - if x and/or a is at the boundary of defined region, then assign the + * result at the boundary + * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for + * Large Parameter (see DLMF 8.12.3 [igam1]) + * - if x > 1 and x > a, using the substraction from the regularized upper + * incomplete gamma + * - otherwise, calculate the series from [igam2] eq (4) + */ + scalar_t absxma_a; + static scalar_t SMALL = 20.0; + static scalar_t LARGE = 200.0; + static scalar_t SMALLRATIO = 0.3; + static scalar_t LARGERATIO = 4.5; + + // boundary values following SciPy + // note that in SciPy, a and x are non-negative, with exclusive 0s (i.e., + // at most 1 of them can be 0), where igamma(0, x) = 1.0 iff x > 0. + if ((x < 0) || (a < 0)) { + // out of defined-region of the function + return std::numeric_limits::quiet_NaN(); + } + else if (a == 0) { + if (x > 0) { + return 1.0; + } + else { + return std::numeric_limits::quiet_NaN(); + } + } + else if (x == 0) { + return 0.0; // zero integration limit + } + else if (std::isinf(a)) { + if (std::isinf(x)) { + return std::numeric_limits::quiet_NaN(); + } + return 0.0; + } + else if (std::isinf(x)) { + return 1.0; + } + + /* Asymptotic regime where a ~ x. See [igam2] */ + absxma_a = std::fabs(x - a) / a; + if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) { + return _igam_helper_asymptotic_series(a, x, 1); + } + else if ((a > LARGE) && (absxma_a < LARGERATIO / std::sqrt(a))) { + return _igam_helper_asymptotic_series(a, x, 1); + } + + if ((x > 1.0) && (x > a)) { + return 1.0 - calc_igammac(a, x); + } + + return _igam_helper_series(a, x); +} + +template <> +c10::BFloat16 calc_igamma(c10::BFloat16 a, c10::BFloat16 x) { + return calc_igamma(float(a), float(x)); +} + +template <> +c10::Half calc_igamma(c10::Half a, c10::Half x) { + return calc_igamma(float(a), float(x)); +} + inline c10::BFloat16 calc_erfinv(c10::BFloat16 a) { return calc_erfinv(float(a)); } template diff --git a/aten/src/ATen/native/MaxPooling.cpp b/aten/src/ATen/native/MaxPooling.cpp index 645822f55065..682af63cdafe 100644 --- a/aten/src/ATen/native/MaxPooling.cpp +++ b/aten/src/ATen/native/MaxPooling.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -98,10 +99,11 @@ Tensor max_pool1d( IntArrayRef dilation, bool ceil_mode) { if (self.is_quantized()) { - return at::quantized_max_pool1d(self, kernel_size, stride, padding, - dilation, ceil_mode); + return at::quantized_max_pool1d( + self, kernel_size, stride, padding, dilation, ceil_mode); } - if (self.requires_grad() || !self.device().is_cpu()) { + if ((self.requires_grad() && at::GradMode::is_enabled()) || + !self.device().is_cpu()) { // Needs indices for grad and with_indices defines CUDA dispatch return std::get<0>(at::max_pool1d_with_indices( self, kernel_size, stride, padding, dilation, ceil_mode)); diff --git a/aten/src/ATen/native/MetaTensor.cpp b/aten/src/ATen/native/MetaTensor.cpp index f8f0231b181c..a7042b283c4c 100644 --- a/aten/src/ATen/native/MetaTensor.cpp +++ b/aten/src/ATen/native/MetaTensor.cpp @@ -14,7 +14,7 @@ Tensor empty_meta( !(options_.has_memory_format() && optional_memory_format.has_value()), "Cannot set memory_format both in TensorOptions and explicit argument; please delete " "the redundant setter."); - TensorOptions options = options_.merge_in(TensorOptions().memory_format(optional_memory_format)); + TensorOptions options = options_.merge_memory_format(optional_memory_format); // TODO: deduplicate this logic with empty_cpu diff --git a/aten/src/ATen/native/NaiveDilatedConvolution.cpp b/aten/src/ATen/native/NaiveDilatedConvolution.cpp index 459dd857727f..e80b0c546362 100644 --- a/aten/src/ATen/native/NaiveDilatedConvolution.cpp +++ b/aten/src/ATen/native/NaiveDilatedConvolution.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -181,10 +182,8 @@ void slow_conv_dilated_all_cpu_template( // Temporary buffer: Tensor columns = at::empty({0}, options); if (output.defined() || grad_weight.defined() || grad_input.defined()) { - int64_t m = std::accumulate( - kernel_size.begin(), kernel_size.end(), 1, std::multiplies()); - int64_t n = std::accumulate( - output_size.begin(), output_size.end(), 1, std::multiplies()); + const int64_t m = prod_intlist(kernel_size); + const int64_t n = prod_intlist(output_size); columns.resize_({nInputPlane * m, n}); } // Initialize diff --git a/aten/src/ATen/native/Pool.h b/aten/src/ATen/native/Pool.h index b89554fd4d48..071460b090cd 100644 --- a/aten/src/ATen/native/Pool.h +++ b/aten/src/ATen/native/Pool.h @@ -28,11 +28,12 @@ static inline T pooling_output_shape_pad_lr( T outputSize = div_rtn( inputSize + pad_l + pad_r - dilation * (kernelSize - 1) - 1 + (ceil_mode ? stride - 1 : 0), stride) + 1; - if (pad_l) { + if (ceil_mode) { // ensure that the last pooling starts inside the image // needed to avoid problems in ceil mode - if ((outputSize - 1) * stride >= inputSize + pad_l) + if ((outputSize - 1) * stride >= inputSize + pad_l) { --outputSize; + } } return outputSize; } diff --git a/aten/src/ATen/native/Pow.cpp b/aten/src/ATen/native/Pow.cpp index db33de9a9475..ca5d1848a4b8 100644 --- a/aten/src/ATen/native/Pow.cpp +++ b/aten/src/ATen/native/Pow.cpp @@ -31,13 +31,11 @@ Tensor& pow_out(Tensor& result, const Tensor& base, Scalar exp) { "result type ", common_dtype, "can't be cast to the desired output type ", result.scalar_type()); - if (exp.isComplex() && (exp.toComplexDouble() == 0.0) ) { - result.resize_as_(base).fill_(1); - } else if (exp.isComplex() && (exp.toComplexDouble() == 1.0) ) { - result.resize_as_(base).fill_(base); - } else if (!exp.isComplex() && (exp.toDouble() == 0.0)) { + auto exponent = (exp.isComplex()) ? exp.toComplexDouble() : exp.toDouble(); + + if (exponent == 0.0) { result.resize_as_(base).fill_(1); - } else if (!exp.isComplex() && (exp.toDouble() == 1.0)) { + } else if (exponent == 1.0) { result.resize_as_(base).copy_(base); } else { auto iter = TensorIterator::unary_op(result, base.to(common_dtype)); diff --git a/aten/src/ATen/native/Resize.h b/aten/src/ATen/native/Resize.h index 8fdc977092f4..be61ffb8b546 100644 --- a/aten/src/ATen/native/Resize.h +++ b/aten/src/ATen/native/Resize.h @@ -73,7 +73,7 @@ static inline void checkInBoundsForStorage( IntArrayRef size, IntArrayRef stride, int64_t storage_offset, - const caffe2::TypeMeta& data_type, + const caffe2::TypeMeta data_type, const Storage& new_storage) { int64_t storage_size_bytes = detail::computeStorageNbytes(size, stride, data_type.itemsize()); diff --git a/aten/src/ATen/native/TensorConversions.cpp b/aten/src/ATen/native/TensorConversions.cpp index 58df4cf110f7..d941f3b8e169 100644 --- a/aten/src/ATen/native/TensorConversions.cpp +++ b/aten/src/ATen/native/TensorConversions.cpp @@ -29,10 +29,15 @@ static inline Tensor to_impl(const Tensor& self, const TensorOptions& options, b return self; } + bool pin_out = (non_blocking && self.is_cuda() && options.device().is_cpu() && + (options.layout() == c10::kStrided)); + if (memory_format == MemoryFormat::Preserve) { if (self.is_non_overlapping_and_dense()) { // Copy all strides - auto r = at::empty_strided(self.sizes(), self.strides(), options.memory_format(c10::nullopt)); + auto r = at::empty_strided(self.sizes(), + self.strides(), + options.memory_format(c10::nullopt).pinned_memory(pin_out)); r.copy_(self, non_blocking); return r; } else { @@ -40,7 +45,9 @@ static inline Tensor to_impl(const Tensor& self, const TensorOptions& options, b } } // See Note [Explicit nullopt MemoryFormat argument] - auto r = at::empty(self.sizes(), options.memory_format(memory_format), c10::nullopt); + auto r = at::empty(self.sizes(), + options.memory_format(memory_format).pinned_memory(pin_out), + c10::nullopt); r.copy_(self, non_blocking); return r; } @@ -56,7 +63,7 @@ Tensor to( !(options_.has_memory_format() && optional_memory_format.has_value()), "Cannot set memory_format both in TensorOptions and explicit argument; please delete " "the redundant setter."); - auto options = options_.merge_in(TensorOptions().memory_format(optional_memory_format)); + auto options = options_.merge_memory_format(optional_memory_format); TORCH_CHECK(options.requires_grad_opt() == c10::nullopt, "to(options) expects unset requires_grad flag, but got " diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp index 82d7363a1b32..0cec2dd32b0e 100644 --- a/aten/src/ATen/native/TensorFactories.cpp +++ b/aten/src/ATen/native/TensorFactories.cpp @@ -166,44 +166,7 @@ Tensor polar(const Tensor& abs, const Tensor& angle) { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ empty ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Tensor empty_cpu(IntArrayRef size, const TensorOptions& options_, c10::optional optional_memory_format) { - - TORCH_CHECK( - !(options_.has_memory_format() && optional_memory_format.has_value()), - "Cannot set memory_format both in TensorOptions and explicit argument; please delete " - "the redundant setter."); - TensorOptions options = options_.merge_in(TensorOptions().memory_format(optional_memory_format)); - - AT_ASSERT(options.device().type() == DeviceType::CPU); - check_size_nonnegative(size); - - c10::Allocator* allocator; - if (options.pinned_memory()) { - allocator = detail::getCUDAHooks().getPinnedMemoryAllocator(); - } else { - allocator = at::getCPUAllocator(); - } - - int64_t nelements = prod_intlist(size); - auto dtype = options.dtype(); - int64_t size_bytes = nelements * dtype.itemsize(); - auto storage_impl = c10::make_intrusive( - c10::StorageImpl::use_byte_size_t(), - size_bytes, - allocator->allocate(size_bytes), - allocator, - /*resizeable=*/true); - - auto tensor = detail::make_tensor( - std::move(storage_impl), at::DispatchKey::CPU, dtype); - // Default TensorImpl has size [0] - if (size.size() != 1 || size[0] != 0) { - tensor.unsafeGetTensorImpl()->set_sizes_contiguous(size); - } - - auto memory_format = options.memory_format_opt().value_or(MemoryFormat::Contiguous); - tensor.unsafeGetTensorImpl()->empty_tensor_restride(memory_format); - - return tensor; + return at::detail::empty_cpu(size, options_, optional_memory_format); } Tensor empty( @@ -277,7 +240,7 @@ Tensor empty_like( TensorOptions options = self.options() .merge_in(options_) - .merge_in(TensorOptions().memory_format(optional_memory_format)); + .merge_memory_format(optional_memory_format); TORCH_CHECK( !(options.layout() != kStrided && @@ -381,7 +344,8 @@ Tensor new_empty( // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ eye ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Tensor eye(int64_t n, const TensorOptions& options) { - return native::eye(n, -1, options); + // the default value of `m` equals to `n` + return native::eye(n, n, options); } Tensor eye(int64_t n, int64_t m, const TensorOptions& options) { @@ -390,15 +354,13 @@ Tensor eye(int64_t n, int64_t m, const TensorOptions& options) { } Tensor& eye_out_cpu(Tensor& result, int64_t n) { - return native::eye_out_cpu(result, n, -1); + // the default value of `m` equals to `n` + return native::eye_out_cpu(result, n, n); } Tensor& eye_out_cpu(Tensor& result, int64_t n, int64_t m) { TORCH_CHECK(n >= 0, "n must be greater or equal to 0, got ", n); - - if(m < 0) { - m = n; - } + TORCH_CHECK(m >= 0, "m must be greater or equal to 0, got ", m); result.resize_({n, m}); result.zero_(); diff --git a/aten/src/ATen/native/TensorFactories.h b/aten/src/ATen/native/TensorFactories.h index f551adcec693..8cae202efe13 100644 --- a/aten/src/ATen/native/TensorFactories.h +++ b/aten/src/ATen/native/TensorFactories.h @@ -61,11 +61,7 @@ inline void check_args( } } -inline void check_size_nonnegative(IntArrayRef size) { - for (auto x: size) { - TORCH_CHECK(x >= 0, "Trying to create tensor with negative dimension ", x, ": ", size); - } -} +using at::check_size_nonnegative; inline void check_supported_max_int_with_precision(int64_t n, const Tensor& tensor) { TORCH_CHECK(at::scalar_tensor(n, tensor.options()).defined(), diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index a42e90f399d9..14fb67e5d4ba 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -368,7 +368,7 @@ static Tensor cat_sparse(TensorList tensors, int64_t dim) { // The dimension in each tensor's values object that corresponds to the overall dimension along which we're catting. int64_t values_dim = wrapped - sparse_dim + 1; // The final size along the catted dimension. - int64_t total_size = std::accumulate(tensors.begin(), tensors.end(), 0, [values_dim](int64_t l, Tensor const &r) { + const int64_t total_size = std::accumulate(tensors.begin(), tensors.end(), static_cast(0), [values_dim](int64_t l, Tensor const &r) { return l + r._values().size(values_dim); }); auto zeros_sizes = tensors[0]._values().sizes().vec(); @@ -1262,6 +1262,47 @@ static inline Tensor & sparse_transpose_(Tensor & self, int64_t dim0, int64_t di return self; } +// torch.row_stack, alias for torch.vstack +Tensor& row_stack_out(Tensor& result, TensorList tensors) { + return at::vstack_out(result, tensors); +} + +Tensor row_stack(TensorList tensors) { + return at::vstack(tensors); +} + +static std::vector reshape_input_for_column_stack(TensorList tensors) { + std::vector result(tensors.size()); + auto transform_lambda = [](const Tensor& input) -> Tensor { + // reshape 0D or 1D tensor t into (t.numel(), 1) + if (input.dim() <= 1) { + return input.reshape({input.numel(), 1}); + } + return input; + }; + std::transform(tensors.cbegin(), + tensors.cend(), + result.begin(), + transform_lambda); + return result; +} + +Tensor& column_stack_out(Tensor& result, TensorList tensors) { + TORCH_CHECK(tensors.size() > 0, + "column_stack expects a non-empty TensorList"); + + auto reshaped_tensors = reshape_input_for_column_stack(tensors); + return at::hstack_out(result, reshaped_tensors); +} + +Tensor column_stack(TensorList tensors) { + TORCH_CHECK(tensors.size() > 0, + "column_stack expects a non-empty TensorList"); + + auto reshaped_tensors = reshape_input_for_column_stack(tensors); + return at::hstack(reshaped_tensors); +} + static Tensor& propagate_transposed_names( Tensor& result, const Tensor& other, @@ -1634,7 +1675,7 @@ Tensor unflatten(const Tensor& self, int64_t dim, IntArrayRef sizes, c10::option TORCH_CHECK(sizes.size() > 0, "unflatten: sizes must be non-empty"); TORCH_INTERNAL_ASSERT(!names || names->size() == sizes.size()); - auto numel = std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies()); + const int64_t numel = prod_intlist(sizes); if (self.has_names()) { TORCH_CHECK(numel == self.size(dim), "unflatten: Provided sizes ", sizes, " don't multiply up to the size of dim ", diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp index 99ec62e972c6..1aebfda85da0 100644 --- a/aten/src/ATen/native/UnaryOps.cpp +++ b/aten/src/ATen/native/UnaryOps.cpp @@ -284,20 +284,20 @@ Tensor& i0_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(re Tensor i0(const Tensor& self) { return unary_op_impl(self, at::i0_out); } Tensor& i0_(Tensor& self) { return unary_op_impl_(self, at::i0_out); } -Tensor& log_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, log_stub); } -Tensor log(const Tensor& self) { return unary_op_impl(self, at::log_out); } +Tensor& log_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, log_stub); } +Tensor log(const Tensor& self) { return unary_op_impl_float(self, log_stub); } Tensor& log_(Tensor& self) { return unary_op_impl_(self, at::log_out); } -Tensor& log10_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, log10_stub); } -Tensor log10(const Tensor& self) { return unary_op_impl(self, at::log10_out); } +Tensor& log10_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, log10_stub); } +Tensor log10(const Tensor& self) { return unary_op_impl_float(self, log10_stub); } Tensor& log10_(Tensor& self) { return unary_op_impl_(self, at::log10_out); } Tensor& log1p_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, log1p_stub); } Tensor log1p(const Tensor& self) { return unary_op_impl(self, at::log1p_out); } Tensor& log1p_(Tensor& self) { return unary_op_impl_(self, at::log1p_out); } -Tensor& log2_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, log2_stub); } -Tensor log2(const Tensor& self) { return unary_op_impl(self, at::log2_out); } +Tensor& log2_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, log2_stub); } +Tensor log2(const Tensor& self) { return unary_op_impl_float(self, log2_stub); } Tensor& log2_(Tensor& self) { return unary_op_impl_(self, at::log2_out); } Tensor& round_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, round_stub); } @@ -316,7 +316,11 @@ Tensor& rsqrt_out(Tensor& result, const Tensor& self) { return unary_op_impl_out Tensor rsqrt(const Tensor& self) { return unary_op_impl(self, at::rsqrt_out); } Tensor& rsqrt_(Tensor& self) { return unary_op_impl_(self, at::rsqrt_out); } -Tensor& sign_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, sign_stub); } +Tensor& sign_out(Tensor& result, const Tensor& self) { + TORCH_CHECK(!self.is_complex(), + "Unlike NumPy, torch.sign is not intended to support complex numbers. Please use torch.sgn instead."); + return unary_op_impl_out(result, self, sign_stub); +} Tensor sign(const Tensor& self) { return unary_op_impl(self, at::sign_out); } Tensor& sign_(Tensor& self) { return unary_op_impl_(self, at::sign_out); } @@ -335,8 +339,8 @@ Tensor& sin_out(Tensor& result, const Tensor& self) { return unary_op_impl_float Tensor sin(const Tensor& self) { return unary_op_impl_float(self, sin_stub); } Tensor& sin_(Tensor& self) { return unary_op_impl_(self, at::sin_out); } -Tensor& cos_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, cos_stub); } -Tensor cos(const Tensor& self) { return unary_op_impl(self, at::cos_out); } +Tensor& cos_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, cos_stub); } +Tensor cos(const Tensor& self) { return unary_op_impl_float(self, cos_stub); } Tensor& cos_(Tensor& self) { return unary_op_impl_(self, at::cos_out); } Tensor& sinh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, sinh_stub); } @@ -448,8 +452,8 @@ Tensor& tanh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out( Tensor tanh(const Tensor& self) { return unary_op_impl(self, at::tanh_out); } Tensor& tanh_(Tensor& self) { return unary_op_impl_(self, at::tanh_out); } -Tensor& tan_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, tan_stub); } -Tensor tan(const Tensor& self) { return unary_op_impl(self, at::tan_out); } +Tensor& tan_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, tan_stub); } +Tensor tan(const Tensor& self) { return unary_op_impl_float(self, tan_stub); } Tensor& tan_(Tensor& self) { return unary_op_impl_(self, at::tan_out); } Tensor& trunc_out(Tensor& result, const Tensor& self) { diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp index fce8c348919b..652f3ee063e1 100644 --- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp @@ -589,17 +589,31 @@ void logit_backward_kernel(TensorIterator& iter, Scalar eps_scalar) { } void tanh_backward_kernel(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.dtype(), "tanh_backward_cpu", [&]() { - auto one_vec = Vec256(scalar_t{1}); + if (isComplexType(iter.dtype())) { + AT_DISPATCH_COMPLEX_TYPES(iter.dtype(), "tanh_backward_cpu", [&]() { + auto one_vec = Vec256(scalar_t{1}); cpu_kernel_vec( iter, [=](scalar_t a, scalar_t b) -> scalar_t { - return a * (scalar_t{1} - b * b); + return a * std::conj(scalar_t{1} - b * b); }, [=](Vec256 a, Vec256 b) { - return a * (one_vec - b * b); + return a * (one_vec - b * b).conj(); }); }); + } else { + AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "tanh_backward_cpu", [&]() { + auto one_vec = Vec256(scalar_t{1}); + cpu_kernel_vec( + iter, + [=](scalar_t a, scalar_t b) -> scalar_t { + return a * (scalar_t{1} - b * b); + }, + [=](Vec256 a, Vec256 b) { + return a * (one_vec - b * b); + }); + }); + } } void mse_kernel(TensorIterator& iter) { @@ -752,6 +766,19 @@ void hypot_kernel(TensorIterator& iter) { }); } +void igamma_kernel(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "igamma_cpu", [&]() { + cpu_kernel_vec( + iter, + [=](scalar_t a, scalar_t b) -> scalar_t { + return calc_igamma(a, b); + }, + [=](Vec256 a, Vec256 b) { + return a.igamma(b); + }); + }); +} + void nextafter_kernel(TensorIterator& iter) { AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "nextafter_cpu", [&]() { cpu_kernel_vec( @@ -810,6 +837,7 @@ REGISTER_DISPATCH(logaddexp2_stub, &logaddexp2_kernel); REGISTER_DISPATCH(gcd_stub, &gcd_kernel); REGISTER_DISPATCH(lcm_stub, &lcm_kernel); REGISTER_DISPATCH(hypot_stub, &hypot_kernel); +REGISTER_DISPATCH(igamma_stub, &igamma_kernel); REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel); REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel); diff --git a/aten/src/ATen/native/cuda/Activation.cu b/aten/src/ATen/native/cuda/Activation.cu index 145fc990daeb..9926d8b05d95 100644 --- a/aten/src/ATen/native/cuda/Activation.cu +++ b/aten/src/ATen/native/cuda/Activation.cu @@ -341,12 +341,10 @@ namespace { void GeluCUDAKernelImpl(TensorIterator& it) { AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, it.dtype(), "GeluCUDAKernelImpl", [&]() { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "GeluCUDAKernelImpl", [&] { - using T_ACC = acc_type; - gpu_kernel(it, [] GPU_LAMBDA(scalar_t x) -> scalar_t { - return static_cast(x) * - c10::cuda::compat::normcdf(static_cast(x)); - }); + using T_ACC = acc_type; + gpu_kernel(it, [] GPU_LAMBDA(scalar_t x) -> scalar_t { + return static_cast(x) * + c10::cuda::compat::normcdf(static_cast(x)); }); }); } @@ -354,17 +352,15 @@ void GeluCUDAKernelImpl(TensorIterator& it) { void GeluBackwardCUDAKernelImpl(TensorIterator& it) { AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, it.dtype(), "GeluBackwardCUDAKernelImpl", [&]() { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "GeluBackwardCUDAKernelImpl", [&] { - using T_ACC = acc_type; - gpu_kernel(it, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { - constexpr T_ACC kBeta = M_2_SQRTPI * M_SQRT1_2 * T_ACC(0.5); - const T_ACC cdf = c10::cuda::compat::normcdf(static_cast(x)); - const T_ACC pdf = - c10::cuda::compat::exp( - T_ACC(-0.5) * static_cast(x) * static_cast(x)) * - kBeta; - return static_cast(dy) * (cdf + static_cast(x) * pdf); - }); + using T_ACC = acc_type; + gpu_kernel(it, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { + constexpr T_ACC kBeta = M_2_SQRTPI * M_SQRT1_2 * T_ACC(0.5); + const T_ACC cdf = c10::cuda::compat::normcdf(static_cast(x)); + const T_ACC pdf = + c10::cuda::compat::exp( + T_ACC(-0.5) * static_cast(x) * static_cast(x)) * + kBeta; + return static_cast(dy) * (cdf + static_cast(x) * pdf); }); }); } @@ -389,68 +385,70 @@ void leaky_relu_backward_kernel(TensorIterator& iter, Scalar negval_) { void hardswish_kernel(TensorIterator& iter) { AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "hardswish_cuda", [&]() { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "hardswish_cuda", [&] { - const scalar_t zero(0.0f); - const scalar_t one_sixth(1.0f / 6.0f); - const scalar_t three(3.0f); - const scalar_t six(6.0f); - gpu_kernel(iter, [zero, one_sixth, three, six]GPU_LAMBDA(scalar_t self_val) -> scalar_t { - return self_val * std::min(std::max(self_val + three, zero), six) * one_sixth; - }); + using T_ACC = acc_type; + const T_ACC zero(0.0f); + const T_ACC one_sixth(1.0f / 6.0f); + const T_ACC three(3.0f); + const T_ACC six(6.0f); + gpu_kernel(iter, [zero, one_sixth, three, six]GPU_LAMBDA(scalar_t self_val) -> scalar_t { + T_ACC x = static_cast(self_val); + return x * std::min(std::max(x + three, zero), six) * one_sixth; }); }); } void hardswish_backward_kernel(TensorIterator& iter) { AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "hardswish_backward_cuda", [&]() { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "hardswish_backward_cuda", [&] { - const scalar_t zero(0.0f); - const scalar_t three(3.0f); - const scalar_t neg_three(-3.0f); - const scalar_t one_half(0.5f); - gpu_kernel( - iter, - [zero, three, neg_three, one_half]GPU_LAMBDA(scalar_t grad_val, scalar_t self_val) -> scalar_t { - if (self_val < neg_three) { - return zero; - } else if (self_val <= three) { - return grad_val * ((self_val / three) + one_half); - } else { - return grad_val; - } - }); + using T_ACC = acc_type; + const T_ACC zero(0.0f); + const T_ACC three(3.0f); + const T_ACC neg_three(-3.0f); + const T_ACC one_half(0.5f); + gpu_kernel( + iter, + [zero, three, neg_three, one_half]GPU_LAMBDA(scalar_t grad_val_, scalar_t self_val_) -> scalar_t { + T_ACC grad_val = static_cast(grad_val_); + T_ACC self_val = static_cast(self_val_); + if (self_val < neg_three) { + return zero; + } else if (self_val <= three) { + return grad_val * ((self_val / three) + one_half); + } else { + return grad_val; + } }); }); } void hardsigmoid_kernel(TensorIterator& iter) { AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "hardsigmoid_cuda", [&]() { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "hardsigmoid_cuda", [&] { - const scalar_t zero(0.0f); - const scalar_t one_sixth(1.0f / 6.0f); - const scalar_t three(3.0f); - const scalar_t six(6.0f); - gpu_kernel(iter, [zero, one_sixth, three, six]GPU_LAMBDA(scalar_t self_val) -> scalar_t { - return std::min(std::max(self_val + three, zero), six) * one_sixth; - }); + using T_ACC = acc_type; + const T_ACC zero(0.0f); + const T_ACC one_sixth(1.0f / 6.0f); + const T_ACC three(3.0f); + const T_ACC six(6.0f); + gpu_kernel(iter, [zero, one_sixth, three, six]GPU_LAMBDA(scalar_t self_val) -> scalar_t { + T_ACC x = static_cast(self_val); + return std::min(std::max(x + three, zero), six) * one_sixth; }); }); } void hardsigmoid_backward_kernel(TensorIterator& iter) { AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "hardsigmoid_backward_cuda", [&]() { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "hardsigmoid_backward_cuda", [&] { - const scalar_t zero(0.0f); - const scalar_t three(3.0f); - const scalar_t neg_three(-3.0f); - const scalar_t one_sixth(1.0f / 6.0f); - gpu_kernel( - iter, - [zero, three, neg_three, one_sixth]GPU_LAMBDA(scalar_t grad_val, scalar_t self_val) -> scalar_t { - return (self_val >= neg_three && self_val <= three) - ? grad_val * one_sixth - : zero; - }); + using T_ACC = acc_type; + const T_ACC zero(0.0f); + const T_ACC three(3.0f); + const T_ACC neg_three(-3.0f); + const T_ACC one_sixth(1.0f / 6.0f); + gpu_kernel( + iter, + [zero, three, neg_three, one_sixth]GPU_LAMBDA(scalar_t grad_val_, scalar_t self_val_) -> scalar_t { + T_ACC grad_val = static_cast(grad_val_); + T_ACC self_val = static_cast(self_val_); + return (self_val >= neg_three && self_val <= three) + ? grad_val * one_sixth + : zero; }); }); } diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu index 88bad0a919b2..318185e43e8a 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu @@ -840,12 +840,13 @@ AT_ERROR("solve: MAGMA library not found in " auto b_data = b.data_ptr(); magma_int_t n = magma_int_cast(A.size(-2), "A.size(-2)"); magma_int_t nrhs = magma_int_cast(b.size(-1), "b.size(-1)"); + magma_int_t lda = std::max(magma_int_t{1}, n); if (b.dim() == 2) { auto ipiv = at::empty({n}, at::kInt); magma_int_t info = 0; - magmaSolve(n, nrhs, A_data, n, ipiv.data_ptr(), - b_data, n, &info); + magmaSolve(n, nrhs, A_data, lda, ipiv.data_ptr(), + b_data, lda, &info); infos[0] = info; } else { auto A_mat_stride = matrixStride(A); @@ -885,7 +886,7 @@ AT_ERROR("solve: MAGMA library not found in " magma_int_t* info_array_cur = &info_array[mini_idx]; magmaSolveBatched( - n, nrhs, A_array_cur, n, ipiv_array_cur, b_array_cur, n, + n, nrhs, A_array_cur, lda, ipiv_array_cur, b_array_cur, lda, info_array_cur, batch_limit, magma_queue); } @@ -893,7 +894,7 @@ AT_ERROR("solve: MAGMA library not found in " // which concisely is equal to batch_size % batch_limit if (batch_size % batch_limit != 0) { magmaSolveBatched( - n, nrhs, &A_array[mini_idx], n, &ipiv_array[mini_idx], &b_array[mini_idx], n, + n, nrhs, &A_array[mini_idx], lda, &ipiv_array[mini_idx], &b_array[mini_idx], lda, &info_array[mini_idx], batch_size % batch_limit, magma_queue); } diff --git a/aten/src/ATen/native/cuda/BinaryMiscBackwardOpsKernels.cu b/aten/src/ATen/native/cuda/BinaryMiscBackwardOpsKernels.cu index 9b7bc28a829e..ed7e2190f75e 100644 --- a/aten/src/ATen/native/cuda/BinaryMiscBackwardOpsKernels.cu +++ b/aten/src/ATen/native/cuda/BinaryMiscBackwardOpsKernels.cu @@ -60,13 +60,21 @@ void logit_backward_kernel_cuda(TensorIterator& iter, Scalar eps_scalar) { } void tanh_backward_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "tanh_backward_cuda", [&]() { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "tanh_backward_cuda", [&] { + if(isComplexType(iter.dtype())) { + AT_DISPATCH_COMPLEX_TYPES(iter.dtype(), "tanh_backward_complex_cuda", [&]() { gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { - return a * (scalar_t{1.} - b * b); + return a * std::conj(scalar_t{1.} - b * b); }); }); - }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "tanh_backward_cuda", [&]() { + AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "tanh_backward_cuda", [&] { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + return a * (scalar_t{1.} - b * b); + }); + }); + }); + } } REGISTER_DISPATCH(sigmoid_backward_stub, &sigmoid_backward_kernel_cuda); diff --git a/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu b/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu index fc9aa74f91f4..2f53c2bb08d7 100644 --- a/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu +++ b/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu @@ -92,6 +92,14 @@ void hypot_kernel_cuda(TensorIterator& iter) { }); } +void igamma_kernel_cuda(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "igamma_cuda", [&]() { + gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + return calc_igamma(a, b); + }); + }); +} + void nextafter_kernel_cuda(TensorIterator& iter) { AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "nextafter_cuda", [&]() { gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { @@ -102,7 +110,7 @@ void nextafter_kernel_cuda(TensorIterator& iter) { void heaviside_kernel_cuda(TensorIterator& iter) { AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, kBFloat16, iter.dtype(), "heaviside_cuda", [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { return a == 0 ? b : static_cast(a > 0); }); }); @@ -116,6 +124,7 @@ REGISTER_DISPATCH(logaddexp2_stub, &logaddexp2_kernel_cuda); REGISTER_DISPATCH(gcd_stub, &gcd_kernel_cuda); REGISTER_DISPATCH(lcm_stub, &lcm_kernel_cuda); REGISTER_DISPATCH(hypot_stub, &hypot_kernel_cuda); +REGISTER_DISPATCH(igamma_stub, &igamma_kernel_cuda); REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel_cuda); REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel_cuda); diff --git a/aten/src/ATen/native/cuda/BinaryMulDivKernel.cu b/aten/src/ATen/native/cuda/BinaryMulDivKernel.cu index be3f4f0bb01e..f80d0906dfa2 100644 --- a/aten/src/ATen/native/cuda/BinaryMulDivKernel.cu +++ b/aten/src/ATen/native/cuda/BinaryMulDivKernel.cu @@ -65,7 +65,7 @@ void div_kernel_cuda(TensorIterator& iter) { } void mul_kernel_cuda(TensorIterator& iter) { - if (!isIntegralType(iter.common_dtype(), /*includeBool*/ false) && + if (!isIntegralType(iter.common_dtype(), /*includeBool*/ true) && (iter.is_cpu_scalar(1) || iter.is_cpu_scalar(2))) { //if common dtype is half the scalar constant can overflow in half precision, and yet the result can //still be representable in the half dtype. Cast scalar to acc_type to have better accuracy diff --git a/aten/src/ATen/native/cuda/Math.cuh b/aten/src/ATen/native/cuda/Math.cuh index eec428ae2a12..ddf3679b4c27 100644 --- a/aten/src/ATen/native/cuda/Math.cuh +++ b/aten/src/ATen/native/cuda/Math.cuh @@ -54,12 +54,12 @@ static inline __host__ __device__ scalar_t zeta(scalar_t _x, scalar_t _q) { a = q; i = 0; b = 0.0; - while((i < 9) || (a <= 9.0)){ + while ((i < 9) || (a <= 9.0)) { i += 1; a += 1.0; b = ::pow( a, -x ); s += b; - if((-MACHEP < (b / s)) && ((b / s) < MACHEP)) { + if ((-MACHEP < (b / s)) && ((b / s) < MACHEP)) { return static_cast(s); } }; @@ -68,16 +68,16 @@ static inline __host__ __device__ scalar_t zeta(scalar_t _x, scalar_t _q) { s -= 0.5 * b; a = 1.0; k = 0.0; - for(int i=0; i < 12; i++) { + for (int i=0; i < 12; i++) { a *= x + k; b /= w; t = a * b / A[i]; s = s + t; t = t / s; - if(t < 0){ + if (t < 0){ t = -t; } - if((-MACHEP (s); } k += 1.0; @@ -174,6 +174,503 @@ static inline __host__ __device__ scalar_t calc_polygamma(int n, scalar_t x) { return ((n % 2) ? 1.0 : -1.0) * ::exp(::lgamma(static_cast(n) + 1.0)) * zeta(static_cast(n + 1), x); } +/* + * This implementation of the regularized incomplete gamma functions and + * their helper functions are derived from the implementation of SciPy's + * gammainc, Cephes's igam and igamc, and Boost's Lanczos approximations. + * See NOTICE for the licenses. + */ +// regularized lower & upper incomplete gamma +template +static __host__ __device__ scalar_t ratevl(scalar_t x, const scalar_t num[], int64_t M, + const scalar_t denom[], int64_t N) { + // evaluating rational function, i.e., the ratio of two polynomials + // the coefficients for numerator are given by `num` while coeffs for + // denumerator are given by `denom` + + using accscalar_t = at::acc_type; + int64_t i, dir; + accscalar_t y, num_ans, denom_ans; + accscalar_t absx = ::fabs(x); + const accscalar_t *p; + + if (absx > 1) { + /* Evaluate as a polynomial in 1/x. */ + dir = -1; + p = num + M; + y = 1 / x; + } + else { + dir = 1; + p = num; + y = x; + } + + /* Evaluate the numerator */ + num_ans = *p; + p += dir; + for (i = 1; i <= M; i++) { + num_ans = num_ans * y + *p; + p += dir; + } + /* Evaluate the denominator */ + if (absx > 1) { + p = denom + N; + } + else { + p = denom; + } + + denom_ans = *p; + p += dir; + for (i = 1; i <= N; i++) { + denom_ans = denom_ans * y + *p; + p += dir; + } + if (absx > 1) { + i = N - M; + return ::pow(x, static_cast(i)) * num_ans / denom_ans; + } + else { + return num_ans / denom_ans; + } +} + +template +static __host__ __device__ scalar_t lanczos_sum_expg_scaled(scalar_t x) { + // lanczos approximation + using accscalar_t = at::acc_type; + + static const accscalar_t lanczos_sum_expg_scaled_num[13] = { + 0.006061842346248906525783753964555936883222, + 0.5098416655656676188125178644804694509993, + 19.51992788247617482847860966235652136208, + 449.9445569063168119446858607650988409623, + 6955.999602515376140356310115515198987526, + 75999.29304014542649875303443598909137092, + 601859.6171681098786670226533699352302507, + 3481712.15498064590882071018964774556468, + 14605578.08768506808414169982791359218571, + 43338889.32467613834773723740590533316085, + 86363131.28813859145546927288977868422342, + 103794043.1163445451906271053616070238554, + 56906521.91347156388090791033559122686859 + }; + static const accscalar_t lanczos_sum_expg_scaled_denom[13] = { + 1., + 66., + 1925., + 32670., + 357423., + 2637558., + 13339535., + 45995730., + 105258076., + 150917976., + 120543840., + 39916800., + 0 + }; + return ratevl(static_cast(x), lanczos_sum_expg_scaled_num, + sizeof(lanczos_sum_expg_scaled_num) / sizeof(lanczos_sum_expg_scaled_num[0]) - 1, + lanczos_sum_expg_scaled_denom, + sizeof(lanczos_sum_expg_scaled_denom) / sizeof(lanczos_sum_expg_scaled_denom[0]) - 1); +} + +template +static __host__ __device__ scalar_t _igam_helper_fac(scalar_t a, scalar_t x) { + // compute x^a * exp(-a) / gamma(a) + // corrected from (15) and (16) in [igam2] by replacing exp(x - a) with + // exp(a - x). + + using accscalar_t = at::acc_type; + accscalar_t ax, fac, res, num, numfac; + static accscalar_t MAXLOG = std::is_same::value ? + 7.09782712893383996843E2 : 88.72283905206835; + static accscalar_t EXP1 = 2.718281828459045; + static accscalar_t lanczos_g = 6.024680040776729583740234375; + + if (::fabs(a - x) > 0.4 * ::fabs(a)) { + ax = a * ::log(x) - x - ::lgamma(a); + if (ax < -MAXLOG) { + return 0.0; + } + return ::exp(ax); + } + + fac = a + lanczos_g - 0.5; + res = ::sqrt(fac / EXP1) / lanczos_sum_expg_scaled(a); + + if ((a < 200) && (x < 200)) { + res *= ::exp(a - x) * ::pow(x / fac, a); + } + else { + num = x - a - lanczos_g + 0.5; + numfac = num / fac; + res *= ::exp(a * (::log1p(numfac) - numfac) + x * (0.5 - lanczos_g) / fac); + } + return res; +} + +template +static __host__ __device__ scalar_t _igam_helper_series(scalar_t a, scalar_t x) { + // Compute igam using DLMF 8.11.4. [igam1] + + using accscalar_t = at::acc_type; + static accscalar_t MACHEP = std::is_same::value ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + static int MAXITER = 2000; + + int i; + accscalar_t ans, ax, c, r; + + ax = _igam_helper_fac(a, x); + if (ax == 0.0) { + return 0.0; + } + + /* power series */ + r = a; + c = 1.0; + ans = 1.0; + + for (i = 0; i < MAXITER; i++) { + r += 1.0; + c *= x / r; + ans += c; + if (c <= MACHEP * ans) { + break; + } + } + return (ans * ax / a); +} + +template +static __host__ __device__ scalar_t _igamc_helper_series(scalar_t a, scalar_t x) { + // Compute igamc using DLMF 8.7.3 [igam1]. This is related to the series in + // _igam_helper_series but extra care is taken to avoid cancellation. + + using accscalar_t = at::acc_type; + int n; + accscalar_t fac = 1; + accscalar_t sum = 0; + accscalar_t term, logx; + static accscalar_t MAXITER = 2000; + static accscalar_t MACHEP = std::is_same::value ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + + for (n = 1; n < MAXITER; n++) { + fac *= -x / n; + term = fac / (a + n); + sum += term; + if (::fabs(term) <= MACHEP * ::fabs(sum)) { + break; + } + } + + logx = ::log(x); + term = -::expm1(a * logx - ::lgamma(1+a)); + return term - ::exp(a * logx - ::lgamma(a)) * sum; +} + +template +static __host__ __device__ scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t x, bool igam) { + // Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1] + + using accscalar_t = at::acc_type; + static const accscalar_t d[25][25] = + {{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2, 1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4, 3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6, 8.296711340953086e-7, -1.7665952736826079e-7, 6.7078535434014986e-9, 1.0261809784240308e-8, -4.3820360184533532e-9, 9.1476995822367902e-10, -2.551419399494625e-11, -5.8307721325504251e-11, 2.4361948020667416e-11, -5.0276692801141756e-12, 1.1004392031956135e-13, 3.3717632624009854e-13, -1.3923887224181621e-13, 2.8534893807047443e-14, -5.1391118342425726e-16, -1.9752288294349443e-15}, + {-1.8518518518518519e-3, -3.4722222222222222e-3, 2.6455026455026455e-3, -9.9022633744855967e-4, 2.0576131687242798e-4, -4.0187757201646091e-7, -1.8098550334489978e-5, 7.6491609160811101e-6, -1.6120900894563446e-6, 4.6471278028074343e-9, 1.378633446915721e-7, -5.752545603517705e-8, 1.1951628599778147e-8, -1.7543241719747648e-11, -1.0091543710600413e-9, 4.1627929918425826e-10, -8.5639070264929806e-11, 6.0672151016047586e-14, 7.1624989648114854e-12, -2.9331866437714371e-12, 5.9966963656836887e-13, -2.1671786527323314e-16, -4.9783399723692616e-14, 2.0291628823713425e-14, -4.13125571381061e-15}, + {4.1335978835978836e-3, -2.6813271604938272e-3, 7.7160493827160494e-4, 2.0093878600823045e-6, -1.0736653226365161e-4, 5.2923448829120125e-5, -1.2760635188618728e-5, 3.4235787340961381e-8, 1.3721957309062933e-6, -6.298992138380055e-7, 1.4280614206064242e-7, -2.0477098421990866e-10, -1.4092529910867521e-8, 6.228974084922022e-9, -1.3670488396617113e-9, 9.4283561590146782e-13, 1.2872252400089318e-10, -5.5645956134363321e-11, 1.1975935546366981e-11, -4.1689782251838635e-15, -1.0940640427884594e-12, 4.6622399463901357e-13, -9.905105763906906e-14, 1.8931876768373515e-17, 8.8592218725911273e-15}, + {6.4943415637860082e-4, 2.2947209362139918e-4, -4.6918949439525571e-4, 2.6772063206283885e-4, -7.5618016718839764e-5, -2.3965051138672967e-7, 1.1082654115347302e-5, -5.6749528269915966e-6, 1.4230900732435884e-6, -2.7861080291528142e-11, -1.6958404091930277e-7, 8.0994649053880824e-8, -1.9111168485973654e-8, 2.3928620439808118e-12, 2.0620131815488798e-9, -9.4604966618551322e-10, 2.1541049775774908e-10, -1.388823336813903e-14, -2.1894761681963939e-11, 9.7909989511716851e-12, -2.1782191880180962e-12, 6.2088195734079014e-17, 2.126978363279737e-13, -9.3446887915174333e-14, 2.0453671226782849e-14}, + {-8.618882909167117e-4, 7.8403922172006663e-4, -2.9907248030319018e-4, -1.4638452578843418e-6, 6.6414982154651222e-5, -3.9683650471794347e-5, 1.1375726970678419e-5, 2.5074972262375328e-10, -1.6954149536558306e-6, 8.9075075322053097e-7, -2.2929348340008049e-7, 2.956794137544049e-11, 2.8865829742708784e-8, -1.4189739437803219e-8, 3.4463580499464897e-9, -2.3024517174528067e-13, -3.9409233028046405e-10, 1.8602338968504502e-10, -4.356323005056618e-11, 1.2786001016296231e-15, 4.6792750266579195e-12, -2.1492464706134829e-12, 4.9088156148096522e-13, -6.3385914848915603e-18, -5.0453320690800944e-14}, + {-3.3679855336635815e-4, -6.9728137583658578e-5, 2.7727532449593921e-4, -1.9932570516188848e-4, 6.7977804779372078e-5, 1.419062920643967e-7, -1.3594048189768693e-5, 8.0184702563342015e-6, -2.2914811765080952e-6, -3.252473551298454e-10, 3.4652846491085265e-7, -1.8447187191171343e-7, 4.8240967037894181e-8, -1.7989466721743515e-14, -6.3061945000135234e-9, 3.1624176287745679e-9, -7.8409242536974293e-10, 5.1926791652540407e-15, 9.3589442423067836e-11, -4.5134262161632782e-11, 1.0799129993116827e-11, -3.661886712685252e-17, -1.210902069055155e-12, 5.6807435849905643e-13, -1.3249659916340829e-13}, + {5.3130793646399222e-4, -5.9216643735369388e-4, 2.7087820967180448e-4, 7.9023532326603279e-7, -8.1539693675619688e-5, 5.6116827531062497e-5, -1.8329116582843376e-5, -3.0796134506033048e-9, 3.4651553688036091e-6, -2.0291327396058604e-6, 5.7887928631490037e-7, 2.338630673826657e-13, -8.8286007463304835e-8, 4.7435958880408128e-8, -1.2545415020710382e-8, 8.6496488580102925e-14, 1.6846058979264063e-9, -8.5754928235775947e-10, 2.1598224929232125e-10, -7.6132305204761539e-16, -2.6639822008536144e-11, 1.3065700536611057e-11, -3.1799163902367977e-12, 4.7109761213674315e-18, 3.6902800842763467e-13}, + {3.4436760689237767e-4, 5.1717909082605922e-5, -3.3493161081142236e-4, 2.812695154763237e-4, -1.0976582244684731e-4, -1.2741009095484485e-7, 2.7744451511563644e-5, -1.8263488805711333e-5, 5.7876949497350524e-6, 4.9387589339362704e-10, -1.0595367014026043e-6, 6.1667143761104075e-7, -1.7562973359060462e-7, -1.2974473287015439e-12, 2.695423606288966e-8, -1.4578352908731271e-8, 3.887645959386175e-9, -3.8810022510194121e-17, -5.3279941738772867e-10, 2.7437977643314845e-10, -6.9957960920705679e-11, 2.5899863874868481e-17, 8.8566890996696381e-12, -4.403168815871311e-12, 1.0865561947091654e-12}, + {-6.5262391859530942e-4, 8.3949872067208728e-4, -4.3829709854172101e-4, -6.969091458420552e-7, 1.6644846642067548e-4, -1.2783517679769219e-4, 4.6299532636913043e-5, 4.5579098679227077e-9, -1.0595271125805195e-5, 6.7833429048651666e-6, -2.1075476666258804e-6, -1.7213731432817145e-11, 3.7735877416110979e-7, -2.1867506700122867e-7, 6.2202288040189269e-8, 6.5977038267330006e-16, -9.5903864974256858e-9, 5.2132144922808078e-9, -1.3991589583935709e-9, 5.382058999060575e-16, 1.9484714275467745e-10, -1.0127287556389682e-10, 2.6077347197254926e-11, -5.0904186999932993e-18, -3.3721464474854592e-12}, + {-5.9676129019274625e-4, -7.2048954160200106e-5, 6.7823088376673284e-4, -6.4014752602627585e-4, 2.7750107634328704e-4, 1.8197008380465151e-7, -8.4795071170685032e-5, 6.105192082501531e-5, -2.1073920183404862e-5, -8.8585890141255994e-10, 4.5284535953805377e-6, -2.8427815022504408e-6, 8.7082341778646412e-7, 3.6886101871706965e-12, -1.5344695190702061e-7, 8.862466778790695e-8, -2.5184812301826817e-8, -1.0225912098215092e-14, 3.8969470758154777e-9, -2.1267304792235635e-9, 5.7370135528051385e-10, -1.887749850169741e-19, -8.0931538694657866e-11, 4.2382723283449199e-11, -1.1002224534207726e-11}, + {1.3324454494800656e-3, -1.9144384985654775e-3, 1.1089369134596637e-3, 9.932404122642299e-7, -5.0874501293093199e-4, 4.2735056665392884e-4, -1.6858853767910799e-4, -8.1301893922784998e-9, 4.5284402370562147e-5, -3.127053674781734e-5, 1.044986828530338e-5, 4.8435226265680926e-11, -2.1482565873456258e-6, 1.329369701097492e-6, -4.0295693092101029e-7, -1.7567877666323291e-13, 7.0145043163668257e-8, -4.040787734999483e-8, 1.1474026743371963e-8, 3.9642746853563325e-18, -1.7804938269892714e-9, 9.7480262548731646e-10, -2.6405338676507616e-10, 5.794875163403742e-18, 3.7647749553543836e-11}, + {1.579727660730835e-3, 1.6251626278391582e-4, -2.0633421035543276e-3, 2.1389686185689098e-3, -1.0108559391263003e-3, -3.9912705529919201e-7, 3.6235025084764691e-4, -2.8143901463712154e-4, 1.0449513336495887e-4, 2.1211418491830297e-9, -2.5779417251947842e-5, 1.7281818956040463e-5, -5.6413773872904282e-6, -1.1024320105776174e-11, 1.1223224418895175e-6, -6.8693396379526735e-7, 2.0653236975414887e-7, 4.6714772409838506e-14, -3.5609886164949055e-8, 2.0470855345905963e-8, -5.8091738633283358e-9, -1.332821287582869e-16, 9.0354604391335133e-10, -4.9598782517330834e-10, 1.3481607129399749e-10}, + {-4.0725121195140166e-3, 6.4033628338080698e-3, -4.0410161081676618e-3, -2.183732802866233e-6, 2.1740441801254639e-3, -1.9700440518418892e-3, 8.3595469747962458e-4, 1.9445447567109655e-8, -2.5779387120421696e-4, 1.9009987368139304e-4, -6.7696499937438965e-5, -1.4440629666426572e-10, 1.5712512518742269e-5, -1.0304008744776893e-5, 3.304517767401387e-6, 7.9829760242325709e-13, -6.4097794149313004e-7, 3.8894624761300056e-7, -1.1618347644948869e-7, -2.816808630596451e-15, 1.9878012911297093e-8, -1.1407719956357511e-8, 3.2355857064185555e-9, 4.1759468293455945e-20, -5.0423112718105824e-10}, + {-5.9475779383993003e-3, -5.4016476789260452e-4, 8.7910413550767898e-3, -9.8576315587856125e-3, 5.0134695031021538e-3, 1.2807521786221875e-6, -2.0626019342754683e-3, 1.7109128573523058e-3, -6.7695312714133799e-4, -6.9011545676562133e-9, 1.8855128143995902e-4, -1.3395215663491969e-4, 4.6263183033528039e-5, 4.0034230613321351e-11, -1.0255652921494033e-5, 6.612086372797651e-6, -2.0913022027253008e-6, -2.0951775649603837e-13, 3.9756029041993247e-7, -2.3956211978815887e-7, 7.1182883382145864e-8, 8.925574873053455e-16, -1.2101547235064676e-8, 6.9350618248334386e-9, -1.9661464453856102e-9}, + {1.7402027787522711e-2, -2.9527880945699121e-2, 2.0045875571402799e-2, 7.0289515966903407e-6, -1.2375421071343148e-2, 1.1976293444235254e-2, -5.4156038466518525e-3, -6.3290893396418616e-8, 1.8855118129005065e-3, -1.473473274825001e-3, 5.5515810097708387e-4, 5.2406834412550662e-10, -1.4357913535784836e-4, 9.9181293224943297e-5, -3.3460834749478311e-5, -3.5755837291098993e-12, 7.1560851960630076e-6, -4.5516802628155526e-6, 1.4236576649271475e-6, 1.8803149082089664e-14, -2.6623403898929211e-7, 1.5950642189595716e-7, -4.7187514673841102e-8, -6.5107872958755177e-17, 7.9795091026746235e-9}, + {3.0249124160905891e-2, 2.4817436002649977e-3, -4.9939134373457022e-2, 5.9915643009307869e-2, -3.2483207601623391e-2, -5.7212968652103441e-6, 1.5085251778569354e-2, -1.3261324005088445e-2, 5.5515262632426148e-3, 3.0263182257030016e-8, -1.7229548406756723e-3, 1.2893570099929637e-3, -4.6845138348319876e-4, -1.830259937893045e-10, 1.1449739014822654e-4, -7.7378565221244477e-5, 2.5625836246985201e-5, 1.0766165333192814e-12, -5.3246809282422621e-6, 3.349634863064464e-6, -1.0381253128684018e-6, -5.608909920621128e-15, 1.9150821930676591e-7, -1.1418365800203486e-7, 3.3654425209171788e-8}, + {-9.9051020880159045e-2, 1.7954011706123486e-1, -1.2989606383463778e-1, -3.1478872752284357e-5, 9.0510635276848131e-2, -9.2828824411184397e-2, 4.4412112839877808e-2, 2.7779236316835888e-7, -1.7229543805449697e-2, 1.4182925050891573e-2, -5.6214161633747336e-3, -2.39598509186381e-9, 1.6029634366079908e-3, -1.1606784674435773e-3, 4.1001337768153873e-4, 1.8365800754090661e-11, -9.5844256563655903e-5, 6.3643062337764708e-5, -2.076250624489065e-5, -1.1806020912804483e-13, 4.2131808239120649e-6, -2.6262241337012467e-6, 8.0770620494930662e-7, 6.0125912123632725e-16, -1.4729737374018841e-7}, + {-1.9994542198219728e-1, -1.5056113040026424e-2, 3.6470239469348489e-1, -4.6435192311733545e-1, 2.6640934719197893e-1, 3.4038266027147191e-5, -1.3784338709329624e-1, 1.276467178337056e-1, -5.6213828755200985e-2, -1.753150885483011e-7, 1.9235592956768113e-2, -1.5088821281095315e-2, 5.7401854451350123e-3, 1.0622382710310225e-9, -1.5335082692563998e-3, 1.0819320643228214e-3, -3.7372510193945659e-4, -6.6170909729031985e-12, 8.4263617380909628e-5, -5.5150706827483479e-5, 1.7769536448348069e-5, 3.8827923210205533e-14, -3.53513697488768e-6, 2.1865832130045269e-6, -6.6812849447625594e-7}, + {7.2438608504029431e-1, -1.3918010932653375, 1.0654143352413968, 1.876173868950258e-4, -8.2705501176152696e-1, 8.9352433347828414e-1, -4.4971003995291339e-1, -1.6107401567546652e-6, 1.9235590165271091e-1, -1.6597702160042609e-1, 6.8882222681814333e-2, 1.3910091724608687e-8, -2.146911561508663e-2, 1.6228980898865892e-2, -5.9796016172584256e-3, -1.1287469112826745e-10, 1.5167451119784857e-3, -1.0478634293553899e-3, 3.5539072889126421e-4, 8.1704322111801517e-13, -7.7773013442452395e-5, 5.0291413897007722e-5, -1.6035083867000518e-5, 1.2469354315487605e-14, 3.1369106244517615e-6}, + {1.6668949727276811, 1.165462765994632e-1, -3.3288393225018906, 4.4692325482864037, -2.6977693045875807, -2.600667859891061e-4, 1.5389017615694539, -1.4937962361134612, 6.8881964633233148e-1, 1.3077482004552385e-6, -2.5762963325596288e-1, 2.1097676102125449e-1, -8.3714408359219882e-2, -7.7920428881354753e-9, 2.4267923064833599e-2, -1.7813678334552311e-2, 6.3970330388900056e-3, 4.9430807090480523e-11, -1.5554602758465635e-3, 1.0561196919903214e-3, -3.5277184460472902e-4, 9.3002334645022459e-14, 7.5285855026557172e-5, -4.8186515569156351e-5, 1.5227271505597605e-5}, + {-6.6188298861372935, 1.3397985455142589e+1, -1.0789350606845146e+1, -1.4352254537875018e-3, 9.2333694596189809, -1.0456552819547769e+1, 5.5105526029033471, 1.2024439690716742e-5, -2.5762961164755816, 2.3207442745387179, -1.0045728797216284, -1.0207833290021914e-7, 3.3975092171169466e-1, -2.6720517450757468e-1, 1.0235252851562706e-1, 8.4329730484871625e-10, -2.7998284958442595e-2, 2.0066274144976813e-2, -7.0554368915086242e-3, 1.9402238183698188e-12, 1.6562888105449611e-3, -1.1082898580743683e-3, 3.654545161310169e-4, -5.1290032026971794e-11, -7.6340103696869031e-5}, + {-1.7112706061976095e+1, -1.1208044642899116, 3.7131966511885444e+1, -5.2298271025348962e+1, 3.3058589696624618e+1, 2.4791298976200222e-3, -2.061089403411526e+1, 2.088672775145582e+1, -1.0045703956517752e+1, -1.2238783449063012e-5, 4.0770134274221141, -3.473667358470195, 1.4329352617312006, 7.1359914411879712e-8, -4.4797257159115612e-1, 3.4112666080644461e-1, -1.2699786326594923e-1, -2.8953677269081528e-10, 3.3125776278259863e-2, -2.3274087021036101e-2, 8.0399993503648882e-3, -1.177805216235265e-9, -1.8321624891071668e-3, 1.2108282933588665e-3, -3.9479941246822517e-4}, + {7.389033153567425e+1, -1.5680141270402273e+2, 1.322177542759164e+2, 1.3692876877324546e-2, -1.2366496885920151e+2, 1.4620689391062729e+2, -8.0365587724865346e+1, -1.1259851148881298e-4, 4.0770132196179938e+1, -3.8210340013273034e+1, 1.719522294277362e+1, 9.3519707955168356e-7, -6.2716159907747034, 5.1168999071852637, -2.0319658112299095, -4.9507215582761543e-9, 5.9626397294332597e-1, -4.4220765337238094e-1, 1.6079998700166273e-1, -2.4733786203223402e-8, -4.0307574759979762e-2, 2.7849050747097869e-2, -9.4751858992054221e-3, 6.419922235909132e-6, 2.1250180774699461e-3}, + {2.1216837098382522e+2, 1.3107863022633868e+1, -4.9698285932871748e+2, 7.3121595266969204e+2, -4.8213821720890847e+2, -2.8817248692894889e-2, 3.2616720302947102e+2, -3.4389340280087117e+2, 1.7195193870816232e+2, 1.4038077378096158e-4, -7.52594195897599e+1, 6.651969984520934e+1, -2.8447519748152462e+1, -7.613702615875391e-7, 9.5402237105304373, -7.5175301113311376, 2.8943997568871961, -4.6612194999538201e-7, -8.0615149598794088e-1, 5.8483006570631029e-1, -2.0845408972964956e-1, 1.4765818959305817e-4, 5.1000433863753019e-2, -3.3066252141883665e-2, 1.5109265210467774e-2}, + {-9.8959643098322368e+2, 2.1925555360905233e+3, -1.9283586782723356e+3, -1.5925738122215253e-1, 1.9569985945919857e+3, -2.4072514765081556e+3, 1.3756149959336496e+3, 1.2920735237496668e-3, -7.525941715948055e+2, 7.3171668742208716e+2, -3.4137023466220065e+2, -9.9857390260608043e-6, 1.3356313181291573e+2, -1.1276295161252794e+2, 4.6310396098204458e+1, -7.9237387133614756e-6, -1.4510726927018646e+1, 1.1111771248100563e+1, -4.1690817945270892, 3.1008219800117808e-3, 1.1220095449981468, -7.6052379926149916e-1, 3.6262236505085254e-1, 2.216867741940747e-1, 4.8683443692930507e-1}}; + + int k, n, sgn; + int maxpow = 0; + static accscalar_t MACHEP = std::is_same::value ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + accscalar_t lambda = x / a; + accscalar_t sigma = (x - a) / a; + accscalar_t eta, res, ck, ckterm, term, absterm; + accscalar_t absoldterm = INFINITY; + accscalar_t etapow[25] = {1}; + accscalar_t sum = 0; + accscalar_t afac = 1; + + if (igam) { + sgn = -1; + } + else { + sgn = 1; + } + + if (lambda > 1) { + eta = ::sqrt(-2 * (::log1p(sigma) - sigma)); + } + else if (lambda < 1) { + eta = -::sqrt(-2 * (::log1p(sigma) - sigma)); + } + else { + eta = 0; + } + res = 0.5 * ::erfc(sgn * eta * ::sqrt(a / 2)); + + for (k = 0; k < 25; k++) { + ck = d[k][0]; + for (n = 1; n < 25; n++) { + if (n > maxpow) { + etapow[n] = eta * etapow[n-1]; + maxpow += 1; + } + ckterm = d[k][n]*etapow[n]; + ck += ckterm; + if (std::fabs(ckterm) < MACHEP * std::fabs(ck)) { + break; + } + } + term = ck * afac; + absterm = std::fabs(term); + if (absterm > absoldterm) { + break; + } + sum += term; + if (absterm < MACHEP * std::fabs(sum)) { + break; + } + absoldterm = absterm; + afac /= a; + } + res += sgn * ::exp(-0.5 * a * eta * eta) * sum / ::sqrt(2 * 3.1415926535 * a); + + return res; +} + +template +static __host__ __device__ scalar_t _igamc_helper_continued_fraction(scalar_t a, scalar_t x) { + // Compute igamc using DLMF 8.9.2. [igam1] + + using accscalar_t = at::acc_type; + int i; + accscalar_t ans, ax, c, yc, r, t, y, z; + accscalar_t pk, pkm1, pkm2, qk, qkm1, qkm2; + int MAXITER = 2000; + static accscalar_t MACHEP = std::is_same::value ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + static accscalar_t BIG = std::is_same::value ? + 4.503599627370496e15 : 16777216.; + static accscalar_t BIGINV = std::is_same::value ? + 2.22044604925031308085e-16 : 5.9604644775390625E-8; + + ax = _igam_helper_fac(a, x); + if (ax == 0.0) { + return 0.0; + } + + /* continued fraction */ + y = 1.0 - a; + z = x + y + 1.0; + c = 0.0; + pkm2 = 1.0; + qkm2 = x; + pkm1 = x + 1.0; + qkm1 = z * x; + ans = pkm1 / qkm1; + + for (i = 0; i < MAXITER; i++) { + c += 1.0; + y += 1.0; + z += 2.0; + yc = y * c; + pk = pkm1 * z - pkm2 * yc; + qk = qkm1 * z - qkm2 * yc; + if (qk != 0) { + r = pk / qk; + t = ::fabs((ans - r) / r); + ans = r; + } + else { + t = 1.0; + } + pkm2 = pkm1; + pkm1 = pk; + qkm2 = qkm1; + qkm1 = qk; + if (::fabs(pk) > BIG) { + pkm2 *= BIGINV; + pkm1 *= BIGINV; + qkm2 *= BIGINV; + qkm1 *= BIGINV; + } + if (t <= MACHEP) { + break; + } + } + return ans * ax; +} + +template +static inline __host__ __device__ scalar_t calc_igammac(scalar_t a, scalar_t x) { + /* the calculation of the regularized upper incomplete gamma function + * is done differently based on the values of a and x: + * - if x and/or a is at the boundary of defined region, then assign the + * result at the boundary + * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for + * Large Parameter (see DLMF 8.12.4 [igam1]) + * - if x > 1.1 and x < a, using the substraction from the regularized lower + * incomplete gamma + * - otherwise, calculate the series from [igam2] eq (5) + */ + + using accscalar_t = at::acc_type; + accscalar_t absxma_a; + + static accscalar_t SMALL = 20.0; + static accscalar_t LARGE = 200.0; + static accscalar_t SMALLRATIO = 0.3; + static accscalar_t LARGERATIO = 4.5; + + if ((x < 0) || (a < 0)) { + // out of defined-region of the function + return std::numeric_limits::quiet_NaN(); + } + else if (a == 0) { + if (x > 0) { + return 0.0; + } + else { + return std::numeric_limits::quiet_NaN(); + } + } + else if (x == 0) { + return 1.0; + } + else if (::isinf(static_cast(a))) { + if (::isinf(static_cast(x))) { + return std::numeric_limits::quiet_NaN(); + } + return 1.0; + } + else if (::isinf(static_cast(x))) { + return 0.0; + } + + absxma_a = ::fabs(x - a) / a; + if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) { + return _igam_helper_asymptotic_series(a, x, 0); + } + else if ((a > LARGE) && (absxma_a < LARGERATIO / ::sqrt(a))) { + return _igam_helper_asymptotic_series(a, x, 0); + } + + if (x > 1.1) { + if (x < a) { + return 1.0 - _igam_helper_series(a, x); + } + else { + return _igamc_helper_continued_fraction(a, x); + } + } + else if (x <= 0.5) { + if (-0.4 / ::log(x) < a) { + return 1.0 - _igam_helper_series(a, x); + } + else { + return _igamc_helper_series(a, x); + } + } + else { + if (x * 1.1 < a) { + return 1.0 - _igam_helper_series(a, x); + } + else { + return _igamc_helper_series(a, x); + } + } +} + +template +static inline __host__ __device__ scalar_t calc_igamma(scalar_t a, scalar_t x) { + /* the calculation of the regularized lower incomplete gamma function + * is done differently based on the values of a and x: + * - if x and/or a is at the boundary of defined region, then assign the + * result at the boundary + * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for + * Large Parameter (see DLMF 8.12.3 [igam1]) + * - if x > 1 and x > a, using the substraction from the regularized upper + * incomplete gamma + * - otherwise, calculate the series from [igam2] eq (4) + */ + + using accscalar_t = at::acc_type; + accscalar_t absxma_a; + static accscalar_t SMALL = 20.0; + static accscalar_t LARGE = 200.0; + static accscalar_t SMALLRATIO = 0.3; + static accscalar_t LARGERATIO = 4.5; + + // boundary values following SciPy + if ((x < 0) || (a < 0)) { + // out of defined-region of the function + return std::numeric_limits::quiet_NaN(); + } + else if (a == 0) { + if (x > 0) { + return 1.0; + } + else { + return std::numeric_limits::quiet_NaN(); + } + } + else if (x == 0) { + return 0.0; // zero integration limit + } + else if (::isinf(static_cast(a))) { + if (::isinf(static_cast(x))) { + return std::numeric_limits::quiet_NaN(); + } + return 0.0; + } + else if (::isinf(static_cast(x))) { + return 1.0; + } + + /* Asymptotic regime where a ~ x. */ + absxma_a = ::fabs(x - a) / a; + if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) { + return _igam_helper_asymptotic_series(a, x, 1); + } + else if ((a > LARGE) && (absxma_a < LARGERATIO / ::sqrt(a))) { + return _igam_helper_asymptotic_series(a, x, 1); + } + + if ((x > 1.0) && (x > a)) { + return 1.0 - calc_igammac(a, x); + } + + return _igam_helper_series(a, x); +} + +// end of regularized lower & upper incomplete gamma template static inline C10_HOST_DEVICE scalar_t calc_gcd(scalar_t a_in, scalar_t b_in) { diff --git a/aten/src/ATen/native/cuda/NaiveDilatedConvolution.cu b/aten/src/ATen/native/cuda/NaiveDilatedConvolution.cu index 2ad6f0785a17..522e3bbd8760 100644 --- a/aten/src/ATen/native/cuda/NaiveDilatedConvolution.cu +++ b/aten/src/ATen/native/cuda/NaiveDilatedConvolution.cu @@ -188,10 +188,8 @@ void slow_conv_dilated_all_cuda_template( int64_t nInputPlane = weight.size(1); int64_t nOutputPlane = weight.size(0); // Temporary buffers: - int64_t m = std::accumulate( - kernel_size.begin(), kernel_size.end(), 1, std::multiplies()); - int64_t output_vsize = std::accumulate( - output_size.begin(), output_size.end(), 1, std::multiplies()); + const int64_t m = prod_intlist(kernel_size); + const int64_t output_vsize = prod_intlist(output_size); Tensor columns = at::empty({0}, options); if (output.defined() || grad_weight.defined() || grad_input.defined()) { columns.resize_({nInputPlane * m, output_vsize}); diff --git a/aten/src/ATen/native/cuda/RangeFactories.cu b/aten/src/ATen/native/cuda/RangeFactories.cu index 38f3f8487fc4..4286f05111b6 100644 --- a/aten/src/ATen/native/cuda/RangeFactories.cu +++ b/aten/src/ATen/native/cuda/RangeFactories.cu @@ -207,62 +207,60 @@ Tensor& range_cuda_out(Tensor& result, Scalar start, Scalar end, Scalar step) { Tensor& arange_cuda_out(Tensor& result, Scalar start, Scalar end, Scalar step) { AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, result.scalar_type(), "arange_cuda", [&]() { - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "arange_cuda", [&] { - using accscalar_t = at::acc_type; - auto xstart = start.to(); - auto xend = end.to(); - auto xstep = step.to(); - - // we use double precision for (start - end) / step - // to compute size_d for consistency across devices. - // The problem with using accscalar_t is that accscalar_t might be float32 on gpu for a float32 scalar_t, - // but double on cpu for the same, - // and the effective output size starts differing on CPU vs GPU because of precision issues, which - // we dont want. - // the corner-case we do want to take into account is int64_t, which has higher precision than double - double size_d; - if (std::is_same::value) { - size_d = std::ceil(static_cast(end.to() - start.to()) - / step.to()); - } else { - size_d = std::ceil(static_cast(end.to() - start.to()) - / step.to()); - } - - TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero"); - TORCH_CHECK(std::isfinite(static_cast(xstart)) && - std::isfinite(static_cast(xend)), - "unsupported range: ", xstart, " -> ", xend); - TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)), - "upper bound and larger bound inconsistent with step sign"); - - TORCH_CHECK(size_d >= 0 && size_d <= static_cast(std::numeric_limits::max()), - "invalid size, possible overflow?"); - int64_t size = static_cast(size_d); - int64_t numel = result.numel(); - - if (numel != size) { - if(numel > 0){ - TORCH_WARN("The number of elements in the out tensor of shape ", result.sizes(), - " is ", numel, " which does not match the computed number of elements ", size, - ". Note that this may occur as a result of rounding error. " - "The out tensor will be resized to a tensor of shape (", size, ",)."); - } - result.resize_({size}); - } - bool is_contiguous = result.is_contiguous(); - Tensor r = !is_contiguous ? at::empty_like(result, LEGACY_CONTIGUOUS_MEMORY_FORMAT) : result; + using accscalar_t = at::acc_type; + auto xstart = start.to(); + auto xend = end.to(); + auto xstep = step.to(); - gpu_kernel_with_index(r, [xstart, xstep]GPU_LAMBDA(int64_t ind) -> scalar_t { - accscalar_t inc = xstep * static_cast(ind); - accscalar_t val = xstart + inc; - return static_cast(val); - }); + // we use double precision for (start - end) / step + // to compute size_d for consistency across devices. + // The problem with using accscalar_t is that accscalar_t might be float32 on gpu for a float32 scalar_t, + // but double on cpu for the same, + // and the effective output size starts differing on CPU vs GPU because of precision issues, which + // we dont want. + // the corner-case we do want to take into account is int64_t, which has higher precision than double + double size_d; + if (std::is_same::value) { + size_d = std::ceil(static_cast(end.to() - start.to()) + / step.to()); + } else { + size_d = std::ceil(static_cast(end.to() - start.to()) + / step.to()); + } - if(!is_contiguous) { - result.copy_(r); + TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero"); + TORCH_CHECK(std::isfinite(static_cast(xstart)) && + std::isfinite(static_cast(xend)), + "unsupported range: ", xstart, " -> ", xend); + TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)), + "upper bound and larger bound inconsistent with step sign"); + + TORCH_CHECK(size_d >= 0 && size_d <= static_cast(std::numeric_limits::max()), + "invalid size, possible overflow?"); + int64_t size = static_cast(size_d); + int64_t numel = result.numel(); + + if (numel != size) { + if(numel > 0){ + TORCH_WARN("The number of elements in the out tensor of shape ", result.sizes(), + " is ", numel, " which does not match the computed number of elements ", size, + ". Note that this may occur as a result of rounding error. " + "The out tensor will be resized to a tensor of shape (", size, ",)."); } + result.resize_({size}); + } + bool is_contiguous = result.is_contiguous(); + Tensor r = !is_contiguous ? at::empty_like(result, LEGACY_CONTIGUOUS_MEMORY_FORMAT) : result; + + gpu_kernel_with_index(r, [xstart, xstep]GPU_LAMBDA(int64_t ind) -> scalar_t { + accscalar_t inc = xstep * static_cast(ind); + accscalar_t val = xstart + inc; + return static_cast(val); }); + + if(!is_contiguous) { + result.copy_(r); + } }); AT_CUDA_CHECK(cudaGetLastError()); diff --git a/aten/src/ATen/native/cuda/ScanKernels.cu b/aten/src/ATen/native/cuda/ScanKernels.cu index 6bc2c381e1db..b0dc71c568ba 100644 --- a/aten/src/ATen/native/cuda/ScanKernels.cu +++ b/aten/src/ATen/native/cuda/ScanKernels.cu @@ -128,16 +128,16 @@ __global__ void tensor_kernel_scan_innermost_dim_with_indices(const scalar_t *se */ template __global__ void tensor_kernel_scan_outer_dim_with_indices(scalar_t *self_, scalar_t *values_, int64_t *indices_, - int num_orows, int num_irows, int row_size, scalar_t init, BinaryFunction binary_op) { - for (int orow = blockIdx.x; orow < num_orows; orow += gridDim.x) { - for (int irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) { + const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size, scalar_t init, BinaryFunction binary_op) { + for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) { + for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) { scalar_t *self = self_ + orow * row_size * num_irows + irow; scalar_t *values = values_ + orow * row_size * num_irows + irow; int64_t *indices = indices_ + orow * row_size * num_irows + irow; scalar_t out = init; int64_t out_idx = 0; - for (int64_t col = 0; col < row_size; ++col) { + for (auto col = decltype(row_size){0}; col < row_size; ++col) { if(THCNumerics::isnan(*self) || (!THCNumerics::isnan(out) && binary_op(*self, out))) { out = *self; out_idx = col; @@ -152,21 +152,34 @@ __global__ void tensor_kernel_scan_outer_dim_with_indices(scalar_t *self_, scala } } +void check_fits_in_unsigned(int64_t val, const char* name) { + constexpr auto umax = std::numeric_limits::max(); + TORCH_CHECK( + val >= 0 && val <= umax, name, " must fit in a 32-bit uint32_t value"); +} + + template __host__ void scan_outer_dim_with_indices(const Tensor& self, Tensor& values, Tensor& indices, int dim, scalar_t init, BinaryFunction binary_op) { - int row_size = self.size(dim); + int64_t row_size = self.size(dim); auto sizes = self.sizes(); // Treat all outer dimensions (i.e. dim_ < dim) as one. - int num_orows = std::accumulate(sizes.begin(), sizes.begin() + dim, 1, std::multiplies()); + const int64_t num_orows = prod_intlist(sizes.begin(), sizes.begin() + dim); // Treat all inner dimensions (i.e. dim > dimension) as one. - int num_irows = std::accumulate(sizes.begin() + dim + 1, sizes.end(), 1, std::multiplies()); + const int64_t num_irows = prod_intlist(sizes.begin() + dim + 1, sizes.end()); + //for performance reasons, cuda kernels use uint32_t for loops over irows, orows and row, + //make sure that input is not bigger than supported by uint32_t + check_fits_in_unsigned(num_irows, "num_irows"); + check_fits_in_unsigned(num_orows, "num_orows"); + check_fits_in_unsigned(row_size, "row_size"); + dim3 threads(std::min(512, int(num_irows))); - int maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; - dim3 grid(std::min(maxGridDim, num_orows), std::min(maxGridDim, ceil_div(num_irows, int(threads.x)))); + int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + dim3 grid(std::min(maxGridDim, num_orows), std::min(maxGridDim, ceil_div(num_irows, int64_t{threads.x}))); tensor_kernel_scan_outer_dim_with_indices<<>>( self.data_ptr(), values.data_ptr(), indices.data_ptr(), num_orows, num_irows, row_size, init, binary_op); @@ -254,16 +267,16 @@ void cummin_helper_cuda(const Tensor& self, Tensor& values, Tensor& indices, int */ template __global__ void tensor_kernel_scan_outer_dim(scalar_t *tgt_, scalar_t *src_, - unsigned num_orows, unsigned num_irows, unsigned row_size, - scalar_t init, BinaryOp binary_op) + const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size, + const scalar_t init, BinaryOp binary_op) { - for (unsigned orow = blockIdx.x; orow < num_orows; orow += gridDim.x) { - for (unsigned irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) { + for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) { + for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) { scalar_t *src = src_ + orow * row_size * num_irows + irow; scalar_t *tgt = tgt_ + orow * row_size * num_irows + irow; scalar_t acc = init; - for (unsigned col = 0; col < row_size; ++col) { + for (uint32_t col = 0; col < row_size; ++col) { acc = binary_op(acc, *src); *tgt = acc; @@ -286,12 +299,12 @@ __global__ void tensor_kernel_scan_outer_dim(scalar_t *tgt_, scalar_t *src_, */ template __device__ void tensor_kernel_scan_innermost_dim_impl(T* row_buf, T *tgt_, T *src_, - unsigned num_rows, unsigned row_size, + const uint32_t num_rows, const uint32_t row_size, T init, BinaryFunction binary_op){ - for (unsigned block_row = blockIdx.x * blockDim.y; + for (uint32_t block_row = blockIdx.x * blockDim.y; block_row < num_rows; block_row += blockDim.y * gridDim.x) { - unsigned row = block_row + threadIdx.y; + uint32_t row = block_row + threadIdx.y; T block_total = init; T *row_src = src_ + row * row_size; @@ -299,10 +312,10 @@ __device__ void tensor_kernel_scan_innermost_dim_impl(T* row_buf, T *tgt_, T *sr // Perform scan on one block at a time, keeping track of the total value of // all blocks processed so far. - for (unsigned block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) { + for (uint32_t block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) { // Load data into shared memory (two values per thread). - unsigned col1 = block_col + threadIdx.x; - unsigned col2 = block_col + num_threads_x + threadIdx.x; + uint32_t col1 = block_col + threadIdx.x; + uint32_t col2 = block_col + num_threads_x + threadIdx.x; if (row < num_rows) { if (col1 < row_size) { row_buf[threadIdx.x] = row_src[col1]; @@ -324,18 +337,18 @@ __device__ void tensor_kernel_scan_innermost_dim_impl(T* row_buf, T *tgt_, T *sr __syncthreads(); // Parallel reduction (up-sweep). - for (unsigned s = num_threads_x, d = 1; s >= 1; s >>= 1, d <<= 1) { + for (uint32_t s = num_threads_x, d = 1; s >= 1; s >>= 1, d <<= 1) { if (row < num_rows && threadIdx.x < s) { - unsigned offset = (2 * threadIdx.x + 1) * d - 1; + uint32_t offset = (2 * threadIdx.x + 1) * d - 1; row_buf[offset + d] = binary_op(row_buf[offset], row_buf[offset + d]); } __syncthreads(); } // Down-sweep. - for (unsigned s = 2, d = num_threads_x / 2; d >= 1; s <<= 1, d >>= 1) { + for (uint32_t s = 2, d = num_threads_x / 2; d >= 1; s <<= 1, d >>= 1) { if (row < num_rows && threadIdx.x < s - 1) { - unsigned offset = 2 * (threadIdx.x + 1) * d - 1; + uint32_t offset = 2 * (threadIdx.x + 1) * d - 1; row_buf[offset + d] = binary_op(row_buf[offset], row_buf[offset + d]); } __syncthreads(); @@ -361,8 +374,8 @@ __global__ typename std::enable_if::value, void>::type tensor_kernel_scan_innermost_dim( T* tgt_, T* src_, - unsigned num_rows, - unsigned row_size, + const uint32_t num_rows, + const uint32_t row_size, T init, BinaryFunction binary_op) { __shared__ T sbuf[num_threads_y][2 * num_threads_x]; @@ -381,8 +394,8 @@ __global__ typename std::enable_if::value, void>::type tensor_kernel_scan_innermost_dim( T* tgt_, T* src_, - unsigned num_rows, - unsigned row_size, + const uint32_t num_rows, + const uint32_t row_size, T init, BinaryFunction binary_op) { // As we cannot directly initialize shared array for complex types @@ -399,23 +412,18 @@ tensor_kernel_scan_innermost_dim( row_buf, tgt_, src_, num_rows, row_size, init, binary_op); } -void check_fits_in_unsigned(int64_t val, const char* name) { - constexpr auto umax = std::numeric_limits::max(); - TORCH_CHECK( - val >= 0 && val <= umax, name, " must fit in a 32-bit unsigned value"); -} template __host__ void scan_outer_dim(const Tensor& self, Tensor& result, int dim, scalar_t init, BinaryFunction binary_op) { - int64_t row_size = self.size(dim); + const int64_t row_size = self.size(dim); auto sizes = self.sizes(); // Treat all outer dimensions (i.e. dim_ < dim) as one. - int64_t num_orows = std::accumulate(sizes.begin(), sizes.begin() + dim, 1, std::multiplies()); + const int64_t num_orows = prod_intlist(sizes.begin(), sizes.begin() + dim); // Treat all inner dimensions (i.e. dim > dimension) as one. - int64_t num_irows = std::accumulate(sizes.begin() + dim + 1, sizes.end(), 1, std::multiplies()); + const int64_t num_irows = prod_intlist(sizes.begin() + dim + 1, sizes.end()); dim3 threads(std::min(512, int(num_irows))); int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; diff --git a/aten/src/ATen/native/cuda/Shape.cu b/aten/src/ATen/native/cuda/Shape.cu index 5da007507905..64af6cb268a2 100644 --- a/aten/src/ATen/native/cuda/Shape.cu +++ b/aten/src/ATen/native/cuda/Shape.cu @@ -237,7 +237,12 @@ void hip_parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension, batchCounter < CAT_ARRAY_BATCH_SIZE && (i+batchCounter) < inputs.size(); ++batchCounter) { - int64_t dimSize = at::native::size(inputs[i+batchCounter], dimension); + int64_t dimSize = 0; + // There is a legacy case where a 1-D empty tensor can be concat with + // high-dimensional tensor + if (inputs[i+batchCounter].numel() > 0) { + dimSize = at::native::size(inputs[i+batchCounter], dimension); + } stackInputs[batchCounter].input = inputs[i+batchCounter].data_ptr(); @@ -338,7 +343,12 @@ void parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension, batchCounter < CAT_ARRAY_BATCH_SIZE && (i+batchCounter) < inputs.size(); ++batchCounter) { - int64_t dimSize = at::native::size(inputs[i+batchCounter], dimension); + int64_t dimSize = 0; + // There is a legacy case where a 1-D empty tensor can be concat with + // high-dimensional tensor + if (inputs[i+batchCounter].numel() > 0) { + dimSize = at::native::size(inputs[i+batchCounter], dimension); + } catMetaData.input[batchCounter] = inputs[i+batchCounter].data_ptr(); catMetaData.offset[batchCounter] = offset; catMetaData.dimSize[batchCounter] = dimSize; @@ -431,7 +441,6 @@ Tensor& cat_out_cuda(Tensor& out, TensorList inputs, int64_t dimension) { auto should_skip = [](const Tensor &t) { return t.dim() == 1 && at::native::size(t, 0) == 0; }; - bool hasSkippedInput = false; const Tensor *notSkippedTensor = NULL; // non-owning reference int nDims = 0; @@ -452,10 +461,8 @@ Tensor& cat_out_cuda(Tensor& out, TensorList inputs, int64_t dimension) { } at::assert_no_internal_overlap(out); - for (int i = 0; i < inputs.size(); i++) - { + for (int i = 0; i < inputs.size(); i++) { if (should_skip(inputs[i])) { - hasSkippedInput = true; continue; } nDims = inputs[i].dim(); @@ -501,11 +508,10 @@ Tensor& cat_out_cuda(Tensor& out, TensorList inputs, int64_t dimension) { // We parallelize the copy if all 6 conditions pass: // // 1. There is more than one input tensor - // 2. No empty inputs - // 3. The out tensor is 32-bit indexable - // 4. The number of dimensions is <= 4 - // 5. All input tensors are contiguous (output tensor may be non-contig) - // 6. All input tensors can use 32-bit indexing + // 2. The out tensor is 32-bit indexable + // 3. The number of dimensions is <= 4 + // 4. All input tensors are contiguous (output tensor may be non-contig) + // 5. All input tensors can use 32-bit indexing const bool all32BitIndexable = std::all_of(inputs.begin(), inputs.end(), [] (const Tensor& t) { @@ -522,7 +528,6 @@ Tensor& cat_out_cuda(Tensor& out, TensorList inputs, int64_t dimension) { }); allSameType = allSameType && (out.scalar_type() == firstType); if (inputs.size() > 1 && - !hasSkippedInput && out.dim() <= CAT_ARRAY_MAX_INPUT_DIMS && at::cuda::detail::canUse32BitIndexMath(out) && allContiguous && diff --git a/aten/src/ATen/native/cuda/Sorting.cu b/aten/src/ATen/native/cuda/Sorting.cu index 0a5760580c06..c6688b286914 100644 --- a/aten/src/ATen/native/cuda/Sorting.cu +++ b/aten/src/ATen/native/cuda/Sorting.cu @@ -319,7 +319,7 @@ std::tuple median_with_indices_impl( NoNamesGuard guard; dim = at::maybe_wrap_dim(dim, self.dim()); - Tensor in = self.dim() > 0 ? self : self.unsqueeze(0); + Tensor in = self.dim() > 0 ? self.contiguous() : self.unsqueeze(0); int64_t size = in.size(dim); TORCH_CHECK( diff --git a/aten/src/ATen/native/cuda/SortingCommon.cuh b/aten/src/ATen/native/cuda/SortingCommon.cuh index 54513955e912..0e5cb7371d58 100644 --- a/aten/src/ATen/native/cuda/SortingCommon.cuh +++ b/aten/src/ATen/native/cuda/SortingCommon.cuh @@ -143,6 +143,7 @@ static uint64_t nextHighestPowerOf2(uint64_t n) { } +// WARNING: This function assumes input tensors are contiguous template void run_launcher( Tensor& values, diff --git a/aten/src/ATen/native/cuda/TensorFactories.cu b/aten/src/ATen/native/cuda/TensorFactories.cu index fb0eb4ca8b09..13f0b53516de 100644 --- a/aten/src/ATen/native/cuda/TensorFactories.cu +++ b/aten/src/ATen/native/cuda/TensorFactories.cu @@ -22,15 +22,13 @@ namespace at { namespace native { Tensor& eye_out_cuda(Tensor& result, int64_t n) { - return at::native::eye_out_cuda(result, n, /*m=*/-1); + // the default value of `m` equals to `n` + return at::native::eye_out_cuda(result, n, n); } Tensor& eye_out_cuda(Tensor& result, int64_t n, int64_t m) { TORCH_CHECK(n >= 0, "n must be greater or equal to 0, got ", n); - - if(m < 0) { - m = n; - } + TORCH_CHECK(m >= 0, "m must be greater or equal to 0, got ", m); result.resize_({n, m}); result.zero_(); diff --git a/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu b/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu index 465e54db51d6..2f7e92f3fc2e 100644 --- a/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu +++ b/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu @@ -43,7 +43,7 @@ void sin_kernel_cuda(TensorIterator& iter) { } void cos_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "cos_cuda", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "cos_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::cos(a); }); @@ -99,7 +99,7 @@ void atanh_kernel_cuda(TensorIterator& iter) { } void tan_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.dtype(), "tan_cuda", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.common_dtype(), "tan_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::tan(a); }); diff --git a/aten/src/ATen/native/cuda/UnaryLogKernels.cu b/aten/src/ATen/native/cuda/UnaryLogKernels.cu index a43fa541554b..44e73173af17 100644 --- a/aten/src/ATen/native/cuda/UnaryLogKernels.cu +++ b/aten/src/ATen/native/cuda/UnaryLogKernels.cu @@ -11,7 +11,7 @@ namespace at { namespace native { void log_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "log_cuda", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "log_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::log(a); }); @@ -19,7 +19,7 @@ void log_kernel_cuda(TensorIterator& iter) { } void log10_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "log10_cuda", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "log10_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::log10(a); }); @@ -35,7 +35,7 @@ void log1p_kernel_cuda(TensorIterator& iter) { } void log2_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "log2_cuda", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "log2_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::log2(a); }); diff --git a/aten/src/ATen/native/cuda/UnaryOpsKernel.cu b/aten/src/ATen/native/cuda/UnaryOpsKernel.cu index c3c8dd1e5094..4b1f0c1a6aa3 100644 --- a/aten/src/ATen/native/cuda/UnaryOpsKernel.cu +++ b/aten/src/ATen/native/cuda/UnaryOpsKernel.cu @@ -223,13 +223,16 @@ void nan_to_num_kernel_cuda( }); } -void kaiser_window_kernel_cuda(TensorIterator& iter, int64_t window_length, double beta){ +void kaiser_window_kernel_cuda(TensorIterator& iter, int64_t window_length, double beta_){ AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "kaiser_window_cuda", [&](){ - AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "kaiser_window_cuda", [&] { - const scalar_t alpha = static_cast((window_length - 1) / 2.0); - gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t a) -> scalar_t { - return calc_i0(static_cast(beta) * ::sqrt(1 - ::pow((a - alpha) / alpha, static_cast(2.0)))) / calc_i0(static_cast(beta)); - }); + using T_ACC = acc_type; + const T_ACC inv_alpha = static_cast(2.0 / (window_length - 1)); + const T_ACC beta = static_cast(beta_); + const T_ACC inv_i0_beta = 1.0 / calc_i0(beta); + gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t a) -> scalar_t { + T_ACC x = static_cast(a) * inv_alpha - 1; + T_ACC y = std::max(0, 1 - x * x); + return calc_i0(beta * ::sqrt(y)) * inv_i0_beta; }); }); } diff --git a/aten/src/ATen/native/group_norm.cpp b/aten/src/ATen/native/group_norm.cpp index 759167095ae3..3bd4daac917b 100644 --- a/aten/src/ATen/native/group_norm.cpp +++ b/aten/src/ATen/native/group_norm.cpp @@ -106,11 +106,8 @@ Tensor group_norm( input.sizes()); const auto input_shape = input.sizes(); - const int64_t HxW = std::accumulate( - input_shape.cbegin() + 2, - input_shape.cend(), - 1LL, - std::multiplies()); + const int64_t HxW = + prod_intlist(input_shape.cbegin() + 2, input_shape.cend()); const Tensor kEmpty; const auto& X = input.is_contiguous() ? input : input.contiguous(); diff --git a/aten/src/ATen/native/layer_norm.h b/aten/src/ATen/native/layer_norm.h index bf931fb26c5f..fa936ab7d4ce 100644 --- a/aten/src/ATen/native/layer_norm.h +++ b/aten/src/ATen/native/layer_norm.h @@ -52,16 +52,10 @@ std::tuple _prepare_layer_norm_inputs( } const int axis = input_ndim - normalized_ndim; - const int64_t M = std::accumulate( - input_shape.cbegin(), - input_shape.cbegin() + axis, - 1LL, - std::multiplies()); - const int64_t N = std::accumulate( - input_shape.cbegin() + axis, - input_shape.cend(), - 1LL, - std::multiplies()); + const int64_t M = + prod_intlist(input_shape.cbegin(), input_shape.cbegin() + axis); + const int64_t N = + prod_intlist(input_shape.cbegin() + axis, input_shape.cend()); const auto& X = input.is_contiguous() ? input : input.contiguous(); const auto& gamma = weight.is_contiguous() ? weight : weight.contiguous(); diff --git a/aten/src/ATen/native/metal/MetalTensor.mm b/aten/src/ATen/native/metal/MetalTensor.mm index 6dfe3932bf16..b1fc38d92a6b 100644 --- a/aten/src/ATen/native/metal/MetalTensor.mm +++ b/aten/src/ATen/native/metal/MetalTensor.mm @@ -17,7 +17,7 @@ class API_AVAILABLE(ios(10.0), macos(10.13)) MetalTensor::Impl { _numel(std::accumulate( std::begin(_sizes), std::end(_sizes), - 1, + (int64_t)1, std::multiplies())), _textureImpl(std::make_unique(sizes)) {} diff --git a/aten/src/ATen/native/miopen/Conv_miopen.cpp b/aten/src/ATen/native/miopen/Conv_miopen.cpp index 3f6e78e77c9f..27e119d377bc 100644 --- a/aten/src/ATen/native/miopen/Conv_miopen.cpp +++ b/aten/src/ATen/native/miopen/Conv_miopen.cpp @@ -468,7 +468,6 @@ void findAlgorithm(const ConvolutionArgs& args, bool benchmark, algo_t* algo) { if (args.params.deterministic && !benchmark) { *algo = search::DEFAULT_ALGO; - return; } if (cache.find(args.params, algo)) { diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 74e28e7f58b8..696551a12a99 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -283,13 +283,13 @@ use_c10_dispatcher: full variants: function dispatch: - DefaultBackend: view_as_real + CPU, CUDA: view_as_real - func: view_as_complex(Tensor(a) self) -> Tensor(a) use_c10_dispatcher: full variants: function dispatch: - DefaultBackend: view_as_complex + CPU, CUDA: view_as_complex - func: sgn(Tensor self) -> Tensor use_c10_dispatcher: full @@ -1546,6 +1546,16 @@ - func: rowwise_prune(Tensor weight, Tensor mask, ScalarType compressed_indices_dtype) -> (Tensor, Tensor) use_c10_dispatcher: full +# row_stack is the alias of vstack +- func: row_stack(Tensor[] tensors) -> Tensor + use_c10_dispatcher: full + dispatch: + Math: row_stack + +- func: row_stack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + Math: row_stack_out + - func: embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor) use_c10_dispatcher: hacky_wrapper_for_legacy_signatures @@ -6505,6 +6515,21 @@ dispatch: DefaultBackend: hypot_ +- func: igamma.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU, CUDA: igamma_out + +- func: igamma(Tensor self, Tensor other) -> Tensor + use_c10_dispatcher: full + variants: method, function + dispatch: + CPU, CUDA: igamma + +- func: igamma_(Tensor(a!) self, Tensor other) -> Tensor(a!) + variants: method + dispatch: + CPU, CUDA: igamma_ + - func: nextafter.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) dispatch: CPU, CUDA: nextafter_out @@ -8658,6 +8683,15 @@ CPU: col2im_backward_cpu CUDA: col2im_backward_cuda +- func: column_stack(Tensor[] tensors) -> Tensor + use_c10_dispatcher: full + dispatch: + Math: column_stack + +- func: column_stack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + Math: column_stack_out + - func: im2col.out(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) out) -> Tensor(a!) python_module: nn dispatch: @@ -8884,6 +8918,18 @@ python_module: linalg variants: function +- func: linalg_tensorsolve(Tensor self, Tensor other, int[]? dims=None) -> Tensor + python_module: linalg + variants: function + dispatch: + Math: linalg_tensorsolve + +- func: linalg_tensorsolve.out(Tensor self, Tensor other, int[]? dims=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + variants: function + dispatch: + Math: linalg_tensorsolve_out + ## Functions that are only for testing # It is undocumented and should not be used outside of tests. - func: _test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=1) -> Tensor diff --git a/aten/src/ATen/native/quantized/TensorFactories.cpp b/aten/src/ATen/native/quantized/TensorFactories.cpp index 11b84fa4713d..a05026d49f46 100644 --- a/aten/src/ATen/native/quantized/TensorFactories.cpp +++ b/aten/src/ATen/native/quantized/TensorFactories.cpp @@ -20,7 +20,7 @@ Tensor empty_affine_quantized( !(options_.has_memory_format() && optional_memory_format.has_value()), "Cannot set memory_format both in TensorOptions and explicit argument; please delete " "the redundant setter."); - auto options = options_.merge_in(TensorOptions().memory_format(optional_memory_format)); + auto options = options_.merge_memory_format(optional_memory_format); TORCH_CHECK( options.has_dtype(), "Must provide data type for Tensor creation functions."); @@ -42,7 +42,7 @@ Tensor empty_per_channel_affine_quantized( !(options_.has_memory_format() && optional_memory_format.has_value()), "Cannot set memory_format both in TensorOptions and explicit argument; please delete " "the redundant setter."); - auto options = options_.merge_in(TensorOptions().memory_format(optional_memory_format)); + auto options = options_.merge_memory_format(optional_memory_format); TORCH_CHECK( options.has_dtype(), "Must provide data type for Tensor creation functions."); diff --git a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp index 4d95ce4ffb4c..91c895685fd3 100644 --- a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp +++ b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp @@ -60,6 +60,30 @@ void CopyToChannelsLast3dTensor( } } +template +void CopyICFirst3dTensorToChannelsLast3dTensor( + int64_t G, + int64_t IC_G, + int64_t OC_G, + int64_t D, + int64_t H, + int64_t W, + const T* src, + T* dst) { + // IC OC/G THW -> G OC/G THW IC/G + const int64_t inner_size = D * H * W; + for (int64_t i = 0; i < G * OC_G; ++i) { + for (int64_t j = 0; j < inner_size; ++j) { + for (int64_t ic = 0; ic < IC_G; ++ic) { + int g = i / OC_G; + int oc = i % OC_G; + dst[(i * inner_size + j) * IC_G + ic] = + src[((g * IC_G + ic) * OC_G + oc) * inner_size + j]; + } + } + } +} + } // namespace template @@ -256,6 +280,75 @@ template fbgemm::conv_param_t<3> MakeFbgemmConvParam<3>( const std::vector& dilations, const std::vector& output_padding, bool transposed); +template <> +Tensor TransposeConvTensorUnpackConversion<3>(const Tensor& src, int groups) { + // OC IC/G DHW -> IC OC/G DHW logically + auto oc_g_ic_g_hw_tensors = src.chunk(groups); + auto fused_tensor = + at::cat(oc_g_ic_g_hw_tensors, 1).set_quantizer_(src.quantizer()); + return fused_tensor.permute({1, 0, 2, 3, 4}); +} + +template <> +Tensor ConvertConvWeightsToChannelLastTensor<2>( + const at::Tensor& src, + int groups, + bool transpose) { + return transpose ? + // 2D conv transpose weight transform + // IC OC/G KH KW -> G OC/G KH KW IC/G + [&]() { + auto ic_g_oc_g_hw_tensors = src.chunk(groups); + for (auto& tensor : ic_g_oc_g_hw_tensors) { + tensor = tensor.unsqueeze(0); + } + auto fused_tensor = + at::cat(ic_g_oc_g_hw_tensors).set_quantizer_(src.quantizer()); + return fused_tensor.permute({0, 2, 3, 4, 1}) + .contiguous(c10::MemoryFormat::Contiguous); + }() + // 2d conv weight transform + : src.contiguous(c10::MemoryFormat::ChannelsLast); +} + +template <> +Tensor ConvertConvWeightsToChannelLastTensor<3>( + const at::Tensor& src, + int groups, + bool transpose) { + if (!transpose) { + return ConvertToChannelsLast3dTensor(src); + } else { + TORCH_CHECK(src.dim() == 5); + Tensor dst; + const int64_t N = src.size(0); + const int64_t IC_G = N / groups; + const int64_t OC_G = src.size(1); + const int64_t D = src.size(2); + const int64_t H = src.size(3); + const int64_t W = src.size(4); + dst = MakeStridedQTensorCPU( + {groups * OC_G, IC_G, D, H, W}, + {D * H * W * IC_G, 1, H * W * IC_G, W * IC_G, IC_G}, + src.options(), + src.quantizer()); + AT_DISPATCH_QINT_TYPES( + src.scalar_type(), "CopyICFirst3dTensorToChannelsLast3dTensor", [&]() { + const Tensor src_contig = src.contiguous(); + CopyICFirst3dTensorToChannelsLast3dTensor( + groups, + IC_G, + OC_G, + D, + H, + W, + src_contig.data_ptr(), + dst.data_ptr()); + }); + return dst; + } +} + } // namespace fbgemm_utils } // namespace native } // namespace at @@ -263,8 +356,9 @@ template fbgemm::conv_param_t<3> MakeFbgemmConvParam<3>( #endif // USE_FBGEMM -template -CAFFE2_API torch::class_> register_conv_params() { + template + CAFFE2_API torch::class_> + register_conv_params() { static auto register_conv_params = torch::class_>( "quantized", "Conv" + c10::to_string(kSpatialDim) + "dPackedParamsBase") diff --git a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h index 40ef0feba61e..0cccf81e35d8 100644 --- a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h +++ b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h @@ -295,6 +295,11 @@ Tensor TransposeConvTensorUnpackConversion( const Tensor& src, int groups); +template +Tensor ConvertConvWeightsToChannelLastTensor( + const at::Tensor& src, + int groups, + bool transpose); } // namespace fbgemm_utils } // namespace native } // namespace at diff --git a/aten/src/ATen/native/quantized/cpu/qadd.cpp b/aten/src/ATen/native/quantized/cpu/qadd.cpp index a12718502dd1..0b9bc6b8e901 100644 --- a/aten/src/ATen/native/quantized/cpu/qadd.cpp +++ b/aten/src/ATen/native/quantized/cpu/qadd.cpp @@ -243,6 +243,15 @@ Tensor qadd_scalar(Tensor qa, Scalar b) { return _add_scalar_out(qc, qa, b); } +template +Tensor qadd_scalar2(Scalar b, Tensor qa) { + TORCH_CHECK(qa.qscheme() == kPerTensorAffine || + qa.qscheme() == kPerTensorSymmetric, + "Only per tensor quantization is supported in Add."); + auto qc = at::empty_like(qa, qa.suggest_memory_format()); + return _add_scalar_out(qc, qa, b); +} + template Tensor qadd_scalar_out(Tensor qa, Scalar b, Tensor out) { check_inputs(qa, out); @@ -269,10 +278,12 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) { m.impl(TORCH_SELECTIVE_NAME("quantized::add"), TORCH_FN(qadd)); m.impl(TORCH_SELECTIVE_NAME("quantized::add.out"), TORCH_FN(qadd_out)); m.impl(TORCH_SELECTIVE_NAME("quantized::add.Scalar"), TORCH_FN(qadd_scalar)); + m.impl(TORCH_SELECTIVE_NAME("quantized::add.Scalar2"), TORCH_FN(qadd_scalar2)); m.impl(TORCH_SELECTIVE_NAME("quantized::add.Scalar_out"), TORCH_FN(qadd_scalar_out)); m.impl(TORCH_SELECTIVE_NAME("quantized::add_relu"), TORCH_FN(qadd)); m.impl(TORCH_SELECTIVE_NAME("quantized::add_relu.out"), TORCH_FN(qadd_out)); m.impl(TORCH_SELECTIVE_NAME("quantized::add_relu.Scalar"), TORCH_FN(qadd_scalar)); + m.impl(TORCH_SELECTIVE_NAME("quantized::add_relu.Scalar2"), TORCH_FN(qadd_scalar2)); m.impl(TORCH_SELECTIVE_NAME("quantized::add_relu.Scalar_out"), TORCH_FN(qadd_scalar_out)); // deprecated functions, kept for backward compatibility m.impl(TORCH_SELECTIVE_NAME("quantized::add_out"), TORCH_FN(qadd_out)); diff --git a/aten/src/ATen/native/quantized/cpu/qconv.cpp b/aten/src/ATen/native/quantized/cpu/qconv.cpp index e96bd26acaba..05762bfb036f 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv.cpp @@ -411,7 +411,7 @@ at::Tensor PackedConvWeight::apply_impl( output_shape = MakeDeConvOutputShape( N, M, - {H, W}, + kSpatialDim == 2 ? std::vector{H, W} : std::vector{D, H, W}, kernel, stride(), padding(), @@ -886,6 +886,9 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) { // transpose m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose1d"), QConv1dInt8::run); m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose2d"), QConvInt8<2, false>::run); + m.impl( + TORCH_SELECTIVE_NAME("quantized::conv_transpose3d"), + QConvInt8<3, false>::run); } TORCH_LIBRARY_IMPL(_quantized, QuantizedCPU, m) { diff --git a/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp index af400558b5fe..c3b20163502d 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp @@ -39,9 +39,6 @@ c10::intrusive_ptr> PackedConvWeight< padding.size() == kSpatialDim, "Specify front/top/left padding only. " "end/bottom/right padding assumed to be equal to front/top/left"); - TORCH_CHECK( - !(transpose && kSpatialDim == 3), - "Currently no support for 3d conv_transpose in FBGEM. "); TORCH_CHECK( !transpose || output_padding.size() == kSpatialDim, "quantized::conv_prepack: Specify top/left output padding " @@ -104,26 +101,10 @@ c10::intrusive_ptr> PackedConvWeight< // for both conv and conv transpose // but PyTorch lays them out as {out_c, in_c/groups, kH, kW} // (or for ConvTranspose {in_c, out_c/groups, kH, kW}) - const at::Tensor weight_nhwc = transpose - ? - // check transpose - // 2D conv transpose weight transform - // IC OC/G KH KW -> OC KH KW IC/G - // transpose does not support 3d yet. - [&]() { - auto ic_g_oc_g_hw_tensors = weight.chunk(groups); - auto fused_tensor = - at::cat(ic_g_oc_g_hw_tensors, 1).set_quantizer_(weight.quantizer()); - return fused_tensor.permute({1, 2, 3, 0}) - .contiguous(c10::MemoryFormat::Contiguous); - }() - : (kSpatialDim == 2 - // 2d conv weight transform - ? weight.contiguous(c10::MemoryFormat::ChannelsLast) - // 3d conv weight transform - : at::native::fbgemm_utils::ConvertToChannelsLast3dTensor(weight)); + const at::Tensor weight_nhwc = + at::native::fbgemm_utils::ConvertConvWeightsToChannelLastTensor(weight, groups, transpose); const int8_t* weight_data_int8 = - reinterpret_cast(weight_nhwc.data_ptr()); + reinterpret_cast(weight_nhwc.data_ptr()); std::vector col_offsets(output_channels); // compute column offsets (Similar to // fbgemm::col_offsets_with_zero_pt_s8acc32_ref) please note that offsets @@ -444,6 +425,7 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) { // ConvTranspose m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose1d_prepack"), TORCH_FN(QConv1dPackWeightInt8::run_deconv)); m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose2d_prepack"), TORCH_FN(QConvPackWeightInt8<2>::run_deconv)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose3d_prepack"), TORCH_FN(QConvPackWeightInt8<3>::run_deconv)); } TORCH_LIBRARY_IMPL(_quantized, QuantizedCPU, m) { @@ -452,6 +434,7 @@ TORCH_LIBRARY_IMPL(_quantized, QuantizedCPU, m) { // ConvTranspose m.impl(TORCH_SELECTIVE_NAME("_quantized::conv_transpose1d_prepack"), TORCH_FN(QConv1dPackWeightInt8::run_deconv)); m.impl(TORCH_SELECTIVE_NAME("_quantized::conv_transpose2d_prepack"), TORCH_FN(QConvPackWeightInt8<2>::run_deconv)); + m.impl(TORCH_SELECTIVE_NAME("_quantized::conv_transpose3d_prepack"), TORCH_FN(QConvPackWeightInt8<3>::run_deconv)); } } // namespace diff --git a/aten/src/ATen/native/quantized/cpu/qconv_unpack.cpp b/aten/src/ATen/native/quantized/cpu/qconv_unpack.cpp index a908a0b77732..484bfe44fc76 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv_unpack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv_unpack.cpp @@ -13,10 +13,6 @@ template std::tuple> PackedConvWeight< kSpatialDim>::unpack() { auto* packed_weights_p = w.get(); - TORCH_CHECK( - !(kSpatialDim != 2 && transpose()), - "FBGEMM does not support 3d unpack right " - "now."); // output channels const int output_channels = packed_weights_p->outputChannels(); const int input_channels = packed_weights_p->inputChannels(); @@ -91,7 +87,7 @@ std::tuple> PackedConvWeight< if(transpose()){ unpacked_weights = at::native::fbgemm_utils::TransposeConvTensorUnpackConversion< - 2>(unpacked_weights, groups); + kSpatialDim>(unpacked_weights, groups); } return std::tuple>( unpacked_weights, bias); @@ -276,6 +272,7 @@ TORCH_LIBRARY_IMPL(quantized, CatchAll, m) { // ConvTranspose is the same, however, we want to have different name. m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose1d_unpack"), TORCH_FN(QConv1dUnpackWeightsInt8::run)); m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose2d_unpack"), TORCH_FN(QConvUnpackWeightsInt8<2>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose3d_unpack"), TORCH_FN(QConvUnpackWeightsInt8<3>::run)); m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose2d_stride"), TORCH_FN(QConvStride<2>::run)); m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose2d_padding"), TORCH_FN(QConvPadding<2>::run)); @@ -283,6 +280,12 @@ TORCH_LIBRARY_IMPL(quantized, CatchAll, m) { m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose2d_dilation"), TORCH_FN(QConvDilation<2>::run)); m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose2d_groups"), TORCH_FN(QConvGroups<2>::run)); m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose2d_transpose"), TORCH_FN(QConvTranspose<2>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose3d_stride"), TORCH_FN(QConvStride<3>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose3d_padding"), TORCH_FN(QConvPadding<3>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose3d_output_padding"), TORCH_FN(QConvOutputPadding<3>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose3d_dilation"), TORCH_FN(QConvDilation<3>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose3d_groups"), TORCH_FN(QConvGroups<3>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose3d_transpose"), TORCH_FN(QConvTranspose<3>::run)); } } // namespace diff --git a/aten/src/ATen/native/quantized/cpu/qmul.cpp b/aten/src/ATen/native/quantized/cpu/qmul.cpp index deeae36dc502..dccc2d718bf1 100644 --- a/aten/src/ATen/native/quantized/cpu/qmul.cpp +++ b/aten/src/ATen/native/quantized/cpu/qmul.cpp @@ -136,6 +136,18 @@ class QMulScalar final { } }; +template +class QMulScalar2 final { + public: + static Tensor run(Scalar b, Tensor qa) { + TORCH_CHECK(qa.qscheme() == kPerTensorAffine || + qa.qscheme() == kPerTensorSymmetric, + "Only per tensor quantization is supported in Mul."); + auto qc = at::empty_like(qa, qa.suggest_memory_format()); + return _mul_scalar_out(qc, qa, b); + } +}; + template class QMulScalarOut final { public: @@ -176,10 +188,12 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) { m.impl(TORCH_SELECTIVE_NAME("quantized::mul"), TORCH_FN(QMul::run)); m.impl(TORCH_SELECTIVE_NAME("quantized::mul.out"), TORCH_FN(QMulOut::run)); m.impl(TORCH_SELECTIVE_NAME("quantized::mul.Scalar"), TORCH_FN(QMulScalar::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::mul.Scalar2"), TORCH_FN(QMulScalar2::run)); m.impl(TORCH_SELECTIVE_NAME("quantized::mul.Scalar_out"), TORCH_FN(QMulScalarOut::run)); m.impl(TORCH_SELECTIVE_NAME("quantized::mul_relu"), TORCH_FN(QMul::run)); m.impl(TORCH_SELECTIVE_NAME("quantized::mul_relu.out"), TORCH_FN(QMulOut::run)); m.impl(TORCH_SELECTIVE_NAME("quantized::mul_relu.Scalar"), TORCH_FN(QMulScalar::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::mul_relu.Scalar2"), TORCH_FN(QMulScalar2::run)); m.impl(TORCH_SELECTIVE_NAME("quantized::mul_relu.Scalar_out"), TORCH_FN(QMulScalarOut::run)); // deprecated functions, kept for backward compatibility m.impl(TORCH_SELECTIVE_NAME("quantized::mul_out"), TORCH_FN(QMulOut::run)); diff --git a/aten/src/ATen/native/quantized/cpu/qnormalization.cpp b/aten/src/ATen/native/quantized/cpu/qnormalization.cpp index 6ed193cd82c9..c8bbe9d29b24 100644 --- a/aten/src/ATen/native/quantized/cpu/qnormalization.cpp +++ b/aten/src/ATen/native/quantized/cpu/qnormalization.cpp @@ -71,11 +71,8 @@ Tensor quantized_group_norm_impl( const int64_t batches = input_shape[0]; const int64_t num_channels = input_shape[1]; - const int64_t elements_per_batch = std::accumulate( - input_shape.cbegin() + 1, - input_shape.cend(), - 1LL, - std::multiplies()); + const int64_t elements_per_batch = + prod_intlist(input_shape.cbegin() + 1, input_shape.cend()); const int64_t M = batches * num_groups; const int64_t N = elements_per_batch / num_groups; diff --git a/aten/src/ATen/native/quantized/library.cpp b/aten/src/ATen/native/quantized/library.cpp index 3b866cf2fd12..c09501deec91 100644 --- a/aten/src/ATen/native/quantized/library.cpp +++ b/aten/src/ATen/native/quantized/library.cpp @@ -23,9 +23,11 @@ TORCH_LIBRARY(quantized, m) { m.def(TORCH_SELECTIVE_SCHEMA("quantized::add(Tensor qa, Tensor qb, float scale, int zero_point) -> Tensor qc")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::add.out(Tensor qa, Tensor qb, Tensor(a!) out) -> Tensor(a!) out")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::add.Scalar(Tensor qa, Scalar b) -> Tensor qc")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::add.Scalar2(Scalar b, Tensor qa) -> Tensor qc")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::add.Scalar_out(Tensor qa, Scalar b, Tensor(a!) out) -> Tensor(a!) out")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::add_relu(Tensor qa, Tensor qb, float scale, int zero_point) -> Tensor qc")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::add_relu.Scalar(Tensor qa, Scalar b) -> Tensor qc")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::add_relu.Scalar2(Scalar b, Tensor qa) -> Tensor qc")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::add_relu.out(Tensor qa, Tensor qb, Tensor(a!) out) -> Tensor(a!) out")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::add_relu.Scalar_out(Tensor qa, Scalar b, Tensor(a!) out) -> Tensor(a!) out")); // deprecated functions, kept for backward compatibility @@ -95,6 +97,7 @@ TORCH_LIBRARY(quantized, m) { // conv_tranpsose m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose1d(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose2d(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose3d(Tensor qx, __torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose1d_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int groups) -> __torch__.torch.classes.quantized.Conv2dPackedParamsBase")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose1d_unpack(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> (Tensor unpacked_weights, Tensor? B_origin)")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose2d_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int groups) -> __torch__.torch.classes.quantized.Conv2dPackedParamsBase")); @@ -105,6 +108,14 @@ TORCH_LIBRARY(quantized, m) { m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose2d_dilation(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int[]")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose2d_groups(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose2d_transpose(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose3d_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int groups) -> __torch__.torch.classes.quantized.Conv3dPackedParamsBase")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose3d_unpack(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> (Tensor unpacked_weights, Tensor? B_origin)")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose3d_stride(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int[]")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose3d_padding(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int[]")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose3d_output_padding(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int[]")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose3d_dilation(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int[]")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose3d_groups(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose3d_transpose(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::elu(Tensor self, float output_scale, int output_zero_point, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_prepack(Tensor weight) -> __torch__.torch.classes.quantized.EmbeddingPackedParamsBase W_prepack")); @@ -142,10 +153,12 @@ TORCH_LIBRARY(quantized, m) { m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul(Tensor qa, Tensor qb, float scale, int zero_point)-> Tensor qc")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul.out(Tensor qa, Tensor qb, Tensor(a!) out)-> Tensor(a!) out")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul.Scalar(Tensor qa, Scalar b)-> Tensor qc")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul.Scalar2(Scalar b, Tensor qa)-> Tensor qc")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul.Scalar_out(Tensor qa, Scalar b, Tensor(a!) out)-> Tensor(a!) out")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul_relu(Tensor qa, Tensor qb, float scale, int zero_point)-> Tensor qc")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul_relu.out(Tensor qa, Tensor qb, Tensor(a!) out)-> Tensor(a!) out")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul_relu.Scalar(Tensor qa, Scalar b)-> Tensor qc")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul_relu.Scalar2(Scalar b, Tensor qa)-> Tensor qc")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul_relu.Scalar_out(Tensor qa, Scalar b, Tensor(a!) out)-> Tensor(a!) out")); // deprecated functions, kept for backward compatibility m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul_out(Tensor qa, Tensor qb, Tensor(a!) out)-> Tensor(a!) out")); @@ -178,6 +191,7 @@ TORCH_LIBRARY(_quantized, m) { m.def(TORCH_SELECTIVE_SCHEMA("_quantized::conv_transpose2d(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("_quantized::conv_transpose1d_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int groups) -> __torch__.torch.classes.quantized.Conv2dPackedParamsBase")); m.def(TORCH_SELECTIVE_SCHEMA("_quantized::conv_transpose2d_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int groups) -> __torch__.torch.classes.quantized.Conv2dPackedParamsBase")); + m.def(TORCH_SELECTIVE_SCHEMA("_quantized::conv_transpose3d_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int groups) -> __torch__.torch.classes.quantized.Conv3dPackedParamsBase")); m.def(TORCH_SELECTIVE_SCHEMA("_quantized::linear(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, float Y_scale_i, int Y_zero_point_i) -> Tensor Y")); m.def(TORCH_SELECTIVE_SCHEMA("_quantized::linear_dynamic(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, bool reduce_range=False) -> Tensor Y")); m.def(TORCH_SELECTIVE_SCHEMA("_quantized::linear_prepack(Tensor W, Tensor? B=None) -> __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack")); diff --git a/aten/src/ATen/native/vulkan/Vulkan.cpp b/aten/src/ATen/native/vulkan/Vulkan.cpp index d6fa6a32291b..d58d7d7dcc09 100644 --- a/aten/src/ATen/native/vulkan/Vulkan.cpp +++ b/aten/src/ATen/native/vulkan/Vulkan.cpp @@ -7,6 +7,7 @@ #include #include +#include #ifdef USE_VULKAN_WRAPPER #include @@ -1182,11 +1183,7 @@ class VulkanTensor::Impl final { explicit Impl(std::vector sizes) : sizes_(std::move(sizes)), strides_(std::vector(sizes_.size())), - numel_(std::accumulate( - std::begin(sizes_), - std::end(sizes_), - 1, - std::multiplies())) { + numel_(prod_intlist(sizes_)) { TORCH_CHECK( initVulkanContextOnce(), "Vulkan Failed to create Vulkan Context"); } @@ -1289,8 +1286,7 @@ class VulkanTensor::Impl final { VkDeviceSize buffer_size_for_sizes(std::vector sizes) const { const auto d = sizes.size(); - const auto numel = std::accumulate( - std::begin(sizes), std::end(sizes), 1, std::multiplies()); + const auto numel = prod_intlist(sizes); VkDeviceSize bufferSize{sizeof(float) * numel}; // alignment to be able to copy between image and buffer if (d == 4) { diff --git a/aten/src/ATen/native/vulkan/VulkanAten.cpp b/aten/src/ATen/native/vulkan/VulkanAten.cpp index 72d5e15208ec..6a781c3ab69b 100644 --- a/aten/src/ATen/native/vulkan/VulkanAten.cpp +++ b/aten/src/ATen/native/vulkan/VulkanAten.cpp @@ -537,6 +537,8 @@ Tensor mean( return new_with_vtensor_vulkan(std::move(output), self.options()); } +#ifndef USE_VULKAN_API + TORCH_LIBRARY_IMPL(aten, Vulkan, m) { m.impl("slice.Tensor", TORCH_FN(at::native::vulkan::aten::slice)); m.impl("reshape", TORCH_FN(at::native::vulkan::aten::reshape)); @@ -570,6 +572,8 @@ TORCH_LIBRARY_IMPL(aten, Vulkan, m) { m.impl_UNBOXED("add_.Tensor", at::native::vulkan::aten::add_); } +#endif /* USE_VULKAN_API */ + Tensor& copy_from_vulkan_(Tensor& self, const Tensor& src) { TORCH_INTERNAL_ASSERT( src.device().type() == DeviceType::Vulkan, diff --git a/aten/src/ATen/native/vulkan/VulkanOpaqueTensorImpl.h b/aten/src/ATen/native/vulkan/VulkanOpaqueTensorImpl.h index 6afc28676f2b..fc4f9945fcaf 100644 --- a/aten/src/ATen/native/vulkan/VulkanOpaqueTensorImpl.h +++ b/aten/src/ATen/native/vulkan/VulkanOpaqueTensorImpl.h @@ -11,7 +11,7 @@ template struct VulkanOpaqueTensorImpl : public OpaqueTensorImpl { VulkanOpaqueTensorImpl( at::DispatchKeySet key_set, - const caffe2::TypeMeta& data_type, + const caffe2::TypeMeta data_type, c10::Device device, OpaqueHandle opaque_handle, c10::IntArrayRef sizes, diff --git a/aten/src/ATen/native/vulkan/VulkanOps.cpp b/aten/src/ATen/native/vulkan/VulkanOps.cpp index 302525582c9d..0e13dce41a7c 100644 --- a/aten/src/ATen/native/vulkan/VulkanOps.cpp +++ b/aten/src/ATen/native/vulkan/VulkanOps.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -553,21 +554,19 @@ void add( void add(VulkanTensor& output, const VulkanTensor& input, const float s) { const auto sizes = input.sizes(); - const auto C = std::accumulate( - sizes.cbegin(), sizes.cend() - 2, 1, std::multiplies()); + const auto C = prod_intlist(sizes.cbegin(), sizes.cend() - 2); const auto C_4 = UP_DIV(C, 4); const auto H = sizes[2]; const auto W = sizes[3]; auto device = context().device(); struct ConstBlock { - int32_t inputSize[4]; + int32_t inputSize[3]; float s; }; ConstBlock cb{{safe_downcast(W), safe_downcast(H), - safe_downcast(C_4), - 0}, + safe_downcast(C_4)}, s}; VBuffer constBuffer = makeUniformConstBuffer((void*)&cb, sizeof(cb)); @@ -606,21 +605,19 @@ void add(VulkanTensor& output, const VulkanTensor& input, const float s) { void mul(VulkanTensor& output, const VulkanTensor& input, const float s) { const auto sizes = input.sizes(); - const auto C = std::accumulate( - sizes.cbegin(), sizes.cend() - 2, 1, std::multiplies()); + const auto C = prod_intlist(sizes.cbegin(), sizes.cend() - 2); const auto C_4 = UP_DIV(C, 4); const auto H = sizes[2]; const auto W = sizes[3]; auto device = context().device(); struct ConstBlock { - int32_t inputSize[4]; + int32_t inputSize[3]; float s; }; ConstBlock cb{{safe_downcast(W), safe_downcast(H), - safe_downcast(C_4), - 0}, + safe_downcast(C_4)}, s}; VBuffer constBuffer = makeUniformConstBuffer((void*)&cb, sizeof(cb)); diff --git a/aten/src/ATen/native/vulkan/api/Adapter.h b/aten/src/ATen/native/vulkan/api/Adapter.h index 0efc08010ef5..4ba02a5e9926 100644 --- a/aten/src/ATen/native/vulkan/api/Adapter.h +++ b/aten/src/ATen/native/vulkan/api/Adapter.h @@ -4,6 +4,7 @@ #include #include +#include namespace at { namespace native { @@ -30,6 +31,10 @@ struct Adapter final { // for now. return VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU == properties.deviceType; } + + inline Shader::WorkGroup local_work_group_size() const { + return { 8u, 8u, 1u, }; + } }; } // namespace api diff --git a/aten/src/ATen/native/vulkan/api/Command.cpp b/aten/src/ATen/native/vulkan/api/Command.cpp index cdf96f88f639..3066e815ebed 100644 --- a/aten/src/ATen/native/vulkan/api/Command.cpp +++ b/aten/src/ATen/native/vulkan/api/Command.cpp @@ -1,5 +1,6 @@ #include #include +#include namespace at { namespace native { @@ -69,6 +70,10 @@ VkCommandBuffer allocate_command_buffer( } // namespace +Command::Buffer::Buffer() + : command_buffer_(VK_NULL_HANDLE) { +} + Command::Buffer::Buffer( const VkDevice device, const VkCommandPool command_pool) @@ -113,19 +118,28 @@ void Command::Buffer::barrier( "Potential reason: This command buffer is moved from."); c10::SmallVector global_memory_barriers; - c10::SmallVector image_memory_barriers; + c10::SmallVector image_memory_barriers; - for (const Resource::Buffer::Barrier& barrier : barrier.buffers) { + if (!barrier.buffers.empty()) { // Using global memory barriers instead of buffer memory barriers for // buffers. The consensus seems to be that there is no advantage in - // using the latter. + // using the latter in favor of the former. - global_memory_barriers.push_back({ - VK_STRUCTURE_TYPE_MEMORY_BARRIER, - nullptr, - barrier.memory.src, - barrier.memory.dst, - }); + VkMemoryBarrier global_memory_barrier{ + VK_STRUCTURE_TYPE_MEMORY_BARRIER, + nullptr, + 0u, + 0u, + }; + + // Coalesce all buffer memory barriers into one global memory barrier. + + for (const Resource::Buffer::Barrier& barrier : barrier.buffers) { + global_memory_barrier.srcAccessMask |= barrier.memory.src; + global_memory_barrier.dstAccessMask |= barrier.memory.dst; + } + + global_memory_barriers.push_back(global_memory_barrier); } for (const Resource::Image::Barrier& barrier : barrier.images) { @@ -248,15 +262,17 @@ void Command::Buffer::dispatch( "This command buffer is in an invalid state! " "Potential reason: This command buffer is moved from."); - static const auto div_round_up = [](const uint32_t n, const uint32_t d) { - return (n + d - 1u) / d; - }; - vkCmdDispatch( command_buffer_, - div_round_up(global_work_group.x, bound_.pipeline.local_work_group.x), - div_round_up(global_work_group.y, bound_.pipeline.local_work_group.y), - div_round_up(global_work_group.z, bound_.pipeline.local_work_group.z)); + div_up( + global_work_group.width, + bound_.pipeline.local_work_group.width), + div_up( + global_work_group.height, + bound_.pipeline.local_work_group.height), + div_up( + global_work_group.depth, + bound_.pipeline.local_work_group.depth)); } void Command::Buffer::submit( diff --git a/aten/src/ATen/native/vulkan/api/Command.h b/aten/src/ATen/native/vulkan/api/Command.h index 38f251ba4d2b..8e2f235cfa27 100644 --- a/aten/src/ATen/native/vulkan/api/Command.h +++ b/aten/src/ATen/native/vulkan/api/Command.h @@ -20,6 +20,7 @@ struct Command final { class Buffer final { public: + Buffer(); Buffer(VkDevice device, VkCommandPool command_pool); Buffer(const Buffer&) = delete; Buffer& operator=(const Buffer&) = delete; @@ -27,6 +28,8 @@ struct Command final { Buffer& operator=(Buffer&&); ~Buffer() = default; + operator bool() const; + void begin(); void end(); void barrier(const Pipeline::Barrier& barrier); @@ -91,6 +94,10 @@ inline Command::Buffer& Command::Buffer::operator=(Buffer&& buffer) { return *this; } +inline Command::Buffer::operator bool() const { + return VK_NULL_HANDLE != command_buffer_; +} + } // namespace api } // namespace vulkan } // namespace native diff --git a/aten/src/ATen/native/vulkan/api/Context.cpp b/aten/src/ATen/native/vulkan/api/Context.cpp index f28d8819626b..ef638f917956 100644 --- a/aten/src/ATen/native/vulkan/api/Context.cpp +++ b/aten/src/ATen/native/vulkan/api/Context.cpp @@ -132,11 +132,12 @@ Context* context() { Descriptor::Set dispatch_prologue( Command::Buffer& command_buffer, const Shader::Layout::Signature& shader_layout_signature, - const Shader::Descriptor& shader_descriptor, - const Shader::WorkGroup& local_work_group) { - Descriptor& descriptor = context()->descriptor(); - Pipeline& pipeline = context()->pipeline(); - Shader& shader = context()->shader(); + const Shader::Descriptor& shader_descriptor) { + Context* const context = api::context(); + const GPU gpu = context->gpu(); + Descriptor& descriptor = context->descriptor(); + Pipeline& pipeline = context->pipeline(); + Shader& shader = context->shader(); const Shader::Layout::Object shader_layout = shader.layout.cache.retrieve({ @@ -149,7 +150,7 @@ Descriptor::Set dispatch_prologue( shader_layout.handle, }), shader.cache.retrieve(shader_descriptor), - local_work_group, + gpu.adapter->local_work_group_size(), })); return descriptor.pool.allocate(shader_layout); diff --git a/aten/src/ATen/native/vulkan/api/Context.h b/aten/src/ATen/native/vulkan/api/Context.h index 95239a1e3042..687ddbbfe931 100644 --- a/aten/src/ATen/native/vulkan/api/Context.h +++ b/aten/src/ATen/native/vulkan/api/Context.h @@ -47,7 +47,6 @@ class Context final { Command::Buffer& command_buffer, const Shader::Layout::Signature& shader_layout_signature, const Shader::Descriptor& shader_descriptor, - const Shader::WorkGroup& local_work_group, const Shader::WorkGroup& global_work_group, Arguments&&... arguments); @@ -140,22 +139,19 @@ inline void Context::dispatch( Command::Buffer& command_buffer, const Shader::Layout::Signature& shader_layout_signature, const Shader::Descriptor& shader_descriptor, - const Shader::WorkGroup& local_work_group, const Shader::WorkGroup& global_work_group, Arguments&&... arguments) { // Forward declaration Descriptor::Set dispatch_prologue( Command::Buffer&, const Shader::Layout::Signature&, - const Shader::Descriptor&, - const Shader::WorkGroup&); + const Shader::Descriptor&); // Factor out template parameter independent code to minimize code bloat. Descriptor::Set descriptor_set = dispatch_prologue( command_buffer, shader_layout_signature, - shader_descriptor, - local_work_group); + shader_descriptor); detail::bind( descriptor_set, diff --git a/aten/src/ATen/native/vulkan/api/Pipeline.cpp b/aten/src/ATen/native/vulkan/api/Pipeline.cpp index 93c954354ab2..9facc3f49e0f 100644 --- a/aten/src/ATen/native/vulkan/api/Pipeline.cpp +++ b/aten/src/ATen/native/vulkan/api/Pipeline.cpp @@ -101,11 +101,11 @@ typename Pipeline::Factory::Handle Pipeline::Factory::operator()( "Invalid Vulkan shader module!"); constexpr uint32_t x_offset = 0u; - constexpr uint32_t x_size = sizeof(Shader::WorkGroup::x); + constexpr uint32_t x_size = sizeof(Shader::WorkGroup::width); constexpr uint32_t y_offset = x_offset + x_size; - constexpr uint32_t y_size = sizeof(Shader::WorkGroup::y); + constexpr uint32_t y_size = sizeof(Shader::WorkGroup::height); constexpr uint32_t z_offset = y_offset + y_size; - constexpr uint32_t z_size = sizeof(Shader::WorkGroup::z); + constexpr uint32_t z_size = sizeof(Shader::WorkGroup::depth); constexpr VkSpecializationMapEntry specialization_map_entires[3]{ // X diff --git a/aten/src/ATen/native/vulkan/api/Pipeline.h b/aten/src/ATen/native/vulkan/api/Pipeline.h index 9ef2bf8c3991..50893b709473 100644 --- a/aten/src/ATen/native/vulkan/api/Pipeline.h +++ b/aten/src/ATen/native/vulkan/api/Pipeline.h @@ -43,8 +43,10 @@ struct Pipeline final { VkPipelineStageFlags dst; } stage; - c10::SmallVector buffers; - c10::SmallVector images; + c10::SmallVector buffers; + c10::SmallVector images; + + operator bool() const; }; // @@ -169,6 +171,13 @@ struct Pipeline final { // Impl // +inline Pipeline::Barrier::operator bool() const { + return (0u != stage.src) || + (0u != stage.dst) || + !buffers.empty() || + !images.empty(); +} + inline bool operator==( const Pipeline::Layout::Descriptor& _1, const Pipeline::Layout::Descriptor& _2) { @@ -193,9 +202,9 @@ inline size_t Pipeline::Factory::Hasher::operator()( return c10::get_hash( descriptor.pipeline_layout, descriptor.shader_module, - descriptor.local_work_group.x, - descriptor.local_work_group.y, - descriptor.local_work_group.z); + descriptor.local_work_group.width, + descriptor.local_work_group.height, + descriptor.local_work_group.depth); } inline Pipeline::Object::operator bool() const { diff --git a/aten/src/ATen/native/vulkan/api/Resource.cpp b/aten/src/ATen/native/vulkan/api/Resource.cpp index 436288645e95..9f29ff2cbfd5 100644 --- a/aten/src/ATen/native/vulkan/api/Resource.cpp +++ b/aten/src/ATen/native/vulkan/api/Resource.cpp @@ -46,13 +46,13 @@ VmaAllocator create_allocator( } VmaAllocationCreateInfo create_allocation_create_info( - const VmaMemoryUsage usage) { + const Resource::Memory::Descriptor& descriptor) { return VmaAllocationCreateInfo{ 0u, /* VMA_ALLOCATION_CREATE_MAPPED_BIT - MoltenVK Issue #175 */ /* VMA_ALLOCATION_CREATE_STRATEGY_MIN_FRAGMENTATION_BIT */ - usage, - 0u, - 0u, + descriptor.usage, + descriptor.required, + descriptor.preferred, 0u, VK_NULL_HANDLE, nullptr, @@ -101,7 +101,7 @@ void* map(const Resource::Memory& memory) { Resource::Memory::Scope::Scope( const VmaAllocator allocator, const VmaAllocation allocation, - const Access access) + const Access::Flags access) : allocator_(allocator), allocation_(allocation), access_(access) { @@ -121,7 +121,7 @@ void Resource::Memory::Scope::operator()(const void* const data) const { vmaUnmapMemory(allocator_, allocation_); - if (Access::Write == access_) { + if (access_ & Access::Write) { // Call will be ignored by implementation if the memory type this allocation // belongs to is not HOST_VISIBLE or is HOST_COHERENT, which is the behavior // we want. @@ -177,6 +177,23 @@ Resource::Image::Sampler::Factory::operator()( }; } +VkFence Resource::Fence::handle(const bool add_to_waitlist) const { + if (!pool) { + return VK_NULL_HANDLE; + } + + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + id < pool->fence_.pool.size(), + "Invalid Vulkan fence!"); + + const VkFence fence = pool->fence_.pool[id].get(); + if (add_to_waitlist) { + pool->fence_.waitlist.push_back(fence); + } + + return fence; +} + void Resource::Fence::wait(const uint64_t timeout_nanoseconds) { const VkFence fence = handle(/* add_to_waitlist = */ false); diff --git a/aten/src/ATen/native/vulkan/api/Resource.h b/aten/src/ATen/native/vulkan/api/Resource.h index afe008c14b20..a9428d272782 100644 --- a/aten/src/ATen/native/vulkan/api/Resource.h +++ b/aten/src/ATen/native/vulkan/api/Resource.h @@ -12,6 +12,7 @@ namespace native { namespace vulkan { namespace api { + struct Resource final { class Pool; @@ -29,22 +30,49 @@ struct Resource final { VkAccessFlags dst; }; + /* + Descriptor + */ + + struct Descriptor final { + VmaMemoryUsage usage; + VkMemoryPropertyFlags /* optional */ required; + VkMemoryPropertyFlags /* optional */ preferred; + }; + VmaAllocator allocator; VmaAllocation allocation; + struct Access final { + typedef uint8_t Flags; + + enum Type : Flags { + Read = 1u << 0u, + Write = 1u << 1u, + }; + + template + using Pointer = std::add_pointer_t< + std::conditional_t< + 0u != (access & Write), + Type, + std::add_const_t>>; + }; + class Scope; template - using Data = Handle; + using Handle = Handle; template< typename Type, - typename Pointer = std::add_pointer_t>> - Data map() const &; + typename Pointer = Access::Pointer> + Handle map() const &; template< typename Type, - typename Pointer = std::add_pointer_t> - Data map() &; + Access::Flags kAccess, + typename Pointer = Access::Pointer> + Handle map() &; private: // Intentionally disabed to ensure memory access is always properly @@ -54,10 +82,10 @@ struct Resource final { // for seemingly ineffective memory writes and hard to hunt down bugs. template - Data map() const && = delete; + Handle map() const && = delete; - template - Data map() && = delete; + template + Handle map() && = delete; }; // @@ -74,7 +102,7 @@ struct Resource final { struct { VkBufferUsageFlags buffer; - VmaMemoryUsage memory; + Memory::Descriptor memory; } usage; }; @@ -171,7 +199,7 @@ struct Resource final { struct { VkImageUsageFlags image; - VmaMemoryUsage memory; + Memory::Descriptor memory; } usage; struct { @@ -241,11 +269,18 @@ struct Resource final { Pool& operator=(Pool&&); ~Pool(); + // Primary + Buffer buffer(const Buffer::Descriptor& descriptor); Image image(const Image::Descriptor& descriptor); Fence fence(); void purge(); + // Helper + + template + Buffer uniform(const Block& block); + private: friend struct Fence; @@ -284,37 +319,42 @@ struct Resource final { class Resource::Memory::Scope final { public: - enum class Access { - Read, - Write, - }; + Scope( + VmaAllocator allocator, + VmaAllocation allocation, + Access::Flags access); - Scope(VmaAllocator allocator, VmaAllocation allocation, Access access); void operator()(const void* data) const; private: VmaAllocator allocator_; VmaAllocation allocation_; - Access access_; + Access::Flags access_; }; template -inline Resource::Memory::Data Resource::Memory::map() const & { +inline Resource::Memory::Handle Resource::Memory::map() const & { void* map(const Memory& memory); - return Data{ + return Handle{ reinterpret_cast(map(*this)), - Scope(allocator, allocation, Scope::Access::Read), + Scope(allocator, allocation, Access::Read), }; } -template -inline Resource::Memory::Data Resource::Memory::map() & { +template +inline Resource::Memory::Handle Resource::Memory::map() & { void* map(const Memory& memory); - return Data{ + static_assert( + (kAccess == Access::Read) || + (kAccess == Access::Write) || + (kAccess == (Access::Read | Access::Write)), + "Invalid memory access!"); + + return Handle{ reinterpret_cast(map(*this)), - Scope(allocator, allocation, Scope::Access::Write), + Scope(allocator, allocation, kAccess), }; } @@ -356,21 +396,29 @@ inline Resource::Fence::operator bool() const { return pool; } -inline VkFence Resource::Fence::handle(const bool add_to_waitlist) const { - if (!pool) { - return VK_NULL_HANDLE; - } - - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - id < pool->fence_.pool.size(), - "Invalid Vulkan fence!"); - - const VkFence fence = pool->fence_.pool[id].get(); - if (add_to_waitlist) { - pool->fence_.waitlist.push_back(fence); +template +inline Resource::Buffer Resource::Pool::uniform(const Block& block) { + Buffer uniform = this->buffer({ + sizeof(Block), + { + VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT, + { + VMA_MEMORY_USAGE_CPU_TO_GPU, + 0u, + 0u, + }, + }, + }); + + { + Memory::Handle memory = uniform.memory.template map< + Block, + Memory::Access::Write>(); + + *memory.get() = block; } - return fence; + return uniform; } } // namespace api diff --git a/aten/src/ATen/native/vulkan/api/Shader.h b/aten/src/ATen/native/vulkan/api/Shader.h index f3e62d9c4d0a..b2238a95de50 100644 --- a/aten/src/ATen/native/vulkan/api/Shader.h +++ b/aten/src/ATen/native/vulkan/api/Shader.h @@ -112,11 +112,7 @@ struct Shader final { // Work Group // - struct WorkGroup final { - uint32_t x; - uint32_t y; - uint32_t z; - }; + typedef VkExtent3D WorkGroup; /* Descriptor @@ -228,9 +224,9 @@ inline void Shader::Layout::Cache::purge() { inline bool operator==( const Shader::WorkGroup& _1, const Shader::WorkGroup& _2) { - return (_1.x == _2.x) && - (_1.y == _2.y) && - (_1.z == _2.z); + return (_1.width == _2.width) && + (_1.height == _2.height) && + (_1.depth == _2.depth); } inline Shader::Descriptor::Descriptor(const char* const glsl) diff --git a/aten/src/ATen/native/vulkan/api/Utils.h b/aten/src/ATen/native/vulkan/api/Utils.h new file mode 100644 index 000000000000..2882d4dc2283 --- /dev/null +++ b/aten/src/ATen/native/vulkan/api/Utils.h @@ -0,0 +1,21 @@ +#pragma once + +#ifdef USE_VULKAN_API + +namespace at { +namespace native { +namespace vulkan { +namespace api { + +inline uint32_t div_up( + const uint32_t numerator, + const uint32_t denominator) { + return (numerator + denominator - 1u) / denominator; +} + +} // namespace api +} // namespace vulkan +} // namespace native +} // namespace at + +#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/api/api.h b/aten/src/ATen/native/vulkan/api/api.h index b46793cbb660..c20d8b71e3c6 100644 --- a/aten/src/ATen/native/vulkan/api/api.h +++ b/aten/src/ATen/native/vulkan/api/api.h @@ -12,5 +12,6 @@ #include #include #include +#include #endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/api/vk_mem_alloc.h b/aten/src/ATen/native/vulkan/api/vk_mem_alloc.h index fdeadf9cdbfa..b468a1c05c6d 100644 --- a/aten/src/ATen/native/vulkan/api/vk_mem_alloc.h +++ b/aten/src/ATen/native/vulkan/api/vk_mem_alloc.h @@ -4594,7 +4594,7 @@ static IterT VmaBinaryFindFirstNotLess(IterT beg, IterT end, const KeyT &key, co size_t down = 0, up = (end - beg); while(down < up) { - const size_t mid = (down + up) / 2; + const size_t mid = down + (up - down) / 2; //Overflow-safe midpoint calculation if(cmp(*(beg+mid), key)) { down = mid + 1; diff --git a/aten/src/ATen/native/vulkan/glsl/add.glsl b/aten/src/ATen/native/vulkan/glsl/add.glsl index 9b7e992e78c5..27e69152ac1b 100644 --- a/aten/src/ATen/native/vulkan/glsl/add.glsl +++ b/aten/src/ATen/native/vulkan/glsl/add.glsl @@ -1,27 +1,28 @@ #version 450 core #define PRECISION $precision + layout(std430) buffer; layout(std430) uniform; -layout(set = 0, rgba16f, binding = 0) writeonly PRECISION uniform image3D uOutput; -layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput0; -layout(set = 0, binding = 2) uniform PRECISION sampler3D uInput1; -layout(set = 0, binding = 3) uniform constBlock { - int W; - int H; - int C; +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict writeonly image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput0; +layout(set = 0, binding = 2) uniform PRECISION sampler3D uInput1; +layout(set = 0, binding = 3) uniform restrict Block { + ivec3 WHC; float alpha; -} -uConstBlock; +} uBlock; layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in; void main() { - ivec3 pos = ivec3(gl_GlobalInvocationID); - ivec3 WHC = ivec3(uConstBlock.W, uConstBlock.H, uConstBlock.C); - if (all(lessThan(pos, WHC))) { - vec4 v = texelFetch(uInput0, pos, 0) + - uConstBlock.alpha * texelFetch(uInput1, pos, 0); - imageStore(uOutput, pos, v); + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.WHC))) { + imageStore( + uOutput, + pos, + texelFetch(uInput0, pos, 0) + uBlock.alpha * texelFetch(uInput1, pos, 0)); } } diff --git a/aten/src/ATen/native/vulkan/glsl/add_.glsl b/aten/src/ATen/native/vulkan/glsl/add_.glsl new file mode 100644 index 000000000000..c872a8193ca3 --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/add_.glsl @@ -0,0 +1,27 @@ +#version 450 core +#define PRECISION $precision + +layout(std430) buffer; +layout(std430) uniform; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba16f) uniform PRECISION image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput0; +layout(set = 0, binding = 2) uniform restrict Block { + ivec3 WHC; + float alpha; +} uBlock; + +layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.WHC))) { + imageStore( + uOutput, + pos, + imageLoad(uOutput, pos) + uBlock.alpha * texelFetch(uInput0, pos, 0)); + } +} diff --git a/aten/src/ATen/native/vulkan/glsl/add_scalar.glsl b/aten/src/ATen/native/vulkan/glsl/add_scalar.glsl index 559cdd7441c3..10a95330a48c 100644 --- a/aten/src/ATen/native/vulkan/glsl/add_scalar.glsl +++ b/aten/src/ATen/native/vulkan/glsl/add_scalar.glsl @@ -1,21 +1,27 @@ #version 450 core +#define PRECISION $precision + layout(std430) buffer; layout(std430) uniform; -layout(set = 0, rgba16f, binding = 0) writeonly highp uniform image3D uOutput; -layout(set = 0, binding = 1) uniform highp sampler3D uInput; -layout(set = 0, binding = 2) uniform constBlock { - ivec4 sizes; +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict writeonly image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; +layout(set = 0, binding = 2) uniform restrict Block { + ivec3 WHC; float other; -} -uConstBlock; +} uBlock; layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in; void main() { - ivec3 pos = ivec3(gl_GlobalInvocationID); - if (all(lessThan(pos, uConstBlock.sizes.xyz))) { - vec4 v = texelFetch(uInput, pos, 0) + uConstBlock.other; - imageStore(uOutput, pos, v); + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.WHC))) { + imageStore( + uOutput, + pos, + texelFetch(uInput, pos, 0) + uBlock.other); } } diff --git a/aten/src/ATen/native/vulkan/glsl/add_scalar_.glsl b/aten/src/ATen/native/vulkan/glsl/add_scalar_.glsl new file mode 100644 index 000000000000..8e736e2a6a71 --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/add_scalar_.glsl @@ -0,0 +1,26 @@ +#version 450 core +#define PRECISION $precision + +layout(std430) buffer; +layout(std430) uniform; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict image3D uOutput; +layout(set = 0, binding = 1) uniform restrict Block { + ivec3 WHC; + float other; +} uBlock; + +layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.WHC))) { + imageStore( + uOutput, + pos, + imageLoad(uOutput, pos) + uBlock.other); + } +} diff --git a/aten/src/ATen/native/vulkan/glsl/mul_scalar.glsl b/aten/src/ATen/native/vulkan/glsl/mul_scalar.glsl index d34a99d2c6e8..8d7fc2b93a7f 100644 --- a/aten/src/ATen/native/vulkan/glsl/mul_scalar.glsl +++ b/aten/src/ATen/native/vulkan/glsl/mul_scalar.glsl @@ -1,21 +1,27 @@ #version 450 core +#define PRECISION $precision + layout(std430) buffer; layout(std430) uniform; -layout(set = 0, rgba16f, binding = 0) writeonly highp uniform image3D uOutput; -layout(set = 0, binding = 1) uniform highp sampler3D uInput; -layout(set = 0, binding = 2) uniform constBlock { - ivec4 sizes; +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict writeonly image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; +layout(set = 0, binding = 2) uniform restrict Block { + ivec3 WHC; float other; -} -uConstBlock; +} uBlock; layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in; void main() { - ivec3 pos = ivec3(gl_GlobalInvocationID); - if (all(lessThan(pos, uConstBlock.sizes.xyz))) { - vec4 v = uConstBlock.other * texelFetch(uInput, pos, 0); - imageStore(uOutput, pos, v); + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.WHC))) { + imageStore( + uOutput, + pos, + texelFetch(uInput, pos, 0) * uBlock.other); } } diff --git a/aten/src/ATen/native/vulkan/glsl/mul_scalar_.glsl b/aten/src/ATen/native/vulkan/glsl/mul_scalar_.glsl new file mode 100644 index 000000000000..9d1626a2ba83 --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/mul_scalar_.glsl @@ -0,0 +1,26 @@ +#version 450 core +#define PRECISION $precision + +layout(std430) buffer; +layout(std430) uniform; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict image3D uOutput; +layout(set = 0, binding = 1) uniform restrict Block { + ivec3 WHC; + float other; +} uBlock; + +layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.WHC))) { + imageStore( + uOutput, + pos, + imageLoad(uOutput, pos) * uBlock.other); + } +} diff --git a/aten/src/ATen/native/vulkan/ops/Add.cpp b/aten/src/ATen/native/vulkan/ops/Add.cpp new file mode 100644 index 000000000000..1c1fa216d3f3 --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Add.cpp @@ -0,0 +1,257 @@ +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { +namespace { + +Tensor add_scalar( + const Tensor& self_arg, + const Scalar other, + const Scalar alpha) { + api::Context* const context = api::context(); + + const Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan(); + const vTensor& v_self = convert(self); + + vTensor v_output{ + context, + self.sizes(), + self.options(), + }; + + api::Command::Buffer command_buffer = context->command().pool.allocate(); + command_buffer.begin(); + { + if (v_output.has_image() && v_self.has_image()) { + const struct { + uint32_t width, height, channels; + float other; + } block { + v_output.extents().width, + v_output.extents().height, + v_output.extents().depth, + other.to() * alpha.to(), + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(add_scalar), + v_output.extents(), + // Write-only access bypasses synchronization but inserts appropriate + // barriers if necessary. + v_output.image(command_buffer, vTensor::Access::Write), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_self.image(command_buffer), + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. + context->resource().pool.uniform(block).object); + } + else { + TORCH_CHECK(false, "Not implemented!"); + } + } + command_buffer.end(); + command_buffer.submit(context->gpu().queue); + + return convert(v_output); +} + +Tensor& add_scalar_( + Tensor& self_arg, + const Scalar other, + const Scalar alpha) { + api::Context* const context = api::context(); + + TORCH_CHECK( + self_arg.is_vulkan(), + "Vulkan: In-place add is only supported on Vulkan tensors."); + + vTensor& v_self = convert(self_arg); + + api::Command::Buffer command_buffer = context->command().pool.allocate(); + command_buffer.begin(); + { + if (v_self.has_image()) { + const struct { + uint32_t width, height, channels; + float other; + } block { + v_self.extents().width, + v_self.extents().height, + v_self.extents().depth, + other.to() * alpha.to(), + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(add_scalar_), + v_self.extents(), + // Read-Write access triggers an async synchronization if necessory + // and inserts appropriate barriers if hazards are detected. + v_self.image(command_buffer, vTensor::Access::Read | vTensor::Access::Write), + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. + context->resource().pool.uniform(block).object); + } + else { + TORCH_CHECK(false, "Not implemented!"); + } + } + command_buffer.end(); + command_buffer.submit(context->gpu().queue); + + return self_arg; +} + +Tensor add_tensor( + const Tensor& self_arg, + const Tensor& other_arg, + const Scalar alpha) { + api::Context* const context = api::context(); + + const Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan(); + const vTensor& v_self = convert(self); + + const Tensor other = other_arg.is_vulkan() ? other_arg : other_arg.vulkan(); + const vTensor& v_other = convert(other); + + vTensor v_output{ + context, + self.sizes(), + self.options(), + }; + + api::Command::Buffer command_buffer = context->command().pool.allocate(); + command_buffer.begin(); + { + if (v_self.has_image() && v_other.has_image()) { + const struct { + uint32_t width, height, channels; + float alpha; + } block { + v_output.extents().width, + v_output.extents().height, + v_output.extents().depth, + alpha.to(), + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(add), + v_output.extents(), + // Write-only access bypasses synchronization but inserts appropriate + // barriers if necessary. + v_output.image(command_buffer, vTensor::Access::Write), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_self.image(command_buffer), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_other.image(command_buffer), + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. + context->resource().pool.uniform(block).object); + } + else { + TORCH_CHECK(false, "Not implemented!"); + } + } + command_buffer.end(); + command_buffer.submit(context->gpu().queue); + + return convert(v_output); +} + +Tensor& add_tensor_( + Tensor& self_arg, + const Tensor& other_arg, + const Scalar alpha) { + api::Context* const context = api::context(); + + TORCH_CHECK( + self_arg.is_vulkan(), + "Vulkan: In-place add is only supported on Vulkan tensors."); + + vTensor& v_self = convert(self_arg); + + const Tensor other = other_arg.is_vulkan() ? other_arg : other_arg.vulkan(); + const vTensor& v_other = convert(other); + + api::Command::Buffer command_buffer = context->command().pool.allocate(); + command_buffer.begin(); + { + if (v_self.has_image() && v_other.has_image()) { + const struct { + uint32_t width, height, channels; + float alpha; + } block { + v_self.extents().width, + v_self.extents().height, + v_self.extents().depth, + alpha.to(), + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(add_), + v_self.extents(), + // Read-Write access triggers an async synchronization if necessory + // and inserts appropriate barriers if hazards are detected. + v_self.image(command_buffer, vTensor::Access::Read | vTensor::Access::Write), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_other.image(command_buffer), + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. + context->resource().pool.uniform(block).object); + } + else { + TORCH_CHECK(false, "Not implemented!"); + } + } + command_buffer.end(); + command_buffer.submit(context->gpu().queue); + + return self_arg; +} + +#ifdef USE_VULKAN_API + +TORCH_LIBRARY_IMPL(aten, Vulkan, m) { + m.impl("add.Scalar", TORCH_FN(add_scalar)); + m.impl("add_.Scalar", TORCH_FN(add_scalar_)); + m.impl("add.Tensor", TORCH_FN(add_tensor)); + m.impl("add_.Tensor", TORCH_FN(add_tensor_)); +} + +#endif /* USE_VULKAN_API */ + +} // namespace +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/ops/Common.h b/aten/src/ATen/native/vulkan/ops/Common.h new file mode 100644 index 000000000000..121b40cbdb4b --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Common.h @@ -0,0 +1,9 @@ +#pragma once + +#ifdef USE_VULKAN_API + +#include +#include +#include + +#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/ops/Copy.cpp b/aten/src/ATen/native/vulkan/ops/Copy.cpp new file mode 100644 index 000000000000..2f74d1be00ab --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Copy.cpp @@ -0,0 +1,149 @@ +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { + +Tensor& copy_(Tensor& self, const Tensor& src) { + // X -> Vulkan + if (at::kVulkan == self.device().type()) { + vTensor& v_self = convert(self); + + // CPU -> Vulkan + if (at::kCPU == src.device().type()) { + // Requesting write-only host access to the tensor never triggers a sync + // as the contents will be overwritten regardless. Having said that, + // appropriate barriers are inserted automatically if WAR or WAW hazards + // are detected. Examples of such scenario for instance are if any of + // these async operations are on going in the background on 'self': + // - On discrete systems: + // * buffer-to-staging transfers + // * staging-to-buffer transfers + // - On UMA buffer is an alias for staging and accessible both on host + // and device. Consequently: + // * buffer-to-image NHWC -> NC4HW packing + // * image-to-buffer NC4HW -> NHWC unpacking + + using Future = vTensor::Future; + Future v_self_future = v_self.host(); + + // This wait() will be a no-op if no hazards are detected, including the + // obvious, yet important, special case of 'self' being an empty tensor. + + Future::Payload v_self_payload = v_self_future.wait(); + + memcpy( + v_self_payload.get(), + src.contiguous().data_ptr(), + std::min(src.nbytes(), self.nbytes())); + } + // Vulkan -> Vulkan + else if (at::kVulkan == src.device().type()) { + api::Command::Buffer command_buffer = api::context()->command().pool.allocate(); + command_buffer.begin(); + + command_buffer.copy( + // - Read-only access is implied on const tensors. Memory barriers + // are automatically inserted if a RAW hazard is detected. + // - Recording any potential pending sync operations into the same + // command buffer prevents an expensive queue submission. + convert(src).buffer(command_buffer), + // - Write-only access never triggers a sync as the contents will be + // overwritten regardless. Having said that, appropriate barriers + // are inserted automatically if WAR or WAW hazards are detected. + // - Recording pending sync operations into the same command buffer + // prevents an expensive queue submission. + v_self.buffer(command_buffer, vTensor::Access::Write)); + + command_buffer.end(); + command_buffer.submit(api::context()->gpu().queue); + } + else { + TORCH_INTERNAL_ASSERT(false, "Unsupported!"); + } + } + // Vulkan -> X + else if (at::kVulkan == src.device().type()) { + const vTensor& v_src = convert(src); + + { + // Similar notes as above applies, with the additional consideration of + // potential syncs on read accesses. Namely, + // - on discrete systems, if the (staging, buffer, image) trio, or + // - on UMA, if the (buffer, image) duo + // have gone out of sync as a result of one processor writing to one + // resource which is then either accessed as an another resource type on + // the same or another processor. Same considerations regarding hazard + // avoidance as above applies. + + using Future = vTensor::Future; + const Future v_src_future = v_src.host(); + + // Vulkan -> CPU + if (at::kCPU == self.device().type()) { + // This wait() is a no-op if data is not out of sync. More often than + // not though, waits here are expected as the GPU catches up with + // compute submitted from CPU. + + const Future::Payload v_src_payload = v_src_future.wait(); + + memcpy( + self.data_ptr(), + v_src_payload.get(), + std::min(src.nbytes(), self.nbytes())); + } + else { + TORCH_INTERNAL_ASSERT(false, "Unsupported!"); + } + } + + // + // WARNING + // + + // This is not great. We almost never want to flush the GPU pipeline as + // that has far reaching consequences, especially if PyTorch is not the only + // process accessing the GPU. If we have done our job properly, above + // synchronization mechanisms should be enough to ensure correctness at a more + // modest cost, as there is no need to flush the entirety of jobs in flight + // if one is only interested on waiting on computation affecting one single + // tensor to finish. + // + // Having said that, we still do need to release all pool resources at one + // point per inference run or we will run out of memory otherwise. There is + // no perfect answer to this problem that checks all boxes, which leaves us + // with one of several design decisions: + // + // 1) Use graph mode to gain an understanding of the computation graph, + // itself allowing us to place pool purges intelligently. Best option + // for performance and memory consumption. Not without its downsides if + // flexibility is a top priority. + // 2) If on eager mode, and hence are seeing operations one at a time, expose + // this release of resources to the user as a Python / C++ function. This + // makes for suboptimal user experience but is efficient in terms of + // performance. + // 3) If on eager mode, and interested in keeping this bookkeeping transparent + // to the user, release all resources somewhere ... like here. This is + // not ideal since it requires a pipeline flush to make sure these objects + // are not already in use by a workload in flight. Cannot do much better + // within the constraints of this approach. Good for user experience, + // suboptimal for performance. + // 4) If on eager mode, and interested in keeping this bookkeeping transparent + // to the user, and performance does not matter, make CPU and GPU run in + // lockstep. Obviously this is just bad. Mentioned for the sake of + // completeness. + + api::context()->flush(); + } + else { + TORCH_INTERNAL_ASSERT(false, "Unsupported!"); + } + + return self; +} + +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/ops/Copy.h b/aten/src/ATen/native/vulkan/ops/Copy.h new file mode 100644 index 000000000000..e69af06357c5 --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Copy.h @@ -0,0 +1,19 @@ +#pragma once + +#ifdef USE_VULKAN_API + +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { + +Tensor& copy_(Tensor& self, const Tensor& src); + +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at + +#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/ops/Factory.cpp b/aten/src/ATen/native/vulkan/ops/Factory.cpp new file mode 100644 index 000000000000..fc58fe46d438 --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Factory.cpp @@ -0,0 +1,58 @@ +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { +namespace { + +Tensor empty_memory_format( + const IntArrayRef sizes, + const TensorOptions& options_, + const optional memory_format = c10::nullopt) { + TORCH_CHECK( + !(options_.has_memory_format() && memory_format.has_value()), + "Cannot set memory_format both in TensorOptions and explicit argument!"); + + const TensorOptions options = options_.merge_in( + TensorOptions().memory_format(memory_format)); + verify(options); + + return convert(vTensor{ + api::context(), + sizes, + options, + }); +} + +Tensor empty_strided( + const IntArrayRef sizes, + const IntArrayRef /* strides */, + const optional dtype, + const optional layout, + const optional device, + const optional pin_memory) { + return empty_memory_format( + sizes, + TensorOptions(). + dtype(dtype). + layout(layout). + device(device). + pinned_memory(pin_memory)); +} + +#ifdef USE_VULKAN_API + +TORCH_LIBRARY_IMPL(aten, Vulkan, m) { + m.impl_UNBOXED("empty.memory_format", at::native::vulkan::ops::empty_memory_format); + m.impl("empty_strided", TORCH_FN(at::native::vulkan::ops::empty_strided)); +} + +#endif /* USE_VULKAN_API */ + +} // namespace +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/ops/Mul.cpp b/aten/src/ATen/native/vulkan/ops/Mul.cpp new file mode 100644 index 000000000000..d7bda8f73016 --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Mul.cpp @@ -0,0 +1,130 @@ +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { +namespace { + +Tensor mul_scalar( + const Tensor& self_arg, + const Scalar other) { + api::Context* const context = api::context(); + + const Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan(); + const vTensor& v_self = convert(self); + + vTensor v_output{ + context, + self.sizes(), + self.options(), + }; + + api::Command::Buffer command_buffer = context->command().pool.allocate(); + command_buffer.begin(); + { + if (v_output.has_image() && v_self.has_image()) { + const struct { + uint32_t width, height, channels; + float other; + } block { + v_output.extents().width, + v_output.extents().height, + v_output.extents().depth, + other.to(), + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(mul_scalar), + v_output.extents(), + // Write-only access bypasses synchronization but inserts appropriate + // barriers if necessary. + v_output.image(command_buffer, vTensor::Access::Write), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_self.image(command_buffer), + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. + context->resource().pool.uniform(block).object); + } + else { + TORCH_CHECK(false, "Not implemented!"); + } + } + command_buffer.end(); + command_buffer.submit(context->gpu().queue); + + return convert(v_output); +} + +Tensor& mul_scalar_( + Tensor& self_arg, + const Scalar other) { + api::Context* const context = api::context(); + + TORCH_CHECK( + self_arg.is_vulkan(), + "Vulkan: In-place add is only supported on Vulkan tensors."); + + vTensor& v_self = convert(self_arg); + + api::Command::Buffer command_buffer = context->command().pool.allocate(); + command_buffer.begin(); + { + if (v_self.has_image()) { + const struct { + uint32_t width, height, channels; + float other; + } block { + v_self.extents().width, + v_self.extents().height, + v_self.extents().depth, + other.to(), + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(mul_scalar_), + v_self.extents(), + // Read-Write access triggers an async synchronization if necessory + // and inserts appropriate barriers if hazards are detected. + v_self.image(command_buffer, vTensor::Access::Read | vTensor::Access::Write), + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. + context->resource().pool.uniform(block).object); + } + else { + TORCH_CHECK(false, "Not implemented!"); + } + } + command_buffer.end(); + command_buffer.submit(context->gpu().queue); + + return self_arg; +} + +#ifdef USE_VULKAN_API + +TORCH_LIBRARY_IMPL(aten, Vulkan, m) { + m.impl("mul.Scalar", TORCH_FN(mul_scalar)); + m.impl("mul_.Scalar", TORCH_FN(mul_scalar_)); +} + +#endif /* USE_VULKAN_API */ + +} // namespace +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/ops/Tensor.cpp b/aten/src/ATen/native/vulkan/ops/Tensor.cpp new file mode 100644 index 000000000000..a5baf716069f --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Tensor.cpp @@ -0,0 +1,1245 @@ +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { +namespace { + +VkDeviceSize bytes( + const IntArrayRef sizes, + const caffe2::TypeMeta dtype) { + VkDeviceSize size = c10::elementSize(c10::typeMetaToScalarType(dtype)); + + // Forward declaration + bool requires_image(IntArrayRef); + + if (requires_image(sizes)) { + // Forward declaration + VkExtent3D image_extents(IntArrayRef); + + const VkExtent3D extents = image_extents(sizes); + size *= extents.width * extents.height * (4u * extents.depth); + } + else { + size *= prod_intlist(sizes); + } + + return size; +} + +VkFormat convert(const caffe2::TypeMeta dtype) { + switch (c10::typeMetaToScalarType(dtype)) { + case kFloat: +#ifdef VULKAN_FP16_INFERENCE + return VK_FORMAT_R16G16B16A16_SFLOAT; +#else + return VK_FORMAT_R32G32B32A32_SFLOAT; +#endif /* VULKAN_FP16_INFERENCE */ + + default: + TORCH_CHECK( + false, + "Vulkan tensor format not supported!"); + } + + return VK_FORMAT_UNDEFINED; +} + +vTensor::Access::Flags convert(const VkAccessFlags vk_access) { + vTensor::Access::Flags access = 0u; + + constexpr VkAccessFlags kRead = + VK_ACCESS_HOST_READ_BIT | + VK_ACCESS_MEMORY_READ_BIT | + VK_ACCESS_SHADER_READ_BIT | + VK_ACCESS_TRANSFER_READ_BIT | + VK_ACCESS_UNIFORM_READ_BIT; + + constexpr VkAccessFlags kWrite = + VK_ACCESS_HOST_WRITE_BIT | + VK_ACCESS_MEMORY_WRITE_BIT | + VK_ACCESS_SHADER_WRITE_BIT | + VK_ACCESS_TRANSFER_WRITE_BIT; + + if (vk_access & kRead) { + access |= vTensor::Access::Read; + } + + if (vk_access & kWrite) { + access |= vTensor::Access::Write; + } + + return access; +} + +vTensor::Buffer allocate_buffer( + api::Context* const context, + const IntArrayRef sizes, + const TensorOptions& options) { + TORCH_CHECK(!sizes.empty(), "Invalid Vulkan tensor size!"); + verify(options); + + // Forward declaration + bool requires_staging(api::Context*); + + const VkFlags usage = [context]() { + VkFlags usage = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; + + if (requires_staging(context)) { + usage |= VK_BUFFER_USAGE_TRANSFER_SRC_BIT | + VK_BUFFER_USAGE_TRANSFER_DST_BIT; + } + + return usage; + }(); + + const auto memory = [context]() -> api::Resource::Memory::Descriptor { + if (requires_staging(context)) { + return { + VMA_MEMORY_USAGE_GPU_ONLY, + 0u, + VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT, + }; + } + + return { + VMA_MEMORY_USAGE_UNKNOWN, + VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT, + VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT, + }; + }(); + + return context->resource().pool.buffer( + vTensor::Buffer::Descriptor{ + bytes(sizes, options.dtype()), + // Usage + { + usage, + memory, + }, + }); +} + +bool requires_image(const IntArrayRef sizes) { + return (1u <= sizes.size()) && (sizes.size() <= 4u); +} + +VkExtent3D image_extents(const IntArrayRef sizes) { + int64_t width = 1; + int64_t height = 1; + int64_t depth = 1; + + switch (sizes.size()) { + case 1: + width = sizes[0]; + break; + + case 2: + width = sizes[1]; + height = sizes[0]; + break; + + case 3: + width = sizes[2]; + height = sizes[1]; + depth = sizes[0]; + break; + + case 4: + width = sizes[3]; + height = sizes[2]; + depth = sizes[0] * sizes[1]; + break; + + default: + TORCH_INTERNAL_ASSERT( + false, + "Only Tensors with 1 <= dim <= 4 can be represented as a Vulkan Image!"); + } + + return { + width, + height, + api::div_up(depth, 4u), + }; +} + +vTensor::Image allocate_image( + api::Context* const context, + const VkExtent3D& extents, + const TensorOptions& options) { + verify(options); + + return context->resource().pool.image( + vTensor::Image::Descriptor{ + VK_IMAGE_TYPE_3D, + convert(options.dtype()), + extents, + // Usage + { + VK_IMAGE_USAGE_SAMPLED_BIT | + VK_IMAGE_USAGE_STORAGE_BIT, + { + VMA_MEMORY_USAGE_GPU_ONLY, + 0u, + VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT, + }, + }, + // View + { + VK_IMAGE_VIEW_TYPE_3D, + convert(options.dtype()), + }, + }); +} + +bool requires_staging(api::Context* const context) { + return !context->gpu().adapter->has_unified_memory(); +} + +vTensor::Buffer allocate_staging( + api::Context* const context, + const IntArrayRef sizes, + const TensorOptions& options) { + TORCH_CHECK(!sizes.empty(), "Invalid Vulkan tensor size!"); + verify(options); + + return context->resource().pool.buffer( + vTensor::Buffer::Descriptor{ + bytes(sizes, options.dtype()), + // Usage + { + VK_BUFFER_USAGE_TRANSFER_SRC_BIT | + VK_BUFFER_USAGE_TRANSFER_DST_BIT, + { + VMA_MEMORY_USAGE_CPU_ONLY, + 0u, + 0u, + }, + }, + }); +} + +vTensor::Fence allocate_fence( + api::Context* const context) { + return context->resource().pool.fence(); +} + +enum class Barrier { + None, + Exectution, + Memory, +}; + +Barrier categorize( + const VkAccessFlags vk_src_access, + const VkAccessFlags vk_dst_access) { + if (0u == vk_src_access) { + return Barrier::None; + } + + const vTensor::Access::Flags src_access = convert(vk_src_access); + const vTensor::Access::Flags dst_access = convert(vk_dst_access); + + if (vTensor::Access::Read == (src_access & vTensor::Access::Read)) { + if (vTensor::Access::Read == (dst_access & vTensor::Access::Read)) { + // RAR (Read after Read) + return Barrier::None; + } + + // WAR (Write after Read) + return Barrier::Exectution; + } + + // RAW (Read after Write), or WAW (Write after Write) + return Barrier::Memory; +}; + +Barrier categorize( + const VkAccessFlags vk_src_access, + const VkAccessFlags vk_dst_access, + const VkImageLayout vk_src_layout, + const VkImageLayout vk_dst_layout) { + if (vk_src_layout != vk_dst_layout) { + return Barrier::Memory; + } + + return categorize(vk_src_access, vk_dst_access); +} + +} // namespace + +vTensor::vTensor( + api::Context* const context, + const IntArrayRef sizes, + const TensorOptions& options) + : view_(new View{ + context, + sizes, + options, + }) { +} + +const vTensor* vTensor::host() const { + view_->staging(Access::Read); + return this; +} + +vTensor* vTensor::host(const Access::Flags access) { + view_->staging(access); + return this; +} + +vTensor::Buffer::Object vTensor::buffer() const & { + return view_->buffer(Access::Read).object; +} + +vTensor::Buffer::Object vTensor::buffer( + const Access::Flags access) & { + return view_->buffer(access).object; +} + +vTensor::Buffer::Object vTensor::buffer( + api::Command::Buffer& command_buffer) const & { + return view_->buffer(command_buffer, Access::Read).object; +} + +vTensor::Buffer::Object vTensor::buffer( + api::Command::Buffer& command_buffer, + const Access::Flags access) & { + return view_->buffer(command_buffer, access).object; +} + +vTensor::Image::Object vTensor::image() const & { + return view_->image(Access::Read).object; +} + +vTensor::Image::Object vTensor::image( + const Access::Flags access) & { + return view_->image(access).object; +} + +vTensor::Image::Object vTensor::image( + api::Command::Buffer& command_buffer) const & { + return view_->image(command_buffer, Access::Read).object; +} + +vTensor::Image::Object vTensor::image( + api::Command::Buffer& command_buffer, + const Access::Flags access) & { + return view_->image(command_buffer, access).object; +} + +vTensor::View::View() + // Resources + : buffer_{}, + image_{}, + staging_{}, + fence_{}, + // Context + context_(nullptr), + // State + state_{}, + // Metadata + extents_{} { +} + +vTensor::View::View( + api::Context* const context, + const IntArrayRef sizes, + const TensorOptions& options) + // Resources + : buffer_{}, + image_{}, + staging_{}, + fence_{}, + // Context + context_(context), + // State + state_(context, sizes), + // Metadata + extents_(image_extents(sizes)), + options_(options), + sizes_(sizes), + strides_(sizes.size()) { + ops::verify(options); +} + +// We typically do not know whether we need a command buffer to service a request +// until we have perfomed a bunch of checks in nested logic, and even then we +// may end up with the always issued state transition optimized away under +// certain conditions, which makes a policy of always allocating a command buffer +// up front, only to end up using it at times, a wasteful approach. This class +// answers that need. + +class vTensor::View::CMD final { + public: + CMD(const View&); + CMD(const View&, api::Command::Buffer&); + CMD(const CMD&) = delete; + CMD& operator=(const CMD&) = delete; + CMD(CMD&&) = delete; + CMD& operator=(CMD&&) = delete; + ~CMD() = default; + + typedef api::Resource::Buffer Buffer; + typedef api::Resource::Image Image; + typedef api::Resource::Fence Fence; + + void barrier(State::Transition transition); + + void copy_buffer_to_staging( + State& state, + const Buffer::Object& buffer, + Buffer::Object& staging); + + void copy_staging_to_buffer( + State& state, + const Buffer::Object& staging, + Buffer::Object& buffer); + + void copy_buffer_to_image( + State& state, + const Buffer::Object& buffer, + Image::Object& image); + + void copy_image_to_buffer( + State& state, + const Image::Object& image, + Buffer::Object& buffer); + + void submit(Fence fence = {}); + + private: + api::Command::Buffer& command_buffer(); + + private: + const View& view_; + + enum class Type { + Internal, + External, + } type; + + union { + api::Command::Buffer internal; + api::Command::Buffer* external; + } command_buffer_; +}; + +vTensor::View::CMD::CMD( + const View& view) + : view_(view), + type(Type::Internal), + command_buffer_{} { +} + +vTensor::View::CMD::CMD( + const View& view, + api::Command::Buffer& external) + : view_(view), + type(Type::External), + command_buffer_{ + .external = &external, + } { +} + +api::Command::Buffer& vTensor::View::CMD::command_buffer() { + switch (type) { + case Type::Internal: + if (!command_buffer_.internal) { + command_buffer_.internal = view_.context_->command().pool.allocate(); + command_buffer_.internal.begin(); + } + + return command_buffer_.internal; + + case Type::External: + return *(command_buffer_.external); + + default: + TORCH_INTERNAL_ASSERT(false, "Unknown command buffer type!"); + break; + } +} + +void vTensor::View::CMD::barrier(State::Transition transition) { + // Buffer and Staging are just an alias for the same memory location on UMA. + + if (view_.state_.is_uma()) { + transition.first.buffer.stage |= transition.first.staging.stage; + transition.first.buffer.access |= transition.first.staging.access; + transition.first.staging = {}; + + transition.second.buffer.stage |= transition.second.staging.stage; + transition.second.buffer.access |= transition.second.staging.access; + transition.second.staging = {}; + } + + // Filter out host dependencies out of source, per Vulkan spec host write ordering guarantees: + // https://www.khronos.org/registry/vulkan/specs/1.2/html/vkspec.html#synchronization-submission-host-writes + + const auto filter_stage =[](VkPipelineStageFlags& stage) { + stage &= ~VK_PIPELINE_STAGE_HOST_BIT; + }; + + filter_stage(transition.first.buffer.stage); + filter_stage(transition.first.staging.stage); + + const auto filter_access =[](VkAccessFlags& access) { + access &= ~(VK_ACCESS_HOST_READ_BIT | VK_ACCESS_HOST_WRITE_BIT); + }; + + filter_access(transition.first.buffer.access); + filter_access(transition.first.staging.access); + + api::Pipeline::Barrier barrier{}; + + if (transition.second.staging) { + const State::Bundle::Buffer from = transition.first.staging; + const State::Bundle::Buffer to = transition.second.staging; + + const Barrier category = categorize( + from.access, + to.access); + + if (Barrier::None != category) { + barrier.stage.src |= from.stage; + barrier.stage.dst |= to.stage; + + if (Barrier::Memory == category) { + barrier.buffers.push_back({ + view_.staging().object, + { + from.access, + to.access, + }, + }); + } + } + } + + if (transition.second.buffer) { + const State::Bundle::Buffer from = transition.first.buffer; + const State::Bundle::Buffer to = transition.second.buffer; + + const Barrier category = categorize( + from.access, + to.access); + + if (Barrier::None != category) { + barrier.stage.src |= from.stage; + barrier.stage.dst |= to.stage; + + if (Barrier::Memory == category) { + barrier.buffers.push_back({ + view_.buffer().object, + { + from.access, + to.access, + }, + }); + } + } + } + + if (transition.second.image) { + const State::Bundle::Image from = transition.first.image; + const State::Bundle::Image to = transition.second.image; + + const Barrier category = categorize( + from.access, + to.access, + from.layout, + to.layout); + + if (Barrier::None != category) { + barrier.stage.src |= from.stage; + barrier.stage.dst |= to.stage; + + if (Barrier::Memory == category) { + TORCH_INTERNAL_ASSERT( + from.layout == view_.image().object.layout, + "Invalid image layout!"); + + barrier.images.push_back({ + view_.image().object, + { + from.access, + to.access, + }, + { + from.layout, + to.layout, + }, + }); + + view_.image().object.layout = to.layout; + } + } + } + + // If we are left with anything meaningful, insert a barrier. + + if (barrier) { + if (0u == barrier.stage.src) { + barrier.stage.src = VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT; + } + + if (0u == barrier.stage.dst) { + barrier.stage.src = VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT; + } + + // Optimization opportunity: delay and batch. + + command_buffer().barrier(barrier); + } +} + +void vTensor::View::CMD::copy_buffer_to_staging( + State& state, + const Buffer::Object& buffer, + Buffer::Object& staging) { + if (state.is_clean(Component::Staging) || state.is_uma()) { + return; + } + + barrier( + state.transition({ + // Staging + { + VK_PIPELINE_STAGE_TRANSFER_BIT, + VK_ACCESS_TRANSFER_WRITE_BIT, + }, + // Buffer + { + VK_PIPELINE_STAGE_TRANSFER_BIT, + VK_ACCESS_TRANSFER_READ_BIT, + }, + // Image + {}, + })); + + command_buffer().copy(buffer, staging); +} + +void vTensor::View::CMD::copy_staging_to_buffer( + State& state, + const Buffer::Object& staging, + Buffer::Object& buffer) { + if (state.is_clean(Component::Buffer) || state.is_uma()) { + return; + } + + barrier( + state.transition({ + // Staging + { + VK_PIPELINE_STAGE_TRANSFER_BIT, + VK_ACCESS_TRANSFER_READ_BIT, + }, + // Buffer + { + VK_PIPELINE_STAGE_TRANSFER_BIT, + VK_ACCESS_TRANSFER_WRITE_BIT, + }, + // Image + {}, + })); + + command_buffer().copy(staging, buffer); +} + +void vTensor::View::CMD::copy_buffer_to_image( + State& state, + const Buffer::Object& buffer, + Image::Object& image) { + if (state.is_clean(Component::Image)) { + return; + } + + barrier( + state.transition({ + // Staging + {}, + // Buffer + { + VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, + VK_ACCESS_SHADER_READ_BIT, + }, + // Image + { + VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, + VK_ACCESS_SHADER_WRITE_BIT, + VK_IMAGE_LAYOUT_GENERAL, + }, + })); + + const struct { + uint32_t width; + uint32_t height; + } block { + view_.extents().width, + view_.extents().height, + }; + + view_.context_->dispatch( + command_buffer(), + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(nchw_to_image), + view_.extents(), + image, + buffer, + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. + view_.context_->resource().pool.uniform(block).object); +} + +void vTensor::View::CMD::copy_image_to_buffer( + State& state, + const Image::Object& image, + Buffer::Object& buffer) { + if (state.is_clean(Component::Buffer)) { + return; + } + + barrier( + state.transition({ + // Staging + {}, + // Buffer + { + VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, + VK_ACCESS_SHADER_WRITE_BIT, + }, + // Image + { + VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, + VK_ACCESS_SHADER_READ_BIT, + VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL, + }, + })); + + const struct { + uint32_t width; + uint32_t height; + } block { + view_.extents().width, + view_.extents().height, + }; + + view_.context_->dispatch( + command_buffer(), + { + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(image_to_nchw), + view_.extents(), + image, + buffer, + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. + view_.context_->resource().pool.uniform(block).object); +} + +void vTensor::View::CMD::submit(const api::Resource::Fence fence) { + if ((Type::Internal == type) && command_buffer_.internal) { + command_buffer_.internal.end(); + command_buffer_.internal.submit(view_.context_->gpu().queue, fence); + } +} + +vTensor::Buffer& vTensor::View::buffer() const { + if (!buffer_) { + buffer_ = allocate_buffer( + context_, + sizes(), + options()); + } + + return buffer_; +} + +vTensor::Buffer& vTensor::View::buffer( + const Access::Flags access) const { + CMD command_buffer(*this); + Buffer& buffer = this->buffer(command_buffer, access); + command_buffer.submit(); + + return buffer; +} + +vTensor::Buffer& vTensor::View::buffer( + api::Command::Buffer& command_buffer_, + const Access::Flags access) const { + CMD command_buffer(*this, command_buffer_); + return buffer(command_buffer, access); +} + +vTensor::Buffer& vTensor::View::buffer( + CMD& command_buffer, + const Access::Flags access) const { + if ((access & Access::Read) && state_.is_dirty(Component::Buffer)) { + if (state_.is_clean(Component::Staging)) { + command_buffer.copy_staging_to_buffer( + state_, + staging(command_buffer, Access::Read).object, + buffer().object); + } + else if (state_.is_clean(Component::Image)) { + command_buffer.copy_image_to_buffer( + state_, + image(command_buffer, Access::Read).object, + buffer().object); + } + else { + TORCH_INTERNAL_ASSERT( + false, + "Invalid state!"); + } + } + + command_buffer.barrier( + state_.transition({ + // Staging + {}, + // Buffer + { + VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, + [access]() { + VkAccessFlags vk_access = 0u; + + if (access & Access::Read) { + vk_access |= VK_ACCESS_SHADER_READ_BIT; + } + + if (access & Access::Write) { + vk_access |= VK_ACCESS_SHADER_WRITE_BIT; + } + + return vk_access; + }(), + }, + // Image + {}, + })); + + if (access & Access::Write) { + state_.set_dirty(Component::All); + } + + state_.set_clean(Component::Buffer); + + return buffer(); +} + +vTensor::Image& vTensor::View::image() const { + if (!image_ && state_.is_available(Component::Image)) { + image_ = allocate_image( + context_, + extents(), + options()); + } + + return image_; +} + +vTensor::Image& vTensor::View::image( + const Access::Flags access) const { + CMD command_buffer(*this); + Image& image = this->image(command_buffer, access); + command_buffer.submit(); + + return image; +} + +vTensor::Image& vTensor::View::image( + api::Command::Buffer& command_buffer_, + const Access::Flags access) const { + CMD command_buffer(*this, command_buffer_); + return image(command_buffer, access); +} + +vTensor::Image& vTensor::View::image( + CMD& command_buffer, + const Access::Flags access) const { + if ((access & Access::Read) && state_.is_dirty(Component::Image)) { + command_buffer.copy_buffer_to_image( + state_, + buffer(command_buffer, Access::Read).object, + image().object); + } + + command_buffer.barrier( + state_.transition({ + // Staging + {}, + // Buffer + {}, + // Image + { + VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, + [access]() { + VkAccessFlags vk_access = 0u; + + if (access & Access::Read) { + vk_access |= VK_ACCESS_SHADER_READ_BIT; + } + + if (access & Access::Write) { + vk_access |= VK_ACCESS_SHADER_WRITE_BIT; + } + + return vk_access; + }(), + [access]() { + if (Access::Read == (access & Access::Read)) { + return VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL; + } + + return VK_IMAGE_LAYOUT_GENERAL; + }(), + }, + })); + + if (access & Access::Write) { + state_.set_dirty(Component::All); + } + + state_.set_clean(Component::Image); + + return image(); +} + +vTensor::Buffer& vTensor::View::staging() const { + if (!state_.is_available(Component::Staging)) { + return buffer(); + } + + if (!staging_) { + staging_ = allocate_staging( + context_, + sizes(), + options()); + } + + return staging_; +} + +vTensor::Buffer& vTensor::View::staging(const Access::Flags access) const { + CMD command_buffer(*this); + Buffer& staging = this->staging(command_buffer, access); + command_buffer.submit(fence()); + + return staging; +} + +vTensor::Buffer& vTensor::View::staging( + CMD& command_buffer, + const Access::Flags access) const { + if ((access & Access::Read) && state_.is_dirty(Component::Staging)) { + command_buffer.copy_buffer_to_staging( + state_, + buffer(command_buffer, Access::Read).object, + staging().object); + } + + command_buffer.barrier( + state_.transition({ + // Staging + { + VK_PIPELINE_STAGE_HOST_BIT, + [access]() { + VkAccessFlags vk_access = 0u; + + if (access & Access::Read) { + vk_access |= VK_ACCESS_HOST_READ_BIT; + } + + if (access & Access::Write) { + vk_access |= VK_ACCESS_HOST_WRITE_BIT; + } + + return vk_access; + }(), + }, + // Buffer + {}, + // Image + {}, + })); + + if (access & Access::Write) { + state_.set_dirty(Component::All); + } + + state_.set_clean(Component::Staging); + + return staging(); +} + +vTensor::Memory& vTensor::View::wait() const { + if (fence_) { + fence_.wait(); + } + + return staging().memory; +} + +vTensor::Fence& vTensor::View::fence() const { + return (fence_ = allocate_fence(context_)); +} + +void vTensor::View::verify() const { + TORCH_INTERNAL_ASSERT(!image_ || state_.is_available(Component::Image)); + TORCH_INTERNAL_ASSERT(!staging_ || state_.is_discrete()); +} + +vTensor::View::State::State() + : available_{}, + dirty_{}, + bundle_{} { +} + +vTensor::View::State::State( + api::Context* const context, + const IntArrayRef sizes) + : available_{}, + dirty_{}, + bundle_{} { + available_ |= Component::Buffer; + + if (requires_image(sizes)) { + available_ |= Component::Image; + } + + if (requires_staging(context)) { + available_ |= Component::Staging; + } +} + +vTensor::View::State::Transition +vTensor::View::State::transition(const Bundle bundle) { + const Bundle from = bundle_; + Bundle& to = bundle_; + + if (bundle.staging) { + to.staging = bundle.staging; + } + + if (bundle.buffer) { + to.buffer = bundle.buffer; + } + + if (bundle.image) { + to.image = bundle.image; + } + +#ifdef DEBUG + // Forward declaration + std::ostream& operator<<( + std::ostream&, + const View::State::Bundle&); + + std::cout << "From:" << std::endl << from << std::endl; + std::cout << "To:" << std::endl << to << std::endl; +#endif /* DEBUG */ + + return Transition{ + from, + to, + }; +} + +void verify(const TensorOptions& options) { + TORCH_CHECK( + !options.has_requires_grad() || !options.requires_grad(), + "'requires_grad' tensor option is not yet supported under Vulkan!"); + + TORCH_CHECK( + !options.has_pinned_memory() || !options.pinned_memory(), + "'pinned_memory' tensor option is not yet supported under Vulkan!"); + + TORCH_CHECK( + !options.has_layout() || (c10::kStrided == options.layout()), + "'layout' tensor option is not yet supported under Vulkan!"); + + TORCH_CHECK( + !options.has_memory_format(), + "'memory_format' tensor option is not yet supported under Vulkan!"); +} + +// +// Debug +// + +namespace { + +// Considering that VkAccessFlags is a weak typedef of a built-in data type, we +// need to introduce a new type to allow overload resolution distinguish between +// the two. + +struct Access final { + VkAccessFlags value; +}; + +std::ostream& operator<<( + std::ostream& stream, + const Access& access) { + stream << "Access: "; + + if (0u == access.value) { + return stream << " 0"; + } + + if (access.value & VK_ACCESS_HOST_READ_BIT) { + stream << " VK_ACCESS_HOST_READ_BIT"; + } + + if (access.value & VK_ACCESS_HOST_WRITE_BIT) { + stream << " VK_ACCESS_HOST_WRITE_BIT"; + } + + if (access.value & VK_ACCESS_MEMORY_READ_BIT) { + stream << " VK_ACCESS_MEMORY_READ_BIT"; + } + + if (access.value & VK_ACCESS_MEMORY_WRITE_BIT) { + stream << " VK_ACCESS_MEMORY_WRITE_BIT"; + } + + if (access.value & VK_ACCESS_SHADER_READ_BIT) { + stream << " VK_ACCESS_SHADER_READ_BIT"; + } + + if (access.value & VK_ACCESS_SHADER_WRITE_BIT) { + stream << " VK_ACCESS_SHADER_WRITE_BIT"; + } + + if (access.value & VK_ACCESS_TRANSFER_READ_BIT) { + stream << " VK_ACCESS_TRANSFER_READ_BIT"; + } + + if (access.value & VK_ACCESS_TRANSFER_WRITE_BIT) { + stream << " VK_ACCESS_TRANSFER_WRITE_BIT"; + } + + return stream; +} + +// Considering that VkImageLayout is a weak typedef of a built-in data type, +// we need to introduce a new type to allow overload resolution distinguish +// between the two. + +struct Layout final { + VkImageLayout value; +}; + +std::ostream& operator<<( + std::ostream& stream, + const Layout& layout) { + stream << "Layout: "; + + switch (layout.value) { + case VK_IMAGE_LAYOUT_UNDEFINED: + stream << " VK_IMAGE_LAYOUT_UNDEFINED"; + break; + + case VK_IMAGE_LAYOUT_GENERAL: + stream << " VK_IMAGE_LAYOUT_GENERAL"; + break; + + case VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL: + stream << " VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL"; + break; + + default: + stream << " Unknown!"; + break; + }; + + return stream; +} + +// Considering that VkPipelineStageFlags is a weak typedef of a built-in data +// type, we need to introduce a new type to allow overload resolution distinguish +// between the two. + +struct Stage final { + VkPipelineStageFlags value; +}; + +std::ostream& operator<<( + std::ostream& stream, + const Stage& stage) { + stream << "Stage: "; + + if (0u == stage.value) { + return stream << " 0"; + } + + if (stage.value & VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT) { + stream << " VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT"; + } + + if (stage.value & VK_PIPELINE_STAGE_HOST_BIT) { + stream << " VK_PIPELINE_STAGE_HOST_BIT"; + } + + if (stage.value & VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT) { + stream << " VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT"; + } + + if (stage.value & VK_PIPELINE_STAGE_TRANSFER_BIT) { + stream << " VK_PIPELINE_STAGE_TRANSFER_BIT"; + } + + return stream; +} + +} // namespace + +std::ostream& operator<<( + std::ostream& stream, + const vTensor::View::State::Bundle& bundle) { + stream << "Staging\n " << + Stage{ + bundle.staging.stage, + } << "\n " << + Access{ + bundle.staging.access, + } << std::endl; + + stream << "Buffer\n " << + Stage{ + bundle.buffer.stage, + } << "\n " << + Access{ + bundle.buffer.access, + } << std::endl; + + stream << "Image\n " << + Stage{ + bundle.image.stage, + } << "\n " << + Access{ + bundle.image.access, + } << "\n " << + Layout{ + bundle.image.layout, + } << std::endl; + + return stream; +} + +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/ops/Tensor.h b/aten/src/ATen/native/vulkan/ops/Tensor.h new file mode 100644 index 000000000000..08eede9a4f18 --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Tensor.h @@ -0,0 +1,598 @@ +#pragma once + +#ifdef USE_VULKAN_API + +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { + +// +// This class represents a Vulkan tensor and provides an abstraction layer +// that allows both the CPU, and the GPU, to view a Vulkan (buffer, image) +// pair as one coherent, synchronized unit of storage on both UMA and discrete +// systems. Expanding on the previous sentence, this class tries to address +// two orthogonal implementation complexities that arise as a result of the +// aforementioned goal of memory coherence: +// +// 1) First, synchronization across processors; CPUs and GPUs are separate +// processors, and even though they share the same address space in a system +// with a unified memory architecture, their address spaces only partially +// overlap on systems with a discrete GPU. Consequently on discrete systems, +// while it is still technically possible to take advantage of this shared +// address space to maintain one single copy of the data, different access +// latencies from CPU and GPU to this shared location usually necessitates +// maintaining two copies each in processor-local memory, otherwise memory +// access latency will hurt from the processor to which this data is not +// close. This shared memory is more often than not located in system memory, +// making for slow GPU read and write access over the PCI-e bus on discrete. +// Maintaining two separate copies on the other hand, requires synchronization +// to guarantee coherence. This is not an issue on UMA and this implementation +// accounts for that optimization. +// +// 2) Second, synchronization across resources (i.e. buffers and images); GPU +// drivers pack images in proprietory formats for better locality of access +// and to enable lossless compression. These conversions are both expensive +// (in general) and manual (in Vulkan.) This requires a second order of +// synchronization to guarantee coherence between the contents of the buffer +// and image otherwise they will go out of sync. +// +// It is extremely important to keep in mind that the functionality this class +// provides is generally expensive. For optimal performance, the user of this +// class should: +// +// 1) Avoid frequent CPU <=> GPU transfers which will be triggered if data is +// write accessed on one processor and read / write accessed on the other. +// +// 2) Avoid frequent buffer <=> image conversions which will be trigerred if +// data is write accessed as a buffer (image) and read accessed as an +// image (buffer). +// +// 3) When and if a synchronization is unavoidable, place as much distance +// between the synchronization is triggered and the data is accessed since +// all synchronizations this class provides are async. +// +// For optimal performance, access the data as images, and keep the data on GPU, +// and above all understand the expensive data flow that this class abstracts +// away. +// +// vTensor tries to address a specific concern and intentionally does not expose +// GPU tensor memory directly. Please keep that behavior intact as the whole +// data model fundamentally depends on limiting what the user can achieve through +// the interface to guarantee performance and coherence. +// +// A vTensor is associated with an api::Context as preparation for multi-GPU +// support. +// + +class vTensor final { + public: + vTensor() = default; + vTensor( + api::Context* context, + IntArrayRef sizes, + const TensorOptions& options); + + /* + Types + */ + + typedef api::Resource::Memory::Access Access; + typedef api::Resource::Buffer Buffer; + typedef api::Resource::Fence Fence; + typedef api::Resource::Image Image; + typedef api::Resource::Memory Memory; + + /* + Future + */ + + template + class Future final { + template + using is_convertible = std::enable_if_t< + std::is_convertible< + Access::Pointer, + Access::Pointer>::value>; + + public: + explicit Future(const vTensor* tensor); + Future(const Future&) = delete; + Future& operator=(const Future&) = delete; + Future(Future&&); + Future& operator=(Future&&) &; + Future& operator=(Future&&) && = delete; + template> + Future(Future&&); + template> + Future& operator=(Future&&) &; + template + Future& operator=(Future&&) && = delete; + ~Future(); + + typedef Memory::Handle< + Access::Pointer< + Type, + kAccess>> Payload; + + // This is a blocking operation as the name suggests. A call to host() will + // trigger an async copy if pending writes are detected. Consequently, for + // optimal performance, put as much time and distance between the place + // where a vTensor::host() call occurs and the location where the returned + // future is explicitly waited on as a result of a call to this function. + + Payload wait() const &; + + private: + // Intentionally disabed to enforce a usage pattern wherein the Future's + // lifetime exceeds that of the Payload as we use the Future's destructor + // to eagerly (as opposed to lazily and upon first use) upload the + // modifications back onto the GPU in an effort to hide the upload latency. + + Payload wait() const && = delete; + + private: + template + friend class Future; + + private: + const vTensor* tensor_; + }; + + /* + Host access - these functions will be expensive if they trigger a GPU -> CPU + sync due to pending writes. A call to host() will trigger an async copy in + such scenarios, which is then explictly waited on as part of Future::wait(). + Consequently, for optimal performance, put as much time and distance between + the place where this function is called, and the location where the future is + waited on. + */ + + template + Future host() const &; + + template + Future host() &; + + /* + Device access - these functions will be expensive if they trigger a buffer + <-> image or CPU -> GPU sync due to pending writes. These functions are + non-blocking on the host as the copy operation is carried out by the GPU + asynchronously. Regardless, they result in extra work that could have been + avoided or at least minimized if all data access had occured through one + single processor (GPU in this case) and on one type of resource (image for + best performance.) Consequently, for optimal performance, avoid mixed reads + and writes across processor boundaries, and do your best to minimize layout + transitions as a result of working with images only (as opposed to mixed + buffer - image usage.) + This implementation intentionally restricts user access to the buffer and + image objects only, as opposed to their underlying memory, for the sake of + predictability of usage and efficiency. + */ + + Buffer::Object buffer() const &; + Buffer::Object buffer(Access::Flags access) &; + Buffer::Object buffer(api::Command::Buffer&) const &; + Buffer::Object buffer(api::Command::Buffer&, Access::Flags) &; + + bool has_image() const; + Image::Object image() const &; + Image::Object image(Access::Flags access) &; + Image::Object image(api::Command::Buffer&) const &; + Image::Object image(api::Command::Buffer&, Access::Flags) &; + + /* + Metadata + */ + + const VkExtent3D& extents() const; + const TensorOptions& options() const; + IntArrayRef sizes() const; + IntArrayRef strides() const; + + private: + // Some overloads below are intentionally disabled to enforce a usage pattern + // that ensures the Tensor's lifetime exceeds that of the scope in which the + // underlying data is accessed. Allowing deleted overloads below to be + // invoked on a temporary would open the door to the possibility of accessing + // the underlying memory out of the expected scope. + + /* + Host + */ + + const vTensor* host() const; + vTensor* host(Access::Flags access); + + template + Future host() const && = delete; + + template + Future host() && = delete; + + /* + Device + */ + + Buffer::Object buffer() const && = delete; + Buffer::Object buffer(Access::Flags) && = delete; + Buffer::Object buffer(api::Command::Buffer&) const && = delete; + Buffer::Object buffer(api::Command::Buffer&, Access::Flags) && = delete; + + Image::Object image() const && = delete; + Image::Object image(Access::Flags) && = delete; + Image::Object image(api::Command::Buffer&) const && = delete; + Image::Object image(api::Command::Buffer&, Access::Flags) && = delete; + + private: + class View final { + public: + View(); + View( + api::Context* context, + IntArrayRef sizes, + const TensorOptions& options); + View(const View&) = delete; + View& operator=(const View&) = delete; + View(View&&) = default; + View operator=(View&&) = delete; + ~View() = default; + + Buffer& buffer(Access::Flags) const; + Buffer& buffer(api::Command::Buffer&, Access::Flags) const; + + bool has_image() const; + Image& image(Access::Flags) const; + Image& image(api::Command::Buffer&, Access::Flags) const; + + Buffer& staging(Access::Flags) const; + Buffer& staging(api::Command::Buffer&, Access::Flags) const; + vTensor::Memory& wait() const; + + const VkExtent3D& extents() const; + const TensorOptions& options() const; + IntArrayRef sizes() const; + IntArrayRef strides() const; + + private: + class CMD; + + class State final { + public: + State(); + State(api::Context*, IntArrayRef); + + struct Bundle final { + struct Buffer final { + VkPipelineStageFlags stage; + VkAccessFlags access; + + operator bool() const; + } staging, buffer; + + struct Image final { + VkPipelineStageFlags stage; + VkAccessFlags access; + VkImageLayout layout; + + operator bool() const; + } image; + }; + + struct Component final { + typedef uint8_t Flags; + + enum Type : Flags { + Buffer = 1u << 0u, + Image = 1u << 1u, + Staging = 1u << 2u, + All = Buffer | Image | Staging, + }; + }; + + // Availability + bool is_available(Component::Flags) const; + bool is_discrete() const; + bool is_uma() const; + + // Clean / Dirty + bool is_clean(Component::Flags) const; + bool is_dirty(Component::Flags) const; + void set_clean(Component::Flags); + void set_dirty(Component::Flags); + + // Transition + typedef std::pair Transition; + Transition transition(Bundle to); + + private: + Component::Flags available_; + Component::Flags dirty_; + Bundle bundle_; + }; + + typedef State::Component Component; + + private: + // Accessors / Lazy Allocation + Buffer& buffer() const; + Buffer& buffer(CMD&, Access::Flags) const; + Image& image() const; + Image& image(CMD&, Access::Flags) const; + Buffer& staging() const; + Buffer& staging(CMD&, Access::Flags) const; + Fence& fence() const; + + // Validation + void verify() const; + + private: + // Resources + mutable Buffer buffer_; + mutable Image image_; + mutable Buffer staging_; + mutable Fence fence_; + + // Context + api::Context* context_; + + // State + mutable State state_; + + // Metadata + VkExtent3D extents_; + TensorOptions options_; + c10::SmallVector sizes_; + c10::SmallVector strides_; + + private: + // Debug + friend std::ostream& operator<<( + std::ostream&, + const View::State::Bundle&); + }; + + // Even at the cost of a heap allocation plus the resulting negative impact + // on cache locality due to the subsequent pointer chasing, it is still + // critcal to share the view across vTensor implementations to minimize + // programmer errors. Ideally this class should have been only made movable, + // and non-copyable - something we cannot do unfortunately due to the inner + // workings of at::TensorImpl requiring copy semantics in + // at::TensorImpl::release_resources() to function as expected. Now that this + // class is made copyable though, a new door to a whole new class of bugs is + // opened, in that there now is a chance of two [shallow] copies, have their + // State objects go out of sync as a result of an operation being performed on + // one shallow copy that is not reflected in the other. Technically, if the + // programmer is very careful, it is possible to avoid this trap and not pay + // the cost of indirection, but the resulting bugs of missing memory barriers + // will be so frustrating to hunt down for those unfamiliar with the internal + // mechanics of this class, that I decided to take the performance pentalty + // of this extra layer of indirection in favor of making this class easier + // to use. + + std::shared_ptr view_; + + private: + // Debug + friend std::ostream& operator<<( + std::ostream&, + const View::State::Bundle&); +}; + +const vTensor& convert(const Tensor& tensor); +vTensor& convert(Tensor& tensor); +Tensor convert(const vTensor& tensor); + +using vTensorImpl = VulkanOpaqueTensorImpl; +void verify(const TensorOptions& options); + +// +// Impl +// + +template +inline vTensor::Future::Future( + const vTensor* const tensor) + : tensor_(tensor) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + tensor_, + "Invalid Vulkan tensor!"); +} + +template +inline vTensor::Future::Future( + Future&& future) + : tensor_(std::move(future.tensor_)) { + future.tensor_ = nullptr; +} + +template +inline vTensor::Future& +vTensor::Future::operator=( + Future&& future) & { + tensor_ = std::move(future.tensor_); + future.tensor_ = nullptr; + return *this; +} + +template +template +inline vTensor::Future::Future( + Future&& future) + : tensor_(std::move(future.tensor_)) { + future.tensor_ = nullptr; +} + +template +template +inline vTensor::Future& +vTensor::Future::operator=( + Future&& future) & { + tensor_ = std::move(future.tensor_); + future.tensor_ = nullptr; + return *this; +} + +template +inline vTensor::Future::~Future() { +#if VULKAN_SYNC_TENSORS_EAGERLY + // Sync eagerly in an effort to hide latency. + // Upside: Kick off the async transfer early on to keep the GPU busy. + // Downside: An extra CPU command submission. + if (tensor_ && (Access::Write & kAccess)) { + if (tensor_->has_image()) { + tensor_->image(); + } + else { + tensor_->buffer(); + } + } +#endif +} + +template +inline typename vTensor::Future::Payload +vTensor::Future::wait() const & { + TORCH_CHECK( + tensor_, + "vTensor::Future is in an invalid state! " + "Potential reason: This future is moved from."); + + return tensor_->view_->wait().template map(); +} + +template +inline vTensor::Future vTensor::host() const & { + return Future(host()); +} + +template +inline vTensor::Future vTensor::host() & { + return Future(host(kAccess)); +} + +inline bool vTensor::has_image() const { + return view_->has_image(); +} + +inline const VkExtent3D& vTensor::extents() const { + return view_->extents(); +} + +inline const TensorOptions& vTensor::options() const { + return view_->options(); +} + +inline IntArrayRef vTensor::sizes() const { + return view_->sizes(); +} + +inline IntArrayRef vTensor::strides() const { + return view_->strides(); +} + +inline bool vTensor::View::has_image() const { + return state_.is_available(View::Component::Image); +} + +inline const VkExtent3D& vTensor::View::extents() const { + return extents_; +} + +inline const TensorOptions& vTensor::View::options() const { + return options_; +} + +inline IntArrayRef vTensor::View::sizes() const { + return sizes_; +} + +inline IntArrayRef vTensor::View::strides() const { + return strides_; +} + +inline vTensor::View::State::Bundle::Buffer::operator bool() const { + return (0u != stage) && + (0u != access); +} + +inline vTensor::View::State::Bundle::Image::operator bool() const { + return (0u != stage) && + (0u != access) && + (VK_IMAGE_LAYOUT_UNDEFINED != layout); +} + +inline bool vTensor::View::State::is_available( + const Component::Flags components) const { + return available_ & components; +} + +inline bool vTensor::View::State::is_discrete() const { + return is_available(Component::Staging); +} + +inline bool vTensor::View::State::is_uma() const { + return !is_discrete(); +} + +inline bool vTensor::View::State::is_clean( + const Component::Flags components) const { + return !is_dirty(components); +} + +inline bool vTensor::View::State::is_dirty( + const Component::Flags components) const { + return dirty_ & components; +} + +inline void vTensor::View::State::set_clean( + const Component::Flags components) { + dirty_ &= ~components; +} + +inline void vTensor::View::State::set_dirty( + const Component::Flags components) { + dirty_ |= components; +} + +inline const vTensor& convert(const Tensor& tensor) { + TORCH_INTERNAL_ASSERT( + tensor.is_vulkan(), + "Vulkan tensor expected!"); + + const vTensorImpl* const impl = + static_cast(tensor.unsafeGetTensorImpl()); + + return impl->opaque_handle(); +} + +inline vTensor& convert(Tensor& tensor) { + TORCH_INTERNAL_ASSERT( + tensor.is_vulkan(), + "Vulkan tensor expected!"); + + vTensorImpl* const impl = + static_cast(tensor.unsafeGetTensorImpl()); + + return impl->unsafe_opaque_handle(); +} + +inline Tensor convert(const vTensor& tensor) { + return at::detail::make_tensor( + DispatchKeySet(DispatchKey::Vulkan), + tensor.options().dtype(), + at::Device(at::kVulkan), + tensor, + tensor.sizes(), + tensor.strides()); +} + +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at + +#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/quantized/QTensorImpl.cpp b/aten/src/ATen/quantized/QTensorImpl.cpp index 40ecbf6d5720..1c79ac186c1a 100644 --- a/aten/src/ATen/quantized/QTensorImpl.cpp +++ b/aten/src/ATen/quantized/QTensorImpl.cpp @@ -5,7 +5,7 @@ namespace at { QTensorImpl::QTensorImpl( Storage&& storage, DispatchKeySet key_set, - const caffe2::TypeMeta& data_type, + const caffe2::TypeMeta data_type, QuantizerPtr quantizer) : TensorImpl(std::move(storage), key_set, data_type), quantizer_(quantizer) {} diff --git a/aten/src/ATen/quantized/QTensorImpl.h b/aten/src/ATen/quantized/QTensorImpl.h index c2728c7aab46..efce432d5863 100644 --- a/aten/src/ATen/quantized/QTensorImpl.h +++ b/aten/src/ATen/quantized/QTensorImpl.h @@ -18,7 +18,7 @@ struct CAFFE2_API QTensorImpl : public c10::TensorImpl { QTensorImpl( Storage&& storage, DispatchKeySet key_set, - const caffe2::TypeMeta& data_type, + const caffe2::TypeMeta data_type, QuantizerPtr quantizer); // TODO: Expose in PyTorch Frontend diff --git a/aten/src/ATen/record_function.cpp b/aten/src/ATen/record_function.cpp index 26e9fd9f21fa..41f31968688d 100644 --- a/aten/src/ATen/record_function.cpp +++ b/aten/src/ATen/record_function.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -376,6 +377,7 @@ void RecordFunction::before(const char* name, int64_t sequence_nr) { name_ = StringView(name); sequence_nr_ = sequence_nr; thread_id_ = currentThreadId(); + operator_name_.reset(); manager().runStartCallbacks(*this); } @@ -387,6 +389,21 @@ void RecordFunction::before(std::string name, int64_t sequence_nr) { name_ = StringView(std::move(name)); sequence_nr_ = sequence_nr; thread_id_ = currentThreadId(); + operator_name_.reset(); + + manager().runStartCallbacks(*this); +} + +void RecordFunction::before( + c10::OperatorHandle const& op, + int64_t sequence_nr) { + if (!active) { + return; + } + sequence_nr_ = sequence_nr; + thread_id_ = currentThreadId(); + operator_name_ = op.operator_name(); + name_ = StringView(op.schema().name()); manager().runStartCallbacks(*this); } diff --git a/aten/src/ATen/record_function.h b/aten/src/ATen/record_function.h index cf839ad4a188..db2ee221a09a 100644 --- a/aten/src/ATen/record_function.h +++ b/aten/src/ATen/record_function.h @@ -1,12 +1,18 @@ #pragma once #include -#include +#include #include +#include +#include #include #include +namespace c10 { +class CAFFE2_API OperatorHandle; +} + namespace at { // Kind of record function scope; @@ -147,6 +153,7 @@ struct TORCH_API RecordFunction { // start callbacks void before(const char* name, int64_t sequence_nr = -1); void before(std::string name, int64_t sequence_nr = -1); + void before(c10::OperatorHandle const& op, int64_t sequence_nr = -1); // Sets node ID for distributed profiling static void setDefaultNodeId(int64_t defaultNodeId); @@ -178,6 +185,10 @@ struct TORCH_API RecordFunction { return handle_; } + inline c10::optional operator_name() const { + return operator_name_; + } + inline void setHandle(RecordFunctionHandle handle) { handle_ = handle; } @@ -213,6 +224,8 @@ struct TORCH_API RecordFunction { int64_t sequence_nr_ = -1; std::vector inputs_; + c10::optional operator_name_; + // Kind of scope this RecordFunction is observing const RecordScope scope_; diff --git a/aten/src/ATen/templates/Functions.cpp b/aten/src/ATen/templates/Functions.cpp index 7c9aa96f6e70..589121af07ef 100644 --- a/aten/src/ATen/templates/Functions.cpp +++ b/aten/src/ATen/templates/Functions.cpp @@ -3,13 +3,7 @@ #include #include -#include -#include -#include #include -#ifdef USE_VULKAN -#include -#endif namespace at { diff --git a/aten/src/ATen/templates/SchemaRegister.cpp b/aten/src/ATen/templates/SchemaRegister.cpp deleted file mode 100644 index f48e732f4760..000000000000 --- a/aten/src/ATen/templates/SchemaRegister.cpp +++ /dev/null @@ -1,10 +0,0 @@ -// ${generated_comment} - -#include -#include - -using namespace at; - -TORCH_LIBRARY_FRAGMENT_THIS_API_IS_FOR_PER_OP_REGISTRATION_ONLY(aten, m) { - ${schema_registrations} -} diff --git a/aten/src/ATen/templates/SparseTypeDerived.cpp b/aten/src/ATen/templates/SparseTypeDerived.cpp deleted file mode 100644 index b0a4fed24a63..000000000000 --- a/aten/src/ATen/templates/SparseTypeDerived.cpp +++ /dev/null @@ -1,43 +0,0 @@ -// required for old g++ to compile PRId64 macros, see -// https://github.com/pytorch/pytorch/issues/3571 -// for context -#define __STDC_FORMAT_MACROS - -#include - -// ${generated_comment} - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -#include -$extra_cuda_headers - -namespace at { - -namespace ${Type} { - -${type_derived_method_definitions} - -} // namespace ${Type} - -TORCH_LIBRARY_IMPL(aten, ${Backend}, m) { - ${function_registrations}; -} - -} // namespace at diff --git a/aten/src/ATen/templates/TensorBody.h b/aten/src/ATen/templates/TensorBody.h index 3f6292f41178..c6067a79a388 100644 --- a/aten/src/ATen/templates/TensorBody.h +++ b/aten/src/ATen/templates/TensorBody.h @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include diff --git a/aten/src/ATen/templates/TypeDefault.cpp b/aten/src/ATen/templates/TypeDefault.cpp index 145a5b421019..4cd1d1586d6a 100644 --- a/aten/src/ATen/templates/TypeDefault.cpp +++ b/aten/src/ATen/templates/TypeDefault.cpp @@ -64,6 +64,10 @@ TORCH_LIBRARY(aten, m) { m.def("get_gradients(int context_id) -> Dict(Tensor, Tensor)"); } +TORCH_LIBRARY_IMPL(aten, Math, m) { + ${math_function_registrations}; +} + TORCH_LIBRARY_IMPL(aten, DefaultBackend, m) { ${default_backend_function_registrations}; } diff --git a/aten/src/ATen/templates/TypeDerived.cpp b/aten/src/ATen/templates/TypeDerived.cpp index d65c13ae8d97..3275ab76ef62 100644 --- a/aten/src/ATen/templates/TypeDerived.cpp +++ b/aten/src/ATen/templates/TypeDerived.cpp @@ -5,12 +5,9 @@ #define __STDC_FORMAT_MACROS #endif -#include - // ${generated_comment} -$storage_tensor_headers -#include +#include #include #include #include diff --git a/aten/src/ATen/templates/TypeDerived.h b/aten/src/ATen/templates/TypeDerived.h deleted file mode 100644 index 4b571f40383f..000000000000 --- a/aten/src/ATen/templates/TypeDerived.h +++ /dev/null @@ -1,38 +0,0 @@ -#pragma once - -// ${generated_comment} - -#include -#include -#include -#include -#include -#include -#include -#include - -$extra_cuda_headers - -namespace c10 { -struct Storage; -} - -namespace at { - -class Tensor; -using TensorList = ArrayRef; - -class Context; -struct Generator; - -struct Quantizer; -// This is temporary typedef to enable Quantizer in aten native function API -// we'll remove them when we are actually exposing Quantizer class -// to frontend -using ConstQuantizerPtr = const c10::intrusive_ptr&; - -namespace ${Type} { - ${type_derived_method_declarations} -} - -} // namespace at diff --git a/aten/src/ATen/test/CMakeLists.txt b/aten/src/ATen/test/CMakeLists.txt index 4268db33fa16..385c87f9e7e7 100644 --- a/aten/src/ATen/test/CMakeLists.txt +++ b/aten/src/ATen/test/CMakeLists.txt @@ -76,9 +76,13 @@ list(APPEND ATen_HIP_TEST_SRCS # ${CMAKE_CURRENT_SOURCE_DIR}/hip/hip_stream_test.cpp list(APPEND ATen_VULKAN_TEST_SRCS - ${CMAKE_CURRENT_SOURCE_DIR}/vulkan_api_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/vulkan_test.cpp) +if(USE_VULKAN_API) + list(APPEND ATen_VULKAN_TEST_SRCS + ${CMAKE_CURRENT_SOURCE_DIR}/vulkan_api_test.cpp) +endif() + list(APPEND ATen_MOBILE_TEST_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/vec256_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpu_profiling_allocator_test.cpp diff --git a/aten/src/ATen/test/vulkan_api_test.cpp b/aten/src/ATen/test/vulkan_api_test.cpp index 77bbbfc8edbe..05febc539966 100644 --- a/aten/src/ATen/test/vulkan_api_test.cpp +++ b/aten/src/ATen/test/vulkan_api_test.cpp @@ -1,9 +1,117 @@ +#ifdef USE_VULKAN_API + #include +#include -#ifdef USE_VULKAN_API +// TODO: These functions should move to a common place. namespace { +bool checkRtol(const at::Tensor& diff, const std::vector& inputs) { + float maxValue = 0.0f; + + for (const auto& tensor : inputs) { + maxValue = fmax(tensor.abs().max().item(), maxValue); + } + + return diff.abs().max().item() < (2e-6 * maxValue); +} + +bool almostEqual(const at::Tensor& a, const at::Tensor& b) { + return checkRtol(a - b, {a, b}); +} + +bool exactlyEqual(const at::Tensor& a, const at::Tensor& b) { + return (a - b).abs().max().item() == 0.0f; +} + +} // namespace + +namespace { + +TEST(VulkanAPITest, add) { + const auto a_cpu = at::rand({11, 7, 139, 109}, at::device(at::kCPU).dtype(at::kFloat)); + const auto a_vulkan = a_cpu.vulkan(); + + const auto b_cpu = at::rand({11, 7, 139, 109}, at::device(at::kCPU).dtype(at::kFloat)); + const auto b_vulkan = b_cpu.vulkan(); + + const auto c_cpu = at::add(a_cpu, b_cpu, 2.1f); + const auto c_vulkan = at::add(a_vulkan, b_vulkan, 2.1f); + + ASSERT_TRUE(almostEqual(c_cpu, c_vulkan.cpu())); +} + +TEST(VulkanAPITest, add_) { + auto a_cpu = at::rand({11, 7, 139, 109}, at::device(at::kCPU).dtype(at::kFloat)); + auto a_vulkan = a_cpu.vulkan(); + + const auto b_cpu = at::rand({11, 7, 139, 109}, at::device(at::kCPU).dtype(at::kFloat)); + const auto b_vulkan = b_cpu.vulkan(); + + a_cpu.add_(b_cpu, 2.1f); + a_vulkan.add_(b_vulkan, 2.1f); + + ASSERT_TRUE(almostEqual(a_cpu, a_vulkan.cpu())); +} + +TEST(VulkanAPITest, add_scalar) { + const auto a_cpu = at::rand({1, 1, 1, 1}, at::device(at::kCPU).dtype(at::kFloat)); + const auto a_vulkan = a_cpu.vulkan(); + + const float b_scalar = 3.1415f; + + const auto c_cpu = at::add(a_cpu, b_scalar, 2.1f); + const auto c_vulkan = at::add(a_vulkan, b_scalar, 2.1f); + + ASSERT_TRUE(almostEqual(c_cpu, c_vulkan.cpu())); +} + +TEST(VulkanAPITest, add_scalar_) { + auto a_cpu = at::rand({11, 7, 139, 109}, at::device(at::kCPU).dtype(at::kFloat)); + auto a_vulkan = a_cpu.vulkan(); + + const float b_scalar = 3.1415f; + + a_cpu.add_(b_scalar, 2.1f); + a_vulkan.add_(b_scalar, 2.1f); + + ASSERT_TRUE(almostEqual(a_cpu, a_vulkan.cpu())); +} + +TEST(VulkanAPITest, mul_scalar) { + const auto a_cpu = at::rand({17, 213, 213, 7}, at::device(at::kCPU).dtype(at::kFloat)); + const auto a_vulkan = a_cpu.vulkan(); + + const float b_scalar = 3.1415f; + + const auto c_cpu = at::mul(a_cpu, b_scalar); + const auto c_vulkan = at::mul(a_vulkan, b_scalar); + + ASSERT_TRUE(almostEqual(c_cpu, c_vulkan.cpu())); +} + +TEST(VulkanAPITest, mul_scalar_) { + auto a_cpu = at::rand({11, 7, 139, 109}, at::device(at::kCPU).dtype(at::kFloat)); + auto a_vulkan = a_cpu.vulkan(); + + const float b_scalar = 3.1415f; + + a_cpu.mul_(b_scalar); + a_vulkan.mul_(b_scalar); + + ASSERT_TRUE(almostEqual(a_cpu, a_vulkan.cpu())); +} + +TEST(VulkanAPITest, copy) { + const auto cpu = at::rand({13, 17, 37, 19}, at::device(at::kCPU).dtype(at::kFloat)); + ASSERT_TRUE(exactlyEqual(cpu, cpu.vulkan().cpu())); +} + +TEST(VulkanAPITest, empty) { + ASSERT_NO_THROW(at::empty({1, 17, 41, 53}, at::device(at::kVulkan).dtype(at::kFloat))); +} + } // namespace #endif /* USE_VULKAN_API */ diff --git a/aten/src/TH/THAllocator.cpp b/aten/src/TH/THAllocator.cpp index 6cdc62ab6da6..20bfd490ed9f 100644 --- a/aten/src/TH/THAllocator.cpp +++ b/aten/src/TH/THAllocator.cpp @@ -288,6 +288,7 @@ THMapAllocator::THMapAllocator(WithFd, const char *filename, int fd, int flags, if (base_ptr_ == MAP_FAILED) { base_ptr_ = nullptr; /* let's be sure it is NULL */ + AT_ERROR("unable to mmap ", size_, " bytes from file <", filename_, ">: ", strerror(errno), " (", errno, ")"); } if (flags_ & TH_ALLOCATOR_MAPPED_KEEPFD) { diff --git a/aten/src/TH/THStorageFunctions.hpp b/aten/src/TH/THStorageFunctions.hpp index b78f8c7a3035..8d5c28daa796 100644 --- a/aten/src/TH/THStorageFunctions.hpp +++ b/aten/src/TH/THStorageFunctions.hpp @@ -8,6 +8,7 @@ #include #include +#include // Note [Weak references for intrusive refcounting] // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/aten/src/TH/generic/THLapack.cpp b/aten/src/TH/generic/THLapack.cpp index 6a776f4d0a17..c0fb51f53e45 100644 --- a/aten/src/TH/generic/THLapack.cpp +++ b/aten/src/TH/generic/THLapack.cpp @@ -2,7 +2,6 @@ #define TH_GENERIC_FILE "TH/generic/THLapack.cpp" #else - TH_EXTERNC void dgels_(char *trans, int *m, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, double *work, int *lwork, int *info); TH_EXTERNC void sgels_(char *trans, int *m, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, float *work, int *lwork, int *info); TH_EXTERNC void dgeev_(char *jobvl, char *jobvr, int *n, double *a, int *lda, double *wr, double *wi, double* vl, int *ldvl, double *vr, int *ldvr, double *work, int *lwork, int *info); diff --git a/aten/src/TH/generic/THTensorLapack.cpp b/aten/src/TH/generic/THTensorLapack.cpp index 2494e21791e4..76d7d7bc48d8 100644 --- a/aten/src/TH/generic/THTensorLapack.cpp +++ b/aten/src/TH/generic/THTensorLapack.cpp @@ -410,7 +410,8 @@ void THTensor_(orgqr)(THTensor *ra_, THTensor *a, THTensor *tau) { if (a == NULL) a = ra_; THArgCheck(THTensor_nDimension(a) == 2, 1, "'input' should be 2 dimensional"); - THArgCheck(!a->is_empty(), 1, "'input' should not be empty"); + THArgCheck(!a->is_empty(), 2, "'input' should not be empty"); + THArgCheck(!tau->is_empty(), 3, "'tau' should not be empty"); THTensor *ra__ = NULL; ra__ = THTensor_(cloneColumnMajor)(ra_, a); diff --git a/benchmarks/operator_benchmark/c2/clip_ranges_test.py b/benchmarks/operator_benchmark/c2/clip_ranges_test.py new file mode 100644 index 000000000000..2bb32f062445 --- /dev/null +++ b/benchmarks/operator_benchmark/c2/clip_ranges_test.py @@ -0,0 +1,51 @@ +import benchmark_caffe2 as op_bench_c2 +import operator_benchmark as op_bench +from benchmark_caffe2 import Caffe2BenchmarkBase # noqa +from caffe2.python import core, dyndep + +dyndep.InitOpsLibrary("@/caffe2/caffe2/fb/operators:clip_ranges_op") + +"""Microbenchmarks for ClipRanges operator.""" + +# Configs for C2 ClipRanges operator +clip_ranges_long_configs = op_bench.cross_product_configs( + LENGTH=range(1, 100), + M=[1], + N=[2], + MAX_LENGTH=range(1, 100), + dtype=["int32"], + tags=["long"] +) + + +clip_ranges_short_configs = op_bench.config_list( + attrs=[ + [6, 1, 2, 1, "int32"], + [7, 1, 2, 2, "int32"], + [8, 1, 2, 3, "int32"], + [9, 1, 2, 4, "int32"], + [10, 1, 2, 5, "int32"], + ], + attr_names=["LENGTH", "M", "N", "MAX_LENGTH", "dtype"], + tags=["short"], +) + + +class ClipRangesBenchmark(op_bench_c2.Caffe2BenchmarkBase): + def init(self, LENGTH, M, N, MAX_LENGTH, dtype): + self.input = self.tensor([LENGTH, M, N], dtype) + self.max_length = MAX_LENGTH + self.set_module_name("clip_ranges") + + def forward(self): + op = core.CreateOperator("ClipRanges", self.input, self.input, max_length=self.max_length) + return op + + +op_bench_c2.generate_c2_test( + clip_ranges_long_configs + clip_ranges_short_configs, ClipRangesBenchmark +) + + +if __name__ == "__main__": + op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/pt/clip_ranges_test.py b/benchmarks/operator_benchmark/pt/clip_ranges_test.py new file mode 100644 index 000000000000..d2c0d575647b --- /dev/null +++ b/benchmarks/operator_benchmark/pt/clip_ranges_test.py @@ -0,0 +1,53 @@ +import operator_benchmark as op_bench +import torch + + +"""Microbenchmarks for ClipRanges operator.""" +torch.ops.load_library("//caffe2/torch/fb/sparsenn:sparsenn_operators") + +# Configs for C2 ClipRanges operator +clip_ranges_long_configs = op_bench.cross_product_configs( + LENGTH=range(1, 100), + M=[1], + N=[2], + MAX_LENGTH=range(1, 100), + device=['cpu', 'cuda'], + dtype=[torch.int32], + tags=["long"], +) + + +clip_ranges_short_configs = op_bench.config_list( + attrs=[ + [6, 1, 2, 1, torch.int32], + [7, 1, 2, 2, torch.int32], + [8, 1, 2, 3, torch.int32], + [9, 1, 2, 4, torch.int32], + [10, 1, 2, 5, torch.int32], + ], + attr_names=["LENGTH", "M", "N", "MAX_LENGTH", "dtype"], + cross_product_configs={ + 'device': ['cpu', 'cuda'], + }, + tags=["short"], +) + + +class ClipRangesBenchmark(op_bench.TorchBenchmarkBase): + def init(self, LENGTH, M, N, MAX_LENGTH, device, dtype): + self.input = torch.rand(LENGTH, M, N, device=device).type(dtype) + self.max_length = MAX_LENGTH + self.set_module_name("clip_ranges") + + def forward(self): + output = torch.ops.fb.clip_ranges(self.input, self.max_length) + return output + + +op_bench.generate_pt_test( + clip_ranges_long_configs + clip_ranges_short_configs, ClipRangesBenchmark +) + + +if __name__ == "__main__": + op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/pt/nan_to_num_test.py b/benchmarks/operator_benchmark/pt/nan_to_num_test.py new file mode 100644 index 000000000000..72f5daf33afa --- /dev/null +++ b/benchmarks/operator_benchmark/pt/nan_to_num_test.py @@ -0,0 +1,59 @@ +import operator_benchmark as op_bench +import torch +import math + + +"""Microbenchmarks for torch.nan_to_num / nan_to_num_ operators""" + +# Configs for PT torch.nan_to_num / nan_to_num_ operators +nan_to_num_long_configs = op_bench.cross_product_configs( + M=[32, 64, 128], + N=range(32, 128, 32), + dtype=[torch.float, torch.double], + op=["nan_to_num", "nan_to_num_"], + replace_inf=[True, False], + tags=["long"], +) + + +nan_to_num_short_configs = op_bench.cross_product_configs( + M=[16, 64], + N=[64, 64], + dtype=[torch.float, torch.double], + op=["nan_to_num", "nan_to_num_"], + replace_inf=[True, False], + tags=["short"], +) + + +class ReplaceNaNBenchmark(op_bench.TorchBenchmarkBase): + def init(self, M, N, dtype, op, replace_inf): + self.input = torch.randn(M, N, dtype=dtype) + self.input[0][0] = float("nan") + self.op = op + self.replace_inf = replace_inf + self.set_module_name("nan_to_num") + + def forward(self): + # compare inplace + if self.op == "nan_to_num": + if self.replace_inf: + output = torch.nan_to_num(self.input, nan=1.0) + else: + output = torch.nan_to_num(self.input, nan=1.0, posinf=math.inf, neginf=-math.inf) + else: + if self.replace_inf: + output = torch.nan_to_num_(self.input, nan=1.0) + else: + output = torch.nan_to_num_(self.input, nan=1.0, posinf=math.inf, neginf=-math.inf) + return output + + +op_bench.generate_pt_test( + nan_to_num_long_configs + nan_to_num_short_configs, + ReplaceNaNBenchmark, +) + + +if __name__ == "__main__": + op_bench.benchmark_runner.main() diff --git a/benchmarks/tensorexpr/__main__.py b/benchmarks/tensorexpr/__main__.py index 62c58de600d2..a1f0a5ee2fed 100644 --- a/benchmarks/tensorexpr/__main__.py +++ b/benchmarks/tensorexpr/__main__.py @@ -133,6 +133,7 @@ def main(): elif args.cuda_fuser == "nvf": import torch torch._C._jit_set_profiling_executor(True) + torch._C._jit_set_texpr_fuser_enabled(False) torch._C._jit_set_nvfuser_enabled(True) torch._C._jit_set_profiling_mode(True) else : diff --git a/benchmarks/tensorexpr/elementwise.py b/benchmarks/tensorexpr/elementwise.py index 6a45a84a1787..af1352dfa949 100644 --- a/benchmarks/tensorexpr/elementwise.py +++ b/benchmarks/tensorexpr/elementwise.py @@ -219,7 +219,7 @@ def __init__(self, mode, device, dtype, N): @classmethod def module(cls): - return "dynamic_simple_element" + return "simple_dynamic_element" def instantiate_input(self): N, = self.rand_shape([self.N]) diff --git a/benchmarks/tensorexpr/reduction.py b/benchmarks/tensorexpr/reduction.py index 57027745a6a7..bc3e4e158a17 100644 --- a/benchmarks/tensorexpr/reduction.py +++ b/benchmarks/tensorexpr/reduction.py @@ -175,7 +175,7 @@ def instantiate_input(self): @staticmethod def module(): - return "dynamic_reduce2d" + return "dynamicreduce2d" class DynamicReduce2DInnerBench(DynamicReduce2DBench): @@ -184,7 +184,7 @@ def __init__(self, mode, device, dtype, dim0, dim1): @staticmethod def module(): - return "dynamic_reduce2d_inner" + return "reduce2d_dynamic_inner" class DynamicReduce2DOuterBench(DynamicReduce2DBench): @@ -193,7 +193,7 @@ def __init__(self, mode, device, dtype, dim0, dim1): @staticmethod def module(): - return "dynamic_reduce2d_outer" + return "reduce2d_dynamic_outer" benchmark.register_benchmark_class(DynamicReduce2DInnerBench) benchmark.register_benchmark_class(DynamicReduce2DOuterBench) diff --git a/c10/core/DefaultDtype.cpp b/c10/core/DefaultDtype.cpp index c4f420ab6e22..583d4452bfbd 100644 --- a/c10/core/DefaultDtype.cpp +++ b/c10/core/DefaultDtype.cpp @@ -3,26 +3,32 @@ namespace c10 { static auto default_dtype = caffe2::TypeMeta::Make(); -static auto default_dtype_as_scalartype = typeMetaToScalarType(default_dtype); +static auto default_dtype_as_scalartype = default_dtype.toScalarType(); static auto default_complex_dtype = caffe2::TypeMeta::Make>(); void set_default_dtype(caffe2::TypeMeta dtype) { - default_dtype = std::move(dtype); - default_dtype_as_scalartype = typeMetaToScalarType(default_dtype); - if(default_dtype_as_scalartype == ScalarType::Double) { - default_complex_dtype = std::move(caffe2::TypeMeta::Make>()); - } else { - default_complex_dtype = std::move(caffe2::TypeMeta::Make>()); + default_dtype = dtype; + default_dtype_as_scalartype = default_dtype.toScalarType(); + switch (default_dtype_as_scalartype) { + case ScalarType::Half: + default_complex_dtype = ScalarType::ComplexHalf; + break; + case ScalarType::Double: + default_complex_dtype = ScalarType::ComplexDouble; + break; + default: + default_complex_dtype = ScalarType::ComplexFloat; + break; } } -const caffe2::TypeMeta& get_default_dtype() { +const caffe2::TypeMeta get_default_dtype() { return default_dtype; } ScalarType get_default_dtype_as_scalartype() { return default_dtype_as_scalartype; } -const caffe2::TypeMeta& get_default_complex_dtype() { +const caffe2::TypeMeta get_default_complex_dtype() { return default_complex_dtype; } } // namespace c10 diff --git a/c10/core/DefaultDtype.h b/c10/core/DefaultDtype.h index eda34b217727..d0a17474bda4 100644 --- a/c10/core/DefaultDtype.h +++ b/c10/core/DefaultDtype.h @@ -9,7 +9,7 @@ class TypeMeta; namespace c10 { C10_API void set_default_dtype(caffe2::TypeMeta dtype); -C10_API const caffe2::TypeMeta& get_default_dtype(); +C10_API const caffe2::TypeMeta get_default_dtype(); C10_API ScalarType get_default_dtype_as_scalartype(); -C10_API const caffe2::TypeMeta& get_default_complex_dtype(); +C10_API const caffe2::TypeMeta get_default_complex_dtype(); } // namespace c10 diff --git a/c10/core/ScalarType.h b/c10/core/ScalarType.h index 8f2acebd84f0..6903cf9f77ce 100644 --- a/c10/core/ScalarType.h +++ b/c10/core/ScalarType.h @@ -3,9 +3,12 @@ #include #include #include +#include +#include +#include #include +#include #include -#include #include #include @@ -68,6 +71,8 @@ enum class ScalarType : int8_t { NumOptions }; +constexpr uint16_t NumScalarTypes = static_cast(ScalarType::NumOptions); + namespace impl { // These are used to map ScalarTypes to C++ types. @@ -94,7 +99,7 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_ScalarTypeToCPPType) #undef SPECIALIZE_ScalarTypeToCPPType -} +} // namespace impl template struct CppTypeToScalarType; @@ -162,64 +167,6 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType) _(c10::complex, ComplexFloat) \ _(c10::complex, ComplexDouble) -static inline caffe2::TypeMeta scalarTypeToTypeMeta(ScalarType scalar_type) { -#define DEFINE_CASE(ctype, name) \ - case ScalarType::name: \ - return caffe2::TypeMeta::Make(); - - switch (scalar_type) { - AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE) - case ScalarType::Undefined: - return caffe2::TypeMeta(); - default: - AT_ERROR( - "Unrecognized Scalartype ", - scalar_type, - " (please report this error)"); - } -#undef DEFINE_CASE -} - -static inline c10::optional tryTypeMetaToScalarType( - caffe2::TypeMeta dtype) { -#define DEFINE_IF(ctype, name) \ - if (dtype == caffe2::TypeMeta::Make()) { \ - return {ScalarType::name}; \ - } - AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_IF) -#undef DEFINE_IF - if (dtype == caffe2::TypeMeta()) { - return {ScalarType::Undefined}; - } - return c10::nullopt; -} - -static inline ScalarType typeMetaToScalarType(caffe2::TypeMeta dtype) { - if (auto scalar_type = tryTypeMetaToScalarType(dtype)) { - return *scalar_type; - } - AT_ERROR( - "Unsupported TypeMeta in ATen: ", dtype, " (please report this error)"); -} - -inline optional optTypeMetaToScalarType(optional type_meta) { - if (!type_meta.has_value()) { - return c10::nullopt; - } - return typeMetaToScalarType(*type_meta); -} - -static inline bool operator==(ScalarType t, caffe2::TypeMeta m) { - if (auto mt = tryTypeMetaToScalarType(m)) { - return (*mt) == t; - } - return false; -} - -static inline bool operator==(caffe2::TypeMeta m, ScalarType t) { - return t == m; -} - #define DEFINE_CONSTANT(_, name) \ constexpr ScalarType k##name = ScalarType::name; diff --git a/c10/core/ScalarTypeToTypeMeta.h b/c10/core/ScalarTypeToTypeMeta.h new file mode 100644 index 000000000000..b6e7f6cf1993 --- /dev/null +++ b/c10/core/ScalarTypeToTypeMeta.h @@ -0,0 +1,55 @@ +#pragma once + +#include +#include + +// these just expose TypeMeta/ScalarType bridge functions in c10 +// TODO move to typeid.h (or codemod away) when TypeMeta et al +// are moved from caffe2 to c10 (see note at top of typeid.h) + +namespace c10 { + +/** + * convert ScalarType enum values to TypeMeta handles + */ +static inline caffe2::TypeMeta scalarTypeToTypeMeta(ScalarType scalar_type) { + return caffe2::TypeMeta::fromScalarType(scalar_type); +} + +/** + * convert TypeMeta handles to ScalarType enum values + */ +static inline ScalarType typeMetaToScalarType(caffe2::TypeMeta dtype) { + return dtype.toScalarType(); +} + +/** + * typeMetaToScalarType(), lifted to optional + */ +static inline optional optTypeMetaToScalarType(optional type_meta) { + if (!type_meta.has_value()) { + return c10::nullopt; + } + return type_meta->toScalarType(); +} + +/** + * convenience: equality across TypeMeta/ScalarType conversion + */ +static inline bool operator==(ScalarType t, caffe2::TypeMeta m) { + return m.isScalarType(t); +} + +static inline bool operator==(caffe2::TypeMeta m, ScalarType t) { + return t == m; +} + +static inline bool operator!=(ScalarType t, caffe2::TypeMeta m) { + return !(t == m); +} + +static inline bool operator!=(caffe2::TypeMeta m, ScalarType t) { + return !(t == m); +} + +} // namespace c10 diff --git a/c10/core/Stream.h b/c10/core/Stream.h index 6962be72bf72..5baac5325af7 100644 --- a/c10/core/Stream.h +++ b/c10/core/Stream.h @@ -124,11 +124,11 @@ class Stream final { } static Stream unpack(uint64_t bits) { - auto stream_id = static_cast(bits) & 0xFFFFFFFFull; + const auto stream_id = static_cast(bits & 0xFFFFFFFFull); bits >>= 32; - auto device_index = static_cast(bits) & 0xFFFFull; + const auto device_index = static_cast(bits & 0xFFFFull); bits >>= 16; - auto device_type = static_cast(bits); + const auto device_type = static_cast(bits); TORCH_CHECK(isValidDeviceType(device_type)); // Unfortunately, we can't check if the StreamId is valid here; it // will be checked upon first use. diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index 8702ed4fdebf..9f2ca1d2ca07 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -47,13 +47,13 @@ const at::Tensor& TensorImpl::grad() const { TensorImpl::TensorImpl( Storage&& storage, DispatchKeySet key_set, - const caffe2::TypeMeta& data_type) + const caffe2::TypeMeta data_type) : TensorImpl(std::move(storage), key_set, data_type, storage.device()) {} -TensorImpl::TensorImpl(DispatchKeySet key_set, const caffe2::TypeMeta& data_type, c10::optional device_opt) +TensorImpl::TensorImpl(DispatchKeySet key_set, const caffe2::TypeMeta data_type, c10::optional device_opt) : TensorImpl({}, key_set, data_type, std::move(device_opt)) {} -TensorImpl::TensorImpl(Storage&& storage, DispatchKeySet key_set, const caffe2::TypeMeta& data_type, +TensorImpl::TensorImpl(Storage&& storage, DispatchKeySet key_set, const caffe2::TypeMeta data_type, c10::optional device_opt) : storage_(std::move(storage)), sizes_{0}, @@ -61,9 +61,11 @@ TensorImpl::TensorImpl(Storage&& storage, DispatchKeySet key_set, const caffe2:: numel_(0), data_type_(data_type), device_opt_(device_opt) { + + init_bitfields(); + if (!key_set.empty()) { - AT_ASSERT(data_type.id() == caffe2::TypeIdentifier::uninitialized() || - device_opt_.has_value()); + TORCH_INTERNAL_ASSERT(data_type == ScalarType::Undefined || device_opt_.has_value()); // UndefinedTensorImpl is a singleton, so we skip logging it C10_LOG_API_USAGE_ONCE("tensor.create"); } diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index 871867c9e2c2..da849b049b65 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -322,24 +322,24 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { TensorImpl( Storage&& storage, DispatchKeySet, - const caffe2::TypeMeta& data_type); + const caffe2::TypeMeta data_type); /** * Construct a 1-dim 0 size tensor that doesn't have a storage. */ - TensorImpl(DispatchKeySet, const caffe2::TypeMeta& data_type, c10::optional device_opt); + TensorImpl(DispatchKeySet, const caffe2::TypeMeta data_type, c10::optional device_opt); // Legacy constructors so I don't have to go update call sites. // TODO: When Variable is added, delete these constructors TensorImpl( Storage&& storage, DispatchKey dispatch_key, - const caffe2::TypeMeta& data_type) + const caffe2::TypeMeta data_type) : TensorImpl( std::move(storage), DispatchKeySet(dispatch_key), data_type) {} - TensorImpl(DispatchKey dispatch_key, const caffe2::TypeMeta& data_type, c10::optional device_opt) + TensorImpl(DispatchKey dispatch_key, const caffe2::TypeMeta data_type, c10::optional device_opt) : TensorImpl(DispatchKeySet(dispatch_key), data_type, device_opt) {} private: @@ -347,7 +347,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { // storage. Still, we pass it in separately because it's easier to write // the initializer list if we're not worried about storage being moved out // from under us. - TensorImpl(Storage&& storage, DispatchKeySet, const caffe2::TypeMeta& data_type, c10::optional); + TensorImpl(Storage&& storage, DispatchKeySet, const caffe2::TypeMeta data_type, c10::optional); public: TensorImpl(const TensorImpl&) = delete; @@ -665,7 +665,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * Returns the TypeMeta of a tensor, which describes what data type * it is (e.g., int, float, ...) */ - const caffe2::TypeMeta& dtype() const { + const caffe2::TypeMeta dtype() const { return data_type_; } @@ -1040,8 +1040,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return; } auto newCapacity = sizes_; - newCapacity[0] = std::max( - newDims[0], std::ceil(sizes_[0] * (growthPct + 100) / 100)); + newCapacity[0] = std::max( + newDims[0], static_cast(std::ceil(sizes_[0] * (1 + growthPct / 100)))); auto oldData = std::move(storage_.data_ptr()); auto oldSize = numel_; auto oldDims = sizes_; @@ -1235,10 +1235,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { void ShareExternalPointer( DataPtr&& data_ptr, - const caffe2::TypeMeta& data_type, + const caffe2::TypeMeta data_type, size_t size_bytes) { TORCH_CHECK( - data_type.id() != caffe2::TypeIdentifier::uninitialized(), + data_type != ScalarType::Undefined, "To share with a raw external pointer you need to pass in an " "initialized data_type(TypeMeta)."); if (!size_bytes) { @@ -1275,7 +1275,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * If the existing data does not match the desired type, it will be deleted * and a new storage will be created. */ - inline void* raw_mutable_data(const caffe2::TypeMeta& meta) { + inline void* raw_mutable_data(const caffe2::TypeMeta meta) { // For 0-size tensors it's fine to return any pointer (including nullptr) if (data_type_ == meta && storage_initialized()) { return static_cast(static_cast(storage_.data()) + storage_offset_ * meta.itemsize()); @@ -1369,7 +1369,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { void set_storage_and_dtype( at::Storage storage, - const caffe2::TypeMeta& data_type) { + const caffe2::TypeMeta data_type) { set_storage_keep_dtype(storage); data_type_ = data_type; } @@ -1675,36 +1675,47 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { // INVARIANT: named_tensor_meta_ != nullptr <==> key_set_.has(DispatchKey::Named) DispatchKeySet key_set_; - // You get to have eight byte-size fields here, before you - // should pack this into a bitfield. + // Tensor is contiguous bool is_contiguous_ = true; + // default member initializers for bit-fields only available with -std=c++2a or -std=gnu++2a + inline void init_bitfields() { + is_channels_last_ = false; + is_channels_last_contiguous_ = false; + is_channels_last_3d_ = false; + is_channels_last_3d_contiguous_ = false; + is_non_overlapping_and_dense_ = false; + is_wrapped_number_ = false; + allow_tensor_metadata_change_ = true; + reserved_ = false; + } + // Tensor is stored in the channels last 2d memory format, when dimensions // order is (N)CHW and C-strides < W-strides < H-strides (< N-strides) // (If size of any dimension is equal to 1, this dimension strides value // is not taken into account). - bool is_channels_last_ = false; + bool is_channels_last_ : 1; // Channels last contiguous tensor is channel last tensor which occupies // contiguous memory block. - bool is_channels_last_contiguous_ = false; + bool is_channels_last_contiguous_ : 1; // Tensor is stored in the channels last 3d memory format, when dimensions // order is (N)CDHW and C-strides < W-strides < H-strides < D - strides (< N-strides) // (If size of any dimension is equal to 1, this dimension strides value // is not taken into account). - bool is_channels_last_3d_ = false; + bool is_channels_last_3d_ : 1; // Channels last 3d contiguous tensor is channel last 3d tensor which occupies // contiguous memory block. - bool is_channels_last_3d_contiguous_ = false; + bool is_channels_last_3d_contiguous_ : 1; // Dense tensor is the tensor that store values in a contiguous block of memory. // Non-overlapping tensor is the tensor in which elements occupy individual // non-repetitive memory. - bool is_non_overlapping_and_dense_ = false; + bool is_non_overlapping_and_dense_ : 1; - bool is_wrapped_number_ = false; + bool is_wrapped_number_ : 1; // NOTE [ Metadata Change for a Detached Tensor ] // @@ -1721,14 +1732,13 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { // NOTE: For a full list of tensor metadata fields, please see // `copy_tensor_metadata()` in TensorImpl and its subclasses to find // which fields are copied by value. - bool allow_tensor_metadata_change_ = true; + bool allow_tensor_metadata_change_ : 1; // we decide to keep reserved_ and it will // live in Tensor after the split // The logic is that if Extend() or ReserveSpace() were ever called, // then subsequent Resize()s will not free up Storage. - bool reserved_ = false; - + bool reserved_ : 1; }; // Note [TensorImpl size constraints] @@ -1781,13 +1791,13 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { // strides SmallVector (pre-allocated 4) // storage offset // numel -// data type pointer +// data type // (optional) device // tensor type id // miscellaneous bitfield // static_assert(sizeof(void*) != sizeof(int64_t) || // if 64-bit... - sizeof(TensorImpl) == sizeof(int64_t) * 31, + sizeof(TensorImpl) == sizeof(int64_t) * 29, "You changed the size of TensorImpl on 64-bit arch." "See Note [TensorImpl size constraints] on how to proceed."); } // namespace c10 diff --git a/c10/core/TensorOptions.h b/c10/core/TensorOptions.h index 10342908a49b..c8c7f058513d 100644 --- a/c10/core/TensorOptions.h +++ b/c10/core/TensorOptions.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -376,15 +377,24 @@ struct C10_API TensorOptions { /// device guard. /// TensorOptions merge_in(TensorOptions options) const noexcept { - TensorOptions r = options; - if (!r.has_device()) r.set_device(device_opt()); - if (!r.has_dtype()) r.set_dtype(dtype_opt()); - if (!r.has_layout()) r.set_layout(layout_opt()); + TensorOptions merged = *this; + if (options.has_device()) merged.set_device(options.device_opt()); + if (options.has_dtype()) merged.set_dtype(options.dtype_opt()); + if (options.has_layout()) merged.set_layout(options.layout_opt()); // NB: requires grad is right biased; not a logical AND/OR! - if (!r.has_requires_grad()) r.set_requires_grad(requires_grad_opt()); - if (!r.has_pinned_memory()) r.set_pinned_memory(pinned_memory_opt()); - if (!r.has_memory_format()) r.set_memory_format(memory_format_opt()); - return r; + if (options.has_requires_grad()) merged.set_requires_grad(options.requires_grad_opt()); + if (options.has_pinned_memory()) merged.set_pinned_memory(options.pinned_memory_opt()); + if (options.has_memory_format()) merged.set_memory_format(options.memory_format_opt()); + return merged; + } + + // TODO remove after TensorOptions rationalization + TensorOptions merge_memory_format(c10::optional optional_memory_format) const noexcept { + TensorOptions merged = *this; + if (optional_memory_format.has_value()) { + merged.set_memory_format(*optional_memory_format); + } + return merged; } // Resolves the tensor type set specified by the current construction axes. @@ -492,8 +502,8 @@ struct C10_API TensorOptions { // NB: We didn't use c10::optional here, because then we can't pack // the has_***_ boolean fields. - caffe2::TypeMeta dtype_ = caffe2::TypeMeta::Make(); // 64-bit Device device_ = at::kCPU; // 32-bit + caffe2::TypeMeta dtype_ = caffe2::TypeMeta::Make(); // 16-bit Layout layout_ = at::kStrided; // 8-bit MemoryFormat memory_format_ = MemoryFormat::Contiguous; // 8-bit diff --git a/c10/core/UndefinedTensorImpl.h b/c10/core/UndefinedTensorImpl.h index 9f1cb93c10eb..26122ed305e2 100644 --- a/c10/core/UndefinedTensorImpl.h +++ b/c10/core/UndefinedTensorImpl.h @@ -28,8 +28,6 @@ struct C10_API UndefinedTensorImpl final : public TensorImpl { private: UndefinedTensorImpl(); static UndefinedTensorImpl _singleton; -public: - friend struct UndefinedType; }; } // namespace c10 diff --git a/c10/mobile/CPUCachingAllocator.cpp b/c10/mobile/CPUCachingAllocator.cpp index b2f193299089..bde4067d45dc 100644 --- a/c10/mobile/CPUCachingAllocator.cpp +++ b/c10/mobile/CPUCachingAllocator.cpp @@ -95,8 +95,8 @@ CPUCachingAllocator* GetThreadLocalCachingAllocator() { WithCPUCachingAllocatorGuard::WithCPUCachingAllocatorGuard( CPUCachingAllocator* allocator) { - caching_allocator_ptr = allocator; prev_caching_allocator_ptr_ = GetThreadLocalCachingAllocator(); + caching_allocator_ptr = allocator; } WithCPUCachingAllocatorGuard::~WithCPUCachingAllocatorGuard() { diff --git a/c10/test/util/intrusive_ptr_test.cpp b/c10/test/util/intrusive_ptr_test.cpp index 233a4442f2a8..2ea283d1a4f0 100644 --- a/c10/test/util/intrusive_ptr_test.cpp +++ b/c10/test/util/intrusive_ptr_test.cpp @@ -1653,6 +1653,21 @@ TEST(WeakIntrusivePtrTest, givenPtr_whenLocking_thenReturnsCorrectObject) { EXPECT_EQ(var.ptr.get(), locked.get()); } +TEST(WeakIntrusivePtrTest, expiredPtr_whenLocking_thenReturnsNullType) { + IntrusiveAndWeak var = make_weak_intrusive(); + // reset the intrusive_ptr to test if weak pointer still valid + var.ptr.reset(); + EXPECT_TRUE(var.weak.expired()); + intrusive_ptr locked = var.weak.lock(); + EXPECT_FALSE(locked.defined()); +} + +TEST(WeakIntrusivePtrTest, weakNullPtr_locking) { + auto weak_ptr = make_invalid_weak(); + intrusive_ptr locked = weak_ptr.lock(); + EXPECT_FALSE(locked.defined()); +} + TEST( WeakIntrusivePtrTest, givenValidPtr_whenMoveAssigning_thenPointsToSameObject) { diff --git a/c10/util/intrusive_ptr.h b/c10/util/intrusive_ptr.h index 308ee883794a..453196510aa8 100644 --- a/c10/util/intrusive_ptr.h +++ b/c10/util/intrusive_ptr.h @@ -589,15 +589,19 @@ class weak_intrusive_ptr final { } intrusive_ptr lock() const noexcept { - auto refcount = target_->refcount_.load(); - do { - if (refcount == 0) { - // Object already destructed, no strong references left anymore. - // Return nullptr. - return intrusive_ptr(NullType::singleton()); - } - } while (!target_->refcount_.compare_exchange_weak(refcount, refcount + 1)); - return intrusive_ptr(target_); + if (expired()) { + return intrusive_ptr(NullType::singleton()); + } else { + auto refcount = target_->refcount_.load(); + do { + if (refcount == 0) { + // Object already destructed, no strong references left anymore. + // Return nullptr. + return intrusive_ptr(NullType::singleton()); + } + } while (!target_->refcount_.compare_exchange_weak(refcount, refcount + 1)); + return intrusive_ptr(target_); + } } /** diff --git a/c10/util/llvmMathExtras.h b/c10/util/llvmMathExtras.h index 8def126c29aa..2c4fbf8a501b 100644 --- a/c10/util/llvmMathExtras.h +++ b/c10/util/llvmMathExtras.h @@ -14,9 +14,10 @@ #define LLVM_SUPPORT_MATHEXTRAS_H #include - #include #include #include + #include + #include #include #include #include @@ -547,26 +548,26 @@ /// (32 bit edition.) /// Ex. Log2_32(32) == 5, Log2_32(1) == 0, Log2_32(0) == -1, Log2_32(6) == 2 inline unsigned Log2_32(uint32_t Value) { - return 31 - countLeadingZeros(Value); + return static_cast(31 - countLeadingZeros(Value)); } /// Return the floor log base 2 of the specified value, -1 if the value is zero. /// (64 bit edition.) inline unsigned Log2_64(uint64_t Value) { - return 63 - countLeadingZeros(Value); + return static_cast(63 - countLeadingZeros(Value)); } /// Return the ceil log base 2 of the specified value, 32 if the value is zero. /// (32 bit edition). /// Ex. Log2_32_Ceil(32) == 5, Log2_32_Ceil(1) == 0, Log2_32_Ceil(6) == 3 inline unsigned Log2_32_Ceil(uint32_t Value) { - return 32 - countLeadingZeros(Value - 1); + return static_cast(32 - countLeadingZeros(Value - 1)); } /// Return the ceil log base 2 of the specified value, 64 if the value is zero. /// (64 bit edition.) inline unsigned Log2_64_Ceil(uint64_t Value) { - return 64 - countLeadingZeros(Value - 1); + return static_cast(64 - countLeadingZeros(Value - 1)); } /// Return the greatest common divisor of the values using Euclid's algorithm. @@ -589,6 +590,7 @@ /// This function takes a 32-bit integer and returns the bit equivalent float. inline float BitsToFloat(uint32_t Bits) { + //TODO: Use bit_cast once C++20 becomes available. float F; static_assert(sizeof(uint32_t) == sizeof(float), "Unexpected type sizes"); memcpy(&F, &Bits, sizeof(Bits)); diff --git a/c10/util/math_compat.h b/c10/util/math_compat.h index 7d1a7b643850..b522cd26f989 100644 --- a/c10/util/math_compat.h +++ b/c10/util/math_compat.h @@ -59,6 +59,14 @@ namespace std { throw std::runtime_error("std::hypot is not implemented on older Android"); } + // TODO: this function needs to be implemented and tested. Currently just throw an error. + inline float igamma(float x, float y) { + throw std::runtime_error("igamma is not implemented on older Android"); + } + inline double igamma(double x, double y) { + throw std::runtime_error("igamma is not implemented on older Android"); + } + // TODO: this function needs to be implemented and tested. Currently just throw an error. inline float nextafter(float x, float y) { throw std::runtime_error("std::nextafter is not implemented on older Android"); @@ -66,7 +74,7 @@ namespace std { inline double nextafter(double x, double y) { throw std::runtime_error("std::nextafter is not implemented on older Android"); } - + // TODO: this function needs to be implemented and tested. Currently just throw an error. inline float exp2(float x) { throw std::runtime_error("std::exp2 is not implemented on older Android"); diff --git a/c10/util/typeid.cpp b/c10/util/typeid.cpp index e2070a1584a2..f3fe048b4cca 100644 --- a/c10/util/typeid.cpp +++ b/c10/util/typeid.cpp @@ -14,42 +14,41 @@ namespace detail { C10_EXPORT void _ThrowRuntimeTypeLogicError(const string& msg) { // In earlier versions it used to be std::abort() but it's a bit hard-core // for a library - AT_ERROR(msg); + TORCH_CHECK(false, msg); } +} // namespace detail +[[noreturn]] void TypeMeta::error_unsupported_typemeta(caffe2::TypeMeta dtype) { + TORCH_CHECK(false, "Unsupported TypeMeta in ATen: ", dtype, " (please report this error)"); +} -} // namespace detail +// see TypeMeta::addTypeMetaData +std::atomic TypeMeta::nextTypeIndex(NumScalarTypes); -template <> -EXPORT_IF_NOT_GCC const detail::TypeMetaData* TypeMeta::_typeMetaDataInstance< - detail::_Uninitialized>() noexcept { - static constexpr detail::TypeMetaData singleton = detail::TypeMetaData( - 0, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - TypeIdentifier::uninitialized(), - "nullptr (uninitialized)"); - return &singleton; +// fixed length array of TypeMetaData instances +detail::TypeMetaData* TypeMeta::typeMetaDatas() { + static detail::TypeMetaData instances[MaxTypeIndex + 1] = { +#define SCALAR_TYPE_META(T, name) \ + /* ScalarType::name */ \ + detail::TypeMetaData( \ + sizeof(T), \ + detail::_PickNew(), \ + detail::_PickPlacementNew(), \ + detail::_PickCopy(), \ + detail::_PickPlacementDelete(), \ + detail::_PickDelete(), \ + TypeIdentifier::Get(), \ + c10::util::get_fully_qualified_type_name()), +AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SCALAR_TYPE_META) +#undef SCALAR_TYPE_META + // The remainder of the array is padded with TypeMetaData blanks. + // The first of these is the entry for ScalarType::Undefined. + // The rest are consumed by CAFFE_KNOWN_TYPE entries. + }; + return instances; } -CAFFE_KNOWN_TYPE(uint8_t) -CAFFE_KNOWN_TYPE(int8_t) -CAFFE_KNOWN_TYPE(int16_t) -CAFFE_KNOWN_TYPE(int) -CAFFE_KNOWN_TYPE(int64_t) -CAFFE_KNOWN_TYPE(at::Half) -CAFFE_KNOWN_TYPE(float) -CAFFE_KNOWN_TYPE(double) -CAFFE_KNOWN_TYPE(c10::complex) -CAFFE_KNOWN_TYPE(c10::complex) -CAFFE_KNOWN_TYPE(c10::complex) -// 11 = undefined type id -// 12 = Tensor (defined in tensor.cc) CAFFE_KNOWN_TYPE(std::string) -CAFFE_KNOWN_TYPE(bool) CAFFE_KNOWN_TYPE(uint16_t) CAFFE_KNOWN_TYPE(char) CAFFE_KNOWN_TYPE(std::unique_ptr) @@ -79,15 +78,11 @@ using _guard_long_unique = std::conditional_t< _guard_long_unique_dummy, T>; } // namespace detail + CAFFE_KNOWN_TYPE(detail::_guard_long_unique); CAFFE_KNOWN_TYPE(detail::_guard_long_unique>) CAFFE_KNOWN_TYPE(float*) CAFFE_KNOWN_TYPE(at::Half*) -CAFFE_KNOWN_TYPE(c10::qint8) -CAFFE_KNOWN_TYPE(c10::quint8) -CAFFE_KNOWN_TYPE(c10::qint32) -CAFFE_KNOWN_TYPE(at::BFloat16) -CAFFE_KNOWN_TYPE(c10::quint4x2) } // namespace caffe2 diff --git a/c10/util/typeid.h b/c10/util/typeid.h index 51833fb545ad..5bdbdc4271df 100644 --- a/c10/util/typeid.h +++ b/c10/util/typeid.h @@ -21,18 +21,14 @@ #include #include #include -#include #include #include #include #include -#include -#include -#include -#include -#include #include +#include + /* * TypeIdentifier is a small type containing an id. * Types must be registered using CAFFE_KNOWN_TYPE() for them to have a type id. @@ -67,7 +63,7 @@ namespace caffe2 { */ class C10_API TypeIdentifier final : public at::IdWrapper { - public: +public: friend std::ostream& operator<<(std::ostream& stream, TypeIdentifier typeId); friend constexpr bool operator<(TypeIdentifier lhs, TypeIdentifier rhs); @@ -87,9 +83,8 @@ class C10_API TypeIdentifier final return TypeIdentifier(c10::util::type_index{0}); } - private: +private: constexpr explicit TypeIdentifier(c10::util::type_index id) : IdWrapper(id) {} - friend class TypeMeta; // TODO Is this friend an issue? }; // Allow usage in std::map / std::set @@ -126,7 +121,16 @@ struct TypeMetaData final { using PlacementDelete = void(void*, size_t); using Delete = void(void*); - TypeMetaData() = delete; + constexpr TypeMetaData() noexcept + : itemsize_(0), + new_(nullptr), + placementNew_(nullptr), + copy_(nullptr), + placementDelete_(nullptr), + delete_(nullptr), + id_(TypeIdentifier::uninitialized()), + name_("nullptr (uninitialized)") {} + constexpr TypeMetaData( size_t itemsize, New* newFn, @@ -136,14 +140,14 @@ struct TypeMetaData final { Delete* deleteFn, TypeIdentifier id, c10::string_view name) noexcept - : itemsize_(itemsize), - new_(newFn), - placementNew_(placementNew), - copy_(copy), - placementDelete_(placementDelete), - delete_(deleteFn), - id_(id), - name_(name) {} + : itemsize_(itemsize), + new_(newFn), + placementNew_(placementNew), + copy_(copy), + placementDelete_(placementDelete), + delete_(deleteFn), + id_(id), + name_(name) {} size_t itemsize_; New* new_; @@ -294,25 +298,24 @@ inline constexpr TypeMetaData::Delete* _PickDelete() noexcept { return &_Delete; } -template -inline C10_TYPENAME_CONSTEXPR TypeMetaData _makeTypeMetaDataInstance() { - C10_HOST_CONSTEXPR_VAR auto typeId = TypeIdentifier::Get(); - C10_TYPENAME_CONSTEXPR auto typeName = c10::util::get_fully_qualified_type_name(); - - return {sizeof(T), - _PickNew(), - _PickPlacementNew(), - _PickCopy(), - _PickPlacementDelete(), - _PickDelete(), - typeId, - typeName}; -} - class _Uninitialized final {}; } // namespace detail +// +// note: this is outside TypeMeta bc gcc seems to have trouble +// with scalarTypeItemSizes as a constexpr static member used by +// a public inline instance method +// + +// item sizes for TypeMeta::itemsize() fast path +static constexpr size_t scalarTypeItemSizes[NumScalarTypes] = { +#define SCALAR_TYPE_SIZE(T, name) sizeof(T), + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SCALAR_TYPE_SIZE) +#undef SCALAR_TYPE_SIZE + 0, // Undefined +}; + /** * TypeMeta is a thin class that allows us to store the type of a container such * as a blob, or the data type of a tensor, with a unique run-time id. It also @@ -338,17 +341,22 @@ class C10_API TypeMeta final { TypeMeta(const TypeMeta& src) noexcept = default; /** - * Assignment operator. + * Assignment operators. */ TypeMeta& operator=(const TypeMeta& src) noexcept = default; TypeMeta(TypeMeta&& rhs) noexcept = default; - private: + inline TypeMeta& operator=(ScalarType scalar_type) noexcept { + index_ = static_cast(scalar_type); + return *this; + } + +private: // TypeMeta can only be created by Make, making sure that we do not // create incorrectly mixed up TypeMeta objects. - explicit TypeMeta(const detail::TypeMetaData* data) noexcept - : data_(data) { + explicit TypeMeta(const uint16_t index) noexcept + : index_(index) { } public: @@ -356,48 +364,66 @@ class C10_API TypeMeta final { * Returns the type id. */ TypeIdentifier id() const noexcept { - return data_->id_; + return data().id_; + } + /** + * true if we represent some ScalarType type + */ + inline bool isScalarType() const noexcept { + return index_ < NumScalarTypes; + } + /** + * true if we represent ScalarType scalar_type + */ + inline bool isScalarType(ScalarType scalar_type) const noexcept { + return index_ == static_cast(scalar_type); } /** * Returns the size of the item. */ - size_t itemsize() const noexcept { - return data_->itemsize_; + inline size_t itemsize() const noexcept { + if (C10_LIKELY(isScalarType())) { + return scalarTypeItemSizes[index_]; + } + return data().itemsize_; } + /** + * Returns the new function pointer for individual items. + */ New* newFn() const noexcept { - return data_->new_; + return data().new_; } /** * Returns the placement new function pointer for individual items. */ PlacementNew* placementNew() const noexcept { - return data_->placementNew_; + return data().placementNew_; } /** * Returns the typed copy function pointer for individual iterms. */ Copy* copy() const noexcept { - return data_->copy_; + return data().copy_; } /** * Returns the destructor function pointer for individual items. */ PlacementDelete* placementDelete() const noexcept { - return data_->placementDelete_; + return data().placementDelete_; } Delete* deleteFn() const noexcept { - return data_->delete_; + return data().delete_; } /** * Returns a printable name for the type. */ c10::string_view name() const noexcept { - return data_->name_; + return data().name_; } friend bool operator==( - const TypeMeta& lhs, - const TypeMeta& rhs) noexcept; + const TypeMeta lhs, + const TypeMeta rhs) noexcept; template bool Match() const noexcept { @@ -412,7 +438,7 @@ class C10_API TypeMeta final { } template - static C10_TYPENAME_CONSTEXPR c10::string_view TypeName() noexcept { + static c10::string_view TypeName() noexcept { return c10::util::get_fully_qualified_type_name(); } @@ -437,35 +463,105 @@ class C10_API TypeMeta final { #pragma GCC diagnostic ignored "-Wunknown-warning-option" #pragma GCC diagnostic ignored "-Wundefined-var-template" #endif - return TypeMeta(_typeMetaDataInstance()); + return TypeMeta(_typeMetaData()); #ifndef _MSC_VER #pragma GCC diagnostic pop #endif } - private: - const detail::TypeMetaData* data_; + /** + * convert ScalarType enum values to TypeMeta handles + */ + static inline caffe2::TypeMeta fromScalarType(ScalarType scalar_type) { + const size_t index = static_cast(scalar_type); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + index < NumScalarTypes, + "Unrecognized Scalartype ", scalar_type, " (please report this error)"); + return TypeMeta(index); + } + + /** + * convert TypeMeta handles to ScalarType enum values + */ + inline ScalarType toScalarType() { + if (C10_LIKELY(isScalarType())) { + return static_cast(index_); + } + error_unsupported_typemeta(*this); + } + +private: + [[noreturn]] static void error_unsupported_typemeta(caffe2::TypeMeta dtype); + + // hard limit number of registered types + // note: constexpr provokes Windows compilation error "member may not be initialized" + // static constexpr size_t MaxTypeIndex = UINT8_MAX; + #define MaxTypeIndex UINT8_MAX + + static std::atomic nextTypeIndex; + + static detail::TypeMetaData* typeMetaDatas(); template - C10_API static const detail::TypeMetaData* _typeMetaDataInstance() noexcept; + static uint16_t addTypeMetaData() { + const uint16_t index = nextTypeIndex++; + TORCH_CHECK(index <= MaxTypeIndex, + "Maximum number of CAFFE_KNOWN_TYPE declarations has been exceeded. ", + "Please report this issue."); + typeMetaDatas()[index] = detail::TypeMetaData{ + sizeof(T), + detail::_PickNew(), + detail::_PickPlacementNew(), + detail::_PickCopy(), + detail::_PickPlacementDelete(), + detail::_PickDelete(), + TypeIdentifier::Get(), + c10::util::get_fully_qualified_type_name()}; + return index; + } + + // specializations return indexes into typeMetaDataInstances() + template + C10_API static uint16_t _typeMetaData() noexcept; + + // + // TypeMeta just wraps this index + // + + uint16_t index_; + + inline const detail::TypeMetaData& data() const { + return typeMetaDatas()[index_]; + } }; +// specializations of TypeMeta::_typeMetaData for ScalarType types + +#define DEFINE_SCALAR_METADATA_INSTANCE(T, name) \ + template <> \ + constexpr uint16_t TypeMeta::_typeMetaData() noexcept { \ + return static_cast(ScalarType::name); \ + } +AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_SCALAR_METADATA_INSTANCE) +#undef DEFINE_SCALAR_METADATA_INSTANCE + template <> -C10_EXPORT const detail::TypeMetaData* TypeMeta::_typeMetaDataInstance< - detail::_Uninitialized>() noexcept; +C10_EXPORT constexpr uint16_t TypeMeta::_typeMetaData() noexcept { + return static_cast(ScalarType::Undefined); +} inline TypeMeta::TypeMeta() noexcept - : data_(_typeMetaDataInstance()) { + : index_(_typeMetaData()) { } inline bool operator==( - const TypeMeta& lhs, - const TypeMeta& rhs) noexcept { - return (lhs.data_ == rhs.data_); + const TypeMeta lhs, + const TypeMeta rhs) noexcept { + return (lhs.index_ == rhs.index_); } inline bool operator!=( - const TypeMeta& lhs, - const TypeMeta& rhs) noexcept { + const TypeMeta lhs, + const TypeMeta rhs) noexcept { return !operator==(lhs, rhs); } @@ -500,13 +596,11 @@ inline std::ostream& operator<<( #define EXPORT_IF_NOT_GCC #endif -#define CAFFE_KNOWN_TYPE(T) \ - template <> \ - EXPORT_IF_NOT_GCC const detail::TypeMetaData* \ - TypeMeta::_typeMetaDataInstance() noexcept { \ - static C10_TYPENAME_CONSTEXPR detail::TypeMetaData singleton = \ - detail::_makeTypeMetaDataInstance(); \ - return &singleton; \ +#define CAFFE_KNOWN_TYPE(T) \ + template <> \ + EXPORT_IF_NOT_GCC uint16_t TypeMeta::_typeMetaData() noexcept { \ + static const uint16_t index = addTypeMetaData(); \ + return index; \ } } // namespace caffe2 diff --git a/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.h b/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.h index 4aef84663adc..ddeea5d5f56c 100644 --- a/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.h +++ b/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.h @@ -189,7 +189,7 @@ class LayerNormFakeFp16Op final : public Operator { int Nout = X.numel(); std::vector inv_scalev(Nout, inv_scale); - std::vector offsetv(Nout, Y_offset - 128.0); + std::vector offsetv(Nout, Y_offset); uint8_t* Y_uint8_data = Y_int8->t.template mutable_data(); fake_fp16::fma_fp16(Nout, Y_fp16.data(), inv_scalev.data(), offsetv.data()); @@ -200,7 +200,6 @@ class LayerNormFakeFp16Op final : public Operator { for (int i = 0; i < Nout; i++) { float halfRes = offsetv[i]; halfRes = round(halfRes); - halfRes = halfRes + 128.0; if (std::isinf(halfRes)) { if (halfRes > 0) { halfRes = qmax; diff --git a/caffe2/contrib/fakelowp/test/test_layernorm_nnpi_fp16.py b/caffe2/contrib/fakelowp/test/test_layernorm_nnpi_fp16.py index 9ff0986116b6..5129a38c5241 100644 --- a/caffe2/contrib/fakelowp/test/test_layernorm_nnpi_fp16.py +++ b/caffe2/contrib/fakelowp/test/test_layernorm_nnpi_fp16.py @@ -1,8 +1,3 @@ - - - - - import numpy as np import caffe2.python.fakelowp.init_shared_libs # noqa from caffe2.proto import caffe2_pb2 diff --git a/caffe2/contrib/opencl/context.h b/caffe2/contrib/opencl/context.h index ce788a39a7cd..15bfda2203f0 100644 --- a/caffe2/contrib/opencl/context.h +++ b/caffe2/contrib/opencl/context.h @@ -59,7 +59,7 @@ class OpenCLContext final { template inline void - CopyItems(const TypeMeta& meta, size_t n, const void* src, void* dst) { + CopyItems(const TypeMeta meta, size_t n, const void* src, void* dst) { CAFFE_ENFORCE(!meta.copy(), "OpenCLContext requires fundamental types."); CopyBytes(n * meta.itemsize(), src, dst); } diff --git a/caffe2/core/context.h b/caffe2/core/context.h index f3f4a9138ce1..b0e99ef1e59e 100644 --- a/caffe2/core/context.h +++ b/caffe2/core/context.h @@ -131,7 +131,7 @@ class CAFFE2_API CPUContext final : public BaseContext { template inline void - CopyItems(const TypeMeta& meta, size_t n, const void* src, void* dst) { + CopyItems(const TypeMeta meta, size_t n, const void* src, void* dst) { if (meta.copy()) { meta.copy()(src, dst, n); } else { diff --git a/caffe2/core/context_base.h b/caffe2/core/context_base.h index bad6872819de..036ac98fdc91 100644 --- a/caffe2/core/context_base.h +++ b/caffe2/core/context_base.h @@ -104,7 +104,7 @@ class CAFFE2_API BaseContext { } void CopyItemsSameDevice( - const caffe2::TypeMeta& meta, + const caffe2::TypeMeta meta, size_t n, const void* src, void* dst) { @@ -117,7 +117,7 @@ class CAFFE2_API BaseContext { } void CopyItemsFromCPU( - const caffe2::TypeMeta& meta, + const caffe2::TypeMeta meta, size_t n, const void* src, void* dst) { @@ -130,7 +130,7 @@ class CAFFE2_API BaseContext { } void CopyItemsToCPU( - const caffe2::TypeMeta& meta, + const caffe2::TypeMeta meta, size_t n, const void* src, void* dst) { diff --git a/caffe2/core/context_gpu.h b/caffe2/core/context_gpu.h index c0930b1a0e61..7406132f8788 100644 --- a/caffe2/core/context_gpu.h +++ b/caffe2/core/context_gpu.h @@ -279,7 +279,7 @@ class CAFFE2_CUDA_API CUDAContext final : public BaseContext { template inline void - CopyItems(const TypeMeta& meta, size_t n, const void* src, void* dst) { + CopyItems(const TypeMeta meta, size_t n, const void* src, void* dst) { CAFFE_ENFORCE(!meta.copy(), "CUDAContext requires fundamental types."); CopyBytes(n * meta.itemsize(), src, dst); } diff --git a/caffe2/core/operator.h b/caffe2/core/operator.h index ea9ae7892a23..a4abc97f73e4 100644 --- a/caffe2/core/operator.h +++ b/caffe2/core/operator.h @@ -1246,7 +1246,7 @@ struct DispatchHelper, ExtraArgs...> { template \ struct DispatchHelper, ExtraArgs...> { \ template \ - static bool call(Op* op, const TypeMeta& meta) { \ + static bool call(Op* op, const TypeMeta meta) { \ static_assert( \ !std::is_same::value, \ "GenericTensorImplementation must be the last in TensorTypes list"); \ @@ -1269,7 +1269,7 @@ struct DispatchHelper, ExtraArgs...> { template \ struct DispatchHelper, ExtraArgs...> { \ template \ - static bool call(Op* /* unused */, const TypeMeta& meta) { \ + static bool call(Op* /* unused */, const TypeMeta meta) { \ CAFFE_THROW("Unsupported type of tensor: ", meta.name()); \ } \ template \ @@ -1287,7 +1287,7 @@ struct DispatchHelper, ExtraArgs...> { TensorTypes, \ ExtraArgs...> { \ template \ - static bool call(Op* op, const TypeMeta&) { \ + static bool call(Op* op, const TypeMeta) { \ return op->template DoRunWithOtherType(); \ } \ template \ diff --git a/caffe2/core/plan_executor.cc b/caffe2/core/plan_executor.cc index 97c309c078e4..06e27ef5be54 100644 --- a/caffe2/core/plan_executor.cc +++ b/caffe2/core/plan_executor.cc @@ -133,8 +133,8 @@ std::function getContinuationTest( // if the blob doesn't exist or is not initialized, return false inline bool getShouldStop(const Blob* b) { if (!b || - b->meta().id() == - TypeIdentifier::uninitialized()) { // not exist or uninitialized + b->meta() == + ScalarType::Undefined) { // not exist or uninitialized return false; } diff --git a/caffe2/core/tensor.h b/caffe2/core/tensor.h index 27f8b471b71b..83df5306e177 100644 --- a/caffe2/core/tensor.h +++ b/caffe2/core/tensor.h @@ -299,14 +299,14 @@ class CAFFE2_API Tensor final { void ShareExternalPointer( void* src, - const TypeMeta& data_type, + const TypeMeta data_type, size_t nbytes = 0, MemoryDeleter d = nullptr) const { CAFFE_ENFORCE_WITH_CALLER( impl_->is_contiguous(), "Right now ShareExternalPointer is only supported for contiguous Tensor."); CAFFE_ENFORCE_WITH_CALLER( - data_type.id() != caffe2::TypeIdentifier::uninitialized(), + data_type != ScalarType::Undefined, "To share with a raw external pointer you need to pass in an " "initialized data_type(TypeMeta)."); impl_.get()->ShareExternalPointer( @@ -315,7 +315,7 @@ class CAFFE2_API Tensor final { void ShareExternalPointer( at::DataPtr&& data_ptr, - const TypeMeta& data_type, + const TypeMeta data_type, size_t nbytes) { impl_.get()->ShareExternalPointer(std::move(data_ptr), data_type, nbytes); } @@ -342,7 +342,7 @@ class CAFFE2_API Tensor final { return impl_.get()->data(); } - inline void* raw_mutable_data(const TypeMeta& meta) const { + inline void* raw_mutable_data(const TypeMeta meta) const { return impl_.get()->raw_mutable_data(meta); } @@ -358,7 +358,7 @@ class CAFFE2_API Tensor final { inline void* raw_mutable_data() const { const auto& data_type = impl_->dtype(); CAFFE_ENFORCE_WITH_CALLER( - data_type.id() != caffe2::TypeIdentifier::uninitialized(), + data_type != ScalarType::Undefined, "Calling raw_mutable_data() without meta, but the current meta is " "of unknown type."); return raw_mutable_data(data_type); @@ -469,7 +469,7 @@ class CAFFE2_API Tensor final { /** * Returns the TypeMeta object associated with the current data type. */ - inline const TypeMeta& dtype() const { + inline const TypeMeta dtype() const { return impl_->dtype(); } @@ -477,7 +477,7 @@ class CAFFE2_API Tensor final { * (To be deprecated) Returns the TypeMeta object associated with the current * data type. */ - inline const TypeMeta& meta() const { + inline const TypeMeta meta() const { return impl_->dtype(); } diff --git a/caffe2/core/types.cc b/caffe2/core/types.cc index d1007fe76e86..c738fc50a288 100644 --- a/caffe2/core/types.cc +++ b/caffe2/core/types.cc @@ -8,7 +8,7 @@ namespace caffe2 { -TensorProto::DataType TypeMetaToDataType(const TypeMeta& meta) { +TensorProto::DataType TypeMetaToDataType(const TypeMeta meta) { static_assert( sizeof(int) == 4, "int in this compiler does not equal to 4 bytes."); static std::map data_type_map{ @@ -36,7 +36,7 @@ TensorProto::DataType TypeMetaToDataType(const TypeMeta& meta) { it == data_type_map.end() ? TensorProto_DataType_UNDEFINED : it->second); } -const TypeMeta& DataTypeToTypeMeta(const TensorProto::DataType& dt) { +const TypeMeta DataTypeToTypeMeta(const TensorProto::DataType& dt) { static std::map type_meta_map{ {TensorProto_DataType_FLOAT, TypeMeta::Make()}, {TensorProto_DataType_INT32, TypeMeta::Make()}, diff --git a/caffe2/core/types.h b/caffe2/core/types.h index c0e8d7bbfb3d..5dda5a5e0810 100644 --- a/caffe2/core/types.h +++ b/caffe2/core/types.h @@ -47,10 +47,10 @@ inline int32_t GetDimFromOrderString(const std::string& str) { inline constexpr char NameScopeSeparator() { return '/'; } // From TypeMeta to caffe2::DataType protobuffer enum. -CAFFE2_API TensorProto::DataType TypeMetaToDataType(const TypeMeta& meta); +CAFFE2_API TensorProto::DataType TypeMetaToDataType(const TypeMeta meta); // From caffe2::DataType protobuffer enum to TypeMeta -CAFFE2_API const TypeMeta& DataTypeToTypeMeta(const TensorProto::DataType& dt); +CAFFE2_API const TypeMeta DataTypeToTypeMeta(const TensorProto::DataType& dt); } // namespace caffe2 diff --git a/caffe2/ideep/utils/ideep_context.h b/caffe2/ideep/utils/ideep_context.h index 823b4bec16bd..d0f1207a08f6 100644 --- a/caffe2/ideep/utils/ideep_context.h +++ b/caffe2/ideep/utils/ideep_context.h @@ -91,7 +91,7 @@ class IDEEPContext final : public BaseContext { template inline void - CopyItems(const TypeMeta& meta, size_t n, const void* src, void* dst) { + CopyItems(const TypeMeta meta, size_t n, const void* src, void* dst) { if (meta.copy()) { meta.copy()(src, dst, n); } else { diff --git a/caffe2/operators/dataset_ops.cc b/caffe2/operators/dataset_ops.cc index 95fcfc1ab923..c311ad23e4ed 100644 --- a/caffe2/operators/dataset_ops.cc +++ b/caffe2/operators/dataset_ops.cc @@ -407,7 +407,7 @@ class UnPackRecordsOp : public Operator { // Precomputer the output sizes to avoid resizing std::vector> outputDims(numTensors); - std::vector metas(numTensors); + std::vector metas(numTensors); CAFFE_ENFORCE( numRows > 0 || InputSize() > 1, @@ -428,7 +428,7 @@ class UnPackRecordsOp : public Operator { // Checks to ensure that dimensions/sizes match CAFFE_ENFORCE_EQ(outputDims[j].size(), input.dim()); - CAFFE_ENFORCE(*metas[j] == input.dtype()); + CAFFE_ENFORCE(metas[j] == input.dtype()); // We look from first dimension, because we concat on the first. for (int k = 1; k < input.dim(); ++k) { CAFFE_ENFORCE_EQ(input.sizes()[k], outputDims[j][k]); @@ -442,7 +442,7 @@ class UnPackRecordsOp : public Operator { std::vector destinations(numTensors); for (int i = 0; i < numTensors; ++i) { Output(i)->Resize(outputDims[i]); - destinations[i] = Output(i)->raw_mutable_data(*metas[i]); + destinations[i] = Output(i)->raw_mutable_data(metas[i]); } for (int i = 0; i < numRows; ++i) { @@ -450,7 +450,7 @@ class UnPackRecordsOp : public Operator { const auto& input = tensors[i][j]; context_.CopyItemsSameDevice( - *metas[j], + metas[j], input.numel(), input.raw_data() /* src */, destinations[j] /* dst */ @@ -468,7 +468,7 @@ class UnPackRecordsOp : public Operator { void getShapeAndMetaFromInput( const Shared2DTensorVectorPtr& inputs, std::vector>& outputDims, - std::vector& metas) { + std::vector& metas) { const auto& inputZero = inputs->at(0); const auto numTensors = inputZero.size(); @@ -479,13 +479,13 @@ class UnPackRecordsOp : public Operator { for (int i = 0; i < numTensors; ++i) { outputDims[i] = inputZero[i].sizes().vec(); outputDims[i][0] = 0; - metas[i] = &inputZero[i].dtype(); + metas[i] = inputZero[i].dtype(); } } void getShapeAndMetaFromPrototypeBlobs( std::vector>& outputDims, - std::vector& metas) { + std::vector& metas) { const auto numTensors = fields_.size(); CAFFE_ENFORCE_EQ(numTensors, InputSize() - 1); CAFFE_ENFORCE_EQ(numTensors, OutputSize()); @@ -493,7 +493,7 @@ class UnPackRecordsOp : public Operator { const auto& input = Input(i + 1); outputDims[i] = input.sizes().vec(); outputDims[i][0] = 0; - metas[i] = &input.dtype(); + metas[i] = input.dtype(); } } diff --git a/caffe2/operators/dataset_ops.h b/caffe2/operators/dataset_ops.h index 70a294e14136..fc890014dbb2 100644 --- a/caffe2/operators/dataset_ops.h +++ b/caffe2/operators/dataset_ops.h @@ -146,7 +146,7 @@ class TreeWalker { return size; } - inline const TypeMeta& meta() const { + inline const TypeMeta meta() const { return walker_.input(fieldId_).dtype(); } diff --git a/caffe2/operators/index_ops.h b/caffe2/operators/index_ops.h index 890753caf2fe..2f5705cb4c26 100644 --- a/caffe2/operators/index_ops.h +++ b/caffe2/operators/index_ops.h @@ -18,7 +18,7 @@ using int64_tValue = int64_t; struct IndexBase { public: - IndexBase(int64_tValue maxElements, const TypeMeta& type) + IndexBase(int64_tValue maxElements, const TypeMeta type) : maxElements_{maxElements}, meta_(type), frozen_{false} {} void Freeze() { @@ -35,7 +35,7 @@ struct IndexBase { virtual ~IndexBase() {} - const TypeMeta& Type() const { + const TypeMeta Type() const { return meta_; } diff --git a/caffe2/operators/mod_op.cc b/caffe2/operators/mod_op.cc index 48c1eea5a415..8faaac51572f 100644 --- a/caffe2/operators/mod_op.cc +++ b/caffe2/operators/mod_op.cc @@ -25,8 +25,6 @@ bool ModOp::DoRunWithType() { return true; } -namespace { - REGISTER_CPU_OPERATOR(Mod, ModOp); OPERATOR_SCHEMA(Mod) .NumInputs(1) @@ -95,5 +93,4 @@ X after running op: .Output(0, "Y", "*(type: Tensor``)* Output tensor of data with modulo operation applied."); SHOULD_NOT_DO_GRADIENT(ModOp); -} // namespace } // namespace caffe2 diff --git a/caffe2/operators/mod_op.cu b/caffe2/operators/mod_op.cu new file mode 100644 index 000000000000..90043a29389f --- /dev/null +++ b/caffe2/operators/mod_op.cu @@ -0,0 +1,63 @@ +#include "caffe2/operators/mod_op.h" + +#include "caffe2/core/context_gpu.h" + +namespace caffe2 { + +namespace { + +template +__global__ void ModOpSimpleKernel(const int N, const int64_t divisor_, + const T* data_ptr, T* output_ptr) { + CUDA_1D_KERNEL_LOOP(i, N) { + output_ptr[i] = data_ptr[i] % divisor_; + } +} + + +template +__global__ void ModOpKernel(const int N, const int64_t divisor_, + const T* data_ptr, T* output_ptr) { + CUDA_1D_KERNEL_LOOP(i, N) { + output_ptr[i] = data_ptr[i] % divisor_; + if (output_ptr[i] && ((output_ptr[i] > 0) != (divisor_ > 0))) { + output_ptr[i] += divisor_; + } + } +} + +} // namespace + +template <> +template +bool ModOp::DoRunWithType() { + auto& data = Input(DATA); + auto N = data.numel(); + const auto* data_ptr = data.template data(); + + auto* output = Output(0, data.sizes(), at::dtype()); + auto* output_ptr = output->template mutable_data(); + + if (sign_follow_divisor_) { + ModOpKernel<<< + CAFFE_GET_BLOCKS(N), + CAFFE_CUDA_NUM_THREADS, + 0, + context_.cuda_stream()>>>( + N, divisor_, data_ptr, output_ptr); + } else { + ModOpSimpleKernel<<< + CAFFE_GET_BLOCKS(N), + CAFFE_CUDA_NUM_THREADS, + 0, + context_.cuda_stream()>>>( + N, divisor_, data_ptr, output_ptr); + } + + return true; + +} + +REGISTER_CUDA_OPERATOR(Mod, ModOp); + +} // namespace caffe2 diff --git a/caffe2/operators/numpy_tile_op.h b/caffe2/operators/numpy_tile_op.h index 8a39b40df0f8..ac9886ec503a 100644 --- a/caffe2/operators/numpy_tile_op.h +++ b/caffe2/operators/numpy_tile_op.h @@ -92,7 +92,7 @@ class NumpyTileOp : public Operator { private: void DoTile( - const TypeMeta& meta, + const TypeMeta meta, int item_size, int outer_dim, int inner_dim, diff --git a/caffe2/operators/tile_op.cc b/caffe2/operators/tile_op.cc index 40684c50575b..b0d797fce7ff 100644 --- a/caffe2/operators/tile_op.cc +++ b/caffe2/operators/tile_op.cc @@ -71,7 +71,7 @@ bool TileOp::DoRunWithType() { // size from axis up const int inner_size = X.size_from_dim(axis); - const TypeMeta& meta = X.dtype(); + const TypeMeta meta = X.dtype(); const int item_size = X.itemsize(); const char* X_ptr = reinterpret_cast(X.raw_data()); char* Y_ptr = reinterpret_cast(Y->raw_mutable_data(meta)); diff --git a/caffe2/operators/utility_ops.cc b/caffe2/operators/utility_ops.cc index b691c24e984a..9abcf5ab0b86 100644 --- a/caffe2/operators/utility_ops.cc +++ b/caffe2/operators/utility_ops.cc @@ -59,6 +59,7 @@ REGISTER_CPU_OPERATOR(GatherRanges, GatherRangesOp); REGISTER_CPU_OPERATOR(LengthsGather, LengthsGatherOp); REGISTER_CPU_OPERATOR(LengthsToSegmentIds, LengthsToSegmentIdsOp); REGISTER_CPU_OPERATOR(LengthsToRanges, LengthsToRangesOp); +REGISTER_CPU_OPERATOR(LengthsToOffsets, LengthsToOffsetsOp); REGISTER_CPU_OPERATOR(SegmentIdsToLengths, SegmentIdsToLengthsOp); REGISTER_CPU_OPERATOR(SegmentIdsToRanges, SegmentIdsToRangesOp); REGISTER_CPU_OPERATOR(LengthsToWeights, LengthsToWeightsOp); @@ -522,20 +523,20 @@ Another output LENGTHS represents each example length within OUTPUT "LENGTHS", "1-D tensor of size N with lengths over gathered data" " for each row in a batch. sum(LENGTHS) == OUTPUT.size()") - .TensorInferenceFunction(OpSchema::NeedsAllInputShapes([]( - const OperatorDef& /* unused */, const vector& in) { - std::vector out(2); - - int total = 1; - for (auto d : in[0].dims()) { - total *= d; - } - out[0].add_dims(total); - out[0].set_data_type(in[0].data_type()); - out[1].add_dims(in[1].dims(0)); - out[1].set_data_type(in[1].data_type()); - return out; - })); + .TensorInferenceFunction(OpSchema::NeedsAllInputShapes( + [](const OperatorDef& /* unused */, const vector& in) { + std::vector out(2); + + int total = 1; + for (auto d : in[0].dims()) { + total *= d; + } + out[0].add_dims(total); + out[0].set_data_type(in[0].data_type()); + out[1].add_dims(in[1].dims(0)); + out[1].set_data_type(in[1].data_type()); + return out; + })); OPERATOR_SCHEMA(LengthsGather) .NumInputs(3) @@ -636,6 +637,30 @@ For example, `[1, 3, 0, 2]` transforms into `[[0, 1], [1, 3], [4, 0], [4, 2]]`. "ranges", "2D tensor of shape len(lengths) X 2 and the same type as `lengths`"); +OPERATOR_SCHEMA(LengthsToOffsets) + .NumInputs(1) + .NumOutputs(1) + .SetDoc(R"DOC( +Given a vector of segment lengths, returns a vector of offsets from these lengths, +which will have the same size as the input vector. Output is going to have +the same type as input. For long tensors explicit casting from int32 to int64 +might be necessary prior to this op. + +For example, `[1, 3, 0, 2]` transforms into `[0, 1, 4, 4]`. +)DOC") + .Input(0, "lengths", "1D tensor of int32 or int64 segment lengths.") + .Output(0, "offsets", "1D tensor of the same shape and type as `lengths`") + .TensorInferenceFunction([](const OperatorDef& def, + const vector& in) { + const ArgumentHelper args(def); + bool include_last_offset = + args.GetSingleArgument("include_last_offset", false); + vector out_shape(in[0].dims().begin(), in[0].dims().end()); + out_shape[0] += include_last_offset ? 1 : 0; + return vector{ + CreateTensorShape(out_shape, in[0].data_type())}; + }); + OPERATOR_SCHEMA(SegmentIdsToLengths) .NumInputs(1, 2) .NumOutputs(1) diff --git a/caffe2/operators/utility_ops.h b/caffe2/operators/utility_ops.h index a82b5666fb7b..bdc9c0bfbfd9 100644 --- a/caffe2/operators/utility_ops.h +++ b/caffe2/operators/utility_ops.h @@ -918,6 +918,45 @@ class LengthsToRangesOp : public Operator { } }; +template +class LengthsToOffsetsOp : public Operator { + public: + USE_OPERATOR_CONTEXT_FUNCTIONS; + + template + explicit LengthsToOffsetsOp(Args&&... args) + : Operator(std::forward(args)...), + include_last_offset_(this->template GetSingleArgument( + "include_last_offset", + false)) {} + + bool RunOnDevice() override { + auto& input = Input(0); + auto* output = Output(0); + auto* input_data = input.template data(); + + CAFFE_ENFORCE(input.sizes().size() == 1, "Input must be a vector."); + auto size = input.numel(); + + output->Resize(size + (include_last_offset_ ? 1 : 0)); + auto* output_data = output->template mutable_data(); + + int32_t offset = 0; + for (int i = 0; i < size; ++i) { + auto len = input_data[i]; + output_data[i] = offset; + offset += len; + } + if (include_last_offset_) { + output_data[size] = offset; + } + return true; + } + + private: + bool include_last_offset_; +}; + template class SegmentIdsToLengthsOp : public Operator { public: diff --git a/caffe2/opt/glow_net_transform.cc b/caffe2/opt/glow_net_transform.cc index 12bd060c27d6..ece62abea258 100644 --- a/caffe2/opt/glow_net_transform.cc +++ b/caffe2/opt/glow_net_transform.cc @@ -13,11 +13,6 @@ C10_DEFINE_bool( true, "Attach AdjustBatch ops at input/outputs of the Onnxifi ops"); -C10_DEFINE_bool( - onnxifi_loop_test_mode, - false, - "For test purpose only. Build a dummy net just to test the functionality"); - C10_DEFINE_bool( enforce_fp32_inputs_into_fp16, false, @@ -146,7 +141,6 @@ void onnxifi( opts.load_model_by_blob = load_model_by_blob; opts.enforce_fp32_inputs_into_fp16 = FLAGS_enforce_fp32_inputs_into_fp16; opts.merge_fp32_inputs_into_fp16 = FLAGS_merge_fp32_inputs_into_fp16; - opts.loop_test = FLAGS_onnxifi_loop_test_mode; opts.predictor_net_ssa_rewritten = predictor_net_ssa_rewritten; opts.timeout = FLAGS_onnxifi_timeout_ms; diff --git a/caffe2/opt/onnxifi_transformer.cc b/caffe2/opt/onnxifi_transformer.cc index e849f0edb272..c77101984790 100644 --- a/caffe2/opt/onnxifi_transformer.cc +++ b/caffe2/opt/onnxifi_transformer.cc @@ -403,235 +403,6 @@ void mergeFp32InputsAndConvertToFp16( } } -NetDef buildLoopTestNet( - const NetDef& net, - const std::unordered_set& initialization_list, - std::unordered_map* shape_hints, - size_t batch_size) { - NetDef net_dummy; - - // Add non-weigh inputs only - for (const auto& i : net.external_input()) { - if (!initialization_list.count(i)) { - net_dummy.add_external_input(i); - } - } - for (const auto& o : net.external_output()) { - net_dummy.add_external_output(o); - } - - // Now categorize the inputs into the following groups. We don't support - // handling of 3d inputs yet, but it can be done easily by converting n-d - // inputs into 2-d with Reshape or ReduceSum - std::unordered_set batched_2d_inputs; - std::unordered_set other_2d_inputs; - std::unordered_set all_1d_inputs; - auto addCast = [&net_dummy]( - const std::string& i, - std::string& in, - caffe2::TensorProto::DataType dtype) mutable { - int multiplier = 1; - if (dtype != caffe2::TensorProto::FLOAT) { - in += "_fp32"; - net_dummy.add_op()->CopyFrom(CreateOperatorDef( - "Clip", - "", - {i}, - {in}, - {MakeArgument("min", 0.0), MakeArgument("max", 1.0)})); - if (dtype == caffe2::TensorProto::INT8 || - dtype == caffe2::TensorProto::UINT8) { - multiplier = sizeof(float) / sizeof(int8_t); - } else if ( - dtype == caffe2::TensorProto::INT16 || - dtype == caffe2::TensorProto::UINT16 || - dtype == caffe2::TensorProto::FLOAT16) { - multiplier = sizeof(float) / sizeof(int16_t); - } else if (dtype == caffe2::TensorProto::INT64) { - // Special case, it should really be 0.5 - multiplier = 0; - } - } - return multiplier; - }; - auto adjustDim = [](int d, int m, TensorShape& shape) { - if (m > 1) { - CAFFE_ENFORCE_EQ(shape.dims(d) % m, 0); - shape.set_dims(d, shape.dims(d) / m); - } else if (m == 0) { - shape.set_dims(d, shape.dims(d) * 2); - } - shape.set_data_type(caffe2::TensorProto::FLOAT); - }; - size_t dim2 = 0; - for (const auto& i : net_dummy.external_input()) { - auto it = shape_hints->find(i); - CAFFE_ENFORCE( - it != shape_hints->end(), "Cannot find shape info for input ", i); - auto& shape = it->second.shape; - std::string in = i; - // Trick here: since backend like glow doesn't support non-float - // arithmatics, we need to be creative and bitcast non-float data type into - // float while maintaining the same bit lengths. We do this by changing the - // shape dim. So that we will always load the same amount of bits onto the - // backend. To avoid numeric complication, we add a Clip. - if (shape.dims_size() == 2) { - auto m = addCast(i, in, shape.data_type()); - adjustDim(1, m, shape); - if (shape.dims(0) == batch_size) { - batched_2d_inputs.emplace(in); - dim2 += shape.dims(1); - } else { - other_2d_inputs.emplace(in); - } - } else if (shape.dims_size() == 1) { - auto m = addCast(i, in, shape.data_type()); - adjustDim(0, m, shape); - all_1d_inputs.emplace(in); - } else { - const std::string fin = i + "_flatten"; - net_dummy.add_op()->CopyFrom( - CreateOperatorDef("Flatten", "", {i}, {fin}, {})); - in = fin; - auto m = addCast(fin, in, shape.data_type()); - auto last = shape.dims_size() - 1; - adjustDim(last, m, shape); - size_t ndim = 1; - for (unsigned k = 1; k < shape.dims_size(); ++k) { - ndim *= shape.dims(k); - } - if (shape.dims(0) == batch_size) { - batched_2d_inputs.emplace(in); - dim2 += ndim; - } else { - other_2d_inputs.emplace(in); - } - } - } - - // Add adjusted shape hints - auto* shape_arg = net_dummy.add_arg(); - auto* qshape_arg = net_dummy.add_arg(); - shape_arg->set_name("input_shape_info"); - qshape_arg->set_name("input_qshape_info"); - for (const auto& i : net_dummy.external_input()) { - auto info = shape_hints->at(i); - if (!info.is_quantized) { - shape_arg->mutable_tensors()->Add()->CopyFrom( - wrapShapeInfoIntoTensorProto(i, info)); - } else { - qshape_arg->mutable_qtensors()->Add()->CopyFrom( - wrapShapeInfoIntoQTensorProto(i, info)); - } - } - - // Collect all the input together into a 2d tensor of {batch_size, X} - std::vector concat2d_batched( - batched_2d_inputs.begin(), batched_2d_inputs.end()); - const std::string concat_out = "batch_2d_concat"; - net_dummy.add_op()->CopyFrom(CreateOperatorDef( - "Concat", - "", - concat2d_batched, - {concat_out, "batch_2d_concat_split_info"}, - {MakeArgument("axis", 1)})); - std::vector scalars; - for (const auto& i : other_2d_inputs) { - std::string o = i + "_reduced"; - net_dummy.add_op()->CopyFrom(CreateOperatorDef( - "ReduceSum", - "", - {i}, - {o}, - {MakeArgument>("axes", {0, 1}), - MakeArgument("keepdims", 0)})); - scalars.emplace_back(std::move(o)); - } - for (const auto& i : all_1d_inputs) { - std::string o = i + "_reduced"; - net_dummy.add_op()->CopyFrom(CreateOperatorDef( - "ReduceSum", - "", - {i}, - {o}, - {MakeArgument>("axes", {0}), - MakeArgument("keepdims", 0)})); - scalars.emplace_back(std::move(o)); - } - const std::string summed = "summed"; - net_dummy.add_op()->CopyFrom( - CreateOperatorDef("Sum", "", scalars, {summed}, {})); - const std::string out = "result_out"; - net_dummy.add_op()->CopyFrom(CreateOperatorDef( - "Add", - "", - {concat_out, summed}, - {out}, - {MakeArgument("broadcast", 1)})); - - for (const auto& o : net_dummy.external_output()) { - const auto it = shape_hints->find(o); - CAFFE_ENFORCE( - it != shape_hints->end(), "Cannot find shape info for output ", o); - const auto& shape = it->second.shape; - // TODO: all doable but I'm lazy - if (shape.data_type() != caffe2::TensorProto::FLOAT) { - CAFFE_THROW("We need a Cast op to match the output data type"); - } - if (shape.dims_size() == 2) { - if (shape.dims(0) == batch_size) { - if (shape.dims(1) > dim2) { - CAFFE_THROW( - "We need Tile op to match the output dim ", - shape.dims(1), - " vs ", - dim2); - } else if (shape.dims(1) == dim2) { - net_dummy.add_op()->CopyFrom( - CreateOperatorDef("Copy", "", {out}, {o}, {})); - } else { - net_dummy.add_op()->CopyFrom(CreateOperatorDef( - "Slice", - "", - {out}, - {o}, - {MakeArgument>("starts", {0, 0}), - MakeArgument>( - "ends", {-1, static_cast(shape.dims(1))})})); - } - } - } else if (shape.dims_size() == 1) { - if (shape.dims(0) == batch_size) { - const std::string oi = o + "_pre"; - net_dummy.add_op()->CopyFrom(CreateOperatorDef( - "Slice", - "", - {out}, - {oi}, - {MakeArgument>("starts", {0, 0}), - MakeArgument>("ends", {-1, 1})})); - net_dummy.add_op()->CopyFrom(CreateOperatorDef( - "Reshape", - "", - {oi}, - {o}, - {MakeArgument>( - "shape", {static_cast(batch_size)})})); - } else { - CAFFE_THROW( - "We need Slice and Tile op to match the output dim ", - shape.dims(0), - " vs ", - batch_size); - } - } else { - CAFFE_THROW("Only support 1D/2D outputs for now"); - } - } - - return net_dummy; -} - } // namespace void splitSparseLengthsSumSparse(NetDef* net, const Workspace& ws) { @@ -928,19 +699,7 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaC2( } } - // Rewrite the net into a dummy in loop test mode - ShapeInfoMap new_shape_hints; - if (opts_.loop_test) { - new_shape_hints = shape_hints; - onnxifi_net = buildLoopTestNet( - onnxifi_net, - initialization_list, - &new_shape_hints, - opts_.bound_shape_spec.max_batch_size); - initialization_list.clear(); - } - - // Add parition info + // Add partition info for (const auto& p : partition_infos_) { onnxifi_net.add_partition_info()->CopyFrom(p); } @@ -965,7 +724,7 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaC2( initialization_list, onnxifi_net_inputs, onnxifi_net_outputs, - opts_.loop_test ? new_shape_hints : shape_hints); + shape_hints); NetDef net_opt = composeResultNet(onnxifi_op); // Debugging stuff diff --git a/caffe2/opt/onnxifi_transformer.h b/caffe2/opt/onnxifi_transformer.h index 9af168a4d20e..86e061411dd9 100644 --- a/caffe2/opt/onnxifi_transformer.h +++ b/caffe2/opt/onnxifi_transformer.h @@ -39,9 +39,6 @@ struct OnnxifiTransformerOptions final : public BackendTransformOptions { // fp16 or not bool merge_fp32_inputs_into_fp16{false}; - // Enter loop test mode - bool loop_test{false}; - // Whether the net has been ssaRewritten bool predictor_net_ssa_rewritten{false}; diff --git a/caffe2/python/gradient_checker.py b/caffe2/python/gradient_checker.py index afb8d5071492..5f116bd6107c 100644 --- a/caffe2/python/gradient_checker.py +++ b/caffe2/python/gradient_checker.py @@ -5,6 +5,7 @@ +import os import numpy as np from caffe2.python import core, workspace, net_drawer @@ -292,8 +293,20 @@ def CheckSimple( if ensure_outputs_are_inferred: self._assertInferTensorChecks(op, grad_ops) + full_grad_check = os.getenv('CAFFE2_FULL_GRAD_CHECK') == '1' + dims_to_check = inputs[input_to_check].size for current_dim in range(dims_to_check): + # Grad check is very expensive (as it involves running the op from + # scratch for each of the input tensor elements). Thus, let's + # run it by default only on a small subset of dimensions. Here we + # apply very scientific approach: the first and the last 3 elements + # of each tensor. Pass CAFFE2_FULL_GRAD_CHECK=1 env var to enable + # the full check + if not full_grad_check and current_dim >= 3 and \ + current_dim + 3 < dims_to_check: + grad_estimate.flat[current_dim] = grad.flat[current_dim] + continue # Positive gradient inputs[input_to_check].flat[current_dim] += self._stepsize pos_loss, _ = self.GetLossAndGrad( diff --git a/caffe2/python/hypothesis_test.py b/caffe2/python/hypothesis_test.py index 045677f8422a..9298134f651c 100644 --- a/caffe2/python/hypothesis_test.py +++ b/caffe2/python/hypothesis_test.py @@ -994,6 +994,38 @@ def op_ref(x): inputs=[np.array(lengths, dtype=np.int32)], reference=op_ref) + @given( + lengths=st.lists( + st.integers(min_value=0, max_value=10), min_size=0, max_size=10 + ), + include_last_offset=st.booleans(), + **hu.gcs_cpu_only + ) + @settings(deadline=None) + def test_lengths_to_offsets(self, lengths, include_last_offset, gc, dc): + op = core.CreateOperator( + "LengthsToOffsets", + ["lengths"], + ["ranges"], + include_last_offset=include_last_offset, + ) + + def op_ref(x): + if not x.size: + arr = [x.reshape(0)] + else: + arr = [np.concatenate(([0], np.cumsum(x)[:-1]))] + if include_last_offset: + arr[0] = np.concatenate((arr[0], np.array([np.sum(x)]))) + return tuple(arr) + + self.assertReferenceChecks( + device_option=gc, + op=op, + inputs=[np.array(lengths, dtype=np.int32)], + reference=op_ref, + ) + @given(prediction=hu.arrays(dims=[10, 3], elements=hu.floats(allow_nan=False, allow_infinity=False, diff --git a/caffe2/python/layers/last_n_window_collector.py b/caffe2/python/layers/last_n_window_collector.py index a16b731a2f78..5e6874b4cca0 100644 --- a/caffe2/python/layers/last_n_window_collector.py +++ b/caffe2/python/layers/last_n_window_collector.py @@ -1,10 +1,6 @@ ## @package last_n_window_collector # Module caffe2.python.layers.last_n_window_collector - - - - from caffe2.python import core, schema from caffe2.python.layers.layers import ModelLayer diff --git a/caffe2/python/operator_test/mod_op_test.py b/caffe2/python/operator_test/mod_op_test.py index 914bffd2067c..03ff766c11e4 100644 --- a/caffe2/python/operator_test/mod_op_test.py +++ b/caffe2/python/operator_test/mod_op_test.py @@ -1,12 +1,7 @@ - - - - - import numpy from caffe2.python import core -from hypothesis import given +from hypothesis import given, settings import caffe2.python.hypothesis_test_util as hu import hypothesis.strategies as st @@ -16,7 +11,8 @@ @st.composite def _data(draw): return draw( - hu.tensor(dtype=np.int64, + hu.tensor( + dtype=np.int64, elements=st.integers( min_value=np.iinfo(np.int64).min, max_value=np.iinfo(np.int64).max ) @@ -25,6 +21,7 @@ def _data(draw): class TestMod(hu.HypothesisTestCase): + @settings(deadline=None) @given( data=_data(), divisor=st.integers( @@ -32,7 +29,7 @@ class TestMod(hu.HypothesisTestCase): ), inplace=st.booleans(), sign_follow_divisor=st.booleans(), - **hu.gcs_cpu_only + **hu.gcs ) def test_mod( self, data, divisor, inplace, sign_follow_divisor, gc, dc diff --git a/caffe2/python/pybind_state.cc b/caffe2/python/pybind_state.cc index 2923b98c565f..65a246e4a39c 100644 --- a/caffe2/python/pybind_state.cc +++ b/caffe2/python/pybind_state.cc @@ -118,7 +118,7 @@ static_assert( sizeof(int) == sizeof(int32_t), "We make an assumption that int is always int32 for numpy " "type mapping."); -int CaffeToNumpyType(const TypeMeta& meta) { +int CaffeToNumpyType(const TypeMeta meta) { #ifdef USE_NUMPY static std::map numpy_type_map{ {TypeMeta::Id(), NPY_BOOL}, @@ -143,7 +143,7 @@ int CaffeToNumpyType(const TypeMeta& meta) { #endif // USE_NUMPY } -const TypeMeta& NumpyTypeToCaffe(int numpy_type) { +const TypeMeta NumpyTypeToCaffe(int numpy_type) { #ifdef USE_NUMPY static std::map caffe_type_map{ {NPY_BOOL, TypeMeta::Make()}, diff --git a/caffe2/python/pybind_state.h b/caffe2/python/pybind_state.h index b8f9dbaf3719..b3926e941194 100644 --- a/caffe2/python/pybind_state.h +++ b/caffe2/python/pybind_state.h @@ -103,8 +103,8 @@ static_assert( "We make an assumption that int is always int32 for numpy " "type mapping."); -int CaffeToNumpyType(const TypeMeta& dtype); -const TypeMeta& NumpyTypeToCaffe(int numpy_type); +int CaffeToNumpyType(const TypeMeta dtype); +const TypeMeta NumpyTypeToCaffe(int numpy_type); class TensorFetcher : public BlobFetcherBase { public: @@ -114,7 +114,7 @@ class TensorFetcher : public BlobFetcherBase { // Checks whether the data with type `dtype` needs to be copied in the context // of `tensor` - bool NeedsCopy(const Tensor* tensor, const TypeMeta& dtype) const { + bool NeedsCopy(const Tensor* tensor, const TypeMeta dtype) const { #ifdef USE_NUMPY return tensor->GetDeviceType() != CPU || CaffeToNumpyType(dtype) == NPY_OBJECT; @@ -200,9 +200,9 @@ class TensorFeeder : public BlobFeederBase { auto g = MakeGuard([&]() { Py_XDECREF(array); }); const auto npy_type = PyArray_TYPE(array); - const TypeMeta& dtype = NumpyTypeToCaffe(npy_type); + const TypeMeta dtype = NumpyTypeToCaffe(npy_type); CAFFE_ENFORCE( - dtype.id() != TypeIdentifier::uninitialized(), + dtype != ScalarType::Undefined, "This numpy data type is not supported: ", PyArray_TYPE(array), "."); diff --git a/caffe2/python/pybind_state_dlpack.cc b/caffe2/python/pybind_state_dlpack.cc index 7b1ec2b8e141..a7204481224f 100644 --- a/caffe2/python/pybind_state_dlpack.cc +++ b/caffe2/python/pybind_state_dlpack.cc @@ -14,7 +14,7 @@ const DLDeviceType* CaffeToDLDeviceType(int device_type) { return it == dl_device_type_map.end() ? nullptr : &it->second; } -const DLDataType* CaffeToDLType(const TypeMeta& meta) { +const DLDataType* CaffeToDLType(const TypeMeta meta) { static std::map dl_type_map{ {TypeMeta::Id(), DLDataType{0, 8, 1}}, {TypeMeta::Id(), DLDataType{0, 16, 1}}, @@ -30,7 +30,7 @@ const DLDataType* CaffeToDLType(const TypeMeta& meta) { return it == dl_type_map.end() ? nullptr : &it->second; } -const TypeMeta& DLTypeToCaffe(const DLDataType& dl_type) { +const TypeMeta DLTypeToCaffe(const DLDataType& dl_type) { try { if (dl_type.lanes != 1) { throw std::invalid_argument("invalid type"); diff --git a/caffe2/python/pybind_state_dlpack.h b/caffe2/python/pybind_state_dlpack.h index 54f3157e7634..bcdbc50a61d4 100644 --- a/caffe2/python/pybind_state_dlpack.h +++ b/caffe2/python/pybind_state_dlpack.h @@ -16,9 +16,9 @@ namespace py = pybind11; const DLDeviceType* CaffeToDLDeviceType(int device_type); -const DLDataType* CaffeToDLType(const TypeMeta& meta); +const DLDataType* CaffeToDLType(const TypeMeta meta); -const TypeMeta& DLTypeToCaffe(const DLDataType& dl_type); +const TypeMeta DLTypeToCaffe(const DLDataType& dl_type); // TODO: remove context template @@ -40,7 +40,7 @@ class DLPackWrapper { if (tensor->numel() <= 0) { tensor->Resize(0); } - if (tensor->dtype().id() == TypeIdentifier::uninitialized()) { + if (tensor->dtype() == ScalarType::Undefined) { // treat uninitialized tensor as float tensor tensor->template mutable_data(); } diff --git a/caffe2/python/pybind_state_ideep.cc b/caffe2/python/pybind_state_ideep.cc index 8d09b0aaa326..bbeaf524f055 100644 --- a/caffe2/python/pybind_state_ideep.cc +++ b/caffe2/python/pybind_state_ideep.cc @@ -97,7 +97,7 @@ class IDeepFetcher : public BlobFetcherBase { }; class IDeepFeeder : public BlobFeederBase { - itensor::data_type type_transform(const TypeMeta &meta) { + itensor::data_type type_transform(const TypeMeta meta) { if (meta == TypeMeta::Make()) return itensor::data_type::f32; else if (meta == TypeMeta::Make()) @@ -119,10 +119,10 @@ class IDeepFeeder : public BlobFeederBase { PyArrayObject *array = PyArray_GETCONTIGUOUS(original_array); auto g = MakeGuard([&]() { Py_XDECREF(array); }); const auto npy_type = PyArray_TYPE(array); - const TypeMeta &meta = NumpyTypeToCaffe(npy_type); + const TypeMeta meta = NumpyTypeToCaffe(npy_type); CAFFE_ENFORCE_NE( - meta.id(), - TypeIdentifier::uninitialized(), + meta, + ScalarType::Undefined, "This numpy data type is not supported: ", PyArray_TYPE(array), "."); @@ -172,7 +172,7 @@ class IDeepFeeder : public BlobFeederBase { auto g = MakeGuard([&]() { Py_XDECREF(array); }); const auto npy_type = PyArray_TYPE(array); - const TypeMeta &meta = NumpyTypeToCaffe(npy_type); + const TypeMeta meta = NumpyTypeToCaffe(npy_type); // TODO: if necessary, use dispatcher. if ((in_place && blob->IsType()) diff --git a/caffe2/python/session.py b/caffe2/python/session.py index de3b09931a30..fb2b57c4f5ee 100644 --- a/caffe2/python/session.py +++ b/caffe2/python/session.py @@ -192,7 +192,7 @@ def _compile_task_group(cls, task_group, setup_net_list=None): task = task_group.to_task() plan = core.Plan('task_group_plan') plan.AddStep(task.get_step()) - return (plan, task.output_list(), task.workspace_type) + return (plan, task.output_list(), task.workspace_type()) def _run_compiled(self, compiled): plan, output_list, workspace_type = compiled diff --git a/caffe2/python/workspace.py b/caffe2/python/workspace.py index 99983e84f097..0aa46ee2d4b3 100644 --- a/caffe2/python/workspace.py +++ b/caffe2/python/workspace.py @@ -335,7 +335,7 @@ def StringifyNetName(name): def GetNetName(net): if isinstance(net, basestring): return net - if type(net).__name__ == "Net": + if type(net).__name__ == "Net" or type(net).__name__ == "NetWithShapeInference": return net.Name() if isinstance(net, caffe2_pb2.NetDef): return net.name diff --git a/caffe2/quantization/server/fully_connected_dnnlowp_op.cc b/caffe2/quantization/server/fully_connected_dnnlowp_op.cc index c7e6804c1dcf..4a5a6e6b7ad0 100644 --- a/caffe2/quantization/server/fully_connected_dnnlowp_op.cc +++ b/caffe2/quantization/server/fully_connected_dnnlowp_op.cc @@ -190,6 +190,9 @@ bool FullyConnectedDNNLowPOp::RunOnDevice() { if (!dequantize_output_) { Y_int32_.resize(Y->size()); + if (Y_int32_.size() < Y_int32_.capacity() / 2) { + Y_int32_.shrink_to_fit(); + } DoNothing<> doNothingObj{}; if (quantize_channelwise_ || filter_qparams_[0].zero_point) { @@ -443,6 +446,9 @@ bool FullyConnectedDNNLowPOp::RunOnDevice() { #endif Y_int32_.resize(Y->size()); + if (Y_int32_.size() < Y_int32_.capacity() / 2) { + Y_int32_.shrink_to_fit(); + } for (int i = 0; i < M; ++i) { for (int j = 0; j < N; ++j) { int32_t sum = 0; diff --git a/caffe2/serialize/inline_container.cc b/caffe2/serialize/inline_container.cc index 63f5a34aa23b..7928d5e3de86 100644 --- a/caffe2/serialize/inline_container.cc +++ b/caffe2/serialize/inline_container.cc @@ -306,6 +306,7 @@ void PyTorchStreamWriter::setup(const string& file_name) { file_name, std::ofstream::out | std::ofstream::trunc | std::ofstream::binary); valid("opening archive ", file_name.c_str()); + TORCH_CHECK(file_stream_, "File ", file_name, " cannot be opened."); writer_func_ = [this](const void* buf, size_t nbytes) -> size_t { file_stream_.write(static_cast(buf), nbytes); return !file_stream_ ? 0 : nbytes; diff --git a/cmake/Codegen.cmake b/cmake/Codegen.cmake index dbfd55e2d0d5..db02f7a8fb16 100644 --- a/cmake/Codegen.cmake +++ b/cmake/Codegen.cmake @@ -178,9 +178,6 @@ if(INTERN_BUILD_ATEN_OPS) --force_schema_registration --op_registration_whitelist ${OP_REGISTRATION_WHITELIST}) endif() - if(USE_VULKAN) - set(GEN_VULKAN_FLAGS --vulkan) - endif() set(GEN_COMMAND "${PYTHON_EXECUTABLE}" -m tools.codegen.gen diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index a52dbc050cbf..0ce2d8b44a32 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1635,6 +1635,7 @@ if(NOT INTERN_BUILD_MOBILE) find_package(LAPACK) if(LAPACK_FOUND) set(USE_LAPACK 1) + list(APPEND Caffe2_PRIVATE_DEPENDENCY_LIBS ${LAPACK_LIBRARIES}) endif() if(NOT USE_CUDA) diff --git a/cmake/Modules/FindBLAS.cmake b/cmake/Modules/FindBLAS.cmake index e93e98a6095d..38e826d1f5f4 100644 --- a/cmake/Modules/FindBLAS.cmake +++ b/cmake/Modules/FindBLAS.cmake @@ -153,7 +153,7 @@ if((NOT BLAS_LIBRARIES) BLAS sgemm "" - "openblas;pthread") + "openblas;pthread;m") if(BLAS_LIBRARIES) set(BLAS_INFO "open") endif(BLAS_LIBRARIES) diff --git a/cmake/Modules/FindLAPACK.cmake b/cmake/Modules/FindLAPACK.cmake index c057f207132f..b0e607d90587 100644 --- a/cmake/Modules/FindLAPACK.cmake +++ b/cmake/Modules/FindLAPACK.cmake @@ -123,6 +123,30 @@ if(BLAS_FOUND) IF((NOT LAPACK_INFO) AND (BLAS_INFO STREQUAL "open")) SET(CMAKE_REQUIRED_LIBRARIES ${BLAS_LIBRARIES}) check_function_exists("cheev_" OPEN_LAPACK_WORKS) + if(OPEN_LAPACK_WORKS) + check_function_exists("cgesdd_" LAPACK_CGESDD_WORKS) + if(NOT LAPACK_CGESDD_WORKS) + find_library(GFORTRAN_LIBRARY + NAMES libgfortran.a gfortran + PATHS /usr/lib/gcc/aarch64-linux-gnu/9/ + /usr/lib/gcc/x86_64-redhat-linux/9/ + /usr/lib/gcc/aarch64-linux-gnu/8/ + /usr/lib/gcc/x86_64-redhat-linux/8/ + /usr/lib/gcc/aarch64-linux-gnu/7/ + /usr/lib/gcc/x86_64-redhat-linux/7/ + ) + list(APPEND CMAKE_REQUIRED_LIBRARIES "${GFORTRAN_LIBRARY}") + unset(LAPACK_CGESDD_WORKS CACHE) + check_function_exists("cgesdd_" LAPACK_CGESDD_WORKS) + if(LAPACK_CGESDD_WORKS) + list(APPEND LAPACK_LIBRARIES "${GFORTRAN_LIBRARY}") + else() + message(WARNING "OpenBlas has been compiled with Lapack support, but cgesdd can not be used") + set(OPEN_LAPACK_WORKS NO) + endif() + endif() + endif() + set(CMAKE_REQUIRED_LIBRARIES) if(OPEN_LAPACK_WORKS) SET(LAPACK_INFO "open") diff --git a/codecov.yml b/codecov.yml index 7ed3d662bb39..525f85e01898 100644 --- a/codecov.yml +++ b/codecov.yml @@ -3,13 +3,18 @@ coverage: project: default: threshold: 1% +codecov: + notify: + after_n_builds: 2 comment: layout: "diff" behavior: once require_changes: true require_base: yes require_head: yes + after_n_builds: 2 branches: - "master" fixes: - "/opt/conda/lib/python3.8/site-packages/::project/" + - "C:/Users/circleci/project/build/win_tmp/build/::project/" diff --git a/docs/source/linalg.rst b/docs/source/linalg.rst index 834b6a60ac93..14d3ca1767e9 100644 --- a/docs/source/linalg.rst +++ b/docs/source/linalg.rst @@ -14,3 +14,4 @@ Functions .. autofunction:: det .. autofunction:: norm +.. autofunction:: tensorsolve diff --git a/docs/source/quantization.rst b/docs/source/quantization.rst index b78ed2c08586..6cfbe186544f 100644 --- a/docs/source/quantization.rst +++ b/docs/source/quantization.rst @@ -45,8 +45,8 @@ The corresponding implementation is chosen automatically based on the PyTorch bu .. note:: - PyTorch 1.3 doesn't provide quantized operator implementations on CUDA yet - - this is direction of future work. Move the model to CPU in order to test the + At the moment PyTorch doesn't provide quantized operator implementations on CUDA - + this is the direction for future work. Move the model to CPU in order to test the quantized functionality. Quantization-aware training (through :class:`~torch.quantization.FakeQuantize`) diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst index 7110631088d7..3bc806067870 100644 --- a/docs/source/tensors.rst +++ b/docs/source/tensors.rst @@ -349,6 +349,8 @@ view of a storage and defines numeric operations on it. .. automethod:: hypot_ .. automethod:: i0 .. automethod:: i0_ + .. automethod:: igamma + .. automethod:: igamma_ .. automethod:: ifft .. automethod:: index_add_ .. automethod:: index_add diff --git a/docs/source/torch.rst b/docs/source/torch.rst index 1a2d0eb72f7d..6ba16aaca3e2 100644 --- a/docs/source/torch.rst +++ b/docs/source/torch.rst @@ -85,6 +85,7 @@ Indexing, Slicing, Joining, Mutating Ops cat chunk + column_stack dstack gather hstack @@ -94,6 +95,7 @@ Indexing, Slicing, Joining, Mutating Ops narrow nonzero reshape + row_stack split squeeze stack @@ -310,6 +312,7 @@ Pointwise Ops logit hypot i0 + igamma mul multiply mvlgamma diff --git a/ios/LibTorch.podspec b/ios/LibTorch.podspec index f74e2dc9f37e..236f1de7988f 100644 --- a/ios/LibTorch.podspec +++ b/ios/LibTorch.podspec @@ -1,6 +1,6 @@ Pod::Spec.new do |s| s.name = 'LibTorch' - s.version = '1.6.1' + s.version = '1.7.0' s.authors = 'PyTorch Team' s.license = { :type => 'BSD' } s.homepage = 'https://github.com/pytorch/pytorch' diff --git a/mypy.ini b/mypy.ini index 535310411720..1fd1ce884520 100644 --- a/mypy.ini +++ b/mypy.ini @@ -119,9 +119,6 @@ ignore_errors = True [mypy-torch.nn.parallel._functions] ignore_errors = True -[mypy-torch.nn.parallel.comm] -ignore_errors = True - [mypy-torch.nn.qat.modules.activations] ignore_errors = True diff --git a/requirements.txt b/requirements.txt index 07127f738ff9..759baf3984c3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,4 @@ requests setuptools six typing_extensions -dataclasses +dataclasses; python_version<"3.7" diff --git a/setup.py b/setup.py index 54f170e3d6c6..2250fff2616a 100644 --- a/setup.py +++ b/setup.py @@ -731,6 +731,7 @@ def print_box(msg): with open(os.path.join(cwd, "README.md"), encoding="utf-8") as f: long_description = f.read() + version_range_max = max(sys.version_info[1], 8) + 1 setup( name=package_name, version=version, @@ -890,7 +891,7 @@ def print_box(msg): 'Topic :: Software Development :: Libraries :: Python Modules', 'Programming Language :: C++', 'Programming Language :: Python :: 3', - ] + ['Programming Language :: Python :: 3.{}' for i in range(python_min_version[1], 9)], + ] + ['Programming Language :: Python :: 3.{}'.format(i) for i in range(python_min_version[1], version_range_max)], license='BSD-3', keywords='pytorch machine learning', ) diff --git a/test/backward_compatibility/check_backward_compatibility.py b/test/backward_compatibility/check_backward_compatibility.py index 5beee3d1c683..9f1f6cc99fe5 100644 --- a/test/backward_compatibility/check_backward_compatibility.py +++ b/test/backward_compatibility/check_backward_compatibility.py @@ -128,6 +128,7 @@ ("aten::_foreach_addcdiv_", datetime.date(2020, 10, 15)), ("aten::_foreach_addcdiv", datetime.date(2020, 10, 15)), ("aten::_foreach_addcmul", datetime.date(2020, 10, 15)), + ("aten::_foreach_zero_", datetime.date(2020, 11, 5)), ("aten::conj", datetime.date(2020, 11, 10)), ("aten::add_relu", datetime.date(2020, 10, 28)), ("aten::add_relu_", datetime.date(2020, 10, 28)), diff --git a/test/cpp/api/autograd.cpp b/test/cpp/api/autograd.cpp index 8a8aa75541ac..502db310de4e 100644 --- a/test/cpp/api/autograd.cpp +++ b/test/cpp/api/autograd.cpp @@ -7,6 +7,7 @@ #include using namespace torch::autograd; +using namespace torch::test; #define ASSERT_VARIABLE_EQ(a,b) ASSERT_TRUE(torch::allclose((a),(b))) #define EXPECT_VARIABLE_EQ(a,b) EXPECT_TRUE(torch::allclose((a),(b))) @@ -154,6 +155,40 @@ TEST(AutogradAPITests, RetainGrad) { ASSERT_VARIABLE_EQ(input * 18, input.grad()); } +TEST(AutogradAPITests, AnomalyMode) { + // Needs to have backtrace as warning and then throw an error + torch::autograd::AnomalyMode::set_enabled(true); + { + WarningCapture warnings; + auto x = torch::tensor({5.0}, torch::requires_grad()); + auto y = x * x; + auto z = y * y; + y += 1; + ASSERT_THROWS_WITH(z.backward(), "inplace"); + ASSERT_TRUE( + warnings.str().find("Traceback of forward") != std::string::npos); + } + { + WarningCapture warnings; + // Double backward + auto x = torch::tensor({0.0}, torch::requires_grad()); + auto y = x.pow(1.5); + auto gr = + grad({y}, {x}, {}, /*retain_graph=*/true, /*create_backward=*/true); + ASSERT_THROWS_WITH(grad({gr[0]}, {x});, "returned nan"); + auto msgs = warnings.messages(); + ASSERT_EQ(msgs.size(), 2); + ASSERT_TRUE( + msgs[0].find("Traceback of forward call that caused the error") != + std::string::npos); + ASSERT_TRUE( + msgs[1].find( + "Traceback of forward call that induced the previous calculation") != + std::string::npos); + } + torch::autograd::AnomalyMode::set_enabled(false); +} + TEST(CustomAutogradTest, CustomFunction) { struct MyFunction : public Function { static Variable forward(AutogradContext *ctx, Variable var1, int mul, Variable var2) { diff --git a/test/cpp/jit/test_custom_class.cpp b/test/cpp/jit/test_custom_class.cpp index a96a3b4a5635..776df23e1737 100644 --- a/test/cpp/jit/test_custom_class.cpp +++ b/test/cpp/jit/test_custom_class.cpp @@ -44,5 +44,47 @@ TEST(CustomClassTest, TorchbindIValueAPI) { test_with_obj(new_stack_ivalue, "boo"); } +class TorchBindTestClass : public torch::jit::CustomClassHolder { + public: + std::string get() { + return "Hello, I am your test custom class"; + } +}; + +constexpr char class_doc_string[] = R"( + I am docstring for TorchBindTestClass + Args: + What is an argument? Oh never mind, I don't take any. + + Return: + How would I know? I am just a holder of some meaningless test methods. + )"; +constexpr char method_doc_string[] = + "I am docstring for TorchBindTestClass get_with_docstring method"; + +namespace { +static auto reg = + torch::class_( + "_TorchBindTest", + "_TorchBindTestClass", + class_doc_string) + .def("get", &TorchBindTestClass::get) + .def("get_with_docstring", &TorchBindTestClass::get, method_doc_string); + +} // namespace + +// Tests DocString is properly propagated when defining CustomClasses. +TEST(CustomClassTest, TestDocString) { + auto class_type = getCustomClass( + "__torch__.torch.classes._TorchBindTest._TorchBindTestClass"); + AT_ASSERT(class_type); + AT_ASSERT(class_type->doc_string() == class_doc_string); + + AT_ASSERT(class_type->getMethod("get").doc_string().empty()); + AT_ASSERT( + class_type->getMethod("get_with_docstring").doc_string() == + method_doc_string); +} + } // namespace jit } // namespace torch diff --git a/test/cpp/jit/test_misc.cpp b/test/cpp/jit/test_misc.cpp index ca4fb2e7620d..54265530eb12 100644 --- a/test/cpp/jit/test_misc.cpp +++ b/test/cpp/jit/test_misc.cpp @@ -65,6 +65,7 @@ #include #include #include +#include #include #include #include @@ -1118,6 +1119,33 @@ TEST(RecordFunctionTest, Basic) { clearCallbacks(); } +TEST(RecordFunctionTest, OperatorNameOverload) { + std::set operator_names; + + at::addGlobalCallback(at::RecordFunctionCallback( + [&operator_names](const at::RecordFunction& fn) { + c10::optional op_name = + fn.operator_name(); + if (op_name.has_value()) { + operator_names.insert(c10::toString(*op_name)); + } else { + operator_names.insert("No Operator Name"); + } + }) + .scopes({at::RecordScope::FUNCTION})); + auto t = torch::randn({1, 2, 3}, at::kCPU); + t.set_requires_grad(false); + auto t2 = t.pow(2); + + at::clearCallbacks(); + EXPECT_TRUE(operator_names.count("No Operator Name") == 0) + << "Expected that all traced operators had an associated OperatorName object"; + EXPECT_TRUE(operator_names.count("aten::randn") == 1) + << "Expected aten::randn to have been called and recorded, but it was not"; + EXPECT_TRUE(operator_names.count("aten::pow.Tensor_Scalar") == 1) + << "Expected aten::pow.Tensor_Scalar to have been called and recorded, but it was not"; +} + class TestThreadLocalDebugInfo : public c10::DebugInfoBase { public: int getModelId() const { diff --git a/test/distributed/_pipeline/sync/test_worker.py b/test/distributed/_pipeline/sync/test_worker.py index 0247a71ba4a8..fb306b52fad9 100644 --- a/test/distributed/_pipeline/sync/test_worker.py +++ b/test/distributed/_pipeline/sync/test_worker.py @@ -13,6 +13,7 @@ from torch.distributed._pipeline.sync.microbatch import Batch from torch.distributed._pipeline.sync.stream import CPUStream from torch.distributed._pipeline.sync.worker import Task, spawn_workers +from torch.testing._internal.common_utils import TEST_WITH_TSAN class fake_device: @@ -24,6 +25,7 @@ class fake_device: index = None +@pytest.mark.skipif(TEST_WITH_TSAN, reason="False positive in TSAN") def test_join_running_workers(): count = 0 @@ -47,6 +49,7 @@ def call_in_worker(i, f): assert count == 10 +@pytest.mark.skipif(TEST_WITH_TSAN, reason="False positive in TSAN") def test_join_running_workers_with_exception(): class ExpectedException(Exception): pass diff --git a/test/distributed/test_c10d.py b/test/distributed/test_c10d.py index ff9fe993460c..270acf9bf5fc 100644 --- a/test/distributed/test_c10d.py +++ b/test/distributed/test_c10d.py @@ -3398,6 +3398,21 @@ def _gpu_model_with_ddp_comm_hook(self, process_group, hook=None, gradient_as_bu return gpu_model + def _gpu_model_with_builtin_ddp_comm_hook(self, process_group, hook=None, gradient_as_bucket_view=False): + device_id = gpus_for_rank(self.world_size)[self.rank][0] + gpu_model = DistributedDataParallel( + ModuleForDdpCommHook().to(device_id), + device_ids=[device_id], + process_group=process_group, + gradient_as_bucket_view=gradient_as_bucket_view, + ) + + # Register a built-in DDP communication hook if defined + if hook is not None: + gpu_model._register_builtin_comm_hook(hook) + + return gpu_model + def _run_and_verify_hook(self, model, input, expected_grad): # Run forward output = model(input, self.rank) @@ -3474,18 +3489,46 @@ def allreduce_hook(state: object, bucket: dist._GradBucket) -> torch._C.Future: # check whether the grads are equal to what DDP without hook would return. self._run_and_verify_hook(gpu_model, 8, 0.25 * torch.ones(2, 2)) + def _test_builtin_ddp_comm_hooks_nccl(self, gradient_as_bucket_view=False): + """ + This unit test verifies whether built-in DDP communication hooks ALLREDUCE and FP16_COMPRESS + can give the same result result with the case of no hook registered. + """ + store = c10d.FileStore(self.file_name, self.world_size) + process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size) + + for comm_hook_type in [dist.BuiltinCommHookType.ALLREDUCE, dist.BuiltinCommHookType.FP16_COMPRESS]: + # Get GPU model with the built-in allreduce communication hook. + gpu_model = self._gpu_model_with_builtin_ddp_comm_hook( + process_group, comm_hook_type, gradient_as_bucket_view) + + # check whether the grads are equal to what DDP without hook would return. + self._run_and_verify_hook(gpu_model, 8, 0.25 * torch.ones(2, 2)) + @requires_nccl() @skip_if_lt_x_gpu(2) @skip_if_rocm def test_ddp_comm_hook_allreduce_hook_nccl(self): self._test_ddp_comm_hook_allreduce_hook_nccl() + @requires_nccl() + @skip_if_lt_x_gpu(2) + @skip_if_rocm + def test_builtin_ddp_comm_hooks_nccl(self): + self._test_builtin_ddp_comm_hooks_nccl() + @requires_nccl() @skip_if_lt_x_gpu(2) @skip_if_rocm def test_ddp_comm_hook_allreduce_hook_nccl_grad_is_view(self): self._test_ddp_comm_hook_allreduce_hook_nccl(gradient_as_bucket_view=True) + @requires_nccl() + @skip_if_lt_x_gpu(2) + @skip_if_rocm + def test_builtin_ddp_comm_hooks_nccl_grad_is_view(self): + self._test_builtin_ddp_comm_hooks_nccl(gradient_as_bucket_view=True) + @requires_nccl() @skip_if_lt_x_gpu(2) @skip_if_rocm @@ -3603,7 +3646,7 @@ def dummy_hook(state, bucket): model._register_comm_hook(None, dummy_hook) with self.assertRaisesRegex( - RuntimeError, "register_comm_hook can only be called once." + RuntimeError, "register_comm_hook or register_builtin_comm_hook can only be called once." ): model._register_comm_hook(None, dummy_hook) diff --git a/test/jit/test_backends.py b/test/jit/test_backends.py index 89330ddbd2d9..9f61ec77a1f6 100644 --- a/test/jit/test_backends.py +++ b/test/jit/test_backends.py @@ -5,7 +5,9 @@ import torch import torch._C +from torch.testing import FileCheck from pathlib import Path + from torch.testing._internal.common_utils import ( IS_FBCODE, IS_MACOS, @@ -34,6 +36,13 @@ def to_test_backend_multi(module, method_compile_spec): return torch._C._jit_to_backend("test_backend", module, method_compile_spec) +def to_test_backend_selective(module, method_compile_spec, submodules): + def _to_test_backend(module): + return to_test_backend(module, method_compile_spec) + + return torch._C._jit_to_backend_selective(module, _to_test_backend, submodules) + + class BasicModule(torch.nn.Module): """ A simple Module used to test to_backend lowering machinery. @@ -81,9 +90,9 @@ def check_function(self, function_name, input): backend_method = self.lowered_module.__getattr__(function_name) # Run methods. - python_output = python_method(input, input) - jit_output = jit_method(input, input) - backend_output = backend_method(input, input) + python_output = python_method(*input) + jit_output = jit_method(*input) + backend_output = backend_method(*input) # The answers returned by Python, JIT and to_backend should all match. self.assertEqual(python_output, backend_output) @@ -95,6 +104,24 @@ def save_load(self): """ self.lowered_module = self.getExportImportCopy(self.lowered_module) + def test_execution(self): + """ + Stub for correctness tests. + """ + pass + + def test_save_load(self): + """ + Stub for serialization tests. + """ + pass + + def test_errors(self): + """ + Stub for testing error checking. + """ + pass + class BasicModuleTest(JitBackendTestCase): """ @@ -116,9 +143,9 @@ def test_execution(self): input = torch.randn(5) # Test all three module methods. - self.check_function("accum", input) - self.check_function("sub_accum", input) - self.check_function("forward", input) + self.check_function("accum", (input, input)) + self.check_function("sub_accum", (input, input)) + self.check_function("forward", (input, input)) @skipIfRocm def test_save_load(self): @@ -166,8 +193,12 @@ def setUp(self): self.module = NestedModuleTest.NestedModule(BasicModule()) # Both modules in self.scripted_module are ScriptModules. self.scripted_module = torch.jit.script(NestedModuleTest.NestedModule(BasicModule())) + + # First, script another instance of NestedModule with share_types=False so that it can be + # selectively lowered without modifying the type of self.scripted_module. lowered_module = to_test_backend_multi( - self.scripted_module, {"forward": {"": ""}} + torch.jit.script(BasicModule()), + {"accum": {"": ""}, "sub_accum": {"": ""}, "forward": {"": ""}}, ) # self.lowered_module is a ScriptModule, but its submodule is a lowered module. self.lowered_module = torch.jit.script(NestedModuleTest.NestedModule(lowered_module)) @@ -177,7 +208,7 @@ def test_execution(self): input = torch.randn(5) # Test forward. - self.check_function("forward", input) + self.check_function("forward", (input, input)) def test_save_load(self): # Lowered module should produce the same outputs. @@ -190,6 +221,161 @@ def test_save_load(self): self.test_execution() +class SelectiveLoweringTest(JitBackendTestCase): + """ + Tests for the selective lowering API. + """ + class OuterModule(torch.nn.Module): + def __init__(self, sub1, sub2, other): + super().__init__() + self.sub1 = sub1 + self.sub2 = sub2 + self.other = other + + def forward(self, x, y): + # Call the module that will be lowered directly to test + # type remapping in modules that are not its parent. + a, b = self.sub1.submodule.forward(x, y) + c, d = self.sub2.forward(x, y) + e, f = self.other.forward(x, y) + return a + c + e, b + d + f + + class MiddleModule(torch.nn.Module): + def __init__(self, submodule): + super().__init__() + self.submodule = submodule + + def forward(self, x, y): + return self.submodule.forward(x, y) + + def setUp(self): + super().setUp() + OuterModule = SelectiveLoweringTest.OuterModule + MiddleModule = SelectiveLoweringTest.MiddleModule + + def script_without_type_sharing(mod): + return torch.jit._recursive.create_script_module(mod, torch.jit._recursive.infer_methods_to_compile, share_types=False) + # Create Python, JIT and backend versions of a hierarchy that looks like this: + # --------- OuterModule -------- + # | | | + # MiddleModule MiddleModule MiddleModule + # | | | + # BasicModule BasicModule BasicModule + # + # Two BasicModules will be lowered and the third will not. + self.module = OuterModule(MiddleModule(BasicModule()), MiddleModule(BasicModule()), MiddleModule(BasicModule())) + self.scripted_module = script_without_type_sharing(OuterModule(MiddleModule( + BasicModule()), MiddleModule(BasicModule()), MiddleModule(BasicModule()))) + self.lowered_module = script_without_type_sharing(OuterModule(MiddleModule( + BasicModule()), MiddleModule(BasicModule()), MiddleModule(BasicModule()))) + self.lowered_module = to_test_backend_selective(self.lowered_module, {"forward": ""}, [ + "sub1.submodule", "sub2.submodule"]) + + def test_execution(self): + input = torch.randn(5) + self.check_function("forward", (input, input)) + + self.test_selective_lowering_type_remap() + + def test_save_load(self): + self.test_execution() + self.save_load() + self.test_execution() + + self.test_selective_lowering_type_remap() + + def test_selective_lowering_type_remap(self): + """ + Check that type remapping and replacement occurred during selective lowering. + """ + # Check that self.lowered_module was not lowered, but that it does contain test_backendLoweredModule due to it + # calling the lowered module directly. + FileCheck() \ + .check("OuterModule") \ + .check("BasicModule") \ + .run(self.scripted_module.graph) + FileCheck() \ + .check("OuterModule") \ + .check_not("__torch__.torch.classes.__backends__.test_backend") \ + .check("test_backendLoweredModule") \ + .run(self.lowered_module.graph) + + # Check that self.lowered_module.sub1/sub2 were not lowered but that BasicModule has been replaced in their graphs. + FileCheck() \ + .check("MiddleModule") \ + .check("BasicModule") \ + .check_not("test_backendLoweredModule") \ + .run(self.scripted_module.sub1.graph) + FileCheck() \ + .check("MiddleModule") \ + .check_not("__torch__.torch.classes.__backends__.test_backend") \ + .check("test_backendLoweredModule") \ + .check_not("BasicModule") \ + .run(self.lowered_module.sub1.graph) + + FileCheck() \ + .check("MiddleModule") \ + .check("BasicModule") \ + .check_not("test_backendLoweredModule") \ + .run(self.scripted_module.sub2.graph) + FileCheck() \ + .check("MiddleModule") \ + .check_not("__torch__.torch.classes.__backends__.test_backend") \ + .check("test_backendLoweredModule") \ + .check_not("BasicModule") \ + .run(self.lowered_module.sub2.graph) + + # Check that self.lowered_module.sub1/sub2.submodule were lowered. Its graph should mention + # __torch__.torch.classes.__backends__.test_backend, the TorchBind class for executing functions + # on the test JIT backend. + FileCheck() \ + .check("test_backendLoweredModule") \ + .check_not("BasicModule") \ + .check("__torch__.torch.classes.__backends__.test_backend") \ + .run(self.lowered_module.sub1.submodule.graph) + + FileCheck() \ + .check("test_backendLoweredModule") \ + .check_not("BasicModule") \ + .check("__torch__.torch.classes.__backends__.test_backend") \ + .run(self.lowered_module.sub2.submodule.graph) + + # Check that self.other and self.other.submodule have been left untouched by the selective lowering process. + FileCheck() \ + .check("MiddleModule") \ + .check("BasicModule") \ + .check_not("__torch__.torch.classes.__backends__.test_backend") \ + .check_not("test_backendLoweredModule") \ + .run(self.scripted_module.other.graph) + FileCheck() \ + .check("BasicModule") \ + .check_not("__torch__.torch.classes.__backends__.test_backend") \ + .check_not("test_backendLoweredModule") \ + .run(self.scripted_module.other.submodule.graph) + + def test_errors(self): + """ + Check errors associated with selective lowering. + """ + # Check error messages thrown when attempting to lower something that is not a ScriptModule. + with self.assertRaisesRegex(RuntimeError, r"Object .* is not a ScriptModule"): + to_test_backend_selective(torch.nn.ReLU(), {"forward": ""}, ["submodule"]) + + MiddleModule = SelectiveLoweringTest.MiddleModule + mod = MiddleModule(BasicModule()) + mod.new_attr = 3 + + with self.assertRaisesRegex(RuntimeError, r"Attribute named new_attr is not a Module"): + to_test_backend_selective(torch.jit.script(mod), {"forward": ""}, ["new_attr"]) + + # Check error message thrown when module hierarchy doesn't have unique types. + OuterModule = SelectiveLoweringTest.OuterModule + mod = OuterModule(MiddleModule(BasicModule()), MiddleModule(BasicModule()), MiddleModule(BasicModule())) + + with self.assertRaisesRegex(RuntimeError, r"Selective lowering is only supported for module hierarchies with unique types"): + to_test_backend_selective(torch.jit.script(mod), {"forward": ""}, ["sub1.submodule"]) + + class TestBackends(JitTestCase): """ This class wraps and invokes all subclasses of JitBackendTestCase so that each one @@ -200,19 +386,27 @@ def __init__(self, name): super().__init__(name) self.basic_module_test = BasicModuleTest(name) self.nested_module_test = NestedModuleTest(name) + self.selective_lowering_test = SelectiveLoweringTest(name) def setUp(self): super().setUp() if not TEST_WITH_ROCM: self.basic_module_test.setUp() self.nested_module_test.setUp() + self.selective_lowering_test.setUp() @skipIfRocm def test_execution(self): self.basic_module_test.test_execution() self.nested_module_test.test_execution() + self.selective_lowering_test.test_execution() @skipIfRocm def test_save_load(self): self.basic_module_test.test_save_load() self.nested_module_test.test_save_load() + self.selective_lowering_test.test_save_load() + + @skipIfRocm + def test_errors(self): + self.selective_lowering_test.test_errors() diff --git a/test/jit/test_freezing.py b/test/jit/test_freezing.py index 5f16b0229c2d..9c59b00a04cd 100644 --- a/test/jit/test_freezing.py +++ b/test/jit/test_freezing.py @@ -7,6 +7,7 @@ from torch.testing._internal.common_quantization import skipIfNoFBGEMM from torch.jit._recursive import wrap_cpp_module +from typing import Any import io @@ -1222,3 +1223,35 @@ def forward(self, cond: bool): mod_eager = Mod() self.assertEqual(mod_eager(True), frozen_mod(True)) self.assertEqual(mod_eager(False), frozen_mod(False)) + + def test_freeze_module_with_non_static_module_dict_index(self): + """ + Test that a Module contained a non-static ModuleDict index + cannot be frozen. + """ + @torch.jit.interface + class ModuleInterface(torch.nn.Module): + def forward(self, inp: Any) -> Any: + pass + + class ImplementsInterface(torch.nn.Module): + def forward(self, inp: Any) -> Any: + if isinstance(inp, torch.Tensor): + return torch.max(inp, dim=0) + + return inp + + # Test annotation of submodule. + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.d = torch.nn.ModuleDict({"module": ImplementsInterface()}) + + def forward(self, x: torch.Tensor, key: str) -> Any: + value: ModuleInterface = self.d[key] + return value.forward(x) + + m = torch.jit.script(Mod()) + m.eval() + with self.assertRaisesRegex(RuntimeError, "Freezing modules containing prim::ModuleDictIndex is not supported"): + mf = torch._C._freeze_module(m._c) diff --git a/test/jit/test_module_containers.py b/test/jit/test_module_containers.py index b53bf10a70c2..e261124bedb5 100644 --- a/test/jit/test_module_containers.py +++ b/test/jit/test_module_containers.py @@ -1,7 +1,7 @@ import os import sys -from typing import List +from typing import Any, List, Tuple from collections import OrderedDict import torch import torch.nn as nn @@ -428,3 +428,64 @@ def forward(self, inputs): m = MyModule() self.checkModule(m, [torch.randn(2, 2)]) + + def test_typed_module_dict(self): + """ + Test that a type annotation can be provided for a ModuleDict that allows + non-static indexing. + """ + @torch.jit.interface + class ModuleInterface(torch.nn.Module): + def forward(self, inp: Any) -> Any: + pass + + class ImplementsInterface(torch.nn.Module): + def forward(self, inp: Any) -> Any: + if isinstance(inp, torch.Tensor): + return torch.max(inp, dim=0) + + return inp + + class DoesNotImplementInterface(torch.nn.Module): + def forward(self, inp: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.max(inp, dim=0) + + # Test annotation of submodule. + + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.d = torch.nn.ModuleDict({"module": ImplementsInterface()}) + + def forward(self, x: torch.Tensor, key: str) -> Any: + value: ModuleInterface = self.d[key] + return value.forward(x) + + m = Mod() + self.checkModule(m, (torch.randn(2, 2), "module")) + + # Test annotation of self. + class ModDict(torch.nn.ModuleDict): + def __init__(self): + super().__init__({"module": ImplementsInterface()}) + + def forward(self, x: torch.Tensor, key: str) -> Any: + submodule: ModuleInterface = self[key] + return submodule.forward(x) + + m = ModDict() + self.checkModule(m, (torch.randn(2, 2), "module")) + + # Test error message thrown when annotated attribute does not comply with the + # annotation. + class ModWithWrongAnnotation(torch.nn.ModuleDict): + def __init__(self): + super().__init__() + self.d = torch.nn.ModuleDict({"module": DoesNotImplementInterface()}) + + def forward(self, x: torch.Tensor, key: str) -> Any: + submodule: ModuleInterface = self.d[key] + return submodule.forward(x) + + with self.assertRaisesRegex(RuntimeError, r"Attribute module is not of annotated type"): + torch.jit.script(ModWithWrongAnnotation()) diff --git a/test/jit/test_save_load.py b/test/jit/test_save_load.py index 178db8357e8f..23751c4fd92b 100644 --- a/test/jit/test_save_load.py +++ b/test/jit/test_save_load.py @@ -938,3 +938,12 @@ def forward(self, a): x = torch.tensor([1., 2., 3., 4.]) self.assertTrue(torch.equal(m(x), m2(x))) + + def test_save_nonexit_file(self): + class Foo(torch.nn.Module): + def forward(self, x): + return 2 * x + + script_module = torch.jit.script(Foo()) + with self.assertRaises(RuntimeError): + script_module.save("NonExist/path/test.pt") diff --git a/test/mobile/test_lite_script_module.py b/test/mobile/test_lite_script_module.py index 16558953611b..ca67875b107f 100644 --- a/test/mobile/test_lite_script_module.py +++ b/test/mobile/test_lite_script_module.py @@ -170,5 +170,51 @@ def forward(self): r"a pytorch class \(class Foo\(torch\.nn\.Module\)\)\'s attributes."): script_module._save_to_buffer_for_lite_interpreter() + def test_unsupported_return_list_with_module_class(self): + class Foo(torch.nn.Module): + def __init__(self): + super(Foo, self).__init__() + + class MyTestModuleForListWithModuleClass(torch.nn.Module): + def __init__(self): + super(MyTestModuleForListWithModuleClass, self).__init__() + self.foo = Foo() + + def forward(self): + my_list: List[Foo] = [self.foo] + return my_list + + script_module = torch.jit.script(MyTestModuleForListWithModuleClass()) + with self.assertRaisesRegex(RuntimeError, + r"^Returining a list or dictionary with pytorch class type " + r"is not supported in mobile module " + r"\(List\[Foo\] or Dict\[int\, Foo\] for class Foo\(torch\.nn\.Module\)\)\. " + r"Workaround\: instead of using pytorch class as their element type\, " + r"use a combination of list\, dictionary\, and single types\.$"): + script_module._save_to_buffer_for_lite_interpreter() + + def test_unsupported_return_dict_with_module_class(self): + class Foo(torch.nn.Module): + def __init__(self): + super(Foo, self).__init__() + + class MyTestModuleForDictWithModuleClass(torch.nn.Module): + def __init__(self): + super(MyTestModuleForDictWithModuleClass, self).__init__() + self.foo = Foo() + + def forward(self): + my_dict: Dict[int, Foo] = {1: self.foo} + return my_dict + + script_module = torch.jit.script(MyTestModuleForDictWithModuleClass()) + with self.assertRaisesRegex(RuntimeError, + r"^Returining a list or dictionary with pytorch class type " + r"is not supported in mobile module " + r"\(List\[Foo\] or Dict\[int\, Foo\] for class Foo\(torch\.nn\.Module\)\)\. " + r"Workaround\: instead of using pytorch class as their element type\, " + r"use a combination of list\, dictionary\, and single types\.$"): + script_module._save_to_buffer_for_lite_interpreter() + if __name__ == '__main__': unittest.main() diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index f08f772aa8e1..38ddb094794e 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -430,6 +430,16 @@ def forward(self, input): m1 = torch.randn(3, 4, 5, 6, 7) self.run_test(MyModel(), m1) + @skipIfUnsupportedMinOpsetVersion(11) + @disableScriptTest() # Need type inference + def test_index_mask_nd(self): + class MyModel(torch.nn.Module): + def forward(self, input): + return input[input > 0] + + m1 = torch.randn(3, 4, 5, 6, 7) + self.run_test(MyModel(), m1) + @disableScriptTest() def test_dict(self): class MyModel(torch.nn.Module): @@ -452,6 +462,42 @@ def forward(self, x_in): x = {"test_key_in": torch.randn(1, 2, 3)} self.run_test(MyModel(), (x,)) + def test_none_as_input(self): + class Model(torch.nn.Module): + def forward(self, x, y): + if y is not None: + return x + y + return x + + x = torch.randn(2, 3) + self.run_test(Model(), (x, None)) + + def test_none_as_tuple_input(self): + class Model(torch.nn.Module): + def forward(self, x, y): + if y[0] is not None: + return x + y[0] + if y[1] is not None: + return x + y[1] + return x + + x = torch.randn(2, 3) + y = torch.randn(2, 3) + self.run_test(Model(), (x, (None, y))) + + def test_none_as_named_input(self): + class Model(torch.nn.Module): + def forward(self, x, y=None, z=None): + if y is not None: + return x + y + if z is not None: + return x + z + return x + + x = torch.randn(2, 3) + z = torch.randn(2, 3) + self.run_test(Model(), (x, None, z)) + @skipIfUnsupportedMinOpsetVersion(9) def test_cste_script(self): class MyModel(torch.jit.ScriptModule): diff --git a/test/quantization/test_quantize.py b/test/quantization/test_quantize.py index 480ca18128c5..1ef6e9b89c0c 100644 --- a/test/quantization/test_quantize.py +++ b/test/quantization/test_quantize.py @@ -25,7 +25,8 @@ float_qparams_dynamic_qconfig, PerChannelMinMaxObserver, QConfigDynamic, - default_dynamic_quant_observer + default_dynamic_quant_observer, + FixedQParamsFakeQuantize, ) from torch.testing._internal.common_quantization import ( @@ -1247,6 +1248,36 @@ def forward(self, x): def test_leaky_relu(self): self._test_activation_op_impl(nn.LeakyReLU, nnq.LeakyReLU, {'negative_slope': 0.1, 'inplace': False}) + +class TestEagerModeQATOps(QuantizationTestCase): + def test_fixed_qparam_ops(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.sigmoid = torch.nn.Sigmoid() + self.hardsigmoid = torch.nn.Hardsigmoid() + self.tanh = torch.nn.Tanh() + self.quant = QuantStub() + self.dequant = DeQuantStub() + + def forward(self, x): + x = self.quant(x) + x = self.sigmoid(x) + x = self.hardsigmoid(x) + x = self.tanh(x) + x = self.dequant(x) + return x + + m = M().train() + m.qconfig = default_qat_qconfig + m = prepare_qat(m) + for attr in ['sigmoid', 'hardsigmoid', 'tanh']: + self.assertEqual(type(getattr(m, attr).activation_post_process), FixedQParamsFakeQuantize) + m = convert(m) + # make sure activation post process is removed + for attr in ['sigmoid', 'hardsigmoid', 'tanh']: + self.assertFalse(hasattr(getattr(m, attr), 'activation_post_process')) + class TestFunctionalModule(QuantizationTestCase): # Histogram Observers are slow, so have no-deadline to ensure test doesn't time out @given(train_mode=st.booleans()) diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py index 4e352a80daa3..34bb78743bc8 100644 --- a/test/quantization/test_quantize_fx.py +++ b/test/quantization/test_quantize_fx.py @@ -3,15 +3,20 @@ import torch.nn as nn import torch.nn.quantized as nnq import torch.nn.quantized.dynamic as nnqd +import torch.nn.intrinsic as nni import torch.nn.intrinsic.quantized as nniq import torch.multiprocessing as mp # graph mode quantization based on fx -from torch.quantization import ( - QuantType, +from torch.quantization.quantize_fx import ( prepare_fx, convert_fx, prepare_qat_fx, +) + +from torch.quantization import ( + QuantType, + quant_type_to_str, default_qconfig, default_dynamic_qconfig, default_dynamic_quant_observer, @@ -23,6 +28,7 @@ convert, PerChannelMinMaxObserver, QConfigDynamic, + FixedQParamsFakeQuantize, ) # test utils @@ -44,16 +50,132 @@ from torch.testing._internal.common_quantization import NodeSpec as ns -from torch.testing._internal.common_quantization import ( - test_only_eval_fn, -) from torch.testing import FileCheck +import copy import itertools import operator import unittest import io +class TestFuseFx(QuantizationTestCase): + def test_fuse_conv_bn_relu(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1d = nn.Conv1d(1, 1, 1) + self.conv2d = nn.Conv2d(1, 1, 1) + self.conv3d = nn.Conv3d(1, 1, 1) + self.bn1d = nn.BatchNorm1d(1) + self.bn2d = nn.BatchNorm2d(1) + self.bn3d = nn.BatchNorm3d(1) + self.conv1d2 = nn.Conv1d(1, 1, 1) + self.conv2d2 = nn.Conv2d(1, 1, 1) + self.conv3d2 = nn.Conv3d(1, 1, 1) + self.bn1d2 = nn.BatchNorm1d(1) + self.bn2d2 = nn.BatchNorm2d(1) + self.bn3d2 = nn.BatchNorm3d(1) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.conv1d(x) + x = self.bn1d(x) + x = self.conv2d(x) + x = self.bn2d(x) + x = self.conv3d(x) + x = self.bn3d(x) + x = self.conv1d2(x) + x = self.bn1d2(x) + x = self.relu(x) + x = self.conv2d2(x) + x = self.bn2d2(x) + x = self.relu(x) + x = self.conv3d2(x) + x = self.bn3d2(x) + x = self.relu(x) + return x + + # test train mode + m = M().train() + # currently we don't check if the module are configured with qconfig before fusion + # TODO: if we decide to do that in the future, this test needs to + # be updated + # train mode fuse_fx is called in prepare_qat_fx + m = prepare_qat_fx(m, {}) + expected_nodes = [ + ns.call_module(nni.ConvBn2d), + ns.call_module(nni.ConvBn3d), + ns.call_module(nni.ConvBnReLU2d), + ns.call_module(nni.ConvBnReLU3d), + ] + # ConvBnRelu1d is not fused + expected_occurrence = { + ns.call_module(nn.ReLU): 1 + } + self.checkGraphModuleNodes( + m, + expected_node_list=expected_nodes, + expected_node_occurrence=expected_occurrence) + + # test eval mode + m = M().eval() + from torch.quantization.quantize_fx import fuse_fx + # fuse_fx is a top level api and only supports eval mode + m = fuse_fx(m) + expected_nodes = [ + ns.call_module(nn.Conv2d), + ns.call_module(nn.Conv3d), + ns.call_module(nni.ConvReLU2d), + ns.call_module(nni.ConvReLU3d), + ] + # ConvBnRelu1d is not fused + expected_occurrence = { + ns.call_module(nn.ReLU): 1 + } + self.checkGraphModuleNodes( + m, + expected_node_list=expected_nodes, + expected_node_occurrence=expected_occurrence) + + def test_fuse_module_relu(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1d = nn.Conv1d(1, 1, 1) + self.conv2d = nn.Conv2d(1, 1, 1) + self.conv3d = nn.Conv3d(1, 1, 1) + self.bn1d = nn.BatchNorm1d(1) + self.bn2d = nn.BatchNorm2d(1) + self.bn3d = nn.BatchNorm3d(1) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.conv1d(x) + x = self.relu(x) + x = self.conv2d(x) + x = self.relu(x) + x = self.conv3d(x) + x = self.relu(x) + x = self.bn1d(x) + x = self.relu(x) + x = self.bn2d(x) + x = self.relu(x) + x = self.bn3d(x) + x = self.relu(x) + return x + + m = M().eval() + from torch.quantization.quantize_fx import fuse_fx + m = fuse_fx(m) + expected_nodes = [ + ns.call_module(nni.ConvReLU1d), + ns.call_module(nni.ConvReLU2d), + ns.call_module(nni.ConvReLU3d), + ns.call_module(nni.BNReLU2d), + ns.call_module(nni.BNReLU3d), + ] + self.checkGraphModuleNodes(m, expected_node_list=expected_nodes) + @skipIfNoFBGEMM class TestQuantizeFx(QuantizationTestCase): def _get_conv_linear_test_cases(self): @@ -265,32 +387,6 @@ def forward(self, x): model_device = next(iter(model_devices)) self.assertEqual(model_device, device) - @skipIfNoFBGEMM - def test_inplace_option(self): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(3, 3, 3) - - def forward(self, x): - return self.conv(x) - - model = M().eval() - qconfig_dict = {'': default_qconfig} - prepared = prepare_fx( - model, qconfig_dict, inplace=False) - test_only_eval_fn(model, self.img_data_2d) - non_inplace_model = convert_fx(prepared, inplace=True) - - prepared = prepare_fx( - model, qconfig_dict, inplace=True) - test_only_eval_fn(model, self.img_data_2d) - inplace_model = convert_fx(prepared, inplace=True) - - non_inplace_res = non_inplace_model(self.img_data_2d[0][0]) - inplace_res = inplace_model(self.img_data_2d[0][0]) - self.assertEqual(non_inplace_res, inplace_res) - @skipIfNoFBGEMM def test_dict_output(self): """ Make sure quantization runs for models with dictionary output @@ -632,104 +728,126 @@ def test_custom_module_class(self): class CustomModule(torch.nn.Module): def __init__(self): super().__init__() - self.conv = torch.nn.Conv2d(1, 1, 1) + self.linear = torch.nn.Linear(3, 3) def forward(self, x): - return self.conv(x) + return self.linear(x) class ObservedCustomModule(torch.nn.Module): - def __init__(self, conv): + def __init__(self, linear): super().__init__() - self.conv = conv + self.linear = linear def forward(self, x): - return self.conv(x) + return self.linear(x) @classmethod def from_float(cls, float_module): assert hasattr(float_module, 'qconfig') - observed = cls(float_module.conv) + observed = cls(float_module.linear) observed.qconfig = float_module.qconfig return observed - class QuantizedCustomModule(torch.nn.Module): - def __init__(self, conv): + class StaticQuantCustomModule(torch.nn.Module): + def __init__(self, linear): super().__init__() - self.conv = conv + self.linear = linear def forward(self, x): - return self.conv(x) + return self.linear(x) @classmethod def from_observed(cls, observed_module): assert hasattr(observed_module, 'qconfig') assert hasattr(observed_module, 'activation_post_process') - observed_module.conv.activation_post_process = \ + observed_module.linear.activation_post_process = \ observed_module.activation_post_process - quantized = cls(nnq.Conv2d.from_float(observed_module.conv)) + quantized = cls(nnq.Linear.from_float(observed_module.linear)) return quantized - class DynamicallyQuantizedCustomModule(torch.nn.Module): - def __init__(self, conv): + class DynamicQuantCustomModule(torch.nn.Module): + def __init__(self, linear): super().__init__() - self.conv = conv + self.linear = linear def forward(self, x): - return self.conv(x) + return self.linear(x) @classmethod def from_observed(cls, observed_module): assert hasattr(observed_module, 'qconfig') - assert hasattr(observed_module, 'activation_post_process') - quantized = cls(nnqd.Conv2d.from_float(observed_module.conv)) + quantized = cls(nnqd.Linear.from_float(observed_module.linear)) return quantized class M(torch.nn.Module): def __init__(self): super().__init__() - self.conv = torch.nn.Conv2d(1, 1, 1) + self.linear = torch.nn.Linear(3, 3) self.custom = CustomModule() def forward(self, x): - x = self.conv(x) + x = self.linear(x) x = self.custom(x) return x class RefM(torch.nn.Module): def __init__(self): super().__init__() - self.conv1 = torch.nn.Conv2d(1, 1, 1) - self.conv2 = torch.nn.Conv2d(1, 1, 1) + self.linear1 = torch.nn.Linear(3, 3) + self.linear2 = torch.nn.Linear(3, 3) def forward(self, x): - x = self.conv1(x) - x = self.conv2(x) + x = self.linear1(x) + x = self.linear2(x) return x - data = torch.randn(1, 1, 1, 1) + data = torch.randn(3, 3) # instantiate M and RefM and align the parameters original_m = M().eval() original_ref_m = RefM().eval() - original_ref_m.conv1.weight = torch.nn.Parameter(original_m.conv.weight.detach()) - original_ref_m.conv1.bias = torch.nn.Parameter(original_m.conv.bias.detach()) - original_ref_m.conv2.weight = torch.nn.Parameter(original_m.custom.conv.weight.detach()) - original_ref_m.conv2.bias = torch.nn.Parameter(original_m.custom.conv.bias.detach()) + original_ref_m.linear1.weight = torch.nn.Parameter(original_m.linear.weight.detach()) + original_ref_m.linear1.bias = torch.nn.Parameter(original_m.linear.bias.detach()) + original_ref_m.linear2.weight = torch.nn.Parameter(original_m.custom.linear.weight.detach()) + original_ref_m.linear2.bias = torch.nn.Parameter(original_m.custom.linear.bias.detach()) + + test_configs = { + "static": (default_qconfig, StaticQuantCustomModule, 3), + "dynamic": (default_dynamic_qconfig, DynamicQuantCustomModule, 0) + } - # TODO: add other quant types after mixed mode support - for quant_type in [QuantType.STATIC]: - qconfig_dict = { - "": default_qconfig, - } - prepare_custom_config_dict = { - "float_to_observed_custom_module_class": { - CustomModule: ObservedCustomModule + for quant_type in [QuantType.DYNAMIC]: + key = quant_type_to_str(quant_type) + qconfig, quantized_module_class, num_observers = test_configs[key] + qconfig_dict = {"": qconfig} + if key == "static": + prepare_custom_config_dict = { + "float_to_observed_custom_module_class": { + "static": { + CustomModule: ObservedCustomModule + } + } } - } - convert_custom_config_dict = { - "observed_to_quantized_custom_module_class": { - ObservedCustomModule: QuantizedCustomModule + convert_custom_config_dict = { + "observed_to_quantized_custom_module_class": { + "static": { + ObservedCustomModule: quantized_module_class + } + } } - } + else: + prepare_custom_config_dict = { + "non_traceable_module_class": [ + CustomModule + ] + } + convert_custom_config_dict = { + "observed_to_quantized_custom_module_class": { + "dynamic": { + CustomModule: quantized_module_class + } + } + } + # check prepared model m = prepare_fx( original_m, @@ -739,7 +857,7 @@ def forward(self, x): m(data) # all activation observers are inserted in the top level module count_check = { - ns.call_module(torch.quantization.MinMaxObserver): 3 + ns.call_module(torch.quantization.MinMaxObserver): num_observers } self.checkGraphModuleNodes(m, expected_node_occurrence=count_check) @@ -747,12 +865,14 @@ def forward(self, x): m = convert_fx( m, convert_custom_config_dict=convert_custom_config_dict) - count_check = { - ns.call_function(torch.quantize_per_tensor) : 1, - ns.call_module(nnq.Conv2d) : 1, - ns.call_method('dequantize') : 1, - } - self.checkGraphModuleNodes(m, expected_node_occurrence=count_check) + if quant_type == QuantType.STATIC: + count_check = { + ns.call_function(torch.quantize_per_tensor) : 1, + ns.call_module(nnq.Linear) : 1, + ns.call_method('dequantize') : 1, + } + self.checkGraphModuleNodes(m, expected_node_occurrence=count_check) + self.assertEqual(type(m.custom), quantized_module_class) res = m(data) # quantize the reference model @@ -815,6 +935,33 @@ def forward(self, x): # make sure these modules are not traced self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) + def test_prepared_model_deepcopy(self): + """Ensures that copy.deepcopy works correctly on a prepared model. + """ + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(1, 1, 1) + self._foobar = 'foobar' + self.foobar2 = 'foobar2' + + def forward(self, x): + x = self.conv(x) + return x + + m = M() + print(m.__dict__.keys()) + m.eval() + qconfig_dict = {'': torch.quantization.default_qconfig} + prepared = prepare_fx(m, qconfig_dict) + # calibrate + prepared(torch.randn(4, 1, 4, 4)) + # copy + prepared_copy = copy.deepcopy(prepared) + # quantize, should run with no errors + quantized = convert_fx(prepared_copy) + + @skipIfNoFBGEMM class TestQuantizeFxOps(QuantizationTestCase): """Unit tests for individual ops @@ -968,7 +1115,10 @@ def __init__(self, is_inplace, is_scalar): def forward(self, x, y): x = self.conv1(x) y = 3 if self.is_scalar else self.conv2(y) + # x = x + y x = self.op(x, y) + # x = y + x + x = self.op(y, x) return x # TODO: decide whether we want to quantize or not @@ -1011,6 +1161,8 @@ def forward(self, x, y): y = 3 if self.is_scalar else self.conv2(y) x = self.op(x, y) x = self.relu(x) + x = self.op(y, x) + x = self.relu(x) return x data = (torch.rand((1, 1, 1, 1), dtype=torch.float), @@ -1410,7 +1562,7 @@ def test_general_value_ops(self): """ class M(torch.nn.Module): def __init__(self): - super(M, self).__init__() + super().__init__() self.conv = torch.nn.Conv2d(3, 3, 3) self.avg_pool1d = torch.nn.AvgPool1d(3) self.avg_pool2d = torch.nn.AvgPool2d(3) @@ -1418,9 +1570,6 @@ def __init__(self): self.adaptive_avg_pool1d = torch.nn.AdaptiveAvgPool1d((1)) self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1)) self.adaptive_avg_pool3d = torch.nn.AdaptiveAvgPool3d((1, 1, 1)) - self.hardsigmoid = torch.nn.Hardsigmoid() - self.sigmoid = torch.nn.Sigmoid() - self.tanh = torch.nn.Tanh() def forward(self, x): x = self.conv(x) @@ -1442,21 +1591,6 @@ def forward(self, x): x = x.mean([2, 3], True) x = F.interpolate(x, 4, mode='nearest') x = F.interpolate(x, 4, mode='linear') - x = self.hardsigmoid(x) - x = F.hardsigmoid(x) - x = F.hardsigmoid(x, inplace=True) - x = x.hardsigmoid() - x.hardsigmoid_() - x = self.sigmoid(x) - x = torch.sigmoid(x) - # F.sigmoid is deprecated - x = x.sigmoid() - x.sigmoid_() - x = self.tanh(x) - # F.tanh is deprecated - x = torch.tanh(x) - x = x.tanh() - x.tanh_() x = self.conv(x) return x @@ -1488,6 +1622,84 @@ def forward(self, x): expected_node_occurrence=count_check, expected_node_list=order_check) + @skipIfNoFBGEMM + def test_fixed_qparams_ops(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 3, 3) + self.sigmoid = torch.nn.Sigmoid() + self.hardsigmoid = torch.nn.Hardsigmoid() + self.tanh = torch.nn.Tanh() + + def forward(self, x): + x = self.conv(x) + # F.sigmoid is deprecated + x = self.sigmoid(x) + x = torch.sigmoid(x) + x = x.sigmoid() + x.sigmoid_() + x = self.hardsigmoid(x) + x = F.hardsigmoid(x) + x = F.hardsigmoid(x, inplace=True) + x = x.hardsigmoid() + x.hardsigmoid_() + x = self.tanh(x) + # F.tanh is deprecated + x = torch.tanh(x) + x = x.tanh() + x.tanh_() + x = self.conv(x) + return x + + for eval_mode in [True, False]: + # This model is not executable since we just put all ops + # in the same forward + m = M() + if eval_mode: + m.eval() + qconfig = default_qconfig + prepare = prepare_fx + fq_count = 0 + else: + m.train() + qconfig = default_qat_qconfig + prepare = prepare_qat_fx + fq_count = 13 + + # nothing to fuse so skipping the fuse step + qconfig_dict = {'': qconfig} + prepared = prepare(m, qconfig_dict) + # check the correct number of activation_post_process is inserted + count_check = { + ns.call_module(FixedQParamsFakeQuantize) : fq_count, + } + self.checkGraphModuleNodes( + prepared, + expected_node_occurrence=count_check) + # not runnable + quantized = convert_fx(prepared) + + # This checks that the dequantize from the output of first conv + # is being propagated to the end, so that we don't insert extra + # observers + # check exact counts of quantize and dequantize + count_check = { + ns.call_function(torch.quantize_per_tensor) : 1, + ns.call_method('dequantize') : 1 + } + order_check = [ + ns.call_function(torch.quantize_per_tensor), + ns.call_module(nnq.Conv2d), + ns.call_module(nn.Sigmoid), + ns.call_module(nnq.Conv2d), + ns.call_method('dequantize'), + ] + self.checkGraphModuleNodes( + quantized, + expected_node_occurrence=count_check, + expected_node_list=order_check) + def test_float_functional(self): class TorchAdd(nn.Module): """Wrapper around torch.add so that all ops can be found at build""" diff --git a/test/quantization/test_quantized_op.py b/test/quantization/test_quantized_op.py index 36f317529285..96568662c052 100644 --- a/test/quantization/test_quantized_op.py +++ b/test/quantization/test_quantized_op.py @@ -91,9 +91,9 @@ def pool_output_shape(input_size, kernel_size, padding, stride, output_size = ( (input_size + 2 * padding - dilation * (kernel_size - 1) - 1 + (stride - 1 if ceiling_mode else 0)) // stride + 1) - if (padding > 0 and + if (ceiling_mode and ((output_size - 1) * stride >= input_size + padding)): - output_size += 1 + output_size -= 1 return output_size """ @@ -3611,6 +3611,122 @@ def test_qconv_transpose2d( Y_q = qconv_op(X_q) self.assertEqual(Y_q_ref, Y_q) + """Tests the correctness of quantized convolution op.""" + @given(batch_size=st.integers(1, 3), + input_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]), + time=st.integers(2, 5), + height=st.integers(10, 16), + width=st.integers(7, 14), + output_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]), + groups=st.integers(1, 3), + kernel_t=st.integers(1, 7), + kernel_h=st.integers(1, 7), + kernel_w=st.integers(1, 7), + stride_t=st.integers(1, 2), + stride_h=st.integers(1, 2), + stride_w=st.integers(1, 2), + pad_t=st.integers(0, 2), + pad_h=st.integers(0, 2), + pad_w=st.integers(0, 2), + o_pad_t=st.integers(0, 2), + o_pad_h=st.integers(0, 2), + o_pad_w=st.integers(0, 2), + dilation=st.integers(1, 2), + X_scale=st.floats(1.2, 1.6), + X_zero_point=st.integers(0, 4), + W_scale=st.lists(st.floats(0.2, 1.6), min_size=1, max_size=2), + W_zero_point=st.lists(st.integers(-5, 5), min_size=1, max_size=2), + Y_scale=st.floats(4.2, 5.6), + Y_zero_point=st.integers(0, 4), + use_bias=st.booleans()) + @override_qengines + def test_qconv_transpose3d( + self, + batch_size, + input_channels_per_group, + time, + height, + width, + output_channels_per_group, + groups, + kernel_t, + kernel_h, + kernel_w, + stride_t, + stride_h, + stride_w, + pad_t, + pad_h, + pad_w, + o_pad_t, + o_pad_h, + o_pad_w, + dilation, + X_scale, + X_zero_point, + W_scale, + W_zero_point, + Y_scale, + Y_zero_point, + use_bias): + if qengine_is_qnnpack(): + return # QNNPACK doesn't support this + assume(o_pad_t < stride_t or o_pad_t < dilation) + assume(o_pad_h < stride_h or o_pad_h < dilation) + assume(o_pad_w < stride_w or o_pad_w < dilation) + + input_channels = input_channels_per_group * groups + output_channels = output_channels_per_group * groups + kernels = (kernel_t, kernel_h, kernel_w) + strides = (stride_t, stride_h, stride_w) + pads = (pad_t, pad_h, pad_w) + o_pads = (o_pad_t, o_pad_h, o_pad_w) + dilations = (dilation, dilation, dilation) + + qconv = torch.ops.quantized.conv_transpose3d + qconv_prepack = torch.ops.quantized.conv_transpose3d_prepack + conv_op = torch.nn.ConvTranspose3d( + in_channels=input_channels, + out_channels=output_channels, + kernel_size=kernels, + stride=strides, + padding=pads, + output_padding=o_pads, + groups=groups, + dilation=dilations, + bias=use_bias + ) + X_q, W_q, bias_float = self._test_qconv_impl( + qconv, qconv_prepack, conv_op, batch_size, + input_channels_per_group, (time, height, width), + output_channels_per_group, groups, kernels, strides, pads, o_pads, + dilations, X_scale, X_zero_point, W_scale, W_zero_point, + Y_scale, Y_zero_point, use_bias, use_relu=False, + use_channelwise=False, use_transpose=True) + + # Test the module implementation + qconv_op = torch.nn.quantized.ConvTranspose3d( + in_channels=input_channels, + out_channels=output_channels, + kernel_size=kernels, + stride=strides, + padding=pads, + output_padding=o_pads, + groups=groups, + dilation=dilations, + bias=use_bias + ) + qconv_op.scale = Y_scale + qconv_op.zero_point = Y_zero_point + qconv_op.set_weight_bias(W_q, bias_float) + + Y_dq_ref = conv_op(X_q.dequantize()) + Y_q_ref = torch.quantize_per_tensor(Y_dq_ref, scale=Y_scale, + zero_point=Y_zero_point, + dtype=torch.quint8) + Y_q = qconv_op(X_q) + self.assertEqual(Y_q_ref, Y_q) + @given( inputs=hu.tensor_conv( spatial_dim=1, batch_size_range=(1, 3), @@ -3863,22 +3979,26 @@ def test_qconv3d( stride_w=st.integers(1, 2), pad_d=st.integers(1, 2), pad_h=st.integers(1, 2), pad_w=st.integers(1, 2), - channelwise=st.booleans(), - qengine=st.sampled_from(("fbgemm",))) + o_pad=st.integers(0, 2), + channelwise=st.booleans()) + @override_qengines def test_qconv3d_unpack( - self, inputs, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, - channelwise, qengine + self, inputs, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, o_pad, + channelwise ): - if qengine not in supported_qengines: - return - - with override_quantized_engine(qengine): - qconv3d_prepack = torch.ops.quantized.conv3d_prepack - qconv3d_unpack = torch.ops.quantized.conv3d_unpack - self._test_qconv_unpack_impl( - qconv3d_prepack, qconv3d_unpack, inputs, - (stride_d, stride_h, stride_w), (pad_d, pad_h, pad_w), None, - channelwise) + if qengine_is_qnnpack(): + return # QNNPACK doesn't support this + transposed = inputs[-1] + if transposed: + qconv_prepack = torch.ops.quantized.conv_transpose3d_prepack + qconv_unpack = torch.ops.quantized.conv_transpose3d_unpack + else: + qconv_prepack = torch.ops.quantized.conv3d_prepack + qconv_unpack = torch.ops.quantized.conv3d_unpack + self._test_qconv_unpack_impl( + qconv_prepack, qconv_unpack, inputs, + (stride_d, stride_h, stride_w), (pad_d, pad_h, pad_w), (o_pad, o_pad, o_pad), + channelwise) class TestPadding(TestCase): @given(batch_size=st.integers(1, 64), diff --git a/test/quantization/test_workflow_module.py b/test/quantization/test_workflow_module.py index a16b01d36d46..cd722d59d2a2 100644 --- a/test/quantization/test_workflow_module.py +++ b/test/quantization/test_workflow_module.py @@ -736,7 +736,7 @@ def test_histogram_observer_same_inputs(self): self.assertEqual(myobs.max_val, 8.0) self.assertEqual(myobs.histogram, [2., 3., 3.]) - @given(N=st.sampled_from([10, 1000, 10**6]), + @given(N=st.sampled_from([10, 1000]), bins=st.sampled_from([256, 512, 1024, 2048]), dtype=st.sampled_from([torch.qint8, torch.quint8]), qscheme=st.sampled_from([torch.per_tensor_affine, torch.per_tensor_symmetric]), diff --git a/test/run_test.py b/test/run_test.py index ad4603e809f2..0fa84c00044c 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -313,6 +313,11 @@ def run_test(test_module, test_directory, options, launcher_cmd=None, extra_unit if extra_unittest_args: assert isinstance(extra_unittest_args, list) unittest_args.extend(extra_unittest_args) + + # If using pytest, replace -f with equivalent -x + if options.pytest: + unittest_args = [arg if arg != '-f' else '-x' for arg in unittest_args] + # Can't call `python -m unittest test_*` here because it doesn't run code # in `if __name__ == '__main__': `. So call `python test_*.py` instead. argv = [test_module + '.py'] + unittest_args diff --git a/test/test_autograd.py b/test/test_autograd.py index 211834f9e82d..5dc5c94c3e53 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -2881,6 +2881,14 @@ def test_pow_scalar_base(self): a = torch.arange(1, 13, dtype=torch.double).view(3, 4).requires_grad_() gradcheck(lambda a: torch.pow(2, a), (a,)) + def test_igamma(self): + # 1e-3 offset to avoid zeros + # NOTE: derivative for s is not implemented + s = (torch.rand(100, dtype=torch.double) + 1e-3) + x = (torch.rand(100, dtype=torch.double) + 1e-3).requires_grad_() + gradcheck(torch.igamma, (s, x)) + gradgradcheck(torch.igamma, (s, x)) + @skipIfNoLapack def test_pinverse(self): # Why is pinverse tested this way, and not ordinarily as other linear algebra methods? @@ -4917,8 +4925,8 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks, # the tests for these ops which do not have 'complex' in variant should not run for complex # and only run for floating point -# TODO(@anjali411): add the commented tests back after updating the formula based on tensorflow definition -separate_complex_tests = ['view_as_real', 'real', 'imag', 'asin', 'acos'] # ['log', 'log10', 'log1p', 'log2', 'reciprocal', 'tan'] +separate_complex_tests = ['view_as_real', 'real', 'imag', 'asin', 'acos', 'div', 'log', + 'log10', 'log1p', 'log2', 'pow', 'tan', 'reciprocal', 'rsqrt', '__rdiv__'] # NOTE: Some non-holomorphic are separately tested in TestAutogradComplex until gradcheck works properly # for non-holomorphic functions @@ -4929,16 +4937,12 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks, 'permute', 'squeeze', 'unsqueeze', 'resize', 'resize_as', 'tril', 'triu', 'chunk', 'split', 'split_with_sizes', 'repeat', 'expand', 'zero_', 'eq_', 'ne_', 'add', '__radd__', 'sum', 'conj', 'sin', 'cos', 'mul', 'sinh', - 'cosh', '__rmul__', 'sgn', 'abs', 'dot', 'vdot', 'tensor_split', - 'matmul', 'bmm', 'mv', 'ger', 'diagonal', ] + separate_complex_tests + 'cosh', '__rmul__', 'sgn', 'abs', 'dot', 'vdot', 'tensor_split', 'matmul', + 'bmm', 'mv', 'ger', 'diagonal', 'atan', 'angle', 'tanh', 'fill_', 'sub'] + separate_complex_tests # this list corresponds to cases that are not currently implemented skip_cuda_list = ['bmm_complex', 'matmul_4d_4d_complex'] -# TODO(@anjali411): add tests for 'sub', 'div -# TODO(@anjali411): add the commented tests back after updating the formula based on tensorflow definition - @anjali411 -# complex_list += ['fill_', 't', '__rdiv__', 'tanh'] - def add_test( name, self_size, diff --git a/test/test_cuda.py b/test/test_cuda.py index 577cd5caa36a..49c7848c9122 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -23,7 +23,7 @@ from torch.testing._internal.common_methods_invocations import tri_tests_args, tri_large_tests_args, \ _compare_trilu_indices, _compare_large_trilu_indices from torch.testing._internal.common_utils import TestCase, get_gpu_type, freeze_rng_state, run_tests, \ - NO_MULTIPROCESSING_SPAWN, skipIfRocm, load_tests, IS_SANDCASTLE, \ + NO_MULTIPROCESSING_SPAWN, skipIfRocm, load_tests, IS_SANDCASTLE, IS_WINDOWS, \ slowTest, skipCUDANonDefaultStreamIf, TEST_WITH_ROCM, TEST_NUMPY from torch.testing._internal.autocast_test_lists import AutocastTestLists @@ -492,7 +492,6 @@ def _test_copy_non_blocking(a, b): event = torch.cuda.Event() a.copy_(b, non_blocking=True) event.record() - self.assertFalse(event.query()) event.synchronize() self.assertEqual(a, b) @@ -505,22 +504,35 @@ def _test_copy_non_blocking(a, b): y = torch.ones(10000000, dtype=torch.uint8).cuda() _test_copy_non_blocking(x, y) - @unittest.skip("skipped because test could be flaky, see #35144") def test_to_non_blocking(self): - def _test_to_non_blocking(a, non_blocking): - stream = torch.cuda.current_stream() - with torch.cuda.stream(stream): - b = a.to('cuda', non_blocking=non_blocking) - self.assertEqual(stream.query(), not non_blocking) - stream.synchronize() - self.assertEqual(a, b) + stream = torch.cuda.current_stream() - # 10MB copies - x = torch.ones(10000000, dtype=torch.uint8) - _test_to_non_blocking(x, True) + def _test_to_non_blocking(a, non_blocking, dst): + torch.cuda.synchronize() + # Pushes an 0.1 second spin to stream so if the copy is non blocking, + # stream will almost surely be active when we query(). + torch.cuda._sleep(int(100 * get_cycles_per_ms())) + b = a.to(device=dst, non_blocking=non_blocking) + self.assertEqual(stream.query(), not non_blocking) + stream.synchronize() + self.assertEqual(a, b) + self.assertTrue(b.is_pinned() == (non_blocking and dst == "cpu")) + + for dst, try_non_blocking in product(("cuda", "cpu"), (True, False)): + # Creates source on the opposite device from destination. + src = torch.randn(1000000, + device="cuda" if dst == "cpu" else "cpu", + pin_memory=True if dst == "cuda" else False) + _test_to_non_blocking(src, try_non_blocking, dst) - y = torch.ones(10000000, dtype=torch.uint8) - _test_to_non_blocking(y, False) + def test_to_cpu_blocking_by_default(self): + src = torch.randn(1000000, device="cuda") + torch.cuda.synchronize() + torch.cuda._sleep(int(100 * get_cycles_per_ms())) + dst = src.to(device="cpu") + self.assertEqual(torch.cuda.current_stream().query(), True) + self.assertEqual(src, dst) + self.assertFalse(dst.is_pinned()) def test_serialization_array_with_storage(self): x = torch.randn(5, 5).cuda() @@ -2151,6 +2163,7 @@ def run(data, model, optimizer, scaler, loss_fn, skip_iter, try_scaling_api): self._run_scaling_case(run, unskipped=3, skipped=1) + @unittest.skipIf(IS_WINDOWS, 'FIXME: fix this test for Windows') def test_grad_scaling_penalty(self): def run(data, model, optimizer, scaler, loss_fn, skip_iter, try_scaling_api): for i, (input, target) in enumerate(data): diff --git a/test/test_dataloader.py b/test/test_dataloader.py index 67a9c8477e8b..0d6ee2e03bd6 100644 --- a/test/test_dataloader.py +++ b/test/test_dataloader.py @@ -1564,7 +1564,7 @@ def test_proper_exit(self): # In all cases, all processes should end properly. if use_workers: exit_methods = [None, 'loader_error', 'loader_kill', 'worker_error', 'worker_kill'] - persistent_workers = self.persistent_workers + persistent_workers = self.persistent_workers else: exit_methods = [None, 'loader_error', 'loader_kill'] persistent_workers = False @@ -1840,6 +1840,12 @@ def test_default_collate_shared_tensor(self): finally: _utils.worker._worker_info = old + def test_excessive_thread_creation_warning(self): + with self.assertWarnsRegex( + UserWarning, + r"excessive worker creation might get DataLoader running slow or even freeze"): + dataloader = DataLoader(self.dataset, batch_size=2, num_workers=1000) + class StringDataset(Dataset): def __init__(self): diff --git a/test/test_fx.py b/test/test_fx.py index b5d03a86a177..05e5a821f4ef 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -814,6 +814,11 @@ def forward(self, x, w): x, w = torch.rand(3, 4), torch.rand(4, 4) self.assertTrue(any(n.target == torch.relu for n in traced.graph.nodes)) + def test_sequential(self): + m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1)) + gm = torch.fx.symbolic_trace(m) + gm_copy = copy.deepcopy(gm) + def test_ctx_mgr(self): @contextlib.contextmanager def do_nothing(): @@ -838,6 +843,25 @@ def test_typename_print(self): output : torch.fx.Node = graph.output(b) self.assertTrue('typing.List[float]' in str(graph)) + def test_inf_nan(self): + class FooMod(torch.nn.Module): + def forward(self, x): + return x + float('inf'), x + float('-inf'), x + float('nan') + + fm = FooMod() + self.checkGraphModule(fm, (torch.rand(3, 4),)) + + def test_inf_nan_kwds(self): + graph : torch.fx.Graph = torch.fx.Graph() + x : torch.fx.Node = graph.create_node('placeholder', 'x') + b : torch.fx.Node = graph.create_node('call_function', operator.add, (x, float('inf')), {}, name='inf') + c : torch.fx.Node = graph.create_node('call_function', operator.add, (x, float('nan')), {}, name='nan') + graph.output((b, c)) + + gm = torch.fx.GraphModule(torch.nn.Module(), graph) + x = torch.rand(3, 4) + self.assertEqual(gm(x), (x + float('inf'), x + float('nan'))) + def test_subgraph_creation(self): class MyModule(torch.nn.Module): def __init__(self): @@ -1055,6 +1079,26 @@ def foo(x, y): x, y = torch.randn(3, 4), torch.randn(3, 4) self.checkGraphModule(foo, (x, y)) + def test_direct_param_use(self): + class TransposeTest(torch.nn.Module): + def __init__(self): + super().__init__() + self.b = torch.nn.Parameter(torch.rand(4, 3)) + + def forward(self, x): + return self.b + + class Foo(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = TransposeTest() + + def forward(self, x): + return self.a.b, self.a.b.t(), self.a.b.view(12) + + traced = torch.fx.symbolic_trace(Foo()) + assert(all('constant' not in node.target for node in traced.graph.nodes)) + if __name__ == '__main__': run_tests() diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py index 8a6fbe4a3a90..d5d3daef722b 100644 --- a/test/test_fx_experimental.py +++ b/test/test_fx_experimental.py @@ -1,7 +1,7 @@ import torch from torch.fx.symbolic_trace import symbolic_trace from torch.fx.experimental import GraphManipulation -from torch.fx.experimental.Partitioner import Partitioner, Device +from torch.fx.experimental.Partitioner import Partitioner, Device, PartitionerConfig from torch.testing._internal.common_utils import run_tests from torch.testing._internal.jit_utils import JitTestCase @@ -20,13 +20,16 @@ def forward(self, a, b): ) partitioner = Partitioner() devices = [ - Device('dev_0', 125), - Device('dev_1', 125), - Device('dev_2', 125) + Device('dev_0', 125, 0), + Device('dev_1', 125, 1), + Device('dev_2', 125, 2) ] - ret = partitioner.partition_graph(traced, m, devices) + partitioner_config = PartitionerConfig(devices) + ret = partitioner.partition_graph(traced, m, partitioner_config) module_with_submodules = ret.module_with_submodules + dag = ret.dag self.assertEqual(traced(a, b), module_with_submodules(a, b)) + assert dag.nodes[0].logical_device_ids == [0] def test_size_based_partition(self): class TestModule(torch.nn.Module): @@ -51,45 +54,113 @@ def forward(self, a, b): ) partitioner = Partitioner() devices = [ - Device('dev_0', 125), - Device('dev_1', 125), - Device('dev_2', 125) + Device('dev_0', 125, 0), + Device('dev_1', 125, 1), + Device('dev_2', 125, 2) ] - ret = partitioner.partition_graph(traced, m, devices) + partitioner_config = PartitionerConfig(devices) + ret = partitioner.partition_graph(traced, m, partitioner_config) module_with_submodules = ret.module_with_submodules + dag = ret.dag self.assertEqual(traced(a, b), module_with_submodules(a, b)) - assert len(module_with_submodules.graph.nodes) == 7 + for i, node in enumerate(dag.nodes): + assert node.logical_device_ids == [i] def test_partition_combining(self): class TestModule(torch.nn.Module): def __init__(self): super().__init__() - self.linear_0 = torch.nn.Linear(4, 4) + self.linear = torch.nn.Linear(4, 4) - def forward(self, a, b): + def forward(self, a): + b = torch.rand(4) add_1 = a + b - c = self.linear_0(a) - add_2 = c + add_1 - return add_2 + linear_1 = self.linear(add_1) + add_2 = torch.rand(4) + a + add_3 = add_2 + linear_1 + return add_3 m = TestModule() traced = symbolic_trace(m) a = torch.rand(4) - b = torch.rand(4) GraphManipulation.get_size_of_all_nodes( traced, - [a, b] + [a] ) partitioner = Partitioner() devices = [ - Device('dev_0', 125), - Device('dev_1', 125), - Device('dev_2', 125) + Device('dev_0', 120, 0), + Device('dev_1', 144, 1) ] - ret = partitioner.partition_graph(traced, m, devices) + partitioner_config = PartitionerConfig(devices, is_sparse_nn=False) + ret = partitioner.partition_graph(traced, m, partitioner_config) module_with_submodules = ret.module_with_submodules - self.assertEqual(traced(a, b), module_with_submodules(a, b)) - assert len(module_with_submodules.graph.nodes) == 5 + dag = ret.dag + self.assertEqual(traced(a), module_with_submodules(a)) + assert dag.nodes[0].logical_device_ids == [0] + assert dag.nodes[0].size_bytes == 80 + assert dag.nodes[1].logical_device_ids == [1] + assert dag.nodes[1].size_bytes == 144 + + def test_sparse_nn_partition(self): + class MyRecommendationModule(torch.nn.Module): + def create_mlp(self, num_of_layers: int, input_size: int, output_size: int): + layers = torch.nn.ModuleList() + for _ in range(num_of_layers): + ll = torch.nn.Linear(input_size, output_size) + layers.append(ll) + layers.append(torch.nn.ReLU()) + return layers + + def __init__(self): + super(MyRecommendationModule, self).__init__() + layers = self.create_mlp(4, 4, 4) + self.bottom_layers = torch.nn.Sequential(*layers) + layers = self.create_mlp(3, 24, 24) + self.top_layers = torch.nn.Sequential(*layers) + self.embedding_layers = torch.nn.ModuleList() + el = torch.nn.EmbeddingBag(500000, 4, mode='sum', sparse=True) + self.embedding_layers.append(el) + for i in range(3): + el = torch.nn.EmbeddingBag(1000000, 4, mode='sum', sparse=True) + self.embedding_layers.append(el) + el = torch.nn.EmbeddingBag(500000, 4, mode='sum', sparse=True) + self.embedding_layers.append(el) + + def forward(self, a, b, offset): + x = self.bottom_layers(a) + y = [] + c = [] + for i in range(len(self.embedding_layers)): + temp = torch.randint(10, (8, )) + c.append(temp + b) + for i in range(len(self.embedding_layers)): + if i % 2 == 0: + y.append(self.embedding_layers[i](c[i], offset)) + else: + y.append(self.embedding_layers[i](torch.randint(10, (8, )), offset)) + z = torch.cat([x] + y, dim=1) + p = self.top_layers(z) + return p + + m = MyRecommendationModule() + a = torch.rand(2, 4) + b = torch.randint(10, (8, )) + offset = torch.randint(1, (2, )) + traced = symbolic_trace(m) + GraphManipulation.get_size_of_all_nodes(traced, [a, b, offset]) + devices = [ + Device('dev_0', 33000000, 0), + Device('dev_1', 33000000, 1), + Device('dev_2', 33000000, 2) + ] + partitioner_config = PartitionerConfig(devices, is_sparse_nn=True) + partitioner = Partitioner() + ret = partitioner.partition_graph(traced, m, partitioner_config) + module_with_submodules = ret.module_with_submodules + dag = ret.dag + self.assertEqual(traced(a, b, offset), module_with_submodules(a, b, offset)) + assert len(module_with_submodules.graph.nodes) == 24 if __name__ == '__main__': run_tests() diff --git a/test/test_jit.py b/test/test_jit.py index c5c4fe52d724..f4c3a8d95b69 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -82,7 +82,7 @@ from itertools import product import itertools from textwrap import dedent -from typing import List, Dict, Optional, Tuple, Union +from typing import List, Dict, NamedTuple, Optional, Tuple, Union import inspect import math import functools @@ -13297,7 +13297,7 @@ def backward(grad_output): ''') cu = torch.jit.CompilationUnit(code) g = cu.tanh.graph - FileCheck().check_count("prim::Function_0", 2).check("None = prim::Constant") \ + FileCheck().check_count("prim::Closure_0", 2).check("None = prim::Constant") \ .check_next("return").run(g) code = dedent(''' @@ -13314,7 +13314,7 @@ def backward(grad_output): ''') cu = torch.jit.CompilationUnit(code) g = cu.tanh.graph - FileCheck().check_count("prim::Function_0", 2).check("int = prim::If") \ + FileCheck().check_count("prim::Closure_0", 2).check("int = prim::If") \ .run(g) code = dedent(''' @@ -13328,9 +13328,9 @@ def backward(grad_output): ''') cu = torch.jit.CompilationUnit(code) fc = FileCheck() - fc.check("prim::Function").check("(Tensor, None) = prim::TupleConstruct") + fc.check("prim::Closure").check("(Tensor, None) = prim::TupleConstruct") # Loop then two if's added in exit transform - fc.check("prim::Function").check("prim::Loop").check_count("prim::If", 2) + fc.check("prim::Closure").check("prim::Loop").check_count("prim::If", 2) fc.run(cu.loop_in_closure.graph) code = dedent(''' @@ -13796,6 +13796,23 @@ def test_non_primitive_types(x): out = test_non_primitive_types(_MyNamedTuple(value=torch.tensor(5.0))) self.assertEqual(out, torch.tensor(6.0)) + def test_namedtuple_type_inference(self): + _AnnotatedNamedTuple = NamedTuple('_NamedTupleAnnotated', [('value', int)]) + _UnannotatedNamedTuple = namedtuple('_NamedTupleUnAnnotated', ['value']) + + def test_check_named_tuple_value(): + named_tuple = _AnnotatedNamedTuple(1) + return named_tuple.value + + self.checkScript(test_check_named_tuple_value, ()) + + def test_error(): + return _UnannotatedNamedTuple(1) + + with self.assertRaisesRegex(RuntimeError, r"Expected a value of type \'Tensor \(inferred\)\' " + r"for argument \'value\' but instead found type \'int\'."): + torch.jit.script(test_error) + def test_isinstance_dynamic(self): @torch.jit.script def foo(a): @@ -15631,7 +15648,7 @@ def add_autograd_test( # Disable complex tests # TODO: Add complex support for jit - if 'complex' in variant_name or name in ['view_as_complex', 'complex']: + if 'complex' in variant_name or name in ['view_as_complex', 'complex', 'angle']: return # Skips aliases, which are tested in test_op_aliases.py diff --git a/test/test_kernel_launch_checks.py b/test/test_kernel_launch_checks.py new file mode 100644 index 000000000000..8796b9913f73 --- /dev/null +++ b/test/test_kernel_launch_checks.py @@ -0,0 +1,42 @@ +from torch.testing._internal.common_utils import TestCase, run_tests +from torch.testing import check_cuda_kernel_launches, check_code_for_cuda_kernel_launches + + +class AlwaysCheckCudaLaunchTest(TestCase): + def test_check_code(self): + """Verifies that the regex works for a few different situations""" + + # Try some different spacings + self.assertEqual(2, check_code_for_cuda_kernel_launches(""" +some_function_call<<<1,2,0,stream>>>(arg1,arg2,arg3); +TORCH_CUDA_KERNEL_LAUNCH_CHECK(); +some_function_call<<<1,2,0,stream>>>(arg1,arg2,arg3); + +some_function_call<<<1,2,0,stream>>>(arg1,arg2,arg3); +TORCH_CUDA_KERNEL_LAUNCH_CHECK(); +some_function_call<<<1,2,0,stream>>>(arg1,arg2,arg3); +some_other_stuff; +some_function_call<<<1,2,0,stream>>>(arg1,arg2,arg3); +TORCH_CUDA_KERNEL_LAUNCH_CHECK(); +some_function_call<<<1,2,0,stream>>> (arg1,arg2,arg3); +TORCH_CUDA_KERNEL_LAUNCH_CHECK(); +some_function_call<<<1,2,0,stream>>> ( arg1 , arg2 , arg3 ) ; + + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); + """)) + + # Does it work for macros? + self.assertEqual(0, check_code_for_cuda_kernel_launches(""" +#define SOME_MACRO(x) some_function_call<<<1,2>>> ( x ) ; \\ + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); + """)) + + def test_check_cuda_launches(self): + check_cuda_kernel_launches() + # TODO: Enable this after warning messages have been dealt with. + self.assertTrue(True) + # self.assertTrue(check_cuda_kernel_launches() == 0) + + +if __name__ == '__main__': + run_tests() diff --git a/test/test_linalg.py b/test/test_linalg.py index b518462ef1cd..56c764e7fea1 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -8,7 +8,8 @@ from torch.testing._internal.common_utils import \ (TestCase, run_tests, TEST_NUMPY, IS_MACOS, IS_WINDOWS, TEST_WITH_ASAN, make_tensor) from torch.testing._internal.common_device_type import \ - (instantiate_device_type_tests, dtypes, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride) + (instantiate_device_type_tests, dtypes, dtypesIfCUDA, + onlyCUDA, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride) from torch.testing._internal.jit_metaprogramming_utils import gen_script_fn_and_args from torch.autograd import gradcheck @@ -304,7 +305,7 @@ def run_test_case(input_size, ord, keepdim, from_dtype, to_dtype, compare_dtype) self.assertEqual(result_converted, result_out_converted, msg=msg) ord_vector = [0, 1, -1, 2, -2, 3, -3, 4.5, -4.5, inf, -inf, None] - ord_matrix = [1, -1, 2, -2, inf, -inf, None] + ord_matrix = ['fro', 'nuc', 1, -1, 2, -2, inf, -inf, None] S = 10 test_cases = [ ((S, ), ord_vector), @@ -332,13 +333,6 @@ def run_test_case(input_size, ord, keepdim, from_dtype, to_dtype, compare_dtype) with self.assertRaisesRegex(RuntimeError, r'provided dtype must match dtype of result'): torch.linalg.norm(input, ord=ord, keepdim=keepdim, dtype=dtype, out=result) - # TODO: Once dtype arg is supported in nuclear and frobenius norms, remove the following test - # and add 'nuc' and 'fro' to ord_matrix above - for ord in ['nuc', 'fro']: - input = torch.randn(10, 10, device=device) - with self.assertRaisesRegex(RuntimeError, f"ord=\'{ord}\' does not yet support the dtype argument"): - torch.linalg.norm(input, ord, dtype=torch.float) - # This test compares torch.linalg.norm and numpy.linalg.norm to ensure that # their vector norm results match @unittest.skipIf(not TEST_NUMPY, "NumPy not found") @@ -1023,6 +1017,126 @@ def test_nuclear_norm_exceptions_old(self, device): self.assertRaisesRegex(RuntimeError, "duplicate or invalid", torch.norm, x, "nuc", (0, 0)) self.assertRaisesRegex(IndexError, "Dimension out of range", torch.norm, x, "nuc", (0, 2)) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) + @dtypesIfCUDA(torch.float, torch.double) + @precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}) + def test_tensorsolve(self, device, dtype): + def run_test(a_shape, dims): + a = torch.randn(a_shape, dtype=dtype, device=device) + b = torch.randn(a_shape[:2], dtype=dtype, device=device) + result = torch.linalg.tensorsolve(a, b, dims=dims) + expected = np.linalg.tensorsolve(a.cpu().numpy(), b.cpu().numpy(), axes=dims) + self.assertEqual(result, expected) + + # check the out= variant + out = torch.empty_like(result) + ans = torch.linalg.tensorsolve(a, b, dims=dims, out=out) + self.assertEqual(ans, out) + self.assertEqual(ans, result) + + a_shapes = [(2, 3, 6), (3, 4, 4, 3)] + dims = [None, (0, 2)] + for a_shape, d in itertools.product(a_shapes, dims): + run_test(a_shape, d) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) + @dtypesIfCUDA(torch.float, torch.double) + def test_tensorsolve_empty(self, device, dtype): + # Check for empty inputs. NumPy does not work for these cases. + a = torch.empty(0, 0, 1, 2, 3, 0, dtype=dtype, device=device) + b = torch.empty(a.shape[:2], dtype=dtype, device=device) + x = torch.linalg.tensorsolve(a, b) + self.assertEqual(torch.tensordot(a, x, dims=len(x.shape)), b) + + # TODO: once "solve_cuda" supports complex dtypes, they shall be added to above tests + @unittest.expectedFailure + @onlyCUDA + @skipCUDAIfNoMagma + @dtypes(torch.cfloat, torch.cdouble) + def test_tensorsolve_xfailed(self, device, dtype): + a_shape = (2, 3, 6) + a = torch.randn(a_shape, dtype=dtype, device=device) + b = torch.randn(a_shape[:2], dtype=dtype, device=device) + result = torch.linalg.tensorsolve(a, b) + expected = np.linalg.tensorsolve(a.cpu().numpy(), b.cpu().numpy()) + self.assertEqual(result, expected) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) + @dtypesIfCUDA(torch.float, torch.double) + @precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}) + def test_tensorsolve_non_contiguous(self, device, dtype): + def run_test_permuted(a_shape, dims): + # check for permuted / transposed inputs + a = torch.randn(a_shape, dtype=dtype, device=device) + a = a.movedim((0, 2), (-2, -1)) + self.assertFalse(a.is_contiguous()) + b = torch.randn(a.shape[:2], dtype=dtype, device=device) + b = b.t() + self.assertFalse(b.is_contiguous()) + result = torch.linalg.tensorsolve(a, b, dims=dims) + expected = np.linalg.tensorsolve(a.cpu().numpy(), b.cpu().numpy(), axes=dims) + self.assertEqual(result, expected) + + def run_test_skipped_elements(a_shape, dims): + # check for inputs with skipped elements + a = torch.randn(a_shape, dtype=dtype, device=device) + a = a[::2] + self.assertFalse(a.is_contiguous()) + b = torch.randn(a_shape[:2], dtype=dtype, device=device) + b = b[::2] + self.assertFalse(b.is_contiguous()) + result = torch.linalg.tensorsolve(a, b, dims=dims) + expected = np.linalg.tensorsolve(a.cpu().numpy(), b.cpu().numpy(), axes=dims) + self.assertEqual(result, expected) + + # check non-contiguous out + out = torch.empty(2 * result.shape[0], *result.shape[1:], dtype=dtype, device=device)[::2] + self.assertFalse(out.is_contiguous()) + ans = torch.linalg.tensorsolve(a, b, dims=dims, out=out) + self.assertEqual(ans, out) + self.assertEqual(ans, result) + + a_shapes = [(2, 3, 6), (3, 4, 4, 3)] + dims = [None, (0, 2)] + for a_shape, d in itertools.product(a_shapes, dims): + run_test_permuted(a_shape, d) + + a_shapes = [(4, 3, 6), (6, 4, 4, 3)] + dims = [None, (0, 2)] + for a_shape, d in itertools.product(a_shapes, dims): + run_test_skipped_elements(a_shape, d) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32) + def test_tensorsolve_errors_and_warnings(self, device, dtype): + # tensorsolve expects the input that can be reshaped to a square matrix + a = torch.eye(2 * 3 * 4).reshape((2 * 3, 4, 2, 3, 4)) + b = torch.randn(8, 4) + self.assertTrue(np.prod(a.shape[2:]) != np.prod(b.shape)) + with self.assertRaisesRegex(RuntimeError, r'Expected self to satisfy the requirement'): + torch.linalg.tensorsolve(a, b) + + # if non-empty out tensor with wrong shape is passed a warning is given + out = torch.empty_like(a) + b = torch.randn(6, 4) + with warnings.catch_warnings(record=True) as w: + # Trigger warning + torch.linalg.tensorsolve(a, b, out=out) + # Check warning occurs + self.assertEqual(len(w), 1) + self.assertTrue("An output with one or more elements was resized" in str(w[-1].message)) + + # dtypes should match + out = torch.empty_like(a).to(torch.int) + with self.assertRaisesRegex(RuntimeError, "result dtype Int does not match self dtype"): + torch.linalg.tensorsolve(a, b, out=out) instantiate_device_type_tests(TestLinalg, globals()) diff --git a/test/test_nn.py b/test/test_nn.py index 784786196914..d14bb5aa9963 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -3055,6 +3055,13 @@ def test_embedding_functional(self): res_F = F.embedding(a, embeddings) self.assertEqual(res_old, res_F) + embed_old = torch.nn.Embedding(4, 3) + embed_old = embed_old.from_pretrained(embeddings, padding_idx=2) + res_old = embed_old(a) + res_F = F.embedding(a, embeddings, padding_idx=2) + + self.assertEqual(res_old, res_F) + @unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines, 'Linear_FP16_weight requires FBGEMM. FBGEMM is only optimized for CPUs' ' with instruction set support avx2 or newer.') @@ -6503,27 +6510,31 @@ def test_PReLU_backward_requires_grad_false(self): @unittest.skipIf( not TEST_NUMPY or not TEST_SCIPY, "Numpy or Scipy not found") def test_gelu(self): - def _test_gelu(n, m, dtype, contiguous): + def _test_gelu(n, m, dtype, contiguous, atol=None, rtol=None): + numpy_dtype = { + torch.bfloat16: torch.float, torch.float: torch.float, torch.double: torch.double + }[dtype] + devices = ['cpu'] if dtype != torch.bfloat16 else [] + \ + ['cuda'] if TEST_CUDA else [] + def _gelu_ref(X): return X * stats.norm.cdf(X) - if contiguous: - X = torch.rand(n, m, dtype=dtype, requires_grad=True) - else: - X = torch.rand(n, m, dtype=dtype, requires_grad=True)[:, ::2] - res = F.gelu(X) - ref = _gelu_ref(X.detach().numpy()) - self.assertEqual(res, ref) - gradcheck(F.gelu, [X], eps=1e-4) - - if TEST_CUDA: - X_cuda = X.cuda() - res_cuda = F.gelu(X_cuda) - self.assertEqual(res_cuda.cpu(), ref) - gradcheck(F.gelu, [X_cuda], eps=1e-4) + for d in devices: + if contiguous: + X = torch.rand(n, m, dtype=dtype, requires_grad=True, device=d) + else: + X = torch.rand(n, m, dtype=dtype, requires_grad=True, device=d)[:, ::2] + res = F.gelu(X) + ref = _gelu_ref(X.to(numpy_dtype).cpu().detach().numpy()) + self.assertEqual(res, ref, rtol=rtol, atol=atol) + if dtype != torch.bfloat16: + gradcheck(F.gelu, [X], eps=1e-4) for n in range(1, 10): for m in range(1, 10): + _test_gelu(n, m, torch.bfloat16, True, 1e-2, 0) + _test_gelu(n, m, torch.bfloat16, False, 1e-2, 0) _test_gelu(n, m, torch.float32, True) _test_gelu(n, m, torch.float32, False) _test_gelu(n, m, torch.float64, True) @@ -10511,6 +10522,27 @@ def test(nonlinearity, *args, **kwargs): test('threshold', 3, 2) test('threshold', 3, 2, inplace=True) + def test_pooling_shape(self, device): + ''' Test the output shape calculation for pooling functions ''' + + # Checks output shape against expected for 1D, 2D and 3D + def check(expected_out_shape, sizes, *args, **kwargs): + for kernel in ['max', 'avg']: + for i in [1, 2, 3]: + if hasattr(torch.nn.functional, f'{kernel}_pool{i}d'): + op = getattr(torch.nn.functional, f'{kernel}_pool{i}d') + t = torch.randn(sizes[:i + 2], device=device) + self.assertEqual(op(t, *args, **kwargs).shape, expected_out_shape[:i + 2]) + + check((1, 1, 3, 3, 4), (1, 1, 5, 6, 7), kernel_size=1, stride=2, padding=0, ceil_mode=True) + check((1, 1, 2, 3, 3), (1, 1, 3, 4, 5), kernel_size=2, stride=2, padding=1, ceil_mode=False) + check((1, 1, 2, 3, 3), (1, 1, 3, 4, 5), kernel_size=2, stride=2, padding=1, ceil_mode=True) + + # Test case from issue https://github.com/pytorch/pytorch/issues/45357 + x = torch.randn(1, 1, 6, 7, device=device) + y = torch.nn.functional.max_pool2d(x, 1, stride=(2, 2), padding=0, ceil_mode=True) + self.assertEqual(y.size(), (1, 1, 3, 4)) + @onlyOnCPUAndCUDA # TODO: fix on XLA def test_adaptive_avg_pool2d_output_size_one(self, device): def helper(size, memory_format): @@ -10734,6 +10766,15 @@ def fn(weight): fn = fn_wrapper(device) _assertGradAndGradgradChecks(self, fn, (weight, )) + def fn_wrapper(device): + def padding_fn(weight): + inp = torch.tensor([[0, 1, 1, 2], [1, 1, 0, 2]], dtype=torch.long).to(device) + return torch.nn.functional.embedding(inp, weight, padding_idx=1) + return padding_fn + + fn = fn_wrapper(device) + _assertGradAndGradgradChecks(self, fn, (weight, )) + def test_embedding_scalar_weight_error(self, device): indices = torch.rand(2, 2, device=device).long() weight = torch.tensor(1.0, device=device) @@ -10830,6 +10871,8 @@ def test_embedding_padding_idx(self, device, dtype): embedding.zero_grad() self.assertEqual(after, pre) + # Test fails on Vg20 + @skipCUDAIfRocm @dtypesIfCUDA(torch.half, torch.float) @dtypes(torch.float) def test_softmax_results(self, device, dtype): @@ -11429,6 +11472,8 @@ def test_embedding_max_norm_device(self, device, dtype): self.assertEqual(output[1], output[2]) self.assertTrue(output.data.norm(p=2, dim=1).le(1).all()) + # Test fails on Vg20 + @skipCUDAIfRocm @onlyCUDA @dtypes(torch.half, torch.float) def test_softmax(self, device, dtype): diff --git a/test/test_op_aliases.py b/test/test_op_aliases.py index 8a106d7860d1..a6d37c9d52e2 100644 --- a/test/test_op_aliases.py +++ b/test/test_op_aliases.py @@ -6,6 +6,7 @@ from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, skipCPUIfNoLapack, skipCUDAIfNoMagma, onlyCPU) +import collections # Information for generating an alias test # NOTE: ending the alias_name with an underscore will interpret the test @@ -150,6 +151,8 @@ def __init__(self, AliasInfo('true_divide_', torch.Tensor.true_divide_, 'div_', torch.Tensor.div_, lambda d: torch.randn(20, device=d), get_args=lambda d: (torch.rand(20, device=d) + .1,), decorators=(onlyCPU,)), + AliasInfo('row_stack', torch.row_stack, 'vstack', torch.vstack, + lambda d: ((torch.randn(20, device=d), torch.randn(20, device=d)))), ) # Placeholder test class for validating that aliases are correctly @@ -157,6 +160,14 @@ def __init__(self, class TestOpNormalization(JitTestCase): pass + +# Clone input tensor and sequence of Tensors +def clone_inp(inp): + if isinstance(inp, collections.Sequence): + return list(map(torch.clone, inp)) + else: + return inp.clone() + # Generates alias tests and adds them to the specified class (cls) def create_alias_tests(cls): for info in alias_infos: @@ -180,10 +191,18 @@ def _fn(t): arg_string = ', '.join((str(arg) for arg in info.get_args(device))) script = fn_template.format(alias_name=info.alias_name, args=arg_string) else: - fn_template = ''' - def _fn(t): + is_input_tensor_list = isinstance(info.get_input(device), collections.Sequence) + # For sequence of Tensors, annotate the type to be List[Tensor] + if is_input_tensor_list: + fn_template = ''' + def _fn(t: List[Tensor]): return op(t{args}) - ''' + ''' + else: + fn_template = ''' + def _fn(t): + return op(t{args}) + ''' arg_string = ", " + ', '.join((str(arg) for arg in info.get_args(device))) script = fn_template.format(args=arg_string) @@ -192,8 +211,8 @@ def _fn(t): # Acquires and checks the graph remaps the alias inp = info.get_input(device) - scripted(inp.clone()) - graph = scripted.graph_for(inp.clone()) + scripted(clone_inp(inp)) + graph = scripted.graph_for(clone_inp(inp)) FileCheck().check(info.original_name).check_not(info.alias_name).run(graph) # Checks that tracing converts aliases @@ -203,9 +222,9 @@ def _fn(t): def _fn(t, info=info, args=args): return info.alias_op(t, *args) - traced = torch.jit.trace(_fn, (inp.clone(),)) - traced(inp.clone()) - graph = traced.graph_for(inp.clone()) + traced = torch.jit.trace(_fn, (clone_inp(inp),)) + traced(clone_inp(inp)) + graph = traced.graph_for(clone_inp(inp)) FileCheck().check(info.original_name).check_not(info.alias_name).run(graph) # Applies decorators @@ -223,10 +242,10 @@ def _test_alias_computation(self, device, info=info): inp = info.get_input(device) args = info.get_args(device) - alias_input = inp.clone() + alias_input = clone_inp(inp) alias_result = alias_op(alias_input, *args) - original_input = inp.clone() + original_input = clone_inp(inp) original_result = alias_op(original_input, *args) self.assertEqual(alias_input, original_input, atol=0, rtol=0) diff --git a/test/test_ops.py b/test/test_ops.py index 5be450d4d41f..1d85f86113e9 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -89,41 +89,35 @@ def _gradgrad_test_helper(self, device, dtype, op, variant): return self._check_helper(device, dtype, op, variant, 'gradgradcheck') # Tests that gradients are computed correctly - # TODO(@anjali411) enable this for torch.cdouble. - @dtypes(torch.double) + @dtypes(torch.double, torch.cdouble) @ops(op_db) def test_fn_grad(self, device, dtype, op): self._grad_test_helper(device, dtype, op, op.get_op()) - # TODO(@anjali411) enable this for torch.cdouble. - @dtypes(torch.double) + @dtypes(torch.double, torch.cdouble) @ops(op_db) def test_method_grad(self, device, dtype, op): self._grad_test_helper(device, dtype, op, op.get_method()) - # TODO(@anjali411) enable this for torch.cdouble. - @dtypes(torch.double) + @dtypes(torch.double, torch.cdouble) @ops(op_db) def test_inplace_grad(self, device, dtype, op): if not op.test_inplace_grad: self.skipTest("Skipped! Inplace gradcheck marked to skip.") self._grad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace())) - # TODO(@anjali411) enable this for torch.cdouble. # Test that gradients of gradients are computed correctly - @dtypes(torch.double) + @dtypes(torch.double, torch.cdouble) @ops(op_db) def test_fn_gradgrad(self, device, dtype, op): self._gradgrad_test_helper(device, dtype, op, op.get_op()) - # TODO(@anjali411) enable this for torch.cdouble. - @dtypes(torch.double) + @dtypes(torch.double, torch.cdouble) @ops(op_db) def test_method_gradgrad(self, device, dtype, op): self._gradgrad_test_helper(device, dtype, op, op.get_method()) - # TODO(@anjali411) enable this for torch.cdouble. - @dtypes(torch.double) + @dtypes(torch.double, torch.cdouble) @ops(op_db) def test_inplace_gradgrad(self, device, dtype, op): if not op.test_inplace_grad: diff --git a/test/test_quantization.py b/test/test_quantization.py index 8a6f05cb19de..682d9baff68e 100644 --- a/test/test_quantization.py +++ b/test/test_quantization.py @@ -44,6 +44,7 @@ from quantization.test_quantize import TestPostTrainingDynamic # noqa: F401 from quantization.test_quantize import TestQuantizationAwareTraining # noqa: F401 from quantization.test_quantize import TestEagerModeOps # noqa: F401 +from quantization.test_quantize import TestEagerModeQATOps # noqa: F401 # TODO: merge with other tests in test_quantize.py? from quantization.test_quantize import TestFunctionalModule # noqa: F401 @@ -60,11 +61,17 @@ from quantization.test_quantize_jit import TestQuantizeDynamicJitOps # noqaa: F401 # 3. GraphModule based graph mode quantization -from quantization.test_quantize_fx import TestQuantizeFx # noqa: F401 -from quantization.test_quantize_fx import TestQuantizeFxOps # noqa: F401 -from quantization.test_quantize_fx import TestQuantizeFxModels # noqa: F401 +try: + from quantization.test_quantize_fx import TestFuseFx # noqa: F401 + from quantization.test_quantize_fx import TestQuantizeFx # noqa: F401 + from quantization.test_quantize_fx import TestQuantizeFxOps # noqa: F401 + from quantization.test_quantize_fx import TestQuantizeFxModels # noqa: F401 +except ImportError: + # In FBCode we separate FX out into a separate target for the sake of dev + # velocity. These are covered by a separate test target `quantization_fx` + pass -# Tooling: numric_suite +# Tooling: numeric_suite from quantization.test_numeric_suite import TestEagerModeNumericSuite # noqa: F401 # Backward Compatibility diff --git a/test/test_tensor_creation_ops.py b/test/test_tensor_creation_ops.py index b0777c7fa12a..7e25b80851c0 100644 --- a/test/test_tensor_creation_ops.py +++ b/test/test_tensor_creation_ops.py @@ -702,6 +702,23 @@ def test_eye(self, device): for dtype in torch.testing.get_all_dtypes(): if dtype == torch.bfloat16: continue + # Test the RuntimeError is raised when either m or n is a negative number + for n, m in ((-1, 1), (1, -1), (-1, -1)): + with self.assertRaisesRegex(RuntimeError, 'must be greater or equal to'): + torch.eye(n, m, device=device, dtype=dtype) + + # Test when the `m` parameter is not provided + for n in (3, 5, 7): + res1 = torch.eye(n, device=device, dtype=dtype) + naive_eye = torch.zeros(n, n, dtype=dtype, device=device) + naive_eye.diagonal(dim1=-2, dim2=-1).fill_(1) + self.assertEqual(naive_eye, res1) + + # Check eye_out outputs + res2 = torch.empty(0, device=device, dtype=dtype) + torch.eye(n, out=res2) + self.assertEqual(res1, res2) + for n, m in product([3, 5, 7], repeat=2): # Construct identity using diagonal and fill res1 = torch.eye(n, m, device=device, dtype=dtype) diff --git a/test/test_torch.py b/test/test_torch.py index 96ac4872289f..d927446140cf 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -26,7 +26,7 @@ from torch.testing._internal.common_methods_invocations import tri_tests_args, run_additional_tri_tests, \ _compare_trilu_indices from torch.testing._internal.common_utils import \ - (TestCase, iter_indices, TEST_NUMPY, TEST_SCIPY, TEST_WITH_ROCM, run_tests, + (TestCase, iter_indices, TEST_NUMPY, TEST_SCIPY, TEST_WITH_ASAN, TEST_WITH_ROCM, run_tests, skipIfNoLapack, suppress_warnings, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN, do_test_dtypes, IS_SANDCASTLE, load_tests, slowTest, skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf, BytesIOContext, @@ -3468,6 +3468,25 @@ def test_pin_memory(self): self.assertIs(pinned, pinned.pin_memory()) self.assertEqual(pinned.data_ptr(), pinned.pin_memory().data_ptr()) + def test_new_methods_requires_grad(self): + size = (10,) + test_cases = [ + # method name, args + ('new_full', [size, 1]), + ('new_empty', [size]), + ('new_zeros', [size]), + ] + for method_name, args in test_cases: + x = torch.randn(size) + for requires_grad in [True, False]: + x_new = x.__getattribute__(method_name)(*args, requires_grad=requires_grad) + self.assertEqual(x_new.requires_grad, requires_grad) + x = torch.randint(10, size) + with self.assertRaisesRegex( + RuntimeError, + r'Only Tensors of floating point and complex dtype can require gradients'): + x_new = x.__getattribute__(method_name)(*args, requires_grad=True) + @unittest.skipIf(not TEST_NUMPY, "Numpy not found") def test_numpy_unresizable(self) -> None: x = np.zeros((2, 2)) @@ -5300,7 +5319,10 @@ def test_complex_assert_raises(self, device): self.assertRaises(RuntimeError, lambda: zeros.index_add(0, torch.arange(0, size[0], dtype=torch.long, device=device), tensor)) - self.assertRaises(RuntimeError, lambda: torch.sign(torch.tensor([4j], device=device, dtype=dtype))) + with self.assertRaisesRegex(RuntimeError, + (r'Unlike NumPy, torch.sign is not intended to support complex numbers\. ' + r'Please use torch.sgn instead\.')): + torch.sign(torch.tensor([4j], device=device, dtype=dtype)) a = torch.rand((2, 2), dtype=dtype, device=device) b = torch.rand((2, 2), dtype=dtype, device=device) @@ -6326,6 +6348,25 @@ def test_heaviside(self, device, dtypes): with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'): input.heaviside_(values) + @onlyCUDA + def test_heaviside_cross_device(self, device): + x = torch.tensor([-9, 5, 0, 6, -2, 2], device='cuda') + y = torch.tensor(0) + result = torch.heaviside(x, y) + expect = torch.tensor([0, 1, 0, 1, 0, 1], device='cuda') + self.assertEqual(result, expect) + + result = torch.heaviside(y, x) + expect = torch.tensor([-9, 5, 0, 6, -2, 2], device='cuda') + self.assertEqual(result, expect) + + x = torch.tensor([-9, 5, 0, 6, -2, 2]) + y = torch.tensor(0, device='cuda') + with self.assertRaisesRegex(RuntimeError, 'Expected all tensors to be on the same device'): + torch.heaviside(x, y) + + with self.assertRaisesRegex(RuntimeError, 'Expected all tensors to be on the same device'): + torch.heaviside(y, x) @unittest.skipIf(not TEST_NUMPY, "Numpy not found") @dtypes(*list(product(torch.testing.get_all_complex_dtypes(), @@ -8096,21 +8137,26 @@ def test_complex_rot90(self, device, dtype): self.compare_with_numpy(torch_fn, np_fn, data) @onlyOnCPUAndCUDA + @precisionOverride({torch.bfloat16: 5e-2, torch.half: 1e-3}) @unittest.skipIf(not TEST_SCIPY, "Scipy not found") - def test_signal_window_functions(self, device): + @dtypesIfCUDA(torch.float, torch.double, torch.bfloat16, torch.half, torch.long) + @dtypesIfCPU(torch.float, torch.double, torch.long) + def test_signal_window_functions(self, device, dtype): def test(name, kwargs): torch_method = getattr(torch, name + '_window') + if not dtype.is_floating_point: + with self.assertRaisesRegex(RuntimeError, r'floating point'): + torch_method(3, dtype=dtype) + return for size in [0, 1, 2, 5, 10, 50, 100, 1024, 2048]: for periodic in [True, False]: - res = torch_method(size, periodic=periodic, **kwargs, device=device) - # NB: scipy always returns a float32 result + res = torch_method(size, periodic=periodic, **kwargs, device=device, dtype=dtype) + # NB: scipy always returns a float64 result ref = torch.from_numpy(signal.get_window((name, *(kwargs.values())), size, fftbins=periodic)) self.assertEqual(res, ref, exact_dtype=False) with self.assertRaisesRegex(RuntimeError, r'not implemented for sparse types'): torch_method(3, layout=torch.sparse_coo) - with self.assertRaisesRegex(RuntimeError, r'floating point'): - torch_method(3, dtype=torch.long) self.assertTrue(torch_method(3, requires_grad=True).requires_grad) self.assertFalse(torch_method(3).requires_grad) @@ -10490,6 +10536,23 @@ def check(op, a, args, key): check(torch.median, [[nan, nan], [1, 2]], [1], [[nan, 1]]) check(torch.nanmedian, [[nan, nan], [1, 2]], [1], [[nan, 1.]]) + # Discontiguous and strided tensors + a = torch.arange(12, device=device) + self.assertEqual(a[::2].median(), torch.tensor(4, device=device)) + self.assertEqual(a[::2].nanmedian(), torch.tensor(4, device=device)) + + a.resize_(3, 4) + self.assertEqual(a.T.median(), torch.tensor(5, device=device)) + self.assertEqual(a.T.nanmedian(), torch.tensor(5, device=device)) + self.assertEqual(a[::2, ::2].median(-1)[0], torch.tensor([0, 8], device=device)) + self.assertEqual(a[::2, ::2].nanmedian(-1)[0], torch.tensor([0, 8], device=device)) + + a.resize_(2, 3, 2) + self.assertEqual(a.T.median(), torch.tensor(5, device=device)) + self.assertEqual(a.T.nanmedian(), torch.tensor(5, device=device)) + self.assertEqual(a[:, ::2, :].median(-1)[0], torch.tensor([[0, 4], [6, 10]], device=device)) + self.assertEqual(a[:, ::2, :].nanmedian(-1)[0], torch.tensor([[0, 4], [6, 10]], device=device)) + @onlyOnCPUAndCUDA @dtypes(torch.float, torch.double) @@ -13644,6 +13707,8 @@ def test_binary_op_mem_overlap(self, device, dtype): ("atan2", True, True, 'cuda'), ("hypot", True, True, 'cpu'), ("hypot", True, True, 'cuda'), + ("igamma", True, True, 'cpu'), + ("igamma", True, True, 'cuda'), ("nextafter", True, True, 'cpu'), ("nextafter", True, True, 'cuda'), ("le", True, True, 'cpu'), @@ -15493,7 +15558,8 @@ def test_orgqr_errors(self, device): ((10,), (2,), r"'input' should be 2 dimensional"), ((10, 6), (20,), r"input.size\(1\) must be greater than or equal to input2.size\(0\)"), ((6, 10), (5,), r"input.size\(0\) must be greater than or equal to input.size\(1\)"), - ((0, 0), (0,), r"'input' should not be empty") + ((0, 0), (0,), r"'input' should not be empty"), + ((2, 2), (2, 0,), r"'tau' should not be empty") ] for a_size, tau_size, error_regex in test_cases: a = torch.rand(*a_size, device=device) @@ -17013,6 +17079,8 @@ def test_exp_slow(self, device, dtype): b = torch.exp(torch.ones(1, dtype=dtype, device=device)) self.assertEqual(a, b.expand(2 ** 31)) + @precisionOverride({torch.bfloat16: 1e-2, torch.float: 0.0002, torch.double: 0.0002}) + @dtypesIfCUDA(torch.float, torch.double, torch.bfloat16) @dtypes(torch.float, torch.double) @unittest.skipIf(not TEST_NUMPY, "Numpy not found") def test_hardswish(self, device, dtype): @@ -17020,7 +17088,6 @@ def test_hardswish(self, device, dtype): expectedOutput = np.multiply( inputValues, np.minimum(np.maximum((np.add(inputValues, 3)), 0), 6) / 6.0) - precision_4dps = 0.0002 inputTensor = torch.tensor(inputValues, dtype=dtype, device=device) expectedOutputTensor = \ @@ -17028,14 +17095,12 @@ def test_hardswish(self, device, dtype): # normal self.assertEqual(torch.nn.functional.hardswish(inputTensor), - expectedOutputTensor, - atol=precision_4dps, rtol=0) + expectedOutputTensor) # inplace inputTensorCpy = inputTensor.clone().detach() torch.nn.functional.hardswish(inputTensorCpy, inplace=True) - self.assertEqual(inputTensorCpy, expectedOutputTensor, - atol=precision_4dps, rtol=0) + self.assertEqual(inputTensorCpy, expectedOutputTensor) @onlyCPU @dtypes(torch.float, torch.double) @@ -17049,6 +17114,8 @@ def test_sigmoid(self, device, dtype): torch.tensor(expectedOutput, dtype=dtype, device=device), atol=precision_4dps, rtol=0) + @precisionOverride({torch.bfloat16: 1e-2, torch.float: 0.0002, torch.double: 0.0002}) + @dtypesIfCUDA(torch.float, torch.double, torch.bfloat16) @dtypes(torch.float, torch.double) @unittest.skipIf(not TEST_NUMPY, "Numpy not found") def test_hardsigmoid(self, device, dtype): @@ -17056,18 +17123,15 @@ def test_hardsigmoid(self, device, dtype): expectedOutput = np.minimum(np.maximum((np.add(inputValues, 3)), 0), 6) / 6.0 inputTensor = torch.tensor(inputValues, dtype=dtype, device=device) - precision_4dps = 0.0002 # normal self.assertEqual(torch.nn.functional.hardsigmoid(inputTensor), - torch.tensor(expectedOutput, dtype=dtype, device=device), - atol=precision_4dps, rtol=0) + torch.tensor(expectedOutput, dtype=dtype, device=device)) # inplace inputTensorCpy = inputTensor.clone().detach() self.assertEqual(torch.nn.functional.hardsigmoid(inputTensorCpy, inplace=True), - torch.tensor(expectedOutput, dtype=dtype, device=device), - atol=precision_4dps, rtol=0) + torch.tensor(expectedOutput, dtype=dtype, device=device)) @skipIfNoSciPy @dtypes(torch.float, torch.double) @@ -17428,6 +17492,70 @@ def test_hypot(self, device, dtype): expected = np.hypot(input[0].cpu().numpy(), input[1].cpu().numpy()) self.assertEqual(actual, expected) + def _helper_test_igamma(self, loglo, loghi, device, dtype): + exp1 = 2.71828182846 + vec1 = torch.logspace(loglo, loghi, steps=500, base=exp1, + dtype=torch.float64, device=device).unsqueeze(-1) + vec1 = vec1.to(dtype) + inputs = [ + (vec1, vec1.transpose(0, 1)), + (vec1, vec1), # for large number, it should approach 0.5 + (vec1, 0.5 * vec1), # test for considerable ratio + (vec1, 2.0 * vec1), + (vec1[::2, :], vec1[::2, :]), # contiguous/discontiguous tests + (vec1[::2, :], vec1[:vec1.shape[0] // 2, :]), + (vec1[:vec1.shape[0] // 2, :], vec1[::2, :]), + ] + half_prec = dtype in [torch.bfloat16, torch.float16] + for input0, input1 in inputs: + actual = torch.igamma(input0, input1) + if half_prec: + input0 = input0.to(torch.float) + input1 = input1.to(torch.float) + expected = scipy.special.gammainc(input0.cpu().numpy(), input1.cpu().numpy()) + expected = torch.from_numpy(expected).to(dtype) + self.assertEqual(actual, expected) + + @skipCUDAIfRocm # see issue https://github.com/pytorch/pytorch/issues/46531 + @dtypesIfCPU(torch.float16, torch.bfloat16, torch.float32, torch.float64) + @dtypes(torch.float32, torch.float64) + @unittest.skipIf(not TEST_SCIPY, "SciPy not found") + @onlyOnCPUAndCUDA + def test_igamma_common(self, device, dtype): + # test igamma for reasonable range of values + loglo = -4 # approx 0.018 + loghi = 4 # approx 54.6 + self._helper_test_igamma(loglo, loghi, device, dtype) + + @dtypesIfCPU(torch.float16, torch.bfloat16, torch.float32, torch.float64) + @dtypes(torch.float32, torch.float64) + @onlyOnCPUAndCUDA + def test_igamma_edge_cases(self, device, dtype): + tkwargs = {"dtype": dtype, "device": device} + infs = torch.zeros((3,), **tkwargs) + float("inf") + zeros = torch.zeros((3,), **tkwargs) + ones = torch.ones((3,), **tkwargs) + zero_to_large = torch.tensor([0., 1., 1e3], **tkwargs) + small_to_inf = torch.tensor([1e-3, 1., float("inf")], **tkwargs) + nans = torch.zeros((3,), **tkwargs) + float("nan") + inpouts = [ + # (a , x), out + ((zeros, small_to_inf), ones), + ((small_to_inf, zeros), zeros), + ((infs, zero_to_large), zeros), + ((zero_to_large, infs), ones), + ((zeros, zeros), nans), + ((infs, infs), nans), + ((-small_to_inf, small_to_inf), nans), + ] + for inputs, output in inpouts: + input0, input1 = inputs + calc = torch.igamma(input0, input1) + if torch.all(torch.isnan(output)): + self.assertTrue(torch.all(torch.isnan(calc))) + else: + self.assertEqual(calc, output) + @dtypes(torch.int64, torch.float64) def test_remainder_edge_cases(self, device, dtype): # Test variations of negative values used as input @@ -17451,8 +17579,8 @@ def test_remainder_edge_cases(self, device, dtype): r = a.remainder(b) r_expected = torch.tensor([0, 0, 0, 0, -3, 3, -2, 2] * 10000, dtype=dtype, device=device) self.assertEqual(r, r_expected) - # Test nan cases + a = torch.tensor([-34, 0, 34] * 20000, dtype=dtype, device=device) b = torch.zeros(3 * 20000, dtype=dtype, device=device) self.assertTrue(torch.isnan(a.remainder(b)).all()) @@ -19150,8 +19278,12 @@ def verify_against_numpy(t): def _test_special_stacks(self, dim, at_least_dim, torch_fn, np_fn, device, dtype): # Test error for non-tuple argument + t = torch.randn(10) with self.assertRaisesRegex(TypeError, "must be tuple of Tensors, not Tensor"): - torch_fn(torch.randn(10)) + torch_fn(t) + # Test error for a single array + with self.assertRaisesRegex(TypeError, "must be tuple of Tensors, not Tensor"): + torch_fn((t)) # Test 0-D num_tensors = random.randint(1, 5) @@ -19201,25 +19333,41 @@ def _test_special_stacks(self, dim, at_least_dim, torch_fn, np_fn, device, dtype @unittest.skipIf(not TEST_NUMPY, "NumPy not found") @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False) + torch.testing.get_all_complex_dtypes())) - def test_hstack(self, device, dtype): - self._test_special_stacks(1, 1, torch.hstack, np.hstack, device, dtype) + def test_hstack_column_stack(self, device, dtype): + ops = ((torch.hstack, np.hstack), (torch.column_stack, np.column_stack)) + for torch_op, np_op in ops: + self._test_special_stacks(1, 1, torch_op, np_op, device, dtype) + + # Test torch.column_stack with combinations of 1D and 2D tensors input + one_dim_tensor = torch.arange(0, 10).to(dtype=dtype, device=device) + two_dim_tensor = torch.arange(0, 100).to(dtype=dtype, device=device).reshape(10, 10) + inputs = two_dim_tensor, one_dim_tensor, two_dim_tensor, one_dim_tensor + torch_result = torch.column_stack(inputs) + + np_inputs = [input.cpu().numpy() for input in inputs] + np_result = np.column_stack(np_inputs) + + self.assertEqual(np_result, + torch_result) @onlyOnCPUAndCUDA @unittest.skipIf(not TEST_NUMPY, "NumPy not found") @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False) + torch.testing.get_all_complex_dtypes())) - def test_vstack(self, device, dtype): - self._test_special_stacks(0, 2, torch.vstack, np.vstack, device, dtype) - for i in range(5): - # Test dimension change for 1D tensor of size (N) and 2D tensor of size (1, N) - n = random.randint(1, 10) - input_a = self._generate_input((n,), dtype, device, with_extremal=False) - input_b = self._generate_input((1, n), dtype, device, with_extremal=False) - torch_input = [input_a, input_b] - np_input = [input.cpu().numpy() for input in torch_input] - actual = torch.vstack(torch_input) - expected = np.vstack(np_input) - self.assertEqual(actual, expected) + def test_vstack_row_stack(self, device, dtype): + ops = ((torch.vstack, np.vstack), (torch.row_stack, np.row_stack)) + for torch_op, np_op in ops: + self._test_special_stacks(0, 2, torch_op, np_op, device, dtype) + for i in range(5): + # Test dimension change for 1D tensor of size (N) and 2D tensor of size (1, N) + n = random.randint(1, 10) + input_a = self._generate_input((n,), dtype, device, with_extremal=False) + input_b = self._generate_input((1, n), dtype, device, with_extremal=False) + torch_input = [input_a, input_b] + np_input = [input.cpu().numpy() for input in torch_input] + actual = torch_op(torch_input) + expected = np_op(np_input) + self.assertEqual(actual, expected) @onlyOnCPUAndCUDA @unittest.skipIf(not TEST_NUMPY, "NumPy not found") @@ -19315,6 +19463,17 @@ def compare_helper_(like_fn, t): tp = t.permute(p) compare_helper_(like_fn, tp) + @unittest.skipIf(TEST_WITH_ASAN, "Integer overflows are not allowed under ASAN") + @dtypes(*torch.testing.get_all_dtypes()) + def test_muldiv_scalar(self, device, dtype): + x = make_tensor((10, 3), device, dtype, low=None, high=None) + s = make_tensor((1,), 'cpu', dtype, low=None, high=None).item() + y = torch.full_like(x, s) + self.assertEqual(x * s, x * y) + self.assertEqual(s * x, y * x) + self.assertEqual(x / s, x / y) + self.assertEqual(s / x, y / x) + # Tests that compare a device's computation with the (gold-standard) CPU's. class TestDevicePrecision(TestCase): exact_dtype = True diff --git a/test/test_type_promotion.py b/test/test_type_promotion.py index 7f10915a5ac4..c023553de402 100644 --- a/test/test_type_promotion.py +++ b/test/test_type_promotion.py @@ -919,8 +919,8 @@ def test_unary_op_out_casting(self, device, dtypes): t = torch.tensor((1), dtype=dtypes[0], device=device) out = torch.empty(0, dtype=dtypes[1], device=device) - ops = (torch.neg, torch.floor, torch.ceil, torch.cos, torch.erf, torch.log) - float_only_ops = {torch.floor, torch.ceil, torch.cos, torch.erf, torch.log} + ops = (torch.neg, torch.floor, torch.ceil, torch.erf) + float_only_ops = {torch.floor, torch.ceil, torch.erf} real_only_ops = {torch.floor, torch.ceil, torch.erf} for op in ops: if dtypes[0] is not dtypes[1]: diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py index 6d4dc91ff5bd..9f3353376913 100644 --- a/test/test_unary_ufuncs.py +++ b/test/test_unary_ufuncs.py @@ -288,7 +288,7 @@ def test_reference_numerics(self, device, dtype, op): # NOTE: For these dtypes, PyTorch computes in the default scalar type (float) # while NumPy computes in float16 self.assertEqualHelper(actual, expected, msg, dtype=dtype, - exact_dtype=exact_dtype, rtol=1e-4, atol=1e-3) + exact_dtype=exact_dtype, rtol=1e-3, atol=1e-2) continue self.assertEqualHelper(actual, expected, msg, dtype=dtype, exact_dtype=exact_dtype) diff --git a/test/test_vmap.py b/test/test_vmap.py index 2400ad1e00ee..69e5da0aa380 100644 --- a/test/test_vmap.py +++ b/test/test_vmap.py @@ -466,6 +466,37 @@ def _assert_uses_vmap_fallback(self, vmap_args, inputs): self.assertEqual(len(wa), 2) self.assertRegex(str(wa[-1].message), FALLBACK_REGEX) + def test_fallback_zero_dim(self): + # NB: One day we will implement a batching rule for torch.atan2. + # If/when we do, this test should be replaced to test the fallback + # path on another operator to avoid bitrot. + op = torch.atan2 + x = torch.randn(11) + y = torch.randn(11) + self._assert_uses_vmap_fallback((op,), (x, y)) + + B0, B1 = 0, 3 + x = torch.randn(B0, 11) + y = torch.randn(11) + + msg = 'The fallback path does not support vmap over dims of size 0' + + with self.assertRaisesRegex(RuntimeError, msg): + vmap(op, (0, None))(x, y) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(op, (None, 0))(y, x) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(op)(x, x) + + x = torch.randn(B0, B1, 11) + y = torch.randn(B1, 11) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(op, (0, None))(x, y) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(op, (None, 0))(y, x) + with self.assertRaisesRegex(RuntimeError, msg): + vmap(op)(x, x) + def test_fallback_atan2(self): # NB: One day we will implement a batching rule for torch.atan2. # If/when we do, this test should be replaced to test the fallback diff --git a/third_party/fbgemm b/third_party/fbgemm index 23cb1db72b03..39d5addbff3c 160000 --- a/third_party/fbgemm +++ b/third_party/fbgemm @@ -1 +1 @@ -Subproject commit 23cb1db72b03e29984eefe58c5c99d733a85435d +Subproject commit 39d5addbff3c942e3c2c97b30c46eed167737b31 diff --git a/third_party/nccl/nccl b/third_party/nccl/nccl index 033d799524fb..cd5a9b73c302 160000 --- a/third_party/nccl/nccl +++ b/third_party/nccl/nccl @@ -1 +1 @@ -Subproject commit 033d799524fb97629af5ac2f609de367472b2696 +Subproject commit cd5a9b73c3028d2496666201588111a8c8d84878 diff --git a/third_party/pybind11 b/third_party/pybind11 index 25abf7efba0b..59a2ac2745d8 160000 --- a/third_party/pybind11 +++ b/third_party/pybind11 @@ -1 +1 @@ -Subproject commit 25abf7efba0b2990f5a6dfb0a31bc65c0f2f4d17 +Subproject commit 59a2ac2745d8a57ac94c6accced73620d59fb844 diff --git a/third_party/tensorpipe b/third_party/tensorpipe index 95ff9319161f..82a114882e21 160000 --- a/third_party/tensorpipe +++ b/third_party/tensorpipe @@ -1 +1 @@ -Subproject commit 95ff9319161fcdb3c674d2bb63fac3e94095b343 +Subproject commit 82a114882e21b176916e2f12a7b566af3d63df71 diff --git a/third_party/tensorpipe.BUILD b/third_party/tensorpipe.BUILD index 29b2cd1d0d0c..66c7b1c7a1ab 100644 --- a/third_party/tensorpipe.BUILD +++ b/third_party/tensorpipe.BUILD @@ -74,6 +74,7 @@ header_template_rule( "#cmakedefine01 TENSORPIPE_HAS_SHM_TRANSPORT": "", "#cmakedefine01 TENSORPIPE_HAS_CMA_CHANNEL": "", "#cmakedefine01 TENSORPIPE_HAS_CUDA_IPC_CHANNEL": "", + "#cmakedefine01 TENSORPIPE_HAS_IBV_TRANSPORT": "", "#cmakedefine01 TENSORPIPE_SUPPORTS_CUDA": "", }, ) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 7fbc3e9fd4b3..8cbcab35685e 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -162,7 +162,7 @@ self: grad * self.sgn() - name: acos(Tensor self) -> Tensor - self: grad * -((-self * self + 1).rsqrt()) + self: grad * -((-self * self + 1).rsqrt()).conj() - name: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor self: grad @@ -213,7 +213,7 @@ self: grad - name: angle(Tensor self) -> Tensor - self: grad.to(self.scalar_type()) * (self*Scalar(c10::complex{0.0, 1.0})).conj() / self.abs().pow(2) + self: angle_backward(grad, self) # The four items below are necessary because TensorIterator doesn't work on # Variables (codegen does not unwrap the input Tensor for all() and any() ). @@ -230,19 +230,19 @@ self: not_implemented("all") - name: acosh(Tensor self) -> Tensor - self: grad * (self.pow(2) - 1).rsqrt() + self: grad * (self.pow(2) - 1).rsqrt().conj() - name: acosh_(Tensor(a!) self) -> Tensor(a!) self: not_implemented("inplace version of acosh") - name: asinh(Tensor self) -> Tensor - self: grad * (self.pow(2) + 1).rsqrt() + self: grad * (self.pow(2) + 1).rsqrt().conj() - name: asinh_(Tensor(a!) self) -> Tensor(a!) self: not_implemented("inplace version of asinh") - name: atanh(Tensor self) -> Tensor - self: grad * 1 / (1 - self.pow(2)) + self: grad * 1 / (1 - self.pow(2)).conj() - name: atanh_(Tensor(a!) self) -> Tensor(a!) self: not_implemented("inplace version of atanh") @@ -251,10 +251,10 @@ self: as_strided_backward(grad, TensorGeometry(self), size, stride, storage_offset) - name: asin(Tensor self) -> Tensor - self: grad * (-self * self + 1).rsqrt() + self: grad * (-self * self + 1).rsqrt().conj() - name: atan(Tensor self) -> Tensor - self: grad / (self * self + 1) + self: grad / (self * self + 1).conj() - name: atan2(Tensor self, Tensor other) -> Tensor self, other: atan2_backward(grad, self, other, grad_input_mask) @@ -540,6 +540,10 @@ - name: i0(Tensor self) -> Tensor self: not_implemented("i0") +- name: igamma(Tensor self, Tensor other) -> Tensor + self: 'not_implemented("igamma: input")' + other: grad * exp((self - 1) * log(other) - other - lgamma(self)) + - name: index.Tensor(Tensor self, Tensor?[] indices) -> Tensor self: index_backward(zeros_like(self), indices, grad) indices: TensorList() @@ -610,16 +614,16 @@ self: grad * polygamma(n + 1, self) - name: log(Tensor self) -> Tensor - self: grad.div(self) + self: grad.div(self.conj()) - name: log10(Tensor self) -> Tensor - self: grad / (self * 2.3025850929940456) + self: grad / (self.conj() * 2.3025850929940456) - name: log1p(Tensor self) -> Tensor self: log1p_backward(grad, self) - name: log2(Tensor self) -> Tensor - self: grad / (self * 0.6931471805599453) + self: grad / (self.conj() * 0.6931471805599453) - name: logaddexp(Tensor self, Tensor other) -> Tensor self: grad / (1 + exp(other - self)) @@ -884,7 +888,7 @@ self: zeros_like(grad) - name: reciprocal(Tensor self) -> Tensor - self: -grad * result * result + self: -grad * (result * result).conj() - name: remainder.Scalar(Tensor self, Scalar other) -> Tensor self: grad @@ -909,7 +913,7 @@ self: zeros_like(grad) - name: rsqrt(Tensor self) -> Tensor - self: -0.5 * grad * result.pow(3) + self: -0.5 * grad * result.pow(3).conj() - name: scatter_.src(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!) self: grad.clone().scatter_(dim, index, 0) @@ -1046,7 +1050,7 @@ index: non_differentiable - name: tan(Tensor self) -> Tensor - self: grad * (1 + result.pow(2)) + self: grad * (1 + result.pow(2)).conj() - name: tanh(Tensor self) -> Tensor self: tanh_backward(grad, result) @@ -1182,7 +1186,7 @@ weight: embedding_backward(grad, indices, weight.size(0), padding_idx, scale_grad_by_freq, sparse) - name: embedding_dense_backward(Tensor grad_output, Tensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor - grad_output: embedding_dense_double_backward(grad, indices) + grad_output: embedding_dense_double_backward(grad, indices, padding_idx) indices: non_differentiable - name: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor) @@ -1670,8 +1674,8 @@ output: grad * grad_output * (-2 * output + 1) - name: tanh_backward(Tensor grad_output, Tensor output) -> Tensor - grad_output: tanh_backward(grad, output) - output: -2 * output * grad * grad_output + grad_output: tanh_backward(grad, output.conj()) + output: grad.conj() * (-2 * output.conj() * grad_output) # cudnn - name: _cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor) diff --git a/tools/autograd/gen_annotated_fn_args.py b/tools/autograd/gen_annotated_fn_args.py index 7b4b0ece8da6..661694f3d6ba 100644 --- a/tools/autograd/gen_annotated_fn_args.py +++ b/tools/autograd/gen_annotated_fn_args.py @@ -20,6 +20,7 @@ get_py_variable_methods, op_name, ) +import argparse import textwrap from .gen_autograd import load_aten_declarations diff --git a/tools/autograd/gen_autograd.py b/tools/autograd/gen_autograd.py index da937a4377fa..2783eb644bc6 100644 --- a/tools/autograd/gen_autograd.py +++ b/tools/autograd/gen_autograd.py @@ -302,7 +302,8 @@ def main(): parser.add_argument('autograd', metavar='AUTOGRAD', help='path to autograd directory') args = parser.parse_args() - gen_autograd(args.declarations, args.out, args.autograd) + gen_autograd(args.declarations, args.out, args.autograd, + SelectiveBuilder.get_nop_selector()) if __name__ == '__main__': diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py index c82b8b298704..f3f2eb3033f2 100644 --- a/tools/autograd/gen_python_functions.py +++ b/tools/autograd/gen_python_functions.py @@ -871,7 +871,6 @@ def go(f: NativeFunction) -> PythonSignature: src_args: Dict[str, PythonArgument] = {a.name: PythonArgument( name=a.name, type=a.type, - cpp_type_str=a.cpp_type_str, default=None, default_init=None, ) for a in itertools.chain(python_sig.input_args, python_sig.input_kwargs)} @@ -922,8 +921,15 @@ def go(f: NativeFunction) -> str: lambda_args = ', '.join(lambda_arg_exprs.exprs) # scatter fields + # TODO: Checking `ps.method and ('requires_grad' in parser_outputs)` is a hacky + # solution for enabling the 'requires_grad' argument for tensor methods + # new_full, new_empty, and new_zeros. A much better but more difficult to + # implement solution involves refactoring according to Ed's description here: + # https://github.com/pytorch/pytorch/issues/36455#issuecomment-614767589 + need_set_requires_grad = ps.tensor_options_args and (not has_tensor_options(f) or ( + ps.method and ('requires_grad' in parser_outputs))) set_requires_grad = f'.set_requires_grad({parser_outputs["requires_grad"].expr})' \ - if ps.tensor_options_args and not has_tensor_options(f) else '' + if need_set_requires_grad else '' auto_no_gil = '' if decl['with_gil'] else 'pybind11::gil_scoped_release no_gil;' diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index aa65a68dff87..04b8d144cd79 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -165,7 +165,8 @@ 'neg', 'complex', 'select', '_s_where', 'as_strided', 'slice', 'constant_pad_nd', 'unbind', 'split', 'split_with_sizes', 'unsafe_split', 'split_with_sizes_backward', 'dot', 'vdot', 'cholesky', 'triangular_solve', 'mm', '_unsafe_view', 'mv', 'ger', - 'bmm', 'diagonal' + 'bmm', 'diagonal', 'cholesky', 'atan', 'log', 'log10', 'log1p', 'log2', 'reciprocal', + 'tan', 'pow', 'rsqrt', 'tanh', 'tanh_backward', 'asinh', 'acosh', 'take', 'fill_' } # Some operators invalidate the grad_accumulator. Let's reset it. @@ -283,9 +284,6 @@ grad_fn->set_next_edges(collect_next_edges( ${args_with_derivatives} )); """) -CALL_DEFAULT = CodeTemplate("""\ -TypeDefault::${type_wrapper_name}(${args})""") - CALL_DISPATCH_VIA_NAMESPACE = CodeTemplate("""\ at::${api_name}(${unpacked_args})""") @@ -807,7 +805,7 @@ def emit_trace_body(declaration): def emit_body(declaration): - strategy = dispatch_strategy(declaration) + assert dispatch_strategy(declaration) == 'use_derived' arguments = declaration['arguments'] returns = declaration['returns'] @@ -864,8 +862,7 @@ def find_args_with_derivatives(differentiable_inputs): requires_derivative = ( base_name not in DONT_REQUIRE_DERIVATIVE and name not in DONT_REQUIRE_DERIVATIVE and - len(differentiable_inputs) > 0 and len(differentiable_outputs) > 0 and - strategy == 'use_derived') + len(differentiable_inputs) > 0 and len(differentiable_outputs) > 0) if func is not None and not requires_derivative: raise RuntimeError('ERROR: derivative ignored for {} -- specified an autograd function without derivative' @@ -1149,28 +1146,20 @@ def enforce_same_tensorimpl_and_storage(env, call): def emit_call(env, tie_return_values): combined = nested_dict(env, declaration) - if strategy == 'use_derived': - # We only care about adding `at::AutoNonVariableTypeMode` guard for non-variable dispatch - # (which corresponds to 'use_derived' strategy). The purpose of this guard is to make sure - # the baseType operations still dispatch to non-Variable type, even if the arguments passed - # in are now Variables. - # See NOTE [ Treating Variables as non-Variables in type dispatch ] for details. - base_type_call = emit_dispatch_call(combined['api_name'], 'self_', combined['unpacked_args']) - if not modifies_arguments and not returns_void: - call = DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES.substitute( - base_type_call=base_type_call) - - call += wrap_output(tie_return_values, 'tmp') - else: - call = DISPATCH_TO_NON_VAR_TYPE_WITHOUT_RETURN_VALUES.substitute( - base_type_call=base_type_call) + # We only care about adding `at::AutoNonVariableTypeMode` guard for non-variable dispatch + # (which corresponds to 'use_derived' strategy). The purpose of this guard is to make sure + # the baseType operations still dispatch to non-Variable type, even if the arguments passed + # in are now Variables. + # See NOTE [ Treating Variables as non-Variables in type dispatch ] for details. + base_type_call = emit_dispatch_call(combined['api_name'], 'self_', combined['unpacked_args']) + if not modifies_arguments and not returns_void: + call = DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES.substitute( + base_type_call=base_type_call) + + call += wrap_output(tie_return_values, 'tmp') else: - args = maybe_unwrap_optional_tensors(declaration, declaration['arguments'], declaration['args']) - - call = CALL_DEFAULT.substitute(declaration, args=args) - if not modifies_arguments and not returns_void: - call = '{} = {}'.format(tie_return_values, call) - call = call + ';' + call = DISPATCH_TO_NON_VAR_TYPE_WITHOUT_RETURN_VALUES.substitute( + base_type_call=base_type_call) call = enforce_same_tensorimpl_and_storage(env, call) return call @@ -1210,16 +1199,14 @@ def emit_increment_version(): declare_returned_variables, tie_return_values, get_return_value = format_return_variables(declaration) - if strategy != 'use_type': - body.extend(unpack_args(env, declaration)) + body.extend(unpack_args(env, declaration)) if requires_derivative: body.extend(emit_check_inplace()) body.extend(setup_derivative(differentiable_inputs)) body.append(declare_returned_variables) body.append(emit_call(env, tie_return_values)) - if strategy == 'use_derived': - body.extend(emit_increment_version()) + body.extend(emit_increment_version()) if requires_derivative: # set_flags has to appear after version_counter, because rebase_history # requires that the counter is incremented before it is called diff --git a/tools/autograd/templates/TraceType.cpp b/tools/autograd/templates/TraceType.cpp index d08c1e3cc5aa..3ac52ed08edc 100644 --- a/tools/autograd/templates/TraceType.cpp +++ b/tools/autograd/templates/TraceType.cpp @@ -1,6 +1,5 @@ #include "torch/csrc/autograd/VariableTypeUtils.h" -#include #include #include "torch/csrc/autograd/function.h" diff --git a/tools/autograd/templates/VariableType.cpp b/tools/autograd/templates/VariableType.cpp index 079427cd97dc..ba2f99369f8d 100644 --- a/tools/autograd/templates/VariableType.cpp +++ b/tools/autograd/templates/VariableType.cpp @@ -1,7 +1,6 @@ #include "torch/csrc/autograd/VariableTypeUtils.h" #include "torch/csrc/autograd/FunctionsManual.h" -#include #include // ${generated_comment} diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index 63446d3a1316..65f5ec1c6903 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -41,7 +41,7 @@ def libtorch_generated_sources(gencode_pattern): "autograd/generated/TraceType_4.cpp", ]] -# copied from https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/CMakeLists.txt +# copied from https://github.com/pytorch/pytorch/blob/f99a693cd9ff7a9b5fdc71357dac66b8192786d3/aten/src/ATen/core/CMakeLists.txt jit_core_headers = [ "torch/csrc/utils/memory.h", "torch/csrc/WindowsTorchApiMacro.h", @@ -69,7 +69,7 @@ jit_core_sources = [ "torch/csrc/jit/frontend/source_range.cpp", ] -# copied from https://github.com/pytorch/pytorch/blob/master/tools/cpp_build/torch/CMakeLists.txt +# copied from https://github.com/pytorch/pytorch/blob/0bde610c14b92d351b968a0228df29e92442b1cc/torch/CMakeLists.txt # There are some common files used in both internal lite-interpreter and full-jit. Making a separate # list for the shared files. @@ -544,6 +544,8 @@ libtorch_python_core_sources = [ libtorch_python_distributed_core_sources = [ "torch/csrc/distributed/c10d/comm.cpp", + "torch/csrc/distributed/c10d/default_comm_hooks.cpp", + "torch/csrc/distributed/c10d/python_comm_hook.cpp", "torch/csrc/distributed/c10d/init.cpp", "torch/csrc/distributed/c10d/reducer.cpp", ] diff --git a/tools/codegen/api/python.py b/tools/codegen/api/python.py index f04a648aade8..bb02407004ab 100644 --- a/tools/codegen/api/python.py +++ b/tools/codegen/api/python.py @@ -1,5 +1,6 @@ from tools.codegen.api.types import * import tools.codegen.api.cpp as cpp +import tools.codegen.local as local from tools.codegen.gen import pythonify_default from tools.codegen.model import * @@ -175,11 +176,6 @@ class PythonArgument: name: str type: Type - - # Consistent with 'type' for most cases, except for some TensorOptions fields - # which are hardcoded (see 'signature()' method). - cpp_type_str: str - default: Optional[str] # Used to generate the default init expr for some PythonArgParser outputs, e.g.: @@ -193,29 +189,15 @@ class PythonArgument: # Compute argument formal for python argument parsing. # Needs to be consistent with torch/csrc/utils/python_arg_parser.h. def argument_str(self, *, method: bool = False) -> str: - name = self.name - typename = _simple_type(self.cpp_type_str) - - # [old codegen] TODO: remove this and make optional types in simple_type - # to be consistent across tensor and other types after make Tensor? be - # optional instead of undefined - if self.type.is_nullable() and '?' not in typename: - typename = f'{typename}?' + type_str = argument_type_str(self.type) # s/self/input/ outside method bindings # [old codegen] TODO: remove this? doesn't rename in codegen, it's just # for the parse string - if name == 'self' and typename == 'Tensor' and not method: + name = self.name + if name == 'self' and type_str == 'Tensor' and not method: name = 'input' - # add list size annotation - size = self.size - if size is not None: - if typename.endswith('?'): - typename = f'{typename[:-1]}[{size}]?' - else: - typename = f'{typename}[{size}]' - # add default if self.default is not None: default = { @@ -223,15 +205,9 @@ def argument_str(self, *, method: bool = False) -> str: 'c10::nullopt': 'None', '{}': 'None', }.get(self.default, self.default) - return f'{typename} {name}={default}' + return f'{type_str} {name}={default}' else: - return f'{typename} {name}' - - @property - def size(self) -> Optional[int]: - l = self.type.is_list_like() - return l.size \ - if l is not None and l.size is not None and str(l.elem) != 'bool' else None + return f'{type_str} {name}' @dataclass(frozen=True) class PythonOutArgument(PythonArgument): @@ -252,7 +228,6 @@ def from_outputs(outputs: Tuple[PythonArgument, ...]) -> Optional['PythonOutArgu return PythonOutArgument( name=outputs[0].name, type=outputs[0].type, - cpp_type_str=outputs[0].cpp_type_str, default='None', default_init=None, outputs=outputs, @@ -263,7 +238,6 @@ def from_outputs(outputs: Tuple[PythonArgument, ...]) -> Optional['PythonOutArgu return PythonOutArgument( name='out', type=ListType(BaseType(BaseTy.Tensor), size), - cpp_type_str='TensorList', default='None', default_init=None, outputs=outputs, @@ -368,7 +342,6 @@ def signature_str(self, *, skip_outputs: bool = False) -> str: class DispatchLambdaArgument: name: str type_str: str - cpp_type_str: str is_out_arg: bool # To pass PyObjects arguments to C++ function (via the lambda wrapper), @@ -424,28 +397,6 @@ class DispatchLambdaArgumentExprs: # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -# The original simple_type is derived from the 'type' field in Declaration.yaml, -# which is generated from the C++ argument type, following some seemingly -# artificial rules: -# -# Concrete C++ types are preferred in most cases, e.g.: -# 'IntArrayRef' instead of 'int[]' -# 'int64_t' instead of 'int' -# -# Constant/Reference annotation and optional field are handled specially, e.g.: -# 'ScalarType?' instead of 'c10::optional' -# 'Tensor' instead of 'const Tensor &' / 'Tensor &' -# -# TODO: This needs to be consistent with python_arg_parser - can we simplify it? -def _simple_type(cpp_type_str: str) -> str: - simple_type = cpp_type_str.replace(' &', '').replace('const ', '') - opt_match = re.match(r'c10::optional<(.+)>', simple_type) - if opt_match: - typename = opt_match.group(1) - # HACK: 'Layout?' needs to be hardcoded to 'Layout'! - simple_type = f'{typename}?' if typename != 'Layout' else 'Layout' - return simple_type - def _cpp_signature(f: NativeFunction, *, method: bool = False) -> cpp.CppSignature: return CppSignatureGroup.from_schema(f.func, method=method).signature @@ -459,6 +410,49 @@ def has_tensor_options(f: NativeFunction) -> bool: # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +def argument_type_str(t: Type) -> str: + if isinstance(t, BaseType): + if t.name == BaseTy.Tensor: + return 'Tensor' + elif t.name == BaseTy.int: + return 'int64_t' + elif t.name == BaseTy.float: + return 'double' + elif t.name == BaseTy.str: + return 'std::string' + elif t.name in [BaseTy.bool, BaseTy.QScheme, BaseTy.Scalar, + BaseTy.ScalarType, BaseTy.Generator, BaseTy.Storage, + BaseTy.Layout, BaseTy.Device, BaseTy.MemoryFormat, + BaseTy.Dimname, BaseTy.Stream, BaseTy.ConstQuantizerPtr]: + # These python schema type names line up with their function schema names + return t.name.name + + elif isinstance(t, OptionalType): + elem = argument_type_str(t.elem) + if elem == 'Layout': + # TODO: fix this special case in PythonArgParser? + return 'Layout' + else: + return f'{elem}?' + + elif isinstance(t, ListType): + if str(t.elem) == 'bool': + assert t.size is not None + return f'std::array' + elif str(t.elem) == 'int': + return f'IntArrayRef[{t.size}]' if t.size is not None else 'IntArrayRef' + elif str(t.elem) == 'Tensor': + return f'TensorList[{t.size}]' if t.size is not None else 'TensorList' + elif str(t.elem) == 'Tensor?': + # TODO: clone the old codegen behavior but does it make sense? + return 'TensorList?' + elif str(t.elem) == 'Dimname': + return f'DimnameList[{t.size}]' if t.size is not None else 'DimnameList' + elem = argument_type_str(t.elem) + return f'ArrayRef<{elem}>' + + raise RuntimeError(f'unrecognized type {repr(t)}') + def argument(cpp_arg: CppArgument) -> PythonArgument: a = cpp_arg.argument if not isinstance(a, Argument): @@ -468,7 +462,6 @@ def argument(cpp_arg: CppArgument) -> PythonArgument: return PythonArgument( name=a.name, type=a.type, - cpp_type_str=cpp_arg.type, # TODO: directly translate a.default to python default default=str(pythonify_default(cpp.default_expr(a.default, a.type))) if a.default is not None else None, @@ -515,55 +508,37 @@ def signature(f: NativeFunction, *, method: bool = False) -> PythonSignature: has_tensor_return = any(r.type.is_tensor_like() for r in f.func.returns) name: str = cpp.name(f.func) - has_options_arg = has_tensor_options(f) - - is_like_function = name.endswith('_like') or f.category_override == 'like' - is_new_function = name.startswith('new_') or f.category_override == 'new' - is_factory_function = has_tensor_return and not has_tensor_input_arg \ - or f.category_override == 'factory' - is_like_or_new_function_with_options = \ - (is_like_function or is_new_function) and has_options_arg + is_factory_function = f.category_override == 'factory' or (has_tensor_return and not has_tensor_input_arg) + is_like_or_new_function = f.category_override in ('new', 'like') or name.startswith('new_') or name.endswith('_like') tensor_options_args: List[PythonArgument] = [] - if is_factory_function or has_options_arg: + if is_factory_function or is_like_or_new_function: tensor_options_args.append(PythonArgument( name='dtype', - cpp_type_str='const ScalarType &', type=BaseType(BaseTy.ScalarType), default=_dtype_default_type_hack(name), - default_init='self.scalar_type()' - if is_like_or_new_function_with_options else None, + default_init='self.scalar_type()' if is_like_or_new_function else None, )) - - if is_factory_function or is_like_or_new_function_with_options: tensor_options_args.append(PythonArgument( name='layout', - cpp_type_str='c10::optional', - type=BaseType(BaseTy.Layout), + type=OptionalType(BaseType(BaseTy.Layout)), default='torch.strided', - default_init='layout_from_backend(self.options().backend())' - if is_like_or_new_function_with_options else None, + default_init='layout_from_backend(self.options().backend())' if is_like_or_new_function else None, )) tensor_options_args.append(PythonArgument( name='device', - cpp_type_str='const Device &', type=BaseType(BaseTy.Device), default='None', - default_init='self.device()' - if is_like_or_new_function_with_options else None, + default_init='self.device()' if is_like_or_new_function else None, )) tensor_options_args.append(PythonArgument( name='pin_memory', - cpp_type_str='bool', type=BaseType(BaseTy.bool), default='False', default_init=None, )) - - if has_tensor_return and (is_factory_function or is_like_function or is_new_function): tensor_options_args.append(PythonArgument( name='requires_grad', - cpp_type_str='bool', type=BaseType(BaseTy.bool), default='False', default_init=None, @@ -660,7 +635,6 @@ def dispatch_lambda_arg(cpp_arg: CppArgument) -> DispatchLambdaArgument: return DispatchLambdaArgument( name=cpp_arg.name, type_str=type_str, - cpp_type_str=cpp_arg.type, is_out_arg=is_out_arg, ) @@ -750,107 +724,91 @@ def cpp_dispatch_exprs(f: NativeFunction, method: bool, *, # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -# TODO: should emit these unpack methods directly from Type to avoid -# indirect translation via cpp_type_str. -UNPACK_METHODS = { - 'const Tensor &': 'tensor', - 'Tensor &': 'tensor', - 'Stream': 'stream', - 'c10::optional': 'optionalTensor', - 'const c10::optional&': 'optionalTensor', - 'c10::optional': 'generator', - 'Storage': 'storage', - 'Storage &': 'storage', - 'const ScalarType &': 'scalartype', - 'const Device &': 'device', - 'c10::optional': 'toDimnameListOptional', - 'c10::optional': 'scalartypeOptional', - 'c10::optional': 'layoutOptional', - 'c10::optional': 'memoryformatOptional', - 'c10::optional': 'scalarOptional', - 'c10::optional': 'intlistOptional', - 'c10::optional': 'toInt64Optional', - 'c10::optional': 'toBoolOptional', - 'c10::optional': 'toDoubleOptional', - 'c10::optional>': 'doublelistOptional', - 'ArrayRef': 'doublelist', - 'IntArrayRef': 'intlist', - 'Scalar': 'scalar', - 'ScalarType': 'scalartype', - 'Dimname': 'dimname', - 'DimnameList': 'dimnamelist', - 'TensorList': 'tensorlist', - 'int64_t': 'toInt64', - 'bool': 'toBool', - 'double': 'toDouble', - 'std::string': 'string', - 'c10::optional': 'stringOptional', -} - -UNPACK_WITH_SIZE_METHODS = { - 'TensorList': 'tensorlist_n<{}>', - 'DimnameList': 'dimnamelist', - 'IntArrayRef': 'intlist', - 'c10::optional': 'intlistOptional', -} - -UNPACK_WITH_DEFAULT_METHODS = { - 'const ScalarType &': 'scalartypeWithDefault', - 'const Device &': 'deviceWithDefault', - 'c10::optional': 'layoutWithDefault', -} +# We explicitly enumerate the PythonArgParser unpacking methods for all +# supported types. This might be more verbose than necessary, partially +# because of the irregularity of unpacking method naming, partially +# because we want to mimic the old codegen behavior - to reject +# unexpected and/or unsupported cases which the old codegen rejects. +# For certain cases it is intentionally more restrictive than necessary, +# e.g.: it doesn't accepts doublelist with definite size. +def arg_parser_unpack_method(t: Type, has_default: bool) -> str: + if has_default and str(t) not in ('ScalarType', 'Device', 'Layout?'): + raise RuntimeError(f'type \'{t}\' does not supported unpacking with default') + + if isinstance(t, BaseType): + if t.name in [BaseTy.Tensor, BaseTy.Stream, BaseTy.Storage, + BaseTy.Scalar, BaseTy.Dimname]: + # These unpack methods line up with their schema names + return t.name.name.lower() + elif t.name == BaseTy.ScalarType: + return 'scalartypeWithDefault' if has_default else 'scalartype' + elif t.name == BaseTy.Device: + return 'deviceWithDefault' if has_default else 'device' + elif t.name == BaseTy.int: + return 'toInt64' + elif t.name == BaseTy.bool: + return 'toBool' + elif t.name == BaseTy.float: + return 'toDouble' + elif t.name == BaseTy.str: + return 'string' + + elif isinstance(t, OptionalType): + if str(t.elem) == 'Tensor': + if local.use_c10_dispatcher().dispatcher_uses_new_style(): + return 'optionalTensor' + else: + return 'tensor' + + elif isinstance(t.elem, BaseType): + if t.elem.name in [BaseTy.ScalarType, BaseTy.Scalar, + BaseTy.int, BaseTy.bool, + BaseTy.float, BaseTy.str]: + # Regular cases: append 'Optional' to elem's unpacking method + return arg_parser_unpack_method(t.elem, False) + 'Optional' + elif t.elem.name == BaseTy.MemoryFormat: + return 'memoryformatOptional' + elif t.elem.name == BaseTy.Generator: + return 'generator' + elif t.elem.name == BaseTy.Layout: + return 'layoutWithDefault' if has_default else 'layoutOptional' + + elif isinstance(t.elem, ListType): + if str(t.elem.elem) == 'int': + # accept definite size + return 'intlistOptional' + elif str(t.elem) == 'float[]': + return 'doublelistOptional' + elif str(t.elem) == 'Dimname[]': + return 'toDimnameListOptional' + + elif isinstance(t, ListType): + if str(t.elem) == 'Tensor' or str(t.elem) == 'Tensor?': + # accept and use definite size + if t.size is not None: + return f'tensorlist_n<{t.size}>' + else: + return 'tensorlist' + elif str(t.elem) == 'Dimname': + # accept definite size + return 'dimnamelist' + elif str(t.elem) == 'int': + # accept definite size + return 'intlist' + elif str(t) == 'float[]': + return 'doublelist' + + raise RuntimeError(f'type \'{t}\' is not supported by PythonArgParser') # Return RHS expression for python argument using PythonArgParser output. # e.g. for arg name 'foo', arg type 'bool', arg_index = 2, returns '_r.toBool(2)' def arg_parser_output_expr( - arg_index: int, a: PythonArgument, la: Optional[DispatchLambdaArgument] + arg_index: int, a: PythonArgument ) -> PythonArgParserOutputExpr: - # The same python signature (and python schema string) is usually - # associated with two aten C++ functions: the base version and the - # out-place variant. Usually the two functions have the same set of - # arguments - of course, except for the output arguments. But in some - # cases they might have slightly different C++ argument types - - # affected by the 'use_c10_dispatcher' state. - # - # More specially, 'Tensor?' type can be translated into - # either 'const c10::optional&' or 'const Tensor &'. - # Unfortunately, this difference can affect how we should access arg - # parser output. The former expects '_r.optionalTensor(i)' while the - # latter expects '_r.tensor(i)'. - # - # Because of this subtle difference, we cannot solely use the shared - # python signature to determine the RHS expr for both C++ variants. - # We could create and use each C++ variant's own python signature, - # but we have to fix the argument index difference between the two - # python signatures like the old codegen does - and it feels wrong as - # technically there is only one shared python signature! - # - # So here we pass in the lambda wrapper's argument and use it to - # decide what PythonArgParser unpack method to use. - # - # TODO: this seems too complicated - maybe we can simplify after full - # c10 dispatch migration? - typename = la.cpp_type_str \ - if a.name != 'out' and la is not None else a.cpp_type_str - - if a.default_init is not None: - # Note: only introduced in tensor_options_args - if typename not in UNPACK_WITH_DEFAULT_METHODS: - raise RuntimeError( - f'type \'{typename}\' is not supported in default_init') - unpack_with_default = UNPACK_WITH_DEFAULT_METHODS[typename] - expr = f'_r.{unpack_with_default}({arg_index}, {a.default_init})' - elif a.size is not None: - if typename not in UNPACK_WITH_SIZE_METHODS: - raise RuntimeError( - f'type \'{typename}\' with definite size ({a.size}) is not supported') - unpack_with_size = UNPACK_WITH_SIZE_METHODS[typename].format(a.size) - expr = f'_r.{unpack_with_size}({arg_index})' - else: - unpack = UNPACK_METHODS.get(typename) - if unpack is None: - raise RuntimeError(f'type \'{typename}\' is not supported') - expr = f'_r.{unpack}({arg_index})' + has_default = a.default_init is not None + unpack_method = arg_parser_unpack_method(a.type, has_default) + default = f', {a.default_init}' if has_default else '' + expr = f'_r.{unpack_method}({arg_index}{default})' return PythonArgParserOutputExpr( name=a.name, @@ -863,17 +821,14 @@ def arg_parser_output_expr( def arg_parser_output_exprs( ps: PythonSignature, f: NativeFunction, *, method: bool ) -> Dict[str, PythonArgParserOutputExpr]: - lambda_args = dispatch_lambda_args(ps, f, method=method) - lambda_args_map = dict(map(lambda a: (a.name, a), lambda_args)) - return {e.name: e for i, a in enumerate(ps.arguments()) - for e in (arg_parser_output_expr(i, a, lambda_args_map.get(a.name)), )} + for e in (arg_parser_output_expr(i, a), )} -# argument name to 'simple_type' for scattered tensor options fields +# argument name to type for scattered tensor options fields TENSOR_OPTIONS_FIELDS = { 'dtype': 'ScalarType', 'device': 'Device', - 'layout': 'Layout', + 'layout': 'Layout?', 'pin_memory': 'bool', 'requires_grad': 'bool', } @@ -909,7 +864,7 @@ def dispatch_lambda_exprs( ]) for i, out_arg in enumerate(a.outputs): lambda_args_exprs[out_arg.name] = f'out[{i}]' - elif a.cpp_type_str == 'c10::optional': + elif str(a.type) == 'Dimname[]?': # [old codegen] # TODO: make this part of something more general, or get rid of it. # optional> are special. The PythonArgParser returns an @@ -937,9 +892,9 @@ def dispatch_lambda_exprs( if a.name not in TENSOR_OPTIONS_FIELDS: raise RuntimeError( f'{f.func}: unrecognized tensor options field \'{a.name}\' in python binding arguments') - if _simple_type(a.cpp_type_str) != TENSOR_OPTIONS_FIELDS.get(a.name): + if str(a.type) != TENSOR_OPTIONS_FIELDS.get(a.name): raise RuntimeError( - f'{f.func}: unrecognized type \'{_simple_type(a.cpp_type_str)}\' for tensor options field \'{a.name}\'') + f'{f.func}: unrecognized type \'{str(a.type)}\' for tensor options field \'{a.name}\'') if not all(map(lambda a: a in tensor_options_args_names, TENSOR_OPTIONS_FIELDS.keys())): raise RuntimeError( f'{f.func}: incomplete tensor options args: {tensor_options_args_names}') diff --git a/tools/codegen/api/types.py b/tools/codegen/api/types.py index e495244a183e..433f3cf5fd67 100644 --- a/tools/codegen/api/types.py +++ b/tools/codegen/api/types.py @@ -341,12 +341,8 @@ class NativeExpr: class NativeArgument: type: str name: str - # Native function arguments have defaults for some reasons (e.g., - # the function prototypes in CPUType.h are defaulted). There isn't - # really any good reason to do this, as these functions are only - # ever called from a context where all defaulted arguments are - # guaranteed to be given explicitly. - # TODO: Remove this + # Native function arguments have defaults to make it a little + # easier to call them directly to bypass dispatch. default: Optional[str] argument: Union[Argument, TensorOptionsArguments] diff --git a/tools/codegen/gen.py b/tools/codegen/gen.py index af358d9d1b7c..134f7163518b 100644 --- a/tools/codegen/gen.py +++ b/tools/codegen/gen.py @@ -158,11 +158,10 @@ def cpp_string(s: str) -> str: # Dispatch keywords in native_functions.yaml that support all backends. KEYWORD_ALL_BACKENDS = ('DefaultBackend', 'Math') -# Generates {dispatch}Type.cpp and {dispatch}Type.h (e.g., CPUType.cpp -# and CPUType.h). This function is also reused to implement per-operator -# registration. It also generates TypeDefault.cpp and TypeDefault.h when -# dispatch target is for all backends (dispatch is None or dispatch in -# KEYWORD_ALL_BACKENDS). +# Generates {dispatch}Type.cpp (e.g., CPUType.cpp). This function is also +# reused to implement per-operator registration. It also generates +# TypeDefault.cpp when dispatch target is for all backends (dispatch is None or +# dispatch in KEYWORD_ALL_BACKENDS). # # {dispatch}Type.cpp # - The primary function of this file is to register all of the @@ -179,36 +178,29 @@ def cpp_string(s: str) -> str: # (as would be the case if you directly registered native:: # functions). # -# {dispatch}Type.h -# - In principle, this file shouldn't exist at all; historically, -# it existed so that we could directly access these functions -# outside of the registration API for the implementation of -# static dispatch. Should be deleted now! -# # This function is also used for a secondary purpose: the registration # logic is also reused to implement per-operator registration. def compute_type_method( dispatch: Optional[str], *, + # TODO: Give more precise type Union[Literal[Target.DEFINITION, + # Target.REGISTRATION]]; requires Literal from typing_extensions + # which we don't have a dep for yet. target: Target, # Selector object to determine which operators to generate # registration code for. - selector: SelectiveBuilder, - # Only valid for generating registrations. If True, only generate - # def() invocations (for schema registration); do not generate - # any impl() invocations for, e.g., catch-all kernels - def_only: bool = False + selector: SelectiveBuilder ) -> Callable[[NativeFunction], Optional[str]]: - if def_only: - assert target is Target.REGISTRATION and dispatch is None + if dispatch is None: + assert target is Target.REGISTRATION @with_native_function def func(f: NativeFunction) -> Optional[str]: + # Has to be here as mypy won't transfer asserts into closures + assert target is not Target.DECLARATION + if dispatch is not None: - if f.dispatch is None or dispatch not in f.dispatch: - return None - else: - if f.dispatch is not None and target is not Target.REGISTRATION: + if dispatch not in f.dispatch: return None op_name = f"aten::{f.func.name}" @@ -219,24 +211,18 @@ def func(f: NativeFunction) -> Optional[str]: returns_type = native.returns_type(f.func.returns) args = native.arguments(f.func) args_str = ', '.join(map(str, args)) - dispatch_to_all_backends = dispatch is None or dispatch in KEYWORD_ALL_BACKENDS + dispatch_to_all_backends = dispatch is not None and dispatch in KEYWORD_ALL_BACKENDS - if target is Target.DECLARATION: - return f"{returns_type} {name}({args_str});" - elif target is Target.DEFINITION: - if f.dispatch is None: - cpp_name = cpp.name(f.func) - impl_name = f"at::native::{cpp_name}" - else: - assert dispatch is not None - impl_name = f"at::native::{f.dispatch[dispatch]}" + if target is Target.DEFINITION: + assert dispatch is not None + impl_name = f"at::native::{f.dispatch[dispatch]}" args_exprs_str = ', '.join(a.name for a in args) return_kw = " return " cuda_guard = "" - if dispatch_to_all_backends or 'CUDA' in dispatch or 'Vulkan' == dispatch: # type: ignore + if dispatch_to_all_backends or 'CUDA' in dispatch: self_args = (a for a in f.func.arguments if a.name == "self") # There is precedence for which argument we use to do @@ -261,7 +247,7 @@ def func(f: NativeFunction) -> Optional[str]: # TODO: There is probably a simpler version of this that # works just as well. - if f.device_guard and (dispatch_to_all_backends or 'Vulkan' == dispatch) and has_tensor_options: + if f.device_guard and dispatch_to_all_backends and has_tensor_options: cuda_guard = cuda_guard_from_tensor_options elif f.device_guard and dispatch is not None and 'CUDA' in dispatch and has_tensor_options: cuda_guard = f"""\ @@ -284,20 +270,18 @@ def func(f: NativeFunction) -> Optional[str]: """ elif target is Target.REGISTRATION: - dispatcher_sig = DispatcherSignature.from_schema(f.func) - - if dispatch_to_all_backends: - type_name = f'TypeDefault::{name}' + if dispatch is None: + return f'm.def({cpp_string(str(f.func))});\n' + elif f.manual_kernel_registration: + return None else: - type_name = f'{dispatch}Type::{name}' + if dispatch_to_all_backends: + type_name = f'TypeDefault::{name}' + else: + type_name = f'{dispatch}Type::{name}' - # def registration only happens in TypeDefault - def_registration = "" - if dispatch is None: - def_registration = f'm.def({cpp_string(str(f.func))});\n' + dispatcher_sig = DispatcherSignature.from_schema(f.func) - impl_registration = "" - if not def_only and not f.manual_kernel_registration and (dispatch is not None or f.dispatch is None): # Figure out which signature the function is if local.use_c10_dispatcher() is UseC10Dispatcher.full: payload = f"TORCH_FN({type_name})" @@ -321,9 +305,7 @@ def func(f: NativeFunction) -> Optional[str]: if dispatch is not None: payload = f"torch::dispatch(DispatchKey::{dispatch},\n{payload})\n" - impl_registration = f'm.impl("{f.func.name}",\n{payload});\n' - - return f"{def_registration}{impl_registration}" + return f'm.impl("{f.func.name}",\n{payload});\n' else: assert_never(target) @@ -439,10 +421,7 @@ def compute_aten_op(f: NativeFunction) -> str: # actual kernel definitions we keep in aten/src/ATen/native/ @with_native_function def compute_native_function_declaration(f: NativeFunction) -> List[str]: - if f.dispatch is None: - ns = [cpp.name(f.func)] - else: - ns = list(f.dispatch.values()) + ns = list(f.dispatch.values()) rs = [] # Sometimes a function name shows up multiple times; only generate @@ -763,8 +742,7 @@ def compute_declaration_yaml(f: NativeFunction) -> object: is_factory_method = any(isinstance(a.argument, TensorOptionsArguments) for a in cpp_args) \ and Variant.method not in f.variants - # Having only Math in dispatch section is equivalent to no dispatch section. - is_abstract = f.dispatch is not None and set(f.dispatch.keys()) != set({'Math'}) # type ignore + is_abstract = f.dispatch.keys() != {'Math'} return OrderedDict([ ('name', cpp.name(f.func)), @@ -803,7 +781,7 @@ def compute_declaration_yaml(f: NativeFunction) -> object: ('device_guard', f.device_guard), ('with_gil', False), ('deprecated', False), - ('has_math_kernel', f.dispatch is not None and 'Math' in f.dispatch), + ('has_math_kernel', 'Math' in f.dispatch), ]) @with_native_function @@ -814,8 +792,9 @@ def compute_registration_declarations(f: NativeFunction) -> str: args_str = ', '.join(map(str, args)) comment_data : Dict[str, str] = { 'schema': f'aten::{f.func}', - 'dispatch': str(f.dispatch is not None), - 'default': str(f.dispatch is not None and any(k in f.dispatch for k in KEYWORD_ALL_BACKENDS)) + # TODO: What exactly is the semantics of the 'dispatch' field? + 'dispatch': str(f.dispatch.keys() != {'Math'}), + 'default': str(any(k in f.dispatch for k in KEYWORD_ALL_BACKENDS)) } return f"""{returns_type} {name}({args_str}); // {json.dumps(comment_data)} """ @@ -931,11 +910,6 @@ def main() -> None: '--rocm', action='store_true', help='reinterpret CUDA as ROCm/HIP and adjust filepaths accordingly') - # TODO: remove this, we should just unconditionally generate Vulkan - parser.add_argument( - '--vulkan', - action='store_true', - help='Generate Vulkan backend functions') # TODO: --op_registration_whitelist will be removed when all call-sites # for gen.py are moved over to using the operator YAML file for mobile # custom build. @@ -1005,48 +979,36 @@ def make_file_manager(install_dir: str) -> FileManager: cuda_fm = make_file_manager(options.install_dir) extra_cuda_headers = '''\ -#include #include #include #include ''' if options.rocm: extra_cuda_headers = '''\ -#include #include #include #include ''' - backends = ["CPU", "SparseCPU", "MkldnnCPU", "CUDA", "SparseCUDA", "QuantizedCPU", "QuantizedCUDA"] - if options.vulkan: - backends.append("Vulkan") + backends = [ + "CPU", + "SparseCPU", + "MkldnnCPU", + "CUDA", + "SparseCUDA", + "QuantizedCPU", + "QuantizedCUDA", + ] if options.backend_whitelist: backends = [b for b in backends if b in options.backend_whitelist] for dispatch in backends: h_template = 'TypeDerived.h' cpp_template = 'TypeDerived.cpp' - # TODO: delete this special case - if 'Sparse' in dispatch: - cpp_template = 'SparseTypeDerived.cpp' fm = cuda_fm if 'CUDA' in dispatch else cpu_fm - fm.write_with_template(f'{dispatch}Type.h', h_template, lambda: { - 'Type': f'{dispatch}Type', - 'extra_cuda_headers': extra_cuda_headers if 'CUDA' in dispatch else '', # TODO: remove this - 'type_derived_method_declarations': list(mapMaybe( - compute_type_method(dispatch, target=Target.DECLARATION, selector=selector), - native_functions - )), - }) fm.write_with_template(f'{dispatch}Type.cpp', cpp_template, lambda: { 'Type': f'{dispatch}Type', - # TODO: remove this 'extra_cuda_headers': extra_cuda_headers if 'CUDA' in dispatch else '', - # TODO: remove this - 'storage_tensor_headers': '#include ', - # TODO: remove this - 'Generator': 'CUDAGeneratorImpl' if 'CUDA' in dispatch else 'CPUGeneratorImpl', 'legacy_th_headers': '#include ' if dispatch == "CPU" else '#include ' if dispatch == "CUDA" else @@ -1064,23 +1026,13 @@ def make_file_manager(install_dir: str) -> FileManager: }) del fm - cpu_fm.write('TypeDefault.h', lambda: { - 'type_method_declarations': - list(mapMaybe( - compute_type_method(None, target=Target.DECLARATION, selector=selector), - native_functions)) + - list(mapMaybe( - compute_type_method('Math', target=Target.DECLARATION, selector=selector), - native_functions)) + - list(mapMaybe( - compute_type_method('DefaultBackend', target=Target.DECLARATION, selector=selector), - native_functions)), - }) + schema_selector = selector + if options.force_schema_registration: + schema_selector = SelectiveBuilder.get_nop_selector() + + # TODO: split this file into separate files cpu_fm.write('TypeDefault.cpp', lambda: { 'type_method_definitions': - list(mapMaybe( - compute_type_method(None, target=Target.DEFINITION, selector=selector), - native_functions)) + list(mapMaybe( compute_type_method('Math', target=Target.DEFINITION, selector=selector), native_functions)) + @@ -1089,10 +1041,12 @@ def make_file_manager(install_dir: str) -> FileManager: native_functions)), 'function_registrations': list(mapMaybe( - compute_type_method(None, target=Target.REGISTRATION, selector=selector), - native_functions)) + list(mapMaybe( - compute_type_method('Math', target=Target.REGISTRATION, selector=selector), - native_functions)), + compute_type_method(None, target=Target.REGISTRATION, selector=schema_selector), + native_functions)), + + 'math_function_registrations': list(mapMaybe( + compute_type_method('Math', target=Target.REGISTRATION, selector=selector), + native_functions)), 'default_backend_function_registrations': list(mapMaybe( compute_type_method('DefaultBackend', target=Target.REGISTRATION, selector=selector), @@ -1123,16 +1077,6 @@ def make_file_manager(install_dir: str) -> FileManager: list(mapMaybe(compute_backend_select(target=Target.REGISTRATION), native_functions)), }) - if options.force_schema_registration: - def computeSchemaRegister() -> Dict[str, object]: - schema_registrations = list(mapMaybe( - compute_type_method(None, target=Target.REGISTRATION, selector=SelectiveBuilder.get_nop_selector(), def_only=True), - native_functions)) - return { - 'schema_registrations': schema_registrations, - } - cpu_fm.write('SchemaRegister.cpp', computeSchemaRegister) - cpu_fm.write('Declarations.yaml', lambda: format_yaml([compute_declaration_yaml(f) for f in native_functions])) cpu_fm.write('RegistrationDeclarations.h', lambda: { 'registration_declarations': [compute_registration_declarations(f) for f in native_functions], diff --git a/tools/codegen/model.py b/tools/codegen/model.py index 56605b7130db..95cb7438a814 100644 --- a/tools/codegen/model.py +++ b/tools/codegen/model.py @@ -98,13 +98,15 @@ class NativeFunction: # registrations don't participate in codegen-based selective build! manual_kernel_registration: bool - # Distinguish between a missing dispatch dict (historically, this - # means to register a catch-all kernel) and a present but empty - # dispatch dict (this means register nothing; arguably, this should - # subsume manual_kernel_registration). + # A mapping of dispatch keys to names of functions implementing + # them. In native_functions.yaml, the dispatch entry is optional; in that + # case, that is equivalent to having written: + # + # dispatch: + # Math: $operator_name # # TODO: str key could be replaced with more explicit enum - dispatch: Optional[Dict[str, str]] + dispatch: Dict[str, str] # The location in the YAML file were this native function entry was # defined. This is for conveniently reporting error messages! @@ -162,9 +164,8 @@ def from_yaml(ei: Dict[str, object], loc: 'Location') -> 'NativeFunction': raw_dispatch = e.pop('dispatch', None) assert raw_dispatch is None or isinstance(raw_dispatch, dict), e - dispatch: Optional[Dict[str, str]] = None + dispatch: Dict[str, str] = {} if raw_dispatch is not None: - dispatch = {} for ks, v in raw_dispatch.items(): if ks == '__line__': continue # not worth tracking line numbers for dispatch entries @@ -172,9 +173,14 @@ def from_yaml(ei: Dict[str, object], loc: 'Location') -> 'NativeFunction': assert isinstance(v, str), e for k in ks.split(","): dispatch[k.strip()] = v + else: + from tools.codegen.api import cpp + dispatch['Math'] = cpp.name(func) - # Throws if both DefaultBackend and Math are provided - assert not (dispatch is not None and 'DefaultBackend' in dispatch and 'Math' in dispatch) + assert not ('DefaultBackend' in dispatch and 'Math' in dispatch), \ + "cannot specify both DefaultBackend and Math on a single kernel; each " \ + "strictly subsumes the other. If you wanted to provide an explicit autograd " \ + "implementation, specify DefaultBackend; otherwise specify Math only" e.pop('__line__') assert not e, f"leftover entries: {e}" diff --git a/tools/generate_torch_version.py b/tools/generate_torch_version.py index b5ea62ff29bb..8129f38eb0ef 100644 --- a/tools/generate_torch_version.py +++ b/tools/generate_torch_version.py @@ -2,6 +2,7 @@ import os import subprocess from pathlib import Path +from distutils.util import strtobool def get_sha(): try: @@ -27,7 +28,7 @@ def get_torch_version(sha=None): if __name__ == "__main__": parser = argparse.ArgumentParser(description="Generate torch/version.py from build and environment metadata.") - parser.add_argument("--is_debug", type=bool, help="Whether this build is debug mode or not.") + parser.add_argument("--is_debug", type=strtobool, help="Whether this build is debug mode or not.") parser.add_argument("--cuda_version", type=str) parser.add_argument("--hip_version", type=str) @@ -47,7 +48,7 @@ def get_torch_version(sha=None): # NB: This is not 100% accurate, because you could have built the # library code with DEBUG, but csrc without DEBUG (in which case # this would claim to be a release build when it's not.) - f.write("debug = {}\n".format(repr(args.is_debug))) + f.write("debug = {}\n".format(repr(bool(args.is_debug)))) f.write("cuda = {}\n".format(repr(args.cuda_version))) f.write("git_version = {}\n".format(repr(sha))) f.write("hip = {}\n".format(repr(args.hip_version))) diff --git a/tools/jit/gen_unboxing_wrappers.py b/tools/jit/gen_unboxing_wrappers.py index 8d1fb00fc8d2..f2896fac7f22 100644 --- a/tools/jit/gen_unboxing_wrappers.py +++ b/tools/jit/gen_unboxing_wrappers.py @@ -535,7 +535,8 @@ def main(): parser.add_argument('template_path', metavar='TEMPLATE_PATH', help='path to templates directory') args = parser.parse_args() - gen_unboxing_wrappers(args.declarations, args.out, args.template_path) + gen_unboxing_wrappers(args.declarations, args.out, args.template_path, + SelectiveBuilder.get_nop_selector()) if __name__ == '__main__': diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index b603c8386a2c..64c746e7eff2 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -599,6 +599,20 @@ class _CudaDeviceProperties: is_integrated: _int is_multi_gpu_board: _int +# Defined in torch/csrc/cuda/python_comm.cpp +def _broadcast(tensor: Tensor, devices: List[_int]) -> List[Tensor]: ... +def _broadcast_out(tensor: Tensor, out_tensors: List[Tensor]) -> List[Tensor]: ... +def _broadcast_coalesced( + tensors: List[Tensor], + devices: List[_int], + buffer_size: _int +) -> List[List[Tensor]]: ... + +def _scatter(tensor: Tensor, devices: List[_int], chunk_sizes: Optional[List[_int]], dim: _int, streams: Optional[List[Stream]]) -> List[Tensor]: ... +def _scatter_out(tensor: Tensor, out_tensors: List[Tensor], dim: _int, streams: Optional[List[Stream]]) -> List[Tensor]: ... +def _gather(tensors: List[Tensor], dim: _int, destination_index: Optional[_int]) -> Tensor: ... +def _gather_out(tensors: List[Tensor], out_tensor: Tensor, dim: _int) -> Tensor: ... + # Defined in torch/csrc/cuda/Stream.cpp class _CudaStreamBase: _cdata: _int diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index 4396155c73ea..a1c800debb59 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -839,7 +839,7 @@ def _get_named_tuple_properties(obj): the_type = torch.jit.annotations.ann_to_type(obj.__annotations__[field], fake_range()) annotations.append(the_type) else: - annotations.append(torch._C.TensorType.get()) + annotations.append(torch._C.TensorType.getInferred()) return type(obj).__name__, fields, annotations diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index 017465b9bdff..9339d805c1b9 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -1542,6 +1542,20 @@ def add_docstr_all(method, docstr): In-place version of :meth:`~Tensor.i0` """) +add_docstr_all('igamma', + r""" +igamma(other) -> Tensor + +See :func:`torch.igamma` +""") + +add_docstr_all('igamma_', + r""" +igamma_(other) -> Tensor + +In-place version of :meth:`~Tensor.igamma` +""") + add_docstr_all('indices', r""" indices() -> Tensor diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 9c220dc259a0..e3fc7acfa160 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -1818,6 +1818,40 @@ def merge_dicts(*dicts): Alias for :func:`torch.clamp`. """.format(**common_args)) +add_docstr(torch.column_stack, + r""" +column_stack(tensors, *, out=None) -> Tensor + +Creates a new tensor by horizontally stacking the tensors in :attr:`tensors`. + +Equivalent to ``torch.hstack(tensors)``, except each zero or one dimensional tensor ``t`` +in :attr:`tensors` is first reshaped into a ``(t.numel(), 1)`` column before being stacked horizontally. + +Args: + tensors (sequence of Tensors): sequence of tensors to concatenate + +Keyword args: + {out} + +Example:: + + >>> a = torch.tensor([1, 2, 3]) + >>> b = torch.tensor([4, 5, 6]) + >>> torch.column_stack((a, b)) + tensor([[1, 4], + [2, 5], + [3, 6]]) + >>> a = torch.arange(5) + >>> b = torch.arange(10).reshape(5, 2) + >>> torch.column_stack((a, b, b)) + tensor([[0, 0, 1, 0, 1], + [1, 2, 3, 2, 3], + [2, 4, 5, 4, 5], + [3, 6, 7, 6, 7], + [4, 8, 9, 8, 9]]) + +""".format(**common_args)) + add_docstr(torch.complex, r""" complex(real, imag, *, out=None) -> Tensor @@ -3316,6 +3350,47 @@ def merge_dicts(*dicts): """.format(**common_args)) +add_docstr(torch.igamma, + r""" +igamma(input, other, *, out=None) -> Tensor + +Computes the regularized lower incomplete gamma function: + +.. math:: + \text{out}_{i} = \frac{1}{\Gamma(\text{input}_i)} \int_0^{\text{other}_i} t^{\text{input}_i-1} e^{-t} dt + +where both :math:`\text{input}_i` and :math:`\text{other}_i` are weakly positive +and at least one is strictly positive. +If both are zero or either is negative then :math:`\text{out}_i=\text{nan}`. +:math:`\Gamma(\cdot)` in the equation above is the gamma function, + +.. math:: + \Gamma(\text{input}_i) = \int_0^\infty t^{(\text{input}_i-1)} e^{-t} dt. + +See :func:`torch.lgamma` for a related function. + +Supports :ref:`broadcasting to a common shape ` +and float inputs. + +.. note:: + The backward pass with respect to :attr:`input` is not yet supported. + Please open an issue on PyTorch's Github to request it. + +""" + r""" +Args: + input (Tensor): the first non-negative input tensor + other (Tensor): the second non-negative input tensor + +Keyword args: + {out} + +Example:: + + >>> a = torch.igamma(torch.tensor([4.0]), torch.tensor([3.0, 4.0, 5.0])) + tensor([0.3528, 0.5665, 0.7350]) + +""".format(**common_args)) + add_docstr(torch.index_select, r""" index_select(input, dim, index, *, out=None) -> Tensor @@ -6737,6 +6812,12 @@ def merge_dicts(*dicts): torch.uint8 """) +add_docstr(torch.row_stack, + r""" +row_stack(tensors, *, out=None) -> Tensor + +Alias of :func:`torch.vstack`. +""".format(**common_args)) add_docstr(torch.round, r""" diff --git a/torch/csrc/DynamicTypes.cpp b/torch/csrc/DynamicTypes.cpp index f7e48c3b682d..92e8a93c284e 100644 --- a/torch/csrc/DynamicTypes.cpp +++ b/torch/csrc/DynamicTypes.cpp @@ -59,7 +59,7 @@ at::DeprecatedTypeProperties* get_type(at::Backend backend, at::ScalarType scala PyTypeObject* getPyTypeObject( const at::Storage& storage, - const caffe2::TypeMeta& dtype) { + const caffe2::TypeMeta dtype) { at::ScalarType scalarType = at::typeMetaToScalarType(dtype); auto attype = &at::getDeprecatedTypeProperties( at::dispatchKeyToBackend(c10::computeDispatchKey(scalarType, c10::nullopt, storage.device_type())), @@ -106,7 +106,7 @@ THPLayout* getTHPLayout(at::Layout layout) { PyObject* createPyObject( const at::Storage& storage, - const caffe2::TypeMeta& data_type) { + const caffe2::TypeMeta data_type) { auto type = getPyTypeObject(storage, data_type); auto obj = THPObjectPtr(type->tp_alloc(type, 0)); if (!obj) throw python_error(); diff --git a/torch/csrc/DynamicTypes.h b/torch/csrc/DynamicTypes.h index 0877fb317cb3..d93d0e3b5cf5 100644 --- a/torch/csrc/DynamicTypes.h +++ b/torch/csrc/DynamicTypes.h @@ -6,6 +6,7 @@ #include #include +#include #include #include @@ -29,7 +30,7 @@ void registerLayoutObject(THPLayout *thp_layout, at::Layout layout); PyObject* createPyObject( const at::Storage& storage, - const caffe2::TypeMeta& data_type); + const caffe2::TypeMeta data_type); at::Storage createStorage(PyObject* obj); bool isStorage(PyObject* obj); diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index d09b3428e4f1..a5df6329030d 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -85,9 +85,9 @@ static PyObject * THPModule_initNames(PyObject *self, PyObject *arg) THPObjectPtr types(PySequence_Fast(arg, "expected a sequence")); if (!types) return nullptr; - int num_classes = PySequence_Fast_GET_SIZE(types.get()); + auto num_classes = PySequence_Fast_GET_SIZE(types.get()); names.reserve(names.size() + num_classes); - for (size_t i = 0; i < num_classes; i++) { + for (Py_ssize_t i = 0; i < num_classes; i++) { PyObject* obj = PySequence_Fast_GET_ITEM(types.get(), i); THPUtils_assert(PyType_Check(obj), "expected a PyTypeObject"); PyTypeObject* type = (PyTypeObject*)obj; @@ -864,7 +864,30 @@ Call this whenever a new thread is created in order to propagate values from ASSERT_TRUE(set_module_attr("_GLIBCXX_USE_CXX11_ABI", Py_False)); #endif - auto defaultGenerator = at::detail::getDefaultCPUGenerator(); +// See note [Pybind11 ABI constants] +#define SET_STR_DEFINE(name) \ + ASSERT_TRUE(set_module_attr("_" # name, THPUtils_packString(name))) + +#ifdef PYBIND11_COMPILER_TYPE + SET_STR_DEFINE(PYBIND11_COMPILER_TYPE); +#else + ASSERT_TRUE(set_module_attr("_" C10_STRINGIZE(PYBIND11_COMPILER_TYPE), Py_None)); +#endif + +#ifdef PYBIND11_STDLIB + SET_STR_DEFINE(PYBIND11_STDLIB); +#else + ASSERT_TRUE(set_module_attr("_" C10_STRINGIZE(PYBIND11_STDLIB), Py_None)); +#endif + +#ifdef PYBIND11_BUILD_ABI + SET_STR_DEFINE(PYBIND11_BUILD_ABI); +#else + ASSERT_TRUE(set_module_attr("_" C10_STRINGIZE(PYBIND11_BUILD_ABI), Py_None)); +#endif +#undef SET_STR_DEFINE + + const auto& defaultGenerator = at::detail::getDefaultCPUGenerator(); THPDefaultCPUGenerator = (THPGenerator*)THPGenerator_initDefaultGenerator(defaultGenerator); // This reference is meant to be given away, so no need to incref here. ASSERT_TRUE(set_module_attr("default_generator", (PyObject*)THPDefaultCPUGenerator, /* incref= */ false)); diff --git a/torch/csrc/api/include/torch/linalg.h b/torch/csrc/api/include/torch/linalg.h index 5ce90dcc972e..c0bae62510e6 100644 --- a/torch/csrc/api/include/torch/linalg.h +++ b/torch/csrc/api/include/torch/linalg.h @@ -28,6 +28,14 @@ inline Tensor& norm_out(Tensor& result, const Tensor& self, std::string ord, opt return torch::linalg_norm_out(result, self, ord, opt_dim, keepdim, opt_dtype); } +inline Tensor tensorsolve(const Tensor& self, const Tensor& other, optional dims) { + return torch::linalg_tensorsolve(self, other, dims); +} + +inline Tensor& tensorsolve_out(Tensor& result, const Tensor& self, const Tensor& other, optional dims) { + return torch::linalg_tensorsolve_out(result, self, other, dims); +} + } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ @@ -53,4 +61,22 @@ inline Tensor& linalg_norm_out(Tensor& result, const Tensor& self, std::string o return detail::norm_out(result, self, ord, opt_dim, keepdim, opt_dtype); } +/// Computes a tensor `x` such that `tensordot(input, x, dims=x.dim()) = other`. +/// +/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.tensorsolve +/// +/// Example: +/// ``` +/// auto a = torch::eye(2*3*4).reshape({2*3, 4, 2, 3, 4}); +/// auto b = torch::randn(2*3, 4); +/// auto x = torch::linalg::tensorsolve(a, b); +/// ``` +inline Tensor tensorsolve(const Tensor& input, const Tensor& other, optional dims) { + return detail::tensorsolve(input, other, dims); +} + +inline Tensor& tensorsolve_out(Tensor& result, const Tensor& input, const Tensor& other, optional dims) { + return detail::tensorsolve_out(result, input, other, dims); +} + }} // torch::linalg diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 689d66c6f405..d3752bce04cc 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -95,6 +95,14 @@ Tensor handle_r_to_c(ScalarType self_st, Tensor gradient_result) { return gradient_result; } +Tensor handle_r_to_c(Tensor self, Tensor gradient_result) { + if (!self.is_complex() && gradient_result.is_complex()) { + // R -> C + return at::real(gradient_result); + } + return gradient_result; +} + Tensor restore_reduced_dims(const Tensor &output, IntArrayRef dims, bool keepdim) { if (keepdim) { return output; @@ -177,16 +185,18 @@ Tensor norm_backward(Tensor grad, const Tensor & self, const optional & } Tensor pow_backward(Tensor grad, const Tensor & self, const Scalar & exponent_) { - double exponent = exponent_.toDouble(); + auto exponent = (exponent_.isComplex()) ? exponent_.toComplexDouble() : exponent_.toDouble(); if (exponent == 0.0) { return at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); } else { - return grad * exponent * self.pow(exponent - 1); + auto out = grad * (exponent * self.pow(exponent - 1)).conj(); + return handle_r_to_c(self, out); } } Tensor pow_backward_self(Tensor grad, const Tensor & self, const Tensor & exponent) { - return at::where(exponent == 0.0, at::zeros({}, grad.options()), grad * exponent * self.pow(exponent - 1)); + auto out = at::where(exponent == 0.0, at::zeros({}, grad.options()), grad * (exponent * self.pow(exponent - 1)).conj()); + return handle_r_to_c(self, out); } // Caveats: @@ -198,18 +208,46 @@ Tensor pow_backward_self(Tensor grad, const Tensor & self, const Tensor & expone // d(a^b)/db = 0 for a > 0 and b -> +0. // Currently, tensorflow agrees with us. Tensor pow_backward_exponent(Tensor grad, const Tensor& self, const Tensor& exponent, Tensor result) { - return grad * at::where(at::logical_and(self == 0, exponent >= 0), + Tensor cond; + if (exponent.is_complex()) { + auto is_real_exp = at::logical_and(at::imag(exponent) == 0, at::real(exponent) >= 0); + cond = at::logical_and(self == 0, is_real_exp); + } else { + cond = at::logical_and(self == 0, exponent >= 0); + } + auto out = grad * at::where(cond, at::zeros({}, grad.options()), - result * self.log()); + (result * self.log()).conj()); + return handle_r_to_c(exponent, out); } Tensor pow_backward_exponent(Tensor grad, const Scalar & base, const Tensor& exponent, Tensor result) { - if (base.toDouble() == 0) { - return grad * at::where(exponent >= 0, + auto base_ = base.isComplex() ? base.toComplexDouble() : base.toDouble(); + auto grad_lambda = [](auto a, auto b) { return (a * std::log(b)).conj(); }; + if (base_ == 0.0) { + auto cond = [](auto exp) { + if (exp.is_complex()) { + return at::logical_and(at::imag(exp) == 0, at::real(exp) >= 0); + } else { + return exp >=0; + } + }; + auto out = grad * at::where(cond(exponent), at::zeros({}, grad.options()), - result * std::log(base.toDouble())); + grad_lambda(result, base_)); + return handle_r_to_c(exponent, out); } else { - return grad * result * std::log(base.toDouble()); + auto out = grad * grad_lambda(result, base_); + return handle_r_to_c(exponent, out); + } +} + +Tensor angle_backward(Tensor grad, const Tensor& self) { + if (self.is_complex()) { + return at::where(self == 0.0, at::zeros({}, self.options()), + grad * self / self.abs().pow(2) * Scalar(c10::complex{0.0, 1.0})); + } else { + return at::zeros_like(self, at::MemoryFormat::Preserve); } } @@ -226,35 +264,23 @@ Tensor sgn_backward(Tensor result, Tensor grad, Tensor self) { // https://arxiv.org/pdf/1701.00392.pdf Section 4.20 return at::where(abs == 0.0, at::zeros({}, grad.options()), (grad/abs - (at::real(grad/self) * result))); } else { - return at::zeros_like(grad, at::MemoryFormat::Preserve); + return at::zeros_like(self, at::MemoryFormat::Preserve); } } Tensor mul_tensor_backward(Tensor grad, Tensor other, ScalarType self_st) { - auto result = grad * other.conj(); - if (!at::isComplexType(self_st) && result.is_complex()) { - // R -> C - result = at::real(result); - } - return result; + auto out = grad * other.conj(); + return handle_r_to_c(self_st, out); } Tensor div_tensor_self_backward(Tensor grad, Tensor other, ScalarType self_st) { auto result = grad / other.conj(); - if (!at::isComplexType(self_st) && result.is_complex()) { - // R -> C - result = at::real(result); - } - return result; + return handle_r_to_c(self_st, result); } Tensor div_tensor_other_backward(Tensor grad, Tensor self, Tensor other) { auto result = -grad * ((self / other) / other).conj(); - if (!other.is_complex() && result.is_complex()) { - // R -> C - result = at::real(result); - } - return result; + return handle_r_to_c(other, result); } Tensor permute_backwards(const Tensor & grad, IntArrayRef fwd_dims) { @@ -2647,7 +2673,7 @@ Tensor log1p_backward(const Tensor& grad, const Tensor& self) { "Use a different mathematical operation which preserves sparsity of gradients, ", "or report a bug if you think this is an error."); } - return grad / (self + 1); + return grad / (self + 1).conj(); } Tensor sparse_constructor_values_backward(const Tensor& sparse_grad_out, const Tensor& indices, IntArrayRef values_shape) { @@ -2668,16 +2694,18 @@ Tensor constant_pad_nd_backward(const Tensor& grad, IntArrayRef pad) { return at::constant_pad_nd(grad, negated_pad, 0); } -Tensor embedding_dense_double_backward(const Tensor & grad, const Tensor & indices) { - // since first backward takes care of padding_idx - // and scaling by frequency, we don't need to worry - // about it here. +Tensor embedding_dense_double_backward(const Tensor & grad, const Tensor & indices, int64_t padding_idx) { + // since first backward takes care of scaling by frequency, + // we don't need to worry about it here. auto gg_weight = grad.index_select(0, indices.reshape(-1)); // reshape gradient as per the shape of indices auto size = indices.sizes().vec(); size.push_back(-1); + if (padding_idx >= 0) { + gg_weight.masked_fill_((indices == padding_idx).reshape({-1, 1}), 0); + } return gg_weight.view(size); } diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index ffd65081f678..46f26610c127 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -44,6 +44,7 @@ at::Tensor pow_backward(at::Tensor grad, const at::Tensor & self, const at::Scal at::Tensor pow_backward_self(at::Tensor grad, const at::Tensor & self, const at::Tensor & exponent); at::Tensor pow_backward_exponent(at::Tensor grad, const at::Tensor& self, const at::Tensor& exponent, at::Tensor result); at::Tensor pow_backward_exponent(at::Tensor grad, const at::Scalar & base, const at::Tensor& exponent, at::Tensor result); +at::Tensor angle_backward(at::Tensor grad, const at::Tensor& self); at::Tensor mul_tensor_backward(Tensor grad, Tensor other, ScalarType self_st); at::Tensor div_tensor_self_backward(Tensor grad, Tensor other, ScalarType self_st); at::Tensor div_tensor_other_backward(Tensor grad, Tensor self, Tensor other); @@ -117,7 +118,7 @@ at::Tensor logdet_backward(const at::Tensor & grad, const at::Tensor& self, cons at::Tensor slogdet_backward(const at::Tensor& grad_logabsdet, const at::Tensor& self, const at::Tensor& signdet, const at::Tensor& logabsdet); at::Tensor log1p_backward(const at::Tensor& grad, const at::Tensor& self); at::Tensor sparse_constructor_values_backward(const at::Tensor& sparse_grad_out, const at::Tensor& indices, at::IntArrayRef values_shape); -at::Tensor embedding_dense_double_backward(const at::Tensor & grad, const at::Tensor & indices); +at::Tensor embedding_dense_double_backward(const at::Tensor & grad, const at::Tensor & indices, int64_t padding_idx); at::Tensor index_backward(at::Tensor zeros_like_self, at::TensorList indices, const at::Tensor& grad); at::Tensor _cudnn_ctc_loss_backward(const at::Tensor& grad_out, const at::Tensor& loss, const at::Tensor& raw_grad, bool zero_infinity); diff --git a/torch/csrc/autograd/anomaly_mode.cpp b/torch/csrc/autograd/anomaly_mode.cpp index bbb76fba656f..16783fdbef15 100644 --- a/torch/csrc/autograd/anomaly_mode.cpp +++ b/torch/csrc/autograd/anomaly_mode.cpp @@ -1,9 +1,48 @@ +#include +#include #include +#include -namespace torch { namespace autograd { +namespace torch { +namespace autograd { bool AnomalyMode::_enabled = false; AnomalyMetadata::~AnomalyMetadata() = default; -}} +void AnomalyMetadata::store_stack() { + traceback_ = c10::get_backtrace(/* frames_to_skip */ 1); +} + +void AnomalyMetadata::print_stack(const std::string& current_node_name) { + TORCH_WARN( + "Error detected in ", + current_node_name, + ". ", + "Traceback of forward call that caused the error:\n", + traceback_); + + auto& cur_parent = parent_; + // if there is no "parent_" in metadata, then it means this metadata's node + // is the root and stop printing the traceback + while (cur_parent) { + auto parent_metadata = cur_parent->metadata(); + TORCH_WARN( + "\n\n", + "Previous calculation was induced by ", + cur_parent->name(), + ". " + "Traceback of forward call that induced the previous calculation:\n", + parent_metadata->traceback_); + // get the parent of this node, if this node is a root, pyparent is simply + // null + cur_parent = parent_metadata->parent_; + } +} + +void AnomalyMetadata::assign_parent(const std::shared_ptr& parent_node) { + parent_ = parent_node; +} + +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/anomaly_mode.h b/torch/csrc/autograd/anomaly_mode.h index 013600b230fc..d7384003bc8e 100644 --- a/torch/csrc/autograd/anomaly_mode.h +++ b/torch/csrc/autograd/anomaly_mode.h @@ -24,9 +24,13 @@ struct TORCH_API AnomalyMode { struct TORCH_API AnomalyMetadata { virtual ~AnomalyMetadata(); - virtual void store_stack() = 0; - virtual void print_stack(const std::string& current_node_name) = 0; - virtual void assign_parent(const std::shared_ptr& parent_node) = 0; + virtual void store_stack(); + virtual void print_stack(const std::string& current_node_name); + virtual void assign_parent(const std::shared_ptr& parent_node); + + private: + std::string traceback_; + std::shared_ptr parent_; }; }} diff --git a/torch/csrc/autograd/engine.h b/torch/csrc/autograd/engine.h index 0dde6e735d10..33623b7f4b20 100644 --- a/torch/csrc/autograd/engine.h +++ b/torch/csrc/autograd/engine.h @@ -284,7 +284,7 @@ struct TORCH_API Engine { std::shared_ptr graph_root); virtual std::unique_ptr make_anomaly_metadata() { - return nullptr; + return std::make_unique(); } // We pass cpu_ready_queue to evaluate_function, so that it knows diff --git a/torch/csrc/distributed/c10d/comm.cpp b/torch/csrc/distributed/c10d/comm.cpp index 2eb79283a088..4ee19888e4c2 100644 --- a/torch/csrc/distributed/c10d/comm.cpp +++ b/torch/csrc/distributed/c10d/comm.cpp @@ -85,44 +85,5 @@ void broadcast_coalesced( } } -PythonCommHook::PythonCommHook(py::object state, py::object hook) - : state_(std::move(state)), hook_(std::move(hook)){}; - -c10::intrusive_ptr PythonCommHook::runHook( - const GradBucket& bucket) { - py::gil_scoped_acquire acquire; - - py::object py_fut = hook_(state_, bucket); - - try { - return py_fut.cast>()->fut; - } catch (const py::cast_error& e) { - auto type = py_fut.get_type(); - auto errMsg = c10::str( - e.what(), - ". DDP communication hook's callback must return a " - "torch.futures.Future or torch._C.Future object, but got ", - type.attr("__module__").cast(), - ".", - type.attr("__qualname__").cast()); - throw std::runtime_error(errMsg); - } -} - -std::vector PythonCommHook::processFuture( - c10::IValue future_value) { - // Since we have a Python hook, future_value can be a PyObject. - if (future_value.isPyObject()) { - // We first convert it to an IValue that contains a TensorVector. - py::gil_scoped_acquire ag; - py::object obj = torch::jit::toPyObject(future_value); - auto value = torch::jit::toIValue( - obj, c10::ListType::create(c10::TensorType::get())); - - return value.toTensorVector(); - } - - return future_value.toTensorVector(); -} } // namespace c10d diff --git a/torch/csrc/distributed/c10d/comm.h b/torch/csrc/distributed/c10d/comm.h index 2eb626c40232..58f40f81ccb3 100644 --- a/torch/csrc/distributed/c10d/comm.h +++ b/torch/csrc/distributed/c10d/comm.h @@ -1,10 +1,8 @@ #pragma once -#include - #include +#include #include -#include namespace c10d { @@ -31,62 +29,59 @@ class GradBucket { return tensors_; } + std::vector& getTensorsRef() { + return tensors_; + } + private: std::vector tensors_; }; -// DDP's c10d reducer allows communication hooks defined as a sub class -// of CommHookInterface. CommHookInterface is an abstract class and can -// be used to implement both Python and CPP hooks. -struct TORCH_PYTHON_API CommHookInterface { +// Base class of both `PythonCommHook` and `CppCommHook`. +// Requires implementing 1) `runHook` method that communicates gradients +// asynchronously, and 2) `parseHookResult` method that converts the hook result +// into a tensor vector. +class TORCH_PYTHON_API CommHookInterface { public: virtual ~CommHookInterface() {} - // runHook takes a GradBucket type bucket and passes the tensors of - // this grad bucket to hook's callback. This function is called once - // the bucket is ready. The hook can perform whatever processing is - // needed and return a Future that will hold the new value of the grad - // bucket's tensors once ready. - virtual c10::intrusive_ptr runHook( - const GradBucket& bucket) = 0; - - // Once the grad bucket of Future is ready, c10d reducer will call this - // function to get the resulting tensors of the grad bucket. Then c10d - // reducer will use these tensors and copy grads to the grads of individual + // Passes the input grad bucket to the registered communication hook. + // Once the tensors in the bucket are ready, kicks off the hook asynchronously + // and returns a future that holds the communication results. + virtual c10::intrusive_ptr runHook( + GradBucket& bucket) = 0; + + // Returns the resulting tensors once the communication hook result is ready. + // The resulting tensors will then be copied to the grads of individual // parameters. - virtual std::vector processFuture(c10::IValue future_value) = 0; + virtual std::vector parseHookResult( + const c10::IValue& result) = 0; }; -// PythonCommHook enables registering a python hook to c10d reducer and is a -// sub class of CommHookInterface. -class TORCH_PYTHON_API PythonCommHook : public CommHookInterface { +// This CppCommHook interface only requires implementing runHook method that +// potentially uses a state. +// Still need TORCH_PYTHON_API instead of TORCH_API to support Windows platform. +template +class TORCH_PYTHON_API CppCommHookInterface : public CommHookInterface { public: - // The constructor takes a state and a callable hook. Inputs are Python - // objects. The state is passed to the hook in runHook function can be used to - // maintain and update any state information that users would like to maintain - // as part of the training process. The hook can perform whatever processing - // user specified and return a Future indicating completion of any async work. - PythonCommHook(py::object state, py::object hook); - - ~PythonCommHook() override { - py::gil_scoped_acquire ag; - state_.dec_ref(); - hook_.dec_ref(); - // explicitly setting PyObject* state_ and hook_ to nullptr to prevent - // py::object's dtor to decref on the PyObject again. - // See Note [Destructing py::object] in python_ivalue.h - state_.ptr() = nullptr; - hook_.ptr() = nullptr; - } + explicit CppCommHookInterface(T& state) : state_(state) {} - c10::intrusive_ptr runHook( - const GradBucket& bucket) override; + virtual ~CppCommHookInterface() {} - std::vector processFuture(c10::IValue future_value) override; + std::vector parseHookResult(const c10::IValue& result) override { + TORCH_INTERNAL_ASSERT( + result.isTensor() || result.isTensorList(), + "expected the hook result is either a Tensor or a TensorList"); - private: - py::object state_; - py::object hook_; + if (result.isTensor()) { + return {result.toTensor()}; + } + + return result.toTensorVector(); + } + + protected: + T state_; // Not owned. }; } // namespace c10d diff --git a/torch/csrc/distributed/c10d/default_comm_hooks.cpp b/torch/csrc/distributed/c10d/default_comm_hooks.cpp new file mode 100644 index 000000000000..10da31bf0b03 --- /dev/null +++ b/torch/csrc/distributed/c10d/default_comm_hooks.cpp @@ -0,0 +1,41 @@ +#include + +#include +#include +#include + +namespace c10d { + +c10::intrusive_ptr AllReduceCommHook::runHook( + GradBucket& bucket) { + auto allreduce_work = state_->allreduce(bucket.getTensorsRef()); + + auto div_by_process_group_size = [allreduce_work, this]() { + auto tensor = allreduce_work->result()[0] / state_->getSize(); + return c10::IValue(tensor); + }; + + auto fut = allreduce_work->getFuture(); + return fut->then(div_by_process_group_size, fut->elementType()); +} + +c10::intrusive_ptr FP16CompressCommHook::runHook( + GradBucket& bucket) { + auto& tensors = bucket.getTensorsRef(); + for (auto& tensor : tensors) { + tensor.copy_(tensor.to(torch::kFloat16)); + } + auto allreduce_work = state_->allreduce(tensors); + + auto decompress_and_div_by_process_group_size = [allreduce_work, this]() { + auto tensor = allreduce_work->result()[0]; + tensor.copy_(tensor.to(torch::kFloat) / state_->getSize()); + return c10::IValue(tensor); + }; + + auto fut = allreduce_work->getFuture(); + return fut->then( + decompress_and_div_by_process_group_size, fut->elementType()); +} + +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/default_comm_hooks.h b/torch/csrc/distributed/c10d/default_comm_hooks.h new file mode 100644 index 000000000000..140fc2505ae4 --- /dev/null +++ b/torch/csrc/distributed/c10d/default_comm_hooks.h @@ -0,0 +1,33 @@ +#pragma once + +#include +#include + +namespace c10d { + +enum class BuiltinCommHookType { + ALLREDUCE = 1, + FP16_COMPRESS = 2, +}; + +class AllReduceCommHook : public CppCommHookInterface { + public: + explicit AllReduceCommHook(ProcessGroup* state) + : CppCommHookInterface(state) {} + + ~AllReduceCommHook() override {} + + c10::intrusive_ptr runHook(GradBucket& bucket) override; +}; + +class FP16CompressCommHook : public CppCommHookInterface { + public: + explicit FP16CompressCommHook(ProcessGroup* state) + : CppCommHookInterface(state) {} + + ~FP16CompressCommHook() override {} + + c10::intrusive_ptr runHook(GradBucket& bucket) override; +}; + +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 47a3ebabe941..7568e9979530 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -26,6 +26,7 @@ #include #include +#include #include #include #include @@ -123,19 +124,25 @@ class PythonStore : public ::c10d::Store { } }; -// This method is called from DDP's Python API. Its inputs are -// a c10d reducer object, state, and callable comm_hook. State and -// comm_hook inputs are Python objects and this function creates a -// c10d PythonCommHook object using these inputs. It later calls -// register_comm_hook function of the reducer input to register that -// PythonCommHook object. +// Called from DDP's Python API to create a c10d Python comm hook object. +// The input state and callable comm_hook are Python objects. It later calls +// register_comm_hook function of the reducer input to register the hook. void _register_comm_hook( ::c10d::Reducer& reducer, py::object state, py::object comm_hook) { reducer.register_comm_hook(std::make_unique<::c10d::PythonCommHook>( std::move(state), std::move(comm_hook))); -}; +} + +// Called from DDP's Python API to create a c10d C++ comm hook. +// The input is an enum hook type. It later calls register_builtin_comm_hook +// function of the reducer input to set the hook type. +void _register_builtin_comm_hook( + ::c10d::Reducer& reducer, + ::c10d::BuiltinCommHookType comm_hook_type) { + reducer.register_builtin_comm_hook(comm_hook_type); +} PyObject* c10d_init(PyObject* _unused, PyObject* noargs) { C10_LOG_API_USAGE_ONCE("c10d.python.import"); @@ -146,12 +153,19 @@ PyObject* c10d_init(PyObject* _unused, PyObject* noargs) { auto module = py::handle(c10d_module).cast(); - module.def( - "_register_comm_hook", - &_register_comm_hook, - py::arg("reducer"), - py::arg("state"), - py::arg("comm_hook")); + module + .def( + "_register_comm_hook", + &_register_comm_hook, + py::arg("reducer"), + py::arg("state"), + py::arg("comm_hook"), + py::call_guard()) + .def( + "_register_builtin_comm_hook", + &_register_builtin_comm_hook, + py::arg("reducer"), + py::arg("comm_hook_type")); shared_ptr_class_<::c10d::GradBucket>(module, "_GradBucket") .def(py::init&>(), py::arg("tensors")) @@ -167,6 +181,11 @@ PyObject* c10d_init(PyObject* _unused, PyObject* noargs) { a single tensor. )"); + py::enum_<::c10d::BuiltinCommHookType>(module, "BuiltinCommHookType", R"( +An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_COMPRESS``.)") + .value("ALLREDUCE", ::c10d::BuiltinCommHookType::ALLREDUCE) + .value("FP16_COMPRESS", ::c10d::BuiltinCommHookType::FP16_COMPRESS); + shared_ptr_class_<::c10d::Reducer>(module, "Reducer") .def( py::init< diff --git a/torch/csrc/distributed/c10d/python_comm_hook.cpp b/torch/csrc/distributed/c10d/python_comm_hook.cpp new file mode 100644 index 000000000000..6b25018d38a3 --- /dev/null +++ b/torch/csrc/distributed/c10d/python_comm_hook.cpp @@ -0,0 +1,60 @@ +#include + +#include +#include +#include +#include + +namespace c10d { + +PythonCommHook::~PythonCommHook() { + py::gil_scoped_acquire ag; + state_.dec_ref(); + hook_.dec_ref(); + // Explicitly set state_ and hook_ to nullptr to prevent py::object's dtor + // to decref on the PyObject again. + // See Note [Destructing py::object] in python_ivalue.h + state_.ptr() = nullptr; + hook_.ptr() = nullptr; +} + +c10::intrusive_ptr PythonCommHook::runHook( + GradBucket& bucket) { + py::gil_scoped_acquire acquire; + + py::object py_fut = hook_(state_, bucket); + + try { + return py_fut.cast>()->fut; + } catch (const py::cast_error& e) { + auto type = py_fut.get_type(); + auto errMsg = c10::str( + e.what(), + ". DDP communication hook's callback must return a " + "torch.futures.Future or torch._C.Future object, but got ", + type.attr("__module__").cast(), + ".", + type.attr("__qualname__").cast()); + throw std::runtime_error(errMsg); + } +} + +std::vector PythonCommHook::parseHookResult( + const c10::IValue& result) { + TORCH_INTERNAL_ASSERT( + result.isPyObject() || result.isTensorList(), + "expected the hook result is either a PyObject or TensorList"); + + if (result.isPyObject()) { + py::gil_scoped_acquire ag; + py::object obj = torch::jit::toPyObject(result); + auto value = torch::jit::toIValue( + obj, c10::ListType::create(c10::TensorType::get())); + + return value.toTensorVector(); + } + + return result.toTensorVector(); +} + +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/python_comm_hook.h b/torch/csrc/distributed/c10d/python_comm_hook.h new file mode 100644 index 000000000000..e38ba096460f --- /dev/null +++ b/torch/csrc/distributed/c10d/python_comm_hook.h @@ -0,0 +1,34 @@ +#pragma once + +#include + +#include +#include +#include +#include + +namespace c10d { + +class TORCH_PYTHON_API PythonCommHook : public CommHookInterface { + public: + // Takes a state and a callable hook. The inputs are Python objects. + // The state is passed to the hook in runHook method, and it can be used to + // maintain and update any state information during the execution of the hook. + // The hook performs user-specified processing and returns a future indicating + // asychronous communication of gradients. + PythonCommHook(py::object state, py::object hook) + : state_(std::move(state)), hook_(std::move(hook)) {} + + ~PythonCommHook() override; + + c10::intrusive_ptr runHook(GradBucket& bucket) override; + + std::vector parseHookResult(const c10::IValue& result) override; + + private: + // Only needed for stateful communication. + py::object state_; + py::object hook_; +}; + +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp index f146a2a9be45..ba114560b144 100644 --- a/torch/csrc/distributed/c10d/reducer.cpp +++ b/torch/csrc/distributed/c10d/reducer.cpp @@ -119,8 +119,19 @@ Reducer::Reducer( // This is used later on when the autograd graph is traversed // to check for parameters for which no gradient is computed, if // find_unused_parameters=True. + // We maintain a mapping of gradient accumulator to vector of variables, + // since multiple parameters may share the same grad accumulator. if (find_unused_parameters_) { - gradAccToVariableMap_[grad_accumulator.get()] = index; + auto gradAcc = gradAccToVariablesMap_.find(grad_accumulator.get()); + if (gradAcc == gradAccToVariablesMap_.end()) { + std::vector indexVec{index}; + gradAccToVariablesMap_[grad_accumulator.get()] = + std::move(indexVec); + } else { + // Scenario where we have indices whose corresponding parameters + // share the same grad accumulator. + gradAcc->second.push_back(index); + } } // The gradient accumulator is stored as weak_ptr in the autograd @@ -197,7 +208,8 @@ Reducer::Reducer( // used for algorithms like Gradient Compression/GossipGrad. This hook can be // registered from Python API using `register_comm_hook`. `PythonCommHook` // enables registering a Python hook and is a subclass of `CommHookInterface`. -// `CommHookInterface` can be used to implement CPP hooks in the future. +// Additionally, there are also some built-in C++ hook implementations that can +// be specified by calling `register_builtin_comm_hook` from Python API. Reducer::~Reducer() noexcept(false) { // Remove all hooks on variables registered by this Reducer. This is necessary @@ -700,7 +712,8 @@ void Reducer::mark_bucket_ready(size_t bucket_index) { if (comm_hook_ == nullptr) { bucket.work = process_group_->allreduce(tensors); } else { - bucket.future_work = comm_hook_->runHook(GradBucket(tensors)); + GradBucket grad_bucket(tensors); + bucket.future_work = comm_hook_->runHook(grad_bucket); } } } @@ -996,14 +1009,15 @@ void Reducer::prepare_for_backward( } // Find accumulator functions that don't show up in this graph. - for (const auto& it : gradAccToVariableMap_) { + for (const auto& it : gradAccToVariablesMap_) { // If the accumulator function is present in the graph, we know // a gradient will be computed for the corresponding parameter. - if (seen.count(it.first) > 0) { - continue; + if (seen.count(it.first) == 0) { + auto& indices = it.second; + unused_parameters_.reserve(unused_parameters_.size() + indices.size()); + unused_parameters_.insert( + unused_parameters_.end(), indices.begin(), indices.end()); } - - unused_parameters_.push_back(it.second); } } @@ -1165,7 +1179,7 @@ void Reducer::finalize_backward() { bucket.future_work->wait(); auto future_result = - comm_hook_->processFuture(bucket.future_work->value()); + comm_hook_->parseHookResult(bucket.future_work->value()); for (size_t i = 0; i < future_result.size(); i++) { auto& replica = bucket.replicas[i]; @@ -1356,7 +1370,8 @@ bool Reducer::rebuild_buckets() { // See Note [DDP Communication Hook] void Reducer::register_comm_hook(std::unique_ptr iface) { TORCH_CHECK( - comm_hook_ == nullptr, "register_comm_hook can only be called once."); + comm_hook_ == nullptr, + "register_comm_hook or register_builtin_comm_hook can only be called once."); // TODO(@sinannasir): Single-process multiple-device mode support for DDP // communication hook. Related to GH Issue #42542. TORCH_CHECK( @@ -1366,6 +1381,33 @@ void Reducer::register_comm_hook(std::unique_ptr iface) { comm_hook_ = std::move(iface); } +// See Note [DDP Communication Hook] +void Reducer::register_builtin_comm_hook( + c10d::BuiltinCommHookType comm_hook_type) { + TORCH_CHECK( + comm_hook_ == nullptr, + "register_builtin_comm_hook or register_comm_hook can only be called once."); + TORCH_CHECK( + replicas_.size() == 1, + "Communication hook does not support single-process multiple-device mode."); + + switch (comm_hook_type) { + case c10d::BuiltinCommHookType::ALLREDUCE: + comm_hook_ = + std::make_unique(process_group_.get()); + LOG(INFO) << "Built-in communication hook ALLREDUCE is registered."; + break; + case c10d::BuiltinCommHookType::FP16_COMPRESS: + comm_hook_ = + std::make_unique(process_group_.get()); + LOG(INFO) << "Built-in communication hook FP16_COMPRESS is registered."; + break; + default: + TORCH_WARN_ONCE( + "Unknown built-in DDP comm hook type is provided. No comm hook will be used."); + } +} + void Reducer::ensure_prior_reduction_finished() { // Check that any prior reduction has finished. // The variable `require_finalize_` is true until all gradients diff --git a/torch/csrc/distributed/c10d/reducer.h b/torch/csrc/distributed/c10d/reducer.h index 25f81857d101..c3125e2a3b36 100644 --- a/torch/csrc/distributed/c10d/reducer.h +++ b/torch/csrc/distributed/c10d/reducer.h @@ -12,6 +12,7 @@ #include #include #include +#include namespace c10d { @@ -58,8 +59,14 @@ class Reducer { // Registers a hook to the reducer. The hook is `CommHookInterface` // type to allow both Python and CPP hooks. This function can only // be called once before calling backward. + // Cannot combine with the call of `register_builtin_comm_hook`. void register_comm_hook(std::unique_ptr iface); + // Registers a built-in C++ comm hook to the reducer. This function can only + // be called once before calling backward. + // Cannot combine with the call of `register_comm_hook`. + void register_builtin_comm_hook(c10d::BuiltinCommHookType comm_hook_type); + // Returns a vector of tensors in each bucket in sequential order. std::vector> get_bucket_tensors() const; @@ -122,8 +129,8 @@ class Reducer { std::vector>> grad_accumulators_; - std::unordered_map - gradAccToVariableMap_; + std::unordered_map> + gradAccToVariablesMap_; std::vector>> hooks_; diff --git a/torch/csrc/distributed/rpc/init.cpp b/torch/csrc/distributed/rpc/init.cpp index d278ac59ca75..2f608bd6fd33 100644 --- a/torch/csrc/distributed/rpc/init.cpp +++ b/torch/csrc/distributed/rpc/init.cpp @@ -393,6 +393,43 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) { Set future that is completed when the profiling event corresponding to the creation of this RRef on the remote node has been recorded. )") + .def( + "backward", + [](PyRRef& self, + int64_t dist_autograd_ctx_id, + bool retain_graph) { + self.backward(dist_autograd_ctx_id, retain_graph); + }, + py::arg("dist_autograd_ctx_id") = -1, + py::arg("retain_graph") = false, + py::call_guard(), + R"( + Runs the backward pass using the RRef as the root of the + backward pass. If ``dist_autograd_ctx_id`` is provided, + we perform a distributed backward pass using the provided + ctx_id starting from the owner of the RRef. In this case, + :meth:`~torch.distributed.autograd.get_gradients` should be + used to retrieve the gradients. If ``dist_autograd_ctx_id`` + is ``None``, it is assumed that this is a local autograd graph + and we only perform a local backward pass. The value of the + RRef is expected to be a scalar Tensor. + + Arguments: + dist_autograd_ctx_id (int, optional): The distributed + autograd context id for which we should retrieve the + gradients (default: -1). + retain_graph(bool, optional): If ``False``, the graph used to + compute the grad will be freed. Note that in nearly all + cases setting this option to ``True`` is not needed and + often can be worked around in a much more efficient way. + Usually, you need to set this to ``True`` to run backward + multiple times (default: False). + + Example:: + >>> import torch.distributed.autograd as dist_autograd + >>> with dist_autograd.context() as context_id: + >>> rref.backward(context_id) + )") // not releasing GIL to avoid context switch .def("__repr__", &PyRRef::str); diff --git a/torch/csrc/distributed/rpc/py_rref.cpp b/torch/csrc/distributed/rpc/py_rref.cpp index 823e21d20b4b..6243d0b78e83 100644 --- a/torch/csrc/distributed/rpc/py_rref.cpp +++ b/torch/csrc/distributed/rpc/py_rref.cpp @@ -1,5 +1,7 @@ #include +#include +#include #include #include #include @@ -283,6 +285,25 @@ c10::IValue PyRRef::toIValue() const { return IValue(rrefPtr); } +void PyRRef::backward(int64_t dist_autograd_ctx_id, bool retain_graph) { + if (rref_->isOwner()) { + const auto& value = + c10::static_intrusive_pointer_cast(rref_)->getValue(); + TORCH_CHECK( + value.isTensor(), "RRef should contain a tensor for .backward()"); + auto root = value.toTensor(); + + if (dist_autograd_ctx_id == -1) { + torch::autograd::backward({root}); + } else { + torch::distributed::autograd::backward( + dist_autograd_ctx_id, {value.toTensor()}, retain_graph); + } + } else { + // TODO + } +} + } // namespace rpc } // namespace distributed } // namespace torch diff --git a/torch/csrc/distributed/rpc/py_rref.h b/torch/csrc/distributed/rpc/py_rref.h index 3cc7ab73f019..6af75cd45c3e 100644 --- a/torch/csrc/distributed/rpc/py_rref.h +++ b/torch/csrc/distributed/rpc/py_rref.h @@ -51,6 +51,9 @@ class PYBIND11_EXPORT PyRRef { // get the type of the data object referenced by this RRef. py::object getRRefType(); + // Run the backward pass with the RRef as the root. + void backward(int64_t dist_autograd_ctx_id, bool retain_graph); + private: c10::intrusive_ptr rref_; c10::optional> profilingFuture_; diff --git a/torch/csrc/jit/backends/backend_init.cpp b/torch/csrc/jit/backends/backend_init.cpp index 17c92cb14023..596c19e6ba1b 100644 --- a/torch/csrc/jit/backends/backend_init.cpp +++ b/torch/csrc/jit/backends/backend_init.cpp @@ -2,11 +2,121 @@ #include #include #include +#include #include namespace torch { namespace jit { +// Get all types that are shared in the module hierarchy rooted at \p mod. +std::unordered_set getSharedModuleTypes(Module& mod) { + // Maintain a set of all TypePtrs. + std::unordered_set types; + // Maintain another set of TypePtrs that have been encountered more than once. + std::unordered_set duplicate_types; + + // Iterate over all modules in the hierarchy, including the root. + for (auto module : mod.modules()) { + auto module_type = module.type(); + if (types.count(module_type) > 0) { + duplicate_types.insert(module_type); + } + + types.insert(module_type); + } + + return duplicate_types; +} + +// Selectively lower \p mod to a backend. \p to_backend +// is called to lower modules. \p modules_to_lower contains +// qualified names of submodules of \p mod that should be lowered. +void toBackendSelectiveImpl( + Module& mod, + const py::function& to_backend, + const std::vector& modules_to_lower, + const std::unordered_set& duplicate_types) { + // This map will be used later to remap types in ancestor module graphs for + // all lowered submodules. + std::unordered_map type_remap; + + // For each module that should be lowered: + for (const auto& module_to_lower : modules_to_lower) { + // Use QualifiedName to parse the qualified module names. + c10::QualifiedName qual_module_name(module_to_lower); + auto& atoms = qual_module_name.atoms(); + + // Search through the module hierarchy using the atoms of + // qual_module_name until current points to the module to + // be lowered and parent points to its parent. + Module current = mod; + Module parent; + + for (size_t i = 0, e = atoms.size(); i < e; ++i) { + IValue submodule = current.attr(atoms[i]); + if (submodule.isModule()) { + if (i == e - 1) { + parent = current; + } + current = submodule.toModule(); + } else { + std::stringstream err; + err << "Attribute named " << atoms[i] << " is not a Module"; + throw std::runtime_error(err.str()); + } + } + + // Check that the parent type is not shared and therefore can be edited. + if (duplicate_types.count(parent.type()) > 0) { + throw py::cast_error(c10::str( + "Selective lowering is only supported for module hierarchies with unique types for selected modules; ", + parent.type()->repr_str(), + " is shared")); + } + + // Call to_backend on the module that needs to be lowered. It needs to be + // wrapped before doing so because _to_jit_backend accepts wrapped modules. + // The result needs to be unwrapped in order to access its type below. + auto lowered_submodule = + py::cast(to_backend(py::module::import("torch.jit._recursive") + .attr("wrap_cpp_module")(current)) + .attr("_c")); + + // Adjust the parent's type so that the type of the submodule matches + // the type of lowered_submodule. + auto parent_type = parent.type(); + + parent_type->unsafeChangeAttributeType( + atoms.back(), lowered_submodule.type()); + parent.setattr(atoms.back(), lowered_submodule._ivalue()); + + // Record the type mapping from old type -> lowered type. + type_remap[current.type()] = lowered_submodule.type(); + } + + // Having lowered all of the modules that needed to be lowered, remap types in + // all graphs in the hierarchy so that the graphs all use the new lowered + // type. + auto type_remap_fn = [&type_remap](TypePtr in) { + auto it = type_remap.find(in); + if (it == type_remap.end()) + return in; + return it->second; + }; + + // modules() iterates over all modules in the hierarchy including the root. + for (auto module : mod.modules()) { + auto module_type = module.type(); + for (auto& fn : module_type->methods()) { + auto method = module.get_method(fn->name()); + auto graph = method.graph(); + graph->remapTypes(type_remap_fn); + auto new_schema = fn->getSchema().cloneWithRemappedTypes(type_remap_fn); + fn->setSchema(new_schema); + } + } +} + void initJitBackendBindings(PyObject* module) { // Bind a function for lowering to each JIT backend. The name of the backend // must be the first argument. For example, to lower a Module to @@ -124,7 +234,7 @@ void initJitBackendBindings(PyObject* module) { static const auto method_ct = CodeTemplate(R"( def $method(self${,def_inputs}): typed_inputs: List[Any] = [${fwd_inputs,}] - $ret, = self.__backend.execute(self.__handles["$method"], typed_inputs) + $unpack, = self.__backend.execute(self.__handles["$method"], typed_inputs) ${refine,} return $ret )"); @@ -181,7 +291,9 @@ void initJitBackendBindings(PyObject* module) { out_ss << "_0"; type_check_ss << "assert isinstance(_0, "; - if (auto out_tuple_ty = out_ty->cast()) { + auto out_tuple_ty = out_ty->cast(); + + if (out_tuple_ty) { auto tuple_elements = out_tuple_ty->elements(); type_check_ss << tuple_elements[0]->str() << ")"; type_checks.emplace_back(type_check_ss.str()); @@ -201,6 +313,14 @@ void initJitBackendBindings(PyObject* module) { method_te.v("def_inputs", def_inputs); method_te.v("fwd_inputs", fwd_inputs); method_te.v("refine", type_checks); + method_te.s("unpack", out_ss.str()); + + // If the output type is a single element tuple then add an extra comma + // to ensure the final output maintains this type. + if (out_tuple_ty && out_tuple_ty->elements().size() == 1) { + out_ss << ","; + } + method_te.s("ret", out_ss.str()); loweredModule.define( @@ -234,6 +354,32 @@ void initJitBackendBindings(PyObject* module) { py::cast(orig_module.attr("_c")), method_compile_spec)); }); + + m.def( + "_jit_to_backend_selective", + [=](py::handle orig_module, + const py::function& to_backend, + const std::vector& modules_to_lower) { + if (auto original_module = + as_module(py::cast(orig_module))) { + // Clone the Module to avoid editing types that are shared with + // Modules in other instances outside this hierarchy. + Module& mod = original_module.value(); + auto cloned_mod = mod.clone(); + // Get all shared module types. Type sharing is only a problem if the + // parent modules of the ones to lower are in this set. + auto shared_types = getSharedModuleTypes(cloned_mod); + toBackendSelectiveImpl( + cloned_mod, to_backend, modules_to_lower, shared_types); + // Wrap the result in a RecursiveScriptModule because that's what + // the caller passed in. + return py::module::import("torch.jit._recursive") + .attr("wrap_cpp_module")(cloned_mod); + } + + throw py::cast_error(c10::str( + "Object ", py::str(orig_module), " is not a ScriptModule")); + }); } } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/frontend/convert_to_ssa.cpp b/torch/csrc/jit/frontend/convert_to_ssa.cpp index 10109aa55824..1dd61c260bd6 100644 --- a/torch/csrc/jit/frontend/convert_to_ssa.cpp +++ b/torch/csrc/jit/frontend/convert_to_ssa.cpp @@ -5,19 +5,20 @@ #include #include #include -#include namespace torch { namespace jit { // At the beginning of the pass the Graph has already undergone type checking, // and writes or reads to a variable are emitted as Loads and Stores in the -// graph. a = 1 print(a) is represented as: -// -// %a.1 : int = prim::Constant[value=1]() -// prim::Store[name="a"](%a.1) -// %a : int = prim::Load[name="a"]() -// prim::Print(%a) +// graph. +// a = 1 +// print(a) +// is represented as: +// %a.1 : int = prim::Constant[value=1]() +// prim::Store[name="a"](%a.1) +// %a : int = prim::Load[name="a"]() +// prim::Print(%a) // // First, this pass recursively adds the Loads & Stores to control flow nodes // Then the graph is converted to SSA form. @@ -149,7 +150,7 @@ struct ControlFlowLoadStores { case prim::Loop: { addLoopLoadStores(n); } break; - case prim::Function: { + case prim::Closure: { for (auto b : n->blocks()) { addControlFlowLoadStores(b); } @@ -157,7 +158,7 @@ struct ControlFlowLoadStores { case prim::Store: { environment_stack->setVar(n->s(attr::name), n->input()->type()); } break; - case prim::LocalVariableScope: { + case prim::ListComprehensionScope: { addControlFlowLoadStores(n->blocks().at(0)); } break; } @@ -204,7 +205,7 @@ struct EraseLoadStores { n->output()->replaceAllUsesWith(var); n->destroy(); } break; - case prim::LocalVariableScope: { + case prim::ListComprehensionScope: { // writes within a local variable scope do not leak into // the rest of the graph auto body = n->blocks().at(0); @@ -279,7 +280,7 @@ struct LoopContinuations { assignExitContinuations(n->blocks().at(0)); assignExitContinuations(n->blocks().at(1)); } break; - case prim::Function: { + case prim::Closure: { LoopContinuations closure_block; closure_block.run(n->blocks().at(0)); } break; diff --git a/torch/csrc/jit/frontend/exit_transforms.cpp b/torch/csrc/jit/frontend/exit_transforms.cpp index 3126d78c3bd2..e14cb6428890 100644 --- a/torch/csrc/jit/frontend/exit_transforms.cpp +++ b/torch/csrc/jit/frontend/exit_transforms.cpp @@ -119,7 +119,7 @@ struct ExitTransformer { static bool isGraphOrClosureBlock(Block* block) { return block->owningNode() == nullptr || - owningNodeKind(block) == prim::Function; + owningNodeKind(block) == prim::Closure; } static void removeOutputs(Block* b) { @@ -425,7 +425,7 @@ struct ExitTransformer { case prim::With: { exit_pair = transformWith(node); } break; - case prim::Function: { + case prim::Closure: { // exits of closure declaration stay local to the closure transformExits(node->blocks().at(0)); } break; diff --git a/torch/csrc/jit/frontend/ir_emitter.cpp b/torch/csrc/jit/frontend/ir_emitter.cpp index 5294e02e739f..a4b239418cfb 100644 --- a/torch/csrc/jit/frontend/ir_emitter.cpp +++ b/torch/csrc/jit/frontend/ir_emitter.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -859,9 +860,12 @@ struct to_ir { return emitStatements(statements.begin(), statements.end()); } - // XXX - right now closures are used _only_ for defining gradients internally + // XXX: Right now closures are not generically implemented and are only used + // as an intermediate form for special tasks, like defining gradients or + // forked functions. + // // There are several unfinished aspects that make them unusable generally - // 1. We do not have a type, ivalue, operator to represent prim::Function, so + // 1. We do not have a type, ivalue, operator to represent prim::Closure, so // closure_node has type None // 2. There is no export logic for it yet, so it cannot be // exported/python_printed @@ -870,9 +874,19 @@ struct to_ir { // the changes to those variables will just get forgotten. // 4. There is no parsing support in frontend.py, this is intentional since it // prevents people from accidentally using this feature. + // + // This function leaves in the graph something like: + // + // %2 : None = prim::Closure() + // block0(): + // %1 : Tensor = prim::DoSomething(%0) + // -> (%1) + // + // A separate pass is required to erase this closure and replace it with + // something actually executable (see liftClosure and inlineForkedClosure). std::shared_ptr emitClosure( const std::function& emit_body) { - Node* closure_node = graph->insertNode(graph->create(prim::Function, 1)); + Node* closure_node = graph->insertNode(graph->create(prim::Closure, 1)); // it is not a real thing yet, so just say the type is None closure_node->output()->setType(NoneType::get()); Block* block = closure_node->addBlock(); @@ -1262,7 +1276,7 @@ struct to_ir { // comprehension introduces it's own scope. no variable assigned // leaks into the rest of the graph Node* n = - graph->insertNode(create(prim::LocalVariableScope, lc.range(), 0)); + graph->insertNode(create(prim::ListComprehensionScope, lc.range(), 0)); auto* comprehension_block = n->addBlock(); pushFrame(comprehension_block); WithInsertPoint guard(comprehension_block); @@ -2094,8 +2108,8 @@ struct to_ir { stmt.range(), *method.graph(), getAugOp(stmt, lhs->type()), - /*inputs=*/{lhs, rhs}, - /*attributes=*/{}, + /*args=*/{lhs, rhs}, + /*kwargs=*/{}, /*self=*/c10::nullopt); } } @@ -2665,9 +2679,9 @@ struct to_ir { if (auto special_form = dynamic_cast(sv.get())) { return emitApplySpecialForm(special_form->form(), apply, type_hint); } - auto inputs = getNamedValues(apply.inputs(), true); - auto attributes = emitAttributes(apply.attributes()); - return sv->call(loc, method, inputs, attributes, n_binders); + auto args = getNamedValues(apply.inputs(), true); + auto kwargs = emitAttributes(apply.attributes()); + return sv->call(loc, method, args, kwargs, n_binders); } // this function handles expressions that look like apply statements @@ -2688,9 +2702,9 @@ struct to_ir { } auto forked = emitSugaredExpr(Expr(trees[0]), 1); TreeList sliced_trees(trees.begin() + 1, trees.end()); - auto inputs = getNamedValues(sliced_trees, true); - auto attributes = emitAttributes(apply.attributes()); - return emitForkExpr(apply.range(), forked, inputs, attributes); + auto args = getNamedValues(sliced_trees, true); + auto kwargs = emitAttributes(apply.attributes()); + return emitForkExpr(apply.range(), forked, args, kwargs); } case prim::annotate: { checkApplyNumInputs(apply, 2); @@ -2932,7 +2946,7 @@ struct to_ir { return emitApplyExpr(apply, n_binders, type_hint); } break; case TK_SUBSCRIPT: { - return emitSubscript(Subscript(tree)); + return emitSubscript(Subscript(tree), type_hint); } break; default: return std::make_shared(emitSimpleExpr(tree, type_hint)); @@ -2965,11 +2979,15 @@ struct to_ir { return graph->insertConstant(maybe_out_stack->at(0), tree->range()); } + /** + * Emit a fork expression, of the form: + * torch.jit.fork(forked, *args, **kwargs) + */ std::shared_ptr emitForkExpr( SourceRange loc, const std::shared_ptr& forked, - at::ArrayRef inputs, - at::ArrayRef attributes) { + at::ArrayRef args, + at::ArrayRef kwargs) { auto g = method.graph(); Node* fork_node; TypePtr out_type; @@ -2989,8 +3007,7 @@ struct to_ir { fork_node->addInput(closure_output); } else { auto emit_closure_body = [&](Block* closure_block) { - auto fn_sugared_output = - forked->call(loc, method, inputs, attributes, 1); + auto fn_sugared_output = forked->call(loc, method, args, kwargs, 1); auto fn_simple_output = fn_sugared_output->asValue(loc, method); closure_block->registerOutput(fn_simple_output); out_type = fn_simple_output->type(); @@ -3788,7 +3805,9 @@ struct to_ir { ->output(); } - std::shared_ptr emitSubscript(const Subscript& subscript) { + std::shared_ptr emitSubscript( + const Subscript& subscript, + TypePtr type_hint = nullptr) { const SugaredValuePtr sv = emitSugaredExpr(subscript.value(), 1); const List& subscript_exprs = subscript.subscript_exprs(); const SourceRange& range = subscript.range(); @@ -3858,7 +3877,7 @@ struct to_ir { return std::make_shared( emitMultidimSlicing(range, sliceable, subscript_exprs)); } else { - return sv->getitem(range, method, idx); + return sv->getitem(range, method, idx, std::move(type_hint)); } } } diff --git a/torch/csrc/jit/frontend/schema_matching.cpp b/torch/csrc/jit/frontend/schema_matching.cpp index fb2e0f20f380..9fd3973f9b3d 100644 --- a/torch/csrc/jit/frontend/schema_matching.cpp +++ b/torch/csrc/jit/frontend/schema_matching.cpp @@ -584,8 +584,8 @@ Value* emitBuiltinCall( const SourceRange& loc, Graph& graph, Symbol name, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, const c10::optional& self) { const auto& variants = getAllOperatorsFor(name); const auto& builtin_functions = getAllBuiltinFunctionsFor(name); @@ -620,7 +620,7 @@ Value* emitBuiltinCall( throw error; } - auto matched = matchSchemas(schemas, loc, graph, inputs, attributes, self); + auto matched = matchSchemas(schemas, loc, graph, args, kwargs, self); if (matched.first < variants.size()) { return emitBuiltinNode(matched.second, loc, graph, name); diff --git a/torch/csrc/jit/frontend/schema_matching.h b/torch/csrc/jit/frontend/schema_matching.h index 88fe23a9682d..83e34bb33ae5 100644 --- a/torch/csrc/jit/frontend/schema_matching.h +++ b/torch/csrc/jit/frontend/schema_matching.h @@ -23,7 +23,7 @@ TORCH_API MatchedSchema matchSchema( const SourceRange& loc, Graph& graph, at::ArrayRef args, - at::ArrayRef kwarg, + at::ArrayRef kwargs, const c10::optional& self = c10::nullopt); TORCH_API std::pair matchSchemas( @@ -31,7 +31,7 @@ TORCH_API std::pair matchSchemas( const SourceRange& loc, Graph& graph, at::ArrayRef args, - at::ArrayRef kwarg, + at::ArrayRef kwargs, const c10::optional& self = c10::nullopt, bool render_errors = false); @@ -43,8 +43,8 @@ TORCH_API Value* emitBuiltinCall( const SourceRange& loc, Graph& graph, Symbol name, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, const c10::optional& self = c10::nullopt); TORCH_API c10::optional findInputWithName( diff --git a/torch/csrc/jit/frontend/sugared_value.cpp b/torch/csrc/jit/frontend/sugared_value.cpp index 69e86716f72e..8810a5a62019 100644 --- a/torch/csrc/jit/frontend/sugared_value.cpp +++ b/torch/csrc/jit/frontend/sugared_value.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -17,14 +18,14 @@ struct NoneValue : SugaredValue { std::shared_ptr PrintValue::call( const SourceRange& loc, Function& m, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) { auto& g = *m.graph(); - if (!attributes.empty()) + if (!kwargs.empty()) throw ErrorReport(loc) << "print doesn't accept any keyword arguments"; - std::vector lowered_inputs = toValues(*m.graph(), inputs); + std::vector lowered_inputs = toValues(*m.graph(), args); g.insertNode(g.create(prim::Print, lowered_inputs, 0)->setSourceRange(loc)); return std::make_shared(); } @@ -46,11 +47,11 @@ builtin_cast_method_to_scalar_type() { std::shared_ptr BuiltinFunction::call( const SourceRange& loc, Function& m, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) { return std::make_shared( - emitBuiltinCall(loc, *m.graph(), symbol, inputs, attributes, self)); + emitBuiltinCall(loc, *m.graph(), symbol, args, kwargs, self)); } // older versions of gcc/clang have a bug where enums can't be used as keys @@ -322,14 +323,14 @@ void SimpleValue::setAttr( std::shared_ptr SimpleValue::call( const SourceRange& loc, Function& m, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) { // allow our 'fake' closures to be called, used for fork serialization // at the moment, but can be expanded later Node* self = getValue()->node(); if (self->kind() == prim::TupleConstruct && self->inputs().size() == 2 && - self->inputs().at(0)->node()->kind() == prim::Function) { + self->inputs().at(0)->node()->kind() == prim::Closure) { std::shared_ptr graph = self->inputs().at(0)->node()->g(attr::Subgraph); Value* context = self->inputs().at(1); @@ -348,16 +349,15 @@ std::shared_ptr SimpleValue::call( auto ret = StrongFunctionPtr(std::move(cu), fn); std::vector ctx_inputs = {close_context}; - ctx_inputs.insert(ctx_inputs.end(), inputs.begin(), inputs.end()); - return FunctionValue(ret).call(loc, m, ctx_inputs, attributes, n_binders); + ctx_inputs.insert(ctx_inputs.end(), args.begin(), args.end()); + return FunctionValue(ret).call(loc, m, ctx_inputs, kwargs, n_binders); } if (auto class_type = getValue()->type()->cast()) { - return attr(loc, m, "__call__") - ->call(loc, m, inputs, attributes, n_binders); + return attr(loc, m, "__call__")->call(loc, m, args, kwargs, n_binders); } - return SugaredValue::call(loc, m, inputs, attributes, n_binders); + return SugaredValue::call(loc, m, args, kwargs, n_binders); } Value* SimpleValue::len(const SourceRange& loc, Function& m) { @@ -377,7 +377,8 @@ Value* SimpleValue::len(const SourceRange& loc, Function& m) { SugaredValuePtr SimpleValue::getitem( const SourceRange& loc, Function& m, - Value* idx) { + Value* idx, + TypePtr type_hint) { Value* val = getValue(); TypePtr val_type = val->type(); Graph& g = *m.graph(); @@ -393,6 +394,17 @@ SugaredValuePtr SimpleValue::getitem( return std::make_shared( g.insert(aten::select, {val, 0, idx}, {}, loc)); } else if (auto class_type = val_type->cast()) { + // Check if this is an indexing operation enabled by a type hint. + // The ModuleDict has already been checked during IR generation to make + // sure its contents implement the module interface referred to by + // type_hint. + if (class_type->is_module() && type_hint) { + auto res = g.insert(prim::ModuleDictIndex, {val, idx}, {}, loc); + res->setType(type_hint); + return std::make_shared(res); + } + + // Defer to the __getitem__ attr on the class. return attr(loc, m, "__getitem__")->call(loc, m, {idx}, {}, 1); } else { throw ErrorReport(loc) << "'" << val_type->repr_str() << "'" @@ -485,7 +497,8 @@ Value* RangeValue::len(const SourceRange& loc, Function& m) { SugaredValuePtr RangeValue::getitem( const SourceRange& loc, Function& m, - Value* idx) { + Value* idx, + TypePtr type_hint) { if (has_only_end_) { return std::make_shared(idx); } else { @@ -535,7 +548,8 @@ Value* IterableTree::len(const SourceRange& loc, Function& m) { SugaredValuePtr IterableTree::getitem( const SourceRange& loc, Function& m, - Value* idx) { + Value* idx, + TypePtr type_hint) { std::vector child_items; for (const SugaredValuePtr& child : children_) { child_items.emplace_back(child->getitem(loc, m, idx)); @@ -569,27 +583,27 @@ void IterableTree::addChild( std::shared_ptr MagicMethod::call( const SourceRange& loc, Function& m, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) { - if (inputs.size() > 0) { - Value* self = inputs[0].value(*m.graph()); + if (args.size() > 0) { + Value* self = args[0].value(*m.graph()); if (auto class_ptr = self->type()->cast()) { return SimpleValue(self) .attr(loc, m, desugared_name_) - ->call(loc, m, inputs.slice(1), attributes, n_binders); + ->call(loc, m, args.slice(1), kwargs, n_binders); } } TORCH_INTERNAL_ASSERT(base_value_); - return base_value_->call(loc, m, inputs, attributes, n_binders); + return base_value_->call(loc, m, args, kwargs, n_binders); } std::shared_ptr ClassValue::call( const SourceRange& loc, Function& m, // note: names for args will be 'argument 0', 'argument 1', etc.. - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) { AT_ASSERT(n_binders <= 1); @@ -602,7 +616,7 @@ std::shared_ptr ClassValue::call( } // Call the init function - MethodValue(self, "__init__").call(loc, m, inputs, attributes, n_binders); + MethodValue(self, "__init__").call(loc, m, args, kwargs, n_binders); return std::make_shared(self); } @@ -621,15 +635,15 @@ std::shared_ptr ClassValue::attr( std::shared_ptr NamedTupleConstructor::call( const SourceRange& loc, Function& m, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) { auto& g = *m.graph(); auto schema = type_->schema(); TORCH_INTERNAL_ASSERT(schema); auto qualname = type_->name(); - auto matched_schema = matchSchema(*schema, loc, g, inputs, attributes); + auto matched_schema = matchSchema(*schema, loc, g, args, kwargs); auto self = g.insertNode( diff --git a/torch/csrc/jit/frontend/sugared_value.h b/torch/csrc/jit/frontend/sugared_value.h index 3523523f5c23..28a18aceda49 100644 --- a/torch/csrc/jit/frontend/sugared_value.h +++ b/torch/csrc/jit/frontend/sugared_value.h @@ -84,8 +84,8 @@ struct TORCH_API SugaredValue const SourceRange& loc, Function& m, // note: names for args will be 'argument 0', 'argument 1', etc.. - at::ArrayRef inputs_, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) { // n_binders is always set to the number of variables an expression is // syntactically bound to: @@ -139,7 +139,8 @@ struct TORCH_API SugaredValue virtual std::shared_ptr getitem( const SourceRange& loc, Function& m, - Value* idx) { + Value* idx, + TypePtr type_hint = nullptr) { throw ErrorReport(loc) << "'" << kind() << "'" << " object is not subscriptable"; } @@ -181,8 +182,8 @@ struct TORCH_API SimpleValue : public SugaredValue { const SourceRange& loc, Function& m, // note: names for args will be 'argument 0', 'argument 1', etc.. - at::ArrayRef inputs_, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override; std::shared_ptr iter(const SourceRange& loc, Function& m) @@ -193,8 +194,11 @@ struct TORCH_API SimpleValue : public SugaredValue { } Value* len(const SourceRange& loc, Function& m) override; - SugaredValuePtr getitem(const SourceRange& loc, Function& m, Value* idx) - override; + SugaredValuePtr getitem( + const SourceRange& loc, + Function& m, + Value* idx, + TypePtr type_hint = nullptr) override; private: Value* value_; @@ -215,8 +219,8 @@ struct TORCH_API BuiltinFunction : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& m, - at::ArrayRef attributes, - at::ArrayRef inputs, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override; // try to create this builtin but if it doesn't exist or the self argument @@ -251,8 +255,11 @@ struct TORCH_API SugaredTupleValue : public SugaredValue { return "Tuple"; } - SugaredValuePtr getitem(const SourceRange& loc, Function& m, Value* idx) - override { + SugaredValuePtr getitem( + const SourceRange& loc, + Function& m, + Value* idx, + TypePtr type_hint = nullptr) override { if (!(idx->type()->cast() && toIValue(idx))) { throw ErrorReport(loc) << "Expected integer literal for index. " @@ -332,8 +339,8 @@ struct TORCH_API ClassValue : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& m, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override; std::shared_ptr attr( @@ -354,8 +361,8 @@ struct TORCH_API NamedTupleConstructor : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& m, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override; std::string kind() const override { @@ -384,8 +391,8 @@ struct FunctionValue : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& f, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override { std::vector schemas; for (Function* callee : callees_) { @@ -398,7 +405,7 @@ struct FunctionValue : public SugaredValue { } schemas.push_back(&callee->getSchema()); } - auto match = matchSchemas(schemas, loc, *f.graph(), inputs, attributes); + auto match = matchSchemas(schemas, loc, *f.graph(), args, kwargs); Value* output = f.graph()->insertFunctionCall(callees_[match.first], match.second); output->node()->setSourceRange(loc); @@ -417,7 +424,7 @@ struct FunctionValue : public SugaredValue { struct TORCH_API ClosureValue : public SugaredValue { ClosureValue(Value* value) : value_(value) { - TORCH_INTERNAL_ASSERT(value_->node()->kind() == prim::Function); + TORCH_INTERNAL_ASSERT(value_->node()->kind() == prim::Closure); } std::string kind() const override { return "closure"; @@ -442,11 +449,11 @@ struct MethodValue : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& f, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override { - std::vector inputsWithSelf = {self_}; - inputsWithSelf.insert(inputsWithSelf.end(), inputs.begin(), inputs.end()); + std::vector argsWithSelf = {self_}; + argsWithSelf.insert(argsWithSelf.end(), args.begin(), args.end()); std::vector schemas; for (const std::string& method_name : method_names_) { if (auto class_type = self_->type()->cast()) { @@ -466,8 +473,7 @@ struct MethodValue : public SugaredValue { false, "method constructed that is not a class or interface"); } } - auto match = - matchSchemas(schemas, loc, *f.graph(), inputsWithSelf, attributes); + auto match = matchSchemas(schemas, loc, *f.graph(), argsWithSelf, kwargs); Value* output = f.graph()->insertMethodCall(method_names_[match.first], match.second); output->node()->setSourceRange(loc); @@ -486,8 +492,8 @@ struct TORCH_API PrintValue : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& m, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override; }; @@ -500,16 +506,16 @@ struct TORCH_API CastValue : public BuiltinFunction { std::shared_ptr call( const SourceRange& loc, Function& m, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override { - if (inputs.size() == 1 && attributes.size() == 0) { - auto v = inputs[0].value(*m.graph()); + if (args.size() == 1 && kwargs.size() == 0) { + auto v = args[0].value(*m.graph()); if (v->type()->isSubtypeOf(type_)) { return std::make_shared(v); } } - return BuiltinFunction::call(loc, m, inputs, attributes, n_binders); + return BuiltinFunction::call(loc, m, args, kwargs, n_binders); } private: @@ -527,17 +533,17 @@ struct TORCH_API TensorCastValue : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& m, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override { - TORCH_INTERNAL_ASSERT(inputs.size() == 0 && attributes.size() == 0); + TORCH_INTERNAL_ASSERT(args.size() == 0 && kwargs.size() == 0); Value* dtype_const = m.graph()->insertConstant(dtype_, loc); - std::vector kwargs{self_, - NamedValue(loc, "dtype", dtype_const)}; + std::vector kwargs_{self_, + NamedValue(loc, "dtype", dtype_const)}; Value* casted_val = m.graph()->insert( /*opname=*/Symbol::fromQualString("aten::to"), - /*args=*/inputs, - /*kwargs=*/kwargs, + /*args=*/args, + /*kwargs=*/kwargs_, /*range=*/loc); return std::make_shared(casted_val); } @@ -560,8 +566,8 @@ struct TORCH_API MagicMethod : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& m, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override; private: @@ -604,8 +610,11 @@ struct TORCH_API RangeValue : SugaredValue { return "range"; } Value* len(const SourceRange& loc, Function& m) override; - SugaredValuePtr getitem(const SourceRange& loc, Function& m, Value* idx) - override; + SugaredValuePtr getitem( + const SourceRange& loc, + Function& m, + Value* idx, + TypePtr type_hint = nullptr) override; std::shared_ptr iter(const SourceRange& loc, Function& m) override; @@ -680,8 +689,11 @@ struct TORCH_API IterableTree : SugaredValue { std::vector get_base_iterables(); Value* len(const SourceRange& loc, Function& m) override; - SugaredValuePtr getitem(const SourceRange& loc, Function& m, Value* idx) - override; + SugaredValuePtr getitem( + const SourceRange& loc, + Function& m, + Value* idx, + TypePtr type_hint = nullptr) override; private: c10::optional unroll_length_ = c10::nullopt; @@ -735,11 +747,11 @@ struct TORCH_API ExceptionValue : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& m, - at::ArrayRef inputs, + at::ArrayRef args, at::ArrayRef /*attributes*/, size_t /*n_binders*/) override { auto exception_message = insertConstant(*m.graph(), message_ + ": ", loc); - for (auto& input : inputs) { + for (auto& input : args) { auto input_str = input.value(*m.graph()); if (!input_str->type()->isSubtypeOf(StringType::get())) { input_str = diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp index bb5872f35f4f..b055d29164a5 100644 --- a/torch/csrc/jit/ir/alias_analysis.cpp +++ b/torch/csrc/jit/ir/alias_analysis.cpp @@ -494,7 +494,7 @@ void AliasDb::analyzeImpl(Node* node) { case prim::MMBatchSide: case prim::BroadcastSizes: case prim::ChunkSizes: - case prim::Function: + case prim::Closure: case prim::CreateObject: case prim::tolist: return analyzeCreator(node); diff --git a/torch/csrc/jit/mobile/import.cpp b/torch/csrc/jit/mobile/import.cpp index c7f41b902ad6..96b07612c903 100644 --- a/torch/csrc/jit/mobile/import.cpp +++ b/torch/csrc/jit/mobile/import.cpp @@ -86,8 +86,7 @@ void print_unsupported_ops_and_throw( TORCH_CHECK( false, "Following ops cannot be found. ", - "May need to add them explicitly to the selective build operator whitelist, ", - "or re-run the export_opnames to update the whitelist:", + "Check fburl.com/missing_ops for the fix.", error_message); } diff --git a/torch/csrc/jit/passes/constant_propagation.cpp b/torch/csrc/jit/passes/constant_propagation.cpp index c3285f2e2426..2981daa0006a 100644 --- a/torch/csrc/jit/passes/constant_propagation.cpp +++ b/torch/csrc/jit/passes/constant_propagation.cpp @@ -89,7 +89,7 @@ namespace { std::unordered_set skip_list = { prim::If, prim::Loop, - prim::Function, + prim::Closure, prim::Constant, prim::AutogradZero, prim::Uninitialized, diff --git a/torch/csrc/jit/passes/freeze_module.cpp b/torch/csrc/jit/passes/freeze_module.cpp index 6b5beb4372a8..76b6f1d234ba 100644 --- a/torch/csrc/jit/passes/freeze_module.cpp +++ b/torch/csrc/jit/passes/freeze_module.cpp @@ -212,6 +212,13 @@ class AttributePropagator { for (Block* sub_block : n->blocks()) { blocks.push(sub_block); } + + // Modules with prim::ModuleDictIndex cannot be frozen because they + // return InterfaceTypes. + TORCH_CHECK( + n->kind() != prim::ModuleDictIndex, + "Freezing modules containing prim::ModuleDictIndex is not supported"); + if (n->kind() == prim::SetAttr || n->kind() == prim::GetAttr) { // By default if interface attributes are present then fail freezing. // If freezingInterfaces is on then Interfaces are folded similarly diff --git a/torch/csrc/jit/passes/inline_forked_closures.cpp b/torch/csrc/jit/passes/inline_forked_closures.cpp index ea5a977e4091..e97d71e32249 100644 --- a/torch/csrc/jit/passes/inline_forked_closures.cpp +++ b/torch/csrc/jit/passes/inline_forked_closures.cpp @@ -19,7 +19,7 @@ void inlineForkedClosure(Node* fork_closure) { Node* function_context_node = fork_closure->input()->node(); if (function_context_node->inputs().size() != 2 || - function_context_node->inputs().at(0)->node()->kind() != prim::Function || + function_context_node->inputs().at(0)->node()->kind() != prim::Closure || function_context_node->inputs().at(1)->node()->kind() != prim::TupleConstruct) { throw ErrorReport(fork_closure->sourceRange()) << "Cannot fork this value"; diff --git a/torch/csrc/jit/passes/lift_closures.cpp b/torch/csrc/jit/passes/lift_closures.cpp index 82e6f2216681..4f5941ce8afb 100644 --- a/torch/csrc/jit/passes/lift_closures.cpp +++ b/torch/csrc/jit/passes/lift_closures.cpp @@ -5,7 +5,7 @@ namespace torch { namespace jit { -// Closures are initially emitted as prim::Function nodes with a single block. +// Closures are initially emitted as prim::Closure nodes with a single block. // Here, we convert the block to a subgraph, adding all closed over variables // as a context tuple input to the closure node. // At this point the closure has already undergone conversion to SSA, @@ -58,7 +58,7 @@ void liftClosures(Block* block) { Node* n = *it; it++; switch (n->kind()) { - case prim::Function: { + case prim::Closure: { liftClosure(n); } break; default: { diff --git a/torch/csrc/jit/passes/normalize_ops.cpp b/torch/csrc/jit/passes/normalize_ops.cpp index 9d9ac0203b90..a06a3f94f3b1 100644 --- a/torch/csrc/jit/passes/normalize_ops.cpp +++ b/torch/csrc/jit/passes/normalize_ops.cpp @@ -78,6 +78,7 @@ const std::unordered_map& getOperatorAliasMap() { {aten::divide, aten::div}, {aten::divide_, aten::div_}, {aten::multiply, aten::mul}, {aten::multiply_, aten::mul_}, {aten::true_divide, aten::div}, {aten::true_divide_, aten::div_}, + {aten::row_stack, aten::vstack}, }; return alias_map; } diff --git a/torch/csrc/jit/python/python_arg_flatten.cpp b/torch/csrc/jit/python/python_arg_flatten.cpp index 41cd3cd2b8af..b854ae14387a 100644 --- a/torch/csrc/jit/python/python_arg_flatten.cpp +++ b/torch/csrc/jit/python/python_arg_flatten.cpp @@ -21,6 +21,7 @@ static constexpr char TupleOpen = '('; static constexpr char TupleClose = ')'; static constexpr char Variable = 'v'; static constexpr char String = 's'; +static constexpr char NoneType = 'n'; } // namespace D namespace { @@ -62,6 +63,8 @@ void flatten_rec(PyObject* obj, ParsedArgs& args) { args.vars.push_back(var); args.desc.metadata.emplace_back(var); args.desc.structure.push_back(D::Variable); + } else if (strcmp(THPUtils_typename(obj), "NoneType") == 0) { + args.desc.structure.push_back(D::NoneType); } else { std::string msg = "Only tuples, lists and Variables supported as JIT inputs/outputs. " @@ -136,6 +139,8 @@ py::object unflatten_rec( throw std::runtime_error("Not enough Variables given to unflatten"); auto str = *str_it++; return py::reinterpret_borrow(THPUtils_packString(str)); + } else if (type == D::NoneType) { + return py::reinterpret_borrow(py::none()); } else { if (var_it == var_it_end) throw std::runtime_error("Not enough Variables given to unflatten"); diff --git a/torch/csrc/jit/python/python_custom_class.cpp b/torch/csrc/jit/python/python_custom_class.cpp index 49c85c8c3c7f..9809b854e6ac 100644 --- a/torch/csrc/jit/python/python_custom_class.cpp +++ b/torch/csrc/jit/python/python_custom_class.cpp @@ -28,7 +28,10 @@ void initPythonCustomClassBindings(PyObject* module) { auto m = py::handle(module).cast(); py::class_(m, "ScriptClass") - .def("__call__", &ScriptClass::__call__); + .def("__call__", &ScriptClass::__call__) + .def_property_readonly("__doc__", [](const ScriptClass& self) { + return self.class_type_.type_->expect()->doc_string(); + }); // This function returns a ScriptClass that wraps the constructor // of the given class, specified by the qualified name passed in. diff --git a/torch/csrc/jit/python/python_ir.cpp b/torch/csrc/jit/python/python_ir.cpp index fc968237e4ba..20d0e6272a19 100644 --- a/torch/csrc/jit/python/python_ir.cpp +++ b/torch/csrc/jit/python/python_ir.cpp @@ -493,6 +493,7 @@ void initPythonIRBindings(PyObject* module_) { }) .def("sourceRange", [](Node& n) { return n.sourceRange().str(); }) .def("hasMultipleOutputs", [](Node& n) { return n.outputs().size() > 1; }) + .def("inputsSize", [](Node& n) { return n.inputs().size(); }) .def("outputsSize", [](Node& n) { return n.outputs().size(); }) .NS(kind) .def("inputsAt", [](Node& n, size_t i) { return n.inputs().at(i); }) diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp index 15d151b761ef..7da82644150e 100644 --- a/torch/csrc/jit/python/python_sugared_value.cpp +++ b/torch/csrc/jit/python/python_sugared_value.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -114,21 +115,20 @@ FunctionSchema PythonValue::getSchema( std::shared_ptr PythonValue::call( const SourceRange& loc, Function& m, - at::ArrayRef inputs_, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) { - std::vector inputsWithSelf; + std::vector argsWithSelf; if (moduleSelf_) { - inputsWithSelf.emplace_back(NamedValue("self", moduleSelf_)); + argsWithSelf.emplace_back(NamedValue("self", moduleSelf_)); } - inputsWithSelf.insert(inputsWithSelf.end(), inputs_.begin(), inputs_.end()); - inputs_ = inputsWithSelf; + argsWithSelf.insert(argsWithSelf.end(), args.begin(), args.end()); - auto schema = getSchema(inputs_.size(), n_binders, loc); - auto inputs = toValues(*m.graph(), inputs_); + auto schema = getSchema(argsWithSelf.size(), n_binders, loc); + auto inputs = toValues(*m.graph(), argsWithSelf); MatchedSchema matched_schema = - matchSchema(schema, loc, *m.graph(), inputs_, attributes); + matchSchema(schema, loc, *m.graph(), argsWithSelf, kwargs); // If if a function is marked as dropped, // we throw an exception if it is invoked. @@ -234,9 +234,11 @@ SugaredValuePtr ModuleValue::asTupleValue(const SourceRange& loc, Function& m) { SugaredValuePtr ModuleValue::getitem( const SourceRange& loc, Function& m, - Value* idx) { + Value* idx, + TypePtr type_hint) { if (concreteType_->getIterableModuleKind() == IterableModuleKind::LIST) { - return getSugaredDict(loc, m)->getModules()->getitem(loc, m, idx); + return getSugaredDict(loc, m)->getModules()->getitem( + loc, m, idx, type_hint); } else if ( concreteType_->getIterableModuleKind() == IterableModuleKind::DICT) { if (auto ivalue = toIValue(idx)) { @@ -252,6 +254,30 @@ SugaredValuePtr ModuleValue::getitem( } } throw ErrorReport(loc) << "Key Error, " << idx_str; + } else if (type_hint) { + // Check that all submodules comply with the type hint. + const auto& self_type = concreteType_->getJitType()->expect(); + for (size_t i = 0; i < self_type->numAttributes(); ++i) { + const auto& attr_type = self_type->getAttribute(i); + if (attr_type->is_module()) { + if (!attr_type->isSubtypeOf(type_hint)) { + auto loc = self_->node()->sourceRange(); + throw ErrorReport(loc) + << "Attribute " << self_type->getAttributeName(i) + << " is not of annotated type " << type_hint->annotation_str(); + } + } + } + + // Emit a prim::ModuleDictIndex operator. This is needed because it's + // difficult to construct a dict in the graph representing the ModuleDict + // and use aten::__getitem__ ops to index into it because any call to + // ModuleDict.setAttr would invalidate that emitted dict. + auto graph = m.graph(); + auto* getitem_node = + graph->insertNode(graph->create(prim::ModuleDictIndex, {self_, idx})); + getitem_node->output(0)->setType(type_hint); + return std::make_shared(getitem_node->output(0)); } throw ErrorReport(loc) << "Unable to extract string literal index. " @@ -652,8 +678,8 @@ void ModuleValue::setAttr( std::shared_ptr BooleanDispatchValue::call( const SourceRange& loc, Function& caller, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) { c10::optional result; Graph& graph = *(caller.graph()); @@ -662,14 +688,14 @@ std::shared_ptr BooleanDispatchValue::call( auto arg_name = py::str(dispatched_fn_["arg_name"]); ErrorReport error(loc); - if (index < inputs.size()) { + if (index < args.size()) { // Dispatch flag is in arg list - result = constant_as(inputs.at(index).value(graph)); + result = constant_as(args.at(index).value(graph)); error << "Argument for boolean dispatch at position " << index << " was not constant"; - } else if (auto i = findInputWithName(arg_name, attributes)) { + } else if (auto i = findInputWithName(arg_name, kwargs)) { // Dispatch flag is in kwargs - result = constant_as(attributes[*i].value(graph)); + result = constant_as(kwargs[*i].value(graph)); error << "Keyword argument '" << arg_name << "' for boolean dispatch at position was not constant"; } else { @@ -688,28 +714,28 @@ std::shared_ptr BooleanDispatchValue::call( } else { value = toSugaredValue(dispatched_fn_["if_false"], caller, loc); } - return value->call(loc, caller, inputs, attributes, n_binders); + return value->call(loc, caller, args, kwargs, n_binders); } std::shared_ptr PythonExceptionValue::call( const SourceRange& loc, Function& caller, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t /*n_binders*/) { Value* error_message = nullptr; - if (inputs.size() == 0) { + if (args.size() == 0) { error_message = insertConstant(*caller.graph(), "", loc); - } else if (inputs.size() == 1) { - error_message = inputs.at(0).value(*caller.graph()); + } else if (args.size() == 1) { + error_message = args.at(0).value(*caller.graph()); } else { std::vector message_values; - message_values.reserve(inputs.size() + attributes.size()); + message_values.reserve(args.size() + kwargs.size()); - for (auto inp : inputs) { + for (const auto& inp : args) { message_values.push_back(inp.value(*caller.graph())); } - for (auto kwarg_inp : attributes) { + for (const auto& kwarg_inp : kwargs) { message_values.push_back(kwarg_inp.value(*caller.graph())); } error_message = @@ -802,10 +828,10 @@ std::shared_ptr createSimpleEnumValue( std::shared_ptr PythonSliceClass::call( const SourceRange& loc, Function& caller, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t /*n_binders*/) { - if (!attributes.empty()) { + if (!kwargs.empty()) { throw ErrorReport(loc) << "Slice does not accept any keyword arguments"; } @@ -824,23 +850,23 @@ std::shared_ptr PythonSliceClass::call( Value* start; Value* stop; Value* step; - size_t n = inputs.size(); + size_t n = args.size(); // Slice's constructor signature is Slice(start=None, stop, step=None) if (n == 1) { // Case where only `stop` is specified. start = ValOr(nullptr, default_start); - stop = ValOr(inputs[0].value(graph), default_stop); + stop = ValOr(args[0].value(graph), default_stop); step = ValOr(nullptr, default_step); } else if (n == 2) { // Case where `start` and `stop` are specified. - start = ValOr(inputs[0].value(graph), default_start); - stop = ValOr(inputs[1].value(graph), default_stop); + start = ValOr(args[0].value(graph), default_start); + stop = ValOr(args[1].value(graph), default_stop); step = ValOr(nullptr, default_step); } else if (n == 3) { // Case where `start`, `stop` and `step` are all specified. - start = ValOr(inputs[0].value(graph), default_start); - stop = ValOr(inputs[1].value(graph), default_stop); - step = ValOr(inputs[2].value(graph), default_step); + start = ValOr(args[0].value(graph), default_start); + stop = ValOr(args[1].value(graph), default_stop); + step = ValOr(args[2].value(graph), default_step); } else { throw ErrorReport(loc) << "slice accepts exactly 1, 2 or 3 arguments, got: " << n; diff --git a/torch/csrc/jit/python/python_sugared_value.h b/torch/csrc/jit/python/python_sugared_value.h index ecb3c6da4ff4..12a5d87b063e 100644 --- a/torch/csrc/jit/python/python_sugared_value.h +++ b/torch/csrc/jit/python/python_sugared_value.h @@ -47,8 +47,8 @@ struct VISIBILITY_HIDDEN PythonValue : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& m, - at::ArrayRef inputs_, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override; std::string kind() const override; @@ -99,8 +99,8 @@ struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& caller, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override { return toSimple(the_list_); } @@ -120,10 +120,10 @@ struct VISIBILITY_HIDDEN ModuleDictMethod : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& f, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override { - if (inputs.size() || attributes.size()) { + if (args.size() || kwargs.size()) { throw ErrorReport(loc) << name_ << " method does not accept any arguments"; } @@ -175,11 +175,11 @@ struct VISIBILITY_HIDDEN ModuleValue : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& caller, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override { return attr(loc, caller, "forward") - ->call(loc, caller, inputs, attributes, n_binders); + ->call(loc, caller, args, kwargs, n_binders); } std::shared_ptr getSugaredDict( @@ -201,7 +201,8 @@ struct VISIBILITY_HIDDEN ModuleValue : public SugaredValue { std::shared_ptr getitem( const SourceRange& loc, Function& m, - Value* idx) override; + Value* idx, + TypePtr type_hint) override; private: Value* self_; @@ -268,8 +269,8 @@ struct VISIBILITY_HIDDEN BooleanDispatchValue : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& caller, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override; private: @@ -308,8 +309,8 @@ struct VISIBILITY_HIDDEN PythonExceptionValue : public ExceptionValue { std::shared_ptr call( const SourceRange& loc, Function& caller, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override; }; @@ -324,8 +325,8 @@ struct VISIBILITY_HIDDEN PythonSliceClass : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& caller, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override; }; diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp index 74e0e75362a6..a99f7469ac65 100644 --- a/torch/csrc/jit/python/script_init.cpp +++ b/torch/csrc/jit/python/script_init.cpp @@ -1187,9 +1187,13 @@ void initJitScriptBindings(PyObject* module) { "name", [](const StrongFunctionPtr& self) { return self.function_->name(); }) .def_property_readonly( - "qualified_name", [](const StrongFunctionPtr& self) { + "qualified_name", + [](const StrongFunctionPtr& self) { return self.function_->qualname().qualifiedName(); - }); + }) + .def_property_readonly("__doc__", [](const StrongFunctionPtr& self) { + return self.function_->doc_string(); + }); py::class_(m, "ScriptMethod", py::dynamic_attr()) .def( diff --git a/torch/csrc/jit/runtime/operator.cpp b/torch/csrc/jit/runtime/operator.cpp index e36208dfb19f..dc1ff95cf735 100644 --- a/torch/csrc/jit/runtime/operator.cpp +++ b/torch/csrc/jit/runtime/operator.cpp @@ -287,7 +287,7 @@ bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) { prim::MMBatchSide, prim::BroadcastSizes, prim::ChunkSizes, - prim::Function, + prim::Closure, prim::TupleUnpack, prim::TupleIndex, prim::TupleSlice, diff --git a/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp b/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp index 3d7e8bb6f60e..d62697c1f9b6 100644 --- a/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp +++ b/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp @@ -40,8 +40,7 @@ C10_DEFINE_bool( namespace torch { namespace jit { -// TODO: keep the else clause for trial runs -#if defined(FBCODE_CAFFE2) || defined(C10_MOBILE) +#if defined(C10_MOBILE) static std::atomic executor_mode{true}; static std::atomic profiling_mode{false}; #else diff --git a/torch/csrc/jit/runtime/register_ops_utils.h b/torch/csrc/jit/runtime/register_ops_utils.h index ae974c063ef3..5bd85d20556a 100644 --- a/torch/csrc/jit/runtime/register_ops_utils.h +++ b/torch/csrc/jit/runtime/register_ops_utils.h @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include diff --git a/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp index f7dd9594347e..4706635a6a0c 100644 --- a/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp @@ -568,6 +568,18 @@ RegisterOperators reg( }; }, aliasAnalysisSpecialCase()), + // This operator is generated inside the compiler for indexing into + // ModuleDict without a statically determinable key. Accordingly, + // self must be a ModuleType and the output must be an InterfaceType. + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA( + "prim::ModuleDictIndex(Any self, str ind) -> Any"), + [](Stack* stack) { + IValue ind = pop(stack); + IValue module_dict = pop(stack); + push(stack, module_dict.toModule().attr(ind.toStringRef())); + }, + aliasAnalysisFromSchema()), Operator( "aten::dict() -> Dict(str, Tensor)", [](Stack* stack) { diff --git a/torch/csrc/jit/runtime/symbolic_script.cpp b/torch/csrc/jit/runtime/symbolic_script.cpp index 578586e9e9ff..6887be516e7b 100644 --- a/torch/csrc/jit/runtime/symbolic_script.cpp +++ b/torch/csrc/jit/runtime/symbolic_script.cpp @@ -1369,8 +1369,8 @@ std::pair, Value*> extractClosure(Value* closure) { Value* context = closure->node()->inputs().at(1); TORCH_CHECK( - fn->node()->kind() == prim::Function, - "closure tuple must contain a prim::Function"); + fn->node()->kind() == prim::Closure, + "closure tuple must contain a prim::Closure"); return std::make_pair(fn->node()->g(attr::Subgraph), context); } diff --git a/torch/csrc/jit/serialization/export_module.cpp b/torch/csrc/jit/serialization/export_module.cpp index 548c816bac6a..8eea619cd535 100644 --- a/torch/csrc/jit/serialization/export_module.cpp +++ b/torch/csrc/jit/serialization/export_module.cpp @@ -131,6 +131,18 @@ std::pair> getFunctionTuple( "use a dictionary type's key-value pair itmes or ", "a pytorch class (class Foo(torch.nn.Module))'s attributes.'"); } + } else if ( + input_type->kind() == TypeKind::ListType || + input_type->kind() == TypeKind::DictType) { + for (const TypePtr& element_type : input_type->containedTypes()) { + TORCH_CHECK( + element_type->kind() != TypeKind::ClassType, + "Returining a list or dictionary with pytorch class type ", + "is not supported in mobile module " + "(List[Foo] or Dict[int, Foo] for class Foo(torch.nn.Module)). " + "Workaround: instead of using pytorch class as their element type, ", + "use a combination of list, dictionary, and single types."); + } } } } else { diff --git a/torch/csrc/jit/serialization/python_print.cpp b/torch/csrc/jit/serialization/python_print.cpp index e04339dacc22..9803829eb683 100644 --- a/torch/csrc/jit/serialization/python_print.cpp +++ b/torch/csrc/jit/serialization/python_print.cpp @@ -801,7 +801,7 @@ struct PythonPrintImpl { } level--; } break; - case prim::Function: { + case prim::Closure: { if (enforce_importable_) { throw ErrorReport(node->sourceRange()) << "closures are not exportable"; @@ -822,6 +822,15 @@ struct PythonPrintImpl { body_ << "):\n"; printBody(graph->block()); } break; + case prim::ModuleDictIndex: { + const auto dict = node->inputs().at(0); + const auto key = node->inputs().at(1); + const auto out = node->outputs().at(0); + assignValuesToTheirUniqueNames(out); + indent(); + body_ << useOf(out) << " : " << out->type()->annotation_str() << " = " + << useOf(dict) << "[" << useOf(key) << "]\n"; + } break; default: auto ss = std::make_shared(&source_range_stack_); printRHS(*ss, node); diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index f9b933b78d65..68ec5c2e304a 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -935,13 +935,15 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { } break; case aten::log: { - return computeOneOperand( - "aten_log", v, [](const ExprHandle& a) { return log(a); }); + return computeOneOperand("aten_log", v, [](const ExprHandle& a) { + return log(promoteIntegerToFloat(a)); + }); } break; case aten::log10: { - return computeOneOperand( - "aten_log10", v, [](const ExprHandle& a) { return log10(a); }); + return computeOneOperand("aten_log10", v, [](const ExprHandle& a) { + return log10(promoteIntegerToFloat(a)); + }); } break; case aten::log1p: { @@ -950,8 +952,9 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { } break; case aten::log2: { - return computeOneOperand( - "aten_log2", v, [](const ExprHandle& a) { return log2(a); }); + return computeOneOperand("aten_log2", v, [](const ExprHandle& a) { + return log2(promoteIntegerToFloat(a)); + }); } break; case aten::exp: { @@ -1355,56 +1358,8 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { } } -void TensorExprKernel::flattenTensors(BackendType backendType) { - if (backendType != BackendType::kCudaCodeGen && - backendType != BackendType::kBlockCodeGen) { - // We only need to flatten for GPU, for other backends just use the same - // tensors. - flatTensorOutputs_ = tensorOutputs_; - return; - } - - flatTensorOutputs_.resize(tensorOutputs_.size()); - for (size_t tensorIdx = 0; tensorIdx < tensorOutputs_.size(); tensorIdx++) { - Tensor* tensor = tensorOutputs_[tensorIdx]; - ExprHandle totalCount = ExprHandle(tensor->dim(0)); - for (int i = 1; i < tensor->ndim(); i++) { - const IntImm* totalCountImm = totalCount.AsNode(); - const IntImm* tensorDimImm = dynamic_cast(tensor->dim(i)); - if (totalCountImm && tensorDimImm) { - // TODO: switch to real constant folding when it is available. - totalCount = ExprHandle(totalCountImm->value() * tensorDimImm->value()); - } else { - totalCount = totalCount * ExprHandle(tensor->dim(i)); - } - } - // Flatten the index for GPU kernels. - // TODO: move this to fusing axis when it is ready. - Tensor* newOut = Compute( - tensor->buf()->name_hint() + "_flat", - {totalCount}, - [tensor](const VarHandle& index) -> ExprHandle { - std::vector dims; - ExprHandle value = index; - for (int i = tensor->ndim() - 1; i >= 0; i--) { - ExprHandle idx = value; - if (i > 0) { - idx = Mod::make(value, ExprHandle(tensor->dim(i))); - } - dims.push_back(idx); - value = value / ExprHandle(tensor->dim(i)); - } - std::reverse(dims.begin(), dims.end()); - return tensor->call(dims); - }); - flatTensorOutputs_[tensorIdx] = newOut; - } -} - Stmt* TensorExprKernel::generateStmt(BackendType backendType) { - flattenTensors(backendType); - - torch::jit::tensorexpr::LoopNest l(flatTensorOutputs_); + torch::jit::tensorexpr::LoopNest l(tensorOutputs_); GRAPH_DEBUG("Original Stmt:\n", std::to_string(l.root_stmt()), "\n"); bool hasReduction = NodeFinder::find(l.root_stmt()).size() != 0; @@ -1417,12 +1372,12 @@ Stmt* TensorExprKernel::generateStmt(BackendType backendType) { l.computeInline(p.second->buf()); } if (backendType == kCudaCodeGen) { - for (size_t i = 0; i < flatTensorOutputs_.size(); i++) { - Tensor* tensor = flatTensorOutputs_[i]; - - // For every output tensor we've created a flattened 1D tensor - let's - // mark the original output tensor with computeInline - l.computeInline(tensorOutputs_[i]->buf()); + for (auto tensor : tensorOutputs_) { + std::vector loops = l.getLoopStmtsFor(tensor); + TORCH_INTERNAL_ASSERT(!loops.empty(), "loops should not be empty"); + For* flattened = nullptr; + LoopNest::flatten(loops, &flattened); + assert(flattened); int loopLevels = getTECudaPointwiseLoopLevels(); const int kDefaultLoopLevels = 2; @@ -1437,8 +1392,7 @@ Stmt* TensorExprKernel::generateStmt(BackendType backendType) { if (blockSize < 0) { blockSize = kDefaultBlockSize; } - std::vector loops = l.getLoopStmtsFor(tensor); - l.splitWithMask(loops[0], blockSize, &outer, &inner); + l.splitWithMask(flattened, blockSize, &outer, &inner); l.setGPUBlockIndex(outer, 0); l.setGPUThreadIndex(inner, 0); } else if (loopLevels == 3) { @@ -1451,8 +1405,7 @@ Stmt* TensorExprKernel::generateStmt(BackendType backendType) { const int kDefaultBlockSize = 256; blockCount = (blockCount > 0) ? blockCount : kDefaultBlockCount; blockSize = (blockSize > 0) ? blockSize : kDefaultBlockSize; - std::vector loops = l.getLoopStmtsFor(tensor); - l.splitWithMask(loops[0], blockCount * blockSize, &outer, &inner); + l.splitWithMask(flattened, blockCount * blockSize, &outer, &inner); l.splitWithMask(inner, blockSize, &inner1, &inner2); l.setGPUBlockIndex(inner1, 0); l.setGPUThreadIndex(inner2, 0); @@ -1465,12 +1418,11 @@ Stmt* TensorExprKernel::generateStmt(BackendType backendType) { if (backendType == kBlockCodeGen) { auto block_analysis = std::make_unique(); - for (size_t i = 0; i < flatTensorOutputs_.size(); i++) { + for (auto tensor : tensorOutputs_) { const int default_fp16_blocksize = 16; const int default_uint8_blocksize = 32; int blockSize = default_fp16_blocksize; // We only handle looplevels == 2 for now - Tensor* tensor = flatTensorOutputs_[i]; // Run Block analysis to get multi dim buffer info auto root_stmt = l.root_stmt(); root_stmt->accept(block_analysis.get()); @@ -1478,12 +1430,16 @@ Stmt* TensorExprKernel::generateStmt(BackendType backendType) { if (tensor->buf()->dtype().scalar_type() == ScalarType::Byte) { blockSize = default_uint8_blocksize; } - l.computeInline(l.getLoopBodyFor(tensorOutputs_[i])); - For* outer; - For* inner; + std::vector loops = l.getLoopStmtsFor(tensor); - TORCH_INTERNAL_ASSERT(loops.size() > 0, "loops should not be empty"); - l.splitWithMask(loops[0], blockSize, &outer, &inner); + TORCH_INTERNAL_ASSERT(!loops.empty(), "loops should not be empty"); + For* flattened = nullptr; + LoopNest::flatten(loops, &flattened); + assert(flattened); + + For* outer = nullptr; + For* inner = nullptr; + l.splitWithMask(flattened, blockSize, &outer, &inner); l.setGPUBlockIndex(outer, 0); l.setGPUThreadIndex(inner, 0); l.setBufferMap(outer, block_analysis->getBufferMap()); @@ -1531,7 +1487,7 @@ std::vector TensorExprKernel::prepareBufferArgs() { params.emplace_back(stride.var); } } - for (auto& o : flatTensorOutputs_) { + for (auto& o : tensorOutputs_) { params.emplace_back(o); } return params; diff --git a/torch/csrc/jit/tensorexpr/kernel.h b/torch/csrc/jit/tensorexpr/kernel.h index 028b18112ab2..b21d6607fffb 100644 --- a/torch/csrc/jit/tensorexpr/kernel.h +++ b/torch/csrc/jit/tensorexpr/kernel.h @@ -125,7 +125,6 @@ class TORCH_API TensorExprKernel { Tensor* computeValue(const torch::jit::Value* v); - void flattenTensors(BackendType backendType); Stmt* generateStmt(BackendType backendType); std::vector prepareBufferArgs(); @@ -191,7 +190,6 @@ class TORCH_API TensorExprKernel { int64_t nInputs_ = 0; std::vector kernelArgs_; std::vector tensorOutputs_; - std::vector flatTensorOutputs_; std::unordered_map tensors_; std::unordered_map scalars_; std::unique_ptr codegen_; diff --git a/torch/csrc/jit/tensorexpr/stmt.h b/torch/csrc/jit/tensorexpr/stmt.h index abdaea147c00..55c1926b3541 100644 --- a/torch/csrc/jit/tensorexpr/stmt.h +++ b/torch/csrc/jit/tensorexpr/stmt.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include diff --git a/torch/custom_class.h b/torch/custom_class.h index 3805cfafc91a..571a584294db 100644 --- a/torch/custom_class.h +++ b/torch/custom_class.h @@ -58,14 +58,16 @@ class class_ { /// see this class exposed as in Python and TorchScript. For example, if /// you pass `foo` as the namespace name and `Bar` as the className, the /// class will appear as `torch.classes.foo.Bar` in Python and TorchScript - explicit class_(const std::string& namespaceName, const std::string& className) { + explicit class_(const std::string& namespaceName, const std::string& className, std::string doc_string = "") { detail::checkValidIdent(namespaceName, "Namespace name"); detail::checkValidIdent(className, "Class name"); qualClassName = std::string("__torch__.torch.classes.") + namespaceName + "." + className; classTypePtr = at::ClassType::create( c10::QualifiedName(qualClassName), - std::weak_ptr()); + std::weak_ptr(), + /*is_module=*/false, + std::move(doc_string)); classTypePtr->addAttribute("capsule", at::CapsuleType::get()); c10::getCustomClassTypeMap().insert( @@ -81,7 +83,7 @@ class class_ { /// `torch::init()` would register a two-argument constructor /// taking an `int` and a `std::string` as argument. template - class_& def(detail::types) { // Used in combination with + class_& def(detail::types, std::string doc_string = "") { // Used in combination with // torch::init<...>() auto func = [](c10::tagged_capsule self, Types... args) { auto classObj = c10::make_intrusive(args...); @@ -89,7 +91,7 @@ class class_ { object->setSlot(0, c10::IValue::make_capsule(std::move(classObj))); }; - defineMethod("__init__", std::move(func)); + defineMethod("__init__", std::move(func), std::move(doc_string)); return *this; } @@ -112,18 +114,18 @@ class class_ { /// // do something /// }) template - class_& def(std::string name, Func f) { + class_& def(std::string name, Func f, std::string doc_string = "") { auto wrapped_f = detail::wrap_func(std::move(f)); - defineMethod(std::move(name), std::move(wrapped_f)); + defineMethod(std::move(name), std::move(wrapped_f), std::move(doc_string)); return *this; } /// This is an unsafe method registration API added for adding custom JIT backend support via custom /// C++ classes. It is not for general purpose use. - class_& _def_unboxed(std::string name, std::function func, c10::FunctionSchema schema) { + class_& _def_unboxed(std::string name, std::function func, c10::FunctionSchema schema, std::string doc_string = "") { auto qualMethodName = qualClassName + "." + name; auto method = std::make_unique( - qualMethodName, std::move(schema), std::move(func)); + qualMethodName, std::move(schema), std::move(func), std::move(doc_string)); classTypePtr->addMethod(method.get()); registerCustomClassMethod(std::move(method)); return *this; @@ -228,7 +230,7 @@ class class_ { private: template - void defineMethod(std::string name, Func func) { + void defineMethod(std::string name, Func func, std::string doc_string = "") { auto qualMethodName = qualClassName + "." + name; auto schema = c10::inferFunctionSchemaSingleReturn(std::move(name), ""); @@ -241,7 +243,7 @@ class class_ { detail::BoxedProxy()(stack, func); }; auto method = std::make_unique( - qualMethodName, std::move(schema), std::move(wrapped_func)); + qualMethodName, std::move(schema), std::move(wrapped_func), std::move(doc_string)); // Register the method here to keep the Method alive. // ClassTypes do not hold ownership of their methods (normally it diff --git a/torch/distributed/algorithms/ddp_comm_hooks/__init__.py b/torch/distributed/algorithms/ddp_comm_hooks/__init__.py index 6b07e23c9476..9b5fd66bdb57 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/__init__.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/__init__.py @@ -1,6 +1,7 @@ from enum import Enum from functools import partial +import torch.distributed as dist import torch.distributed.algorithms.ddp_comm_hooks.default_hooks as default import torch.distributed.algorithms.ddp_comm_hooks.quantization_hooks as quantization from torch.nn.parallel import DistributedDataParallel @@ -38,6 +39,7 @@ def register_ddp_comm_hook( to the DDP model. User can specify the type of hook as an enum ``DDPCommHookType`` type using ``comm_hook_type`` input. State input will be passed to the model. + Uses Python comm hook implementations. Example:: >>> register_ddp_comm_hook(DDPCommHookType.FP16_COMPRESS, model, state) diff --git a/torch/distributed/constants.py b/torch/distributed/constants.py index f7718c4a20e0..dc541c932d11 100644 --- a/torch/distributed/constants.py +++ b/torch/distributed/constants.py @@ -2,7 +2,7 @@ # Default process group wide timeout, if applicable. # This only applies to the gloo and nccl backends -# (only if NCCL_BLOCKING_WAIT is set to 1). To make an attempt at -# backwards compatibility with THD, we use an extraordinarily high default -# timeout, given that THD did not have timeouts. +# (only if NCCL_BLOCKING_WAIT or NCCL_ASYNC_ERROR_HANDLING is set to 1). +# To make an attempt at backwards compatibility with THD, we use an +# extraordinarily high default timeout, given that THD did not have timeouts. default_pg_timeout = timedelta(minutes=30) diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index e5a67edfae57..3090721c20ff 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -393,7 +393,20 @@ def init_process_group(backend, the process group. Default value equals 30 minutes. This is applicable for the ``gloo`` backend. For ``nccl``, this is applicable only if the environment variable ``NCCL_BLOCKING_WAIT`` - is set to 1. + or ``NCCL_ASYNC_ERROR_HANDLING`` is set to 1. When + ``NCCL_BLOCKING_WAIT`` is set, this is the duration for which the + process will block and wait for collectives to complete before + throwing an exception. When ``NCCL_ASYNC_ERROR_HANDLING`` is set, + this is the duration after which collectives will be aborted + asynchronously and the process will crash. ``NCCL_BLOCKING_WAIT`` + will provide errors to the user which can be caught and handled, + but due to its blocking nature, it has a performance overhead. On + the other hand, ``NCCL_ASYNC_ERROR_HANDLING`` has little + performance overhead, but crashes the process on errors. This is + done since CUDA execution is async and it is no longer safe to + continue executing user code since failed async NCCL operations + might result in subsequent CUDA operations to run on corrupted + data. Only one of these two environment variables should be set. group_name (str, optional, deprecated): Group name. To enable ``backend == Backend.MPI``, PyTorch needs to be built from source @@ -1344,7 +1357,8 @@ def all_gather_object(object_list, obj, group=group.WORLD): object_list (list[Any]): Output list. It should be correctly sized as the size of the group for this collective and will contain the output. object (Any): Pickable Python object to be broadcast from current process. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Returns: None. If the calling rank is part of this group, the output of the @@ -1356,6 +1370,13 @@ def all_gather_object(object_list, obj, group=group.WORLD): collective since it does not provide an ``async_op`` handle and thus will be a blocking call. + .. note:: For NCCL-based processed groups, internal tensor representations + of objects must be moved to the GPU device before communication takes + place. In this case, the device used is given by + ``torch.cuda.current_device()`` and it is the user's responsiblity to + ensure that this is set so that each rank has an individual GPU, via + ``torch.cuda.set_device()``. + .. warning:: :func:`all_gather_object` uses ``pickle`` module implicitly, which is known to be insecure. It is possible to construct malicious pickle data @@ -1367,16 +1388,19 @@ def all_gather_object(object_list, obj, group=group.WORLD): input_tensor, local_size = _object_to_tensor(obj) group_backend = get_backend(group) - my_rank = get_rank() is_nccl_backend = group_backend == Backend.NCCL + current_device = torch.device("cpu") if is_nccl_backend: - input_tensor, local_size = input_tensor.to(my_rank), local_size.to(my_rank) + # See note about using torch.cuda.current_device() here in docstring. + # We cannot simply use my_rank since rank == device is not necessarily + # true. + current_device = torch.cuda.current_device() + input_tensor = input_tensor.to(current_device) + local_size = local_size.to(current_device) # Gather all local sizes. This is so that we can find the max size, and index # until the correct size when deserializing the tensors. group_size = get_world_size(group=group) - object_sizes_tensor = torch.zeros(group_size, dtype=int).to( - my_rank if is_nccl_backend else "cpu" - ) + object_sizes_tensor = torch.zeros(group_size, dtype=int, device=current_device) object_size_list = [ object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) ] @@ -1386,8 +1410,8 @@ def all_gather_object(object_list, obj, group=group.WORLD): # Resize tensor to max size across all ranks. input_tensor.resize_(max_object_size) coalesced_output_tensor = torch.empty( - max_object_size * group_size, dtype=torch.uint8 - ).to(my_rank if is_nccl_backend else "cpu") + max_object_size * group_size, dtype=torch.uint8, device=current_device + ) # Output tensors are nonoverlapping views of coalesced_output_tensor output_tensors = [ coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] @@ -1414,7 +1438,8 @@ def gather_object(obj, object_gather_list=None, dst=0, group=group.WORLD): collective and will contain the output. Must be ``None`` on non-dst ranks. (default is ``None``) dst (int, optional): Destination rank. (default is 0) - group: (ProcessGroup, optional): The process group to work on. + group: (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Returns: None. On the ``dst`` rank, ``object_gather_list`` will contain the @@ -1440,20 +1465,22 @@ def gather_object(obj, object_gather_list=None, dst=0, group=group.WORLD): _validate_output_list_for_rank(my_rank, dst, object_gather_list) input_tensor, local_size = _object_to_tensor(obj) group_backend = get_backend(group) + current_device = torch.device("cpu") is_nccl_backend = group_backend == Backend.NCCL if is_nccl_backend: - input_tensor, local_size = input_tensor.to(my_rank), local_size.to(my_rank) + current_device = torch.cuda.current_device() + input_tensor = input_tensor.to(current_device) + local_size = local_size.to(current_device) # Gather all local sizes. This is so that we can find the max size, and index # until the correct size when deserializing the tensors. group_size = get_world_size(group=group) - object_sizes_tensor = torch.zeros(group_size, dtype=int).to( - my_rank if is_nccl_backend else "cpu" - ) + object_sizes_tensor = torch.zeros(group_size, dtype=int, device=current_device) object_size_list = [ object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) ] - # Allgather tensor sizes. An all-gather is needed here despite this being a gather, - # since each rank needs to broadcast a tensor of the same (maximal) size. + # Allgather tensor sizes. An all-gather is needed here despite this being a + # gather, since each rank needs to broadcast a tensor of the same (maximal) + # size. all_gather(object_size_list, local_size, group=group) max_object_size = max(object_size_list) # Resize tensor to max size across all ranks. @@ -1461,8 +1488,8 @@ def gather_object(obj, object_gather_list=None, dst=0, group=group.WORLD): # Avoid populating output tensors if the result won't be gathered on this rank. if my_rank == dst: coalesced_output_tensor = torch.empty( - max_object_size * group_size, dtype=torch.uint8 - ).to(my_rank if is_nccl_backend else "cpu") + max_object_size * group_size, dtype=torch.uint8, device=current_device + ) # Output tensors are nonoverlapping views of coalesced_output_tensor output_tensors = [ coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] @@ -1495,15 +1522,23 @@ def broadcast_object_list(object_list, src, group=group.WORLD): Each object must be picklable. Only objects on the ``src`` rank will be broadcast, but each rank must provide lists of equal sizes. src (int): Source rank from which to broadcast ``object_list``. - group: (ProcessGroup, optional): The process group to work on. + group: (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Returns: ``None``. If rank is part of the group, ``object_list`` will contain the broadcasted objects from ``src`` rank. - .. note:: Note that this API differs slightly from the broadcast collective - since it does not provide an ``async_op`` handle and thus will be a - blocking call. + .. note:: For NCCL-based processed groups, internal tensor representations + of objects must be moved to the GPU device before communication takes + place. In this case, the device used is given by + ``torch.cuda.current_device()`` and it is the user's responsiblity to + ensure that this is set so that each rank has an individual GPU, via + ``torch.cuda.set_device()``. + + .. note:: Note that this API differs slightly from the :func:`all_gather` + collective since it does not provide an ``async_op`` handle and thus + will be a blocking call. .. warning:: :func:`broadcast_object_list` uses ``pickle`` module implicitly, which @@ -1524,8 +1559,14 @@ def broadcast_object_list(object_list, src, group=group.WORLD): group_backend = get_backend(group) is_nccl_backend = group_backend == Backend.NCCL + current_device = torch.device("cpu") if is_nccl_backend: - object_sizes_tensor = object_sizes_tensor.to(my_rank) + # See note about using torch.cuda.current_device() here in docstring. + # We cannot simply use my_rank since rank == device is not necessarily + # true. + current_device = torch.cuda.current_device() + object_sizes_tensor = object_sizes_tensor.to(current_device) + object_sizes_tensor = object_sizes_tensor.to(current_device) # Broadcast object sizes broadcast(object_sizes_tensor, src=src, group=group) @@ -1537,7 +1578,7 @@ def broadcast_object_list(object_list, src, group=group.WORLD): object_tensor = torch.ByteTensor(torch.sum(object_sizes_tensor).item()) if is_nccl_backend: - object_tensor = object_tensor.to(my_rank) + object_tensor = object_tensor.to(current_device) broadcast(object_tensor, src=src, group=group) # Deserialize objects using their stored sizes. offset = 0 diff --git a/torch/distributed/nn/api/remote_module.py b/torch/distributed/nn/api/remote_module.py index 225cb4842bd1..e0aebf87ae84 100644 --- a/torch/distributed/nn/api/remote_module.py +++ b/torch/distributed/nn/api/remote_module.py @@ -17,6 +17,7 @@ import torch.distributed.rpc as rpc from torch import Tensor, device, dtype, nn from torch.distributed.nn.jit import instantiator +from torch.distributed.rpc.utils import _parse_remote_device from torch.nn.parameter import Parameter from torch.utils.hooks import RemovableHandle @@ -64,8 +65,7 @@ def _raise_not_supported(name): class _RemoteModule(nn.Module): def __init__( self, - on: str, - device: torch.device, + remote_device: str, module_cls: nn.Module, args: Tuple = None, kwargs: Dict[str, Any] = None, @@ -100,8 +100,10 @@ def __init__( ``def forward_async(input: Tensor) -> Future[Tensor]:``. Arguments: - on (str or WorkerInfo): id or name of the destination worker. - device (torch.device): Device on the destination worker where we‘d like to place this module. + remote_device (str): Device on the destination worker where we‘d like to place this module. + The format should be "/", where the device field can be parsed as torch.device type. + E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0". + In addition, the device field can be optional and the default value is "cpu". module_cls (nn.Module): For example, >>> class MyModule(nn.Module): >>> def forward(input): @@ -132,7 +134,7 @@ def __init__( >>> >>> rpc.init_rpc("worker0", rank=0, world_size=2) >>> remote_linear_module = RemoteModule( - >>> "worker1", "cpu", nn.Linear, args=(20, 30), + >>> "worker1/cpu", nn.Linear, args=(20, 30), >>> ) >>> input = torch.randn(128, 20) >>> ret_fut = remote_linear_module.forward_async(input) @@ -155,18 +157,22 @@ def __init__( args = args if args is not None else () kwargs = kwargs if kwargs is not None else {} - self.on = on + self.on, self.device = _parse_remote_device(remote_device) if _module_interface_cls is not None: # Users reply on this field to know if this generated RemoteModule is TorchScript-able. self.is_scriptable = True # Instantiate template on remote side. - fut = rpc.rpc_async(on, _instantiate_template, (_module_interface_cls,)) + fut = rpc.rpc_async( + self.on, _instantiate_template, (_module_interface_cls,) + ) # Instantiate template on local side. - generated_module = instantiator.instantiate_scriptable_remote_module_template( - _module_interface_cls + generated_module = ( + instantiator.instantiate_scriptable_remote_module_template( + _module_interface_cls + ) ) generated_methods = generated_module._generated_methods @@ -178,9 +184,9 @@ def __init__( # Create the module on the remote side. self.module_rref = rpc.rpc_sync( - on, + self.on, _create_module, - (module_cls, args, kwargs, device, _module_interface_cls), + (module_cls, args, kwargs, self.device, _module_interface_cls), ) # Install generated methods. @@ -329,8 +335,10 @@ class RemoteModule(_RemoteModule): ``def forward_async(input: Tensor) -> Future[Tensor]:``. Arguments: - to (str or WorkerInfo): id or name of the destination worker. - device (torch.device): Device on the destination worker where we‘d like to place this module. + remote_device (str): Device on the destination worker where we‘d like to place this module. + The format should be "/", where the device field can be parsed as torch.device type. + E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0". + In addition, the device field can be optional and the default value is "cpu". module_cls (nn.Module): For example, >>> class MyModule(nn.Module): >>> def forward(input): @@ -357,7 +365,7 @@ class RemoteModule(_RemoteModule): >>> >>> rpc.init_rpc("worker0", rank=0, world_size=2) >>> remote_linear_module = RemoteModule( - >>> "worker1", nn.Linear, args=(20, 30), + >>> "worker1/cpu", nn.Linear, args=(20, 30), >>> ) >>> input = torch.randn(128, 20) >>> ret_fut = remote_linear_module.forward_async(input) @@ -374,10 +382,9 @@ class RemoteModule(_RemoteModule): def __init__( self, - on: str, - device: torch.device, + remote_device: str, module_cls: nn.Module, args: Tuple = None, kwargs: Dict[str, Any] = None, ): - super().__init__(on, device, module_cls, args, kwargs) + super().__init__(remote_device, module_cls, args, kwargs) diff --git a/torch/distributed/rpc/api.py b/torch/distributed/rpc/api.py index 1888241e4da4..7cb99066b507 100644 --- a/torch/distributed/rpc/api.py +++ b/torch/distributed/rpc/api.py @@ -141,6 +141,34 @@ def _broadcast_to_followers(sequence_id, objects_map): states.gathered_objects = objects_map states.proceed_signal.set() +_thread_local_var = threading.local() + +@contextlib.contextmanager +def _wait_all(): + r""" + A context manager that collects all futures returned by ``rpc_async`` and + waits them on the context manager's exit; relieving the user of needing + to explicitly call wait. + + + Example:: + >>> # On worker 0: + >>> import torch + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> with rpc._wait_all(): + >>> fut_1 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1)) + >>> fut_2 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1)) + >>> #fut_1 and fut_2 are waited on + """ + _thread_local_var.future_list = [] + try: + yield + finally: + try: + torch.futures.wait_all(_thread_local_var.future_list) + finally: + del _thread_local_var.future_list @_require_initialized def _all_gather(obj, timeout=UNSET_RPC_TIMEOUT): @@ -830,4 +858,7 @@ def rpc_async(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT): >>> rpc.init_rpc("worker1", rank=1, world_size=2) >>> rpc.shutdown() """ - return _invoke_rpc(to, func, RPCExecMode.ASYNC, args, kwargs, timeout) + fut = _invoke_rpc(to, func, RPCExecMode.ASYNC, args, kwargs, timeout) + if hasattr(_thread_local_var, "future_list"): + _thread_local_var.future_list.append(fut) + return fut diff --git a/torch/distributed/rpc/utils.py b/torch/distributed/rpc/utils.py new file mode 100644 index 000000000000..15924c4a72f0 --- /dev/null +++ b/torch/distributed/rpc/utils.py @@ -0,0 +1,37 @@ +def _parse_remote_device(remote_device: str): + r""" + Parses the remote device. + + Arguments: + remote_device (str): Device on the destination worker where we‘d like to place this module. + The format should be "/", where the device field can be parsed as torch.device type. + E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0". + In addition, the device field can be optional and the default value is "cpu". + + Returns: + A workername and a device. + """ + fields = remote_device.split("/") + if len(fields) == 2: + [on, device] = fields + elif len(fields) == 1: + on = fields[0] + device = "cpu" + else: + raise RuntimeError( + "Could not parse remote_device: {}. The valid format is '/'".format( + remote_device + ) + ) + + # Since the workername in the input remote device won't be validated until the created remote module is executed, + # only do some very basic sanity check on workername at the module creation time. + # As currently there is no regex to describe the format of workername, just check whether the workername is empty. + if not on: + raise RuntimeError( + "The workername in remote_device '{}' cannot be empty. The valid format is '/'".format( + remote_device + ) + ) + + return on, device diff --git a/torch/functional.py b/torch/functional.py index e31ec40d63d7..cb9be8117fa8 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -1449,7 +1449,12 @@ def _lu_impl(A, pivot=True, get_infos=False, out=None): - **factorization** (*Tensor*): the factorization of size :math:`(*, m, n)` - - **pivots** (*IntTensor*): the pivots of size :math:`(*, m)` + - **pivots** (*IntTensor*): the pivots of size :math:`(*, \text{min}(m, n))`. + ``pivots`` stores all the intermediate transpositions of rows. + The final permutation ``perm`` could be reconstructed by + applying ``swap(perm[i], perm[pivots[i] - 1])`` for ``i = 0, ..., pivots.size(-1) - 1``, + where ``perm`` is initially the identity permutation of :math:`m` elements + (essentially this is what :func:`torch.lu_unpack` is doing). - **infos** (*IntTensor*, *optional*): if :attr:`get_infos` is ``True``, this is a tensor of size :math:`(*)` where non-zero values indicate whether factorization for the matrix or diff --git a/torch/fx/__init__.py b/torch/fx/__init__.py index 792a905432a5..c7fbd6fbf0ea 100644 --- a/torch/fx/__init__.py +++ b/torch/fx/__init__.py @@ -2,7 +2,7 @@ r''' **This feature is experimental and its stability is not currently guaranteed. Proceed at your own risk** -FX (Functional Transformations) is a toolkit for capturing and transforming functional PyTorch programs. It +FX is a toolkit for capturing and transforming functional PyTorch programs. It consists of GraphModule and a corresponding intermediate representation (IR). When GraphModule is constructed with an `nn.Module` instance as its argument, GraphModule will trace through the computation of that Module's `forward` method symbolically and record those operations in the FX intermediate representation. diff --git a/torch/fx/experimental/Partitioner.py b/torch/fx/experimental/Partitioner.py index d01ed6075e2b..9c1bdaaa1335 100644 --- a/torch/fx/experimental/Partitioner.py +++ b/torch/fx/experimental/Partitioner.py @@ -1,18 +1,31 @@ from torch.fx.graph_module import GraphModule from torch.fx.node import Node, map_arg -from typing import Dict, List, Union, Set, NamedTuple, Tuple +from typing import Dict, List, Set, NamedTuple, Tuple import torch from torch.fx.experimental.subgraph_creation_example import split_module import operator -class DAGNode(NamedTuple): +class DAGNode(): """ DAGNode class maintains useful information for a partition (submodule). inputs(submodule node) and outputs(submodule node). """ - submodule: Node - input_nodes: List[Node] - output_nodes: List[Node] + def __init__( + self, + submodule_node: Node, + input_nodes: List[Node], + output_nodes: List[Node], + logical_device_ids: List[int], + size_bytes: int + ) -> None: + self.submodule_node: Node = submodule_node + self.input_nodes: List[Node] = input_nodes + self.output_nodes: List[Node] = output_nodes + self.logical_device_ids: List[int] = logical_device_ids + self.size_bytes = size_bytes + + def __str__(self) -> str: + return str(self.submodule_node) class DAG: """DAG class contains all the DAG nodes""" @@ -21,37 +34,46 @@ def __init__(self) -> None: def create_node( self, - submodule: Node, + submodule_node: Node, input_nodes: List[Node], - output_nodes: List[Node] + output_nodes: List[Node], + logical_devices: List[int], + size_bytes: int ) -> None: - node = DAGNode(submodule, input_nodes, output_nodes) + node = DAGNode(submodule_node, input_nodes, output_nodes, logical_devices, size_bytes) self.nodes.append(node) class Partition: """Partition class contains all the information about an individual partition. It also provides necessary methods for manipulation the partition. """ - def __init__(self, partition_id: int, fx_module: GraphModule) -> None: - self.graph_module = fx_module + def __init__(self, partition_id: int) -> None: self.nodes: Set[Node] = set() self.partition_id = partition_id self.parents: Set['Partition'] = set() self.children: Set['Partition'] = set() self.bfs_level: int = -1 + self.used_mem_bytes: int = 0 + self.logical_device_ids: List[int] = [] - def add_node(self, node: Node) -> None: - """Append a new node into the partition.""" - self.nodes.add(node) + def __str__(self): + return str(self.partition_id) - def add_parent(self, partition: 'Partition') -> None: - self.parents.add(partition) + def recalculate_mem_size(self): + self.used_mem_bytes = 0 + for node in self.nodes: + self.used_mem_bytes += get_extra_size_of(node, self.nodes) - def add_child(self, partition: 'Partition') -> None: - self.children.add(partition) + def add_node(self, node): + input_nodes: Dict[Node, None] = {} + map_arg(node.args, lambda n: input_nodes.setdefault(n)) + map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) + # Add current node's input nodes if they are placeholder or constants + for n in input_nodes: + if n.op in {'placeholder', 'get_attr'}: + self.nodes.add(n) + self.nodes.add(node) - def __str__(self): - return str(self.partition_id) class PartitionResult(NamedTuple): """NameTuple used for returning DAG and a new graph module @@ -61,45 +83,75 @@ class PartitionResult(NamedTuple): class Device(NamedTuple): name: str - available_mem_bytes: Union[float, int] + available_mem_bytes: int + logical_id: int + +class PartitionerConfig(NamedTuple): + devices: List[Device] + is_sparse_nn: bool = False + +def get_extra_size_of(node: Node, nodes: Set[Node]) -> int: + """Given a node and a set of nodes, + this function return the extra size that needed + if this node is included in this set. + """ + # Find all its input nodes + input_nodes: Dict[Node, None] = {} + map_arg(node.args, lambda n: input_nodes.setdefault(n)) + map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) + # Calculate total size of related nodes + total_size_of_input_nodes = 0 + for n in input_nodes: + # Make sure this node hasn't been in this set yet + if n not in nodes: + size_bytes = getattr(n, 'size_bytes', None) + if size_bytes: + total_size_of_input_nodes += size_bytes.output_size + else: + raise RuntimeError('node has no size_bytes attr') + # Don't forget the op node itself + size_bytes = getattr(node, 'size_bytes', None) + if size_bytes: + total_size_of_input_nodes += size_bytes.total_size + else: + raise RuntimeError('node has no size_bytes attr') + return total_size_of_input_nodes + +def calculate_mem_bytes_needed(p1, p2): + nodes = p1.nodes.union(p2.nodes) + mem_bytes_needed = 0 + for node in nodes: + mem_bytes_needed += get_extra_size_of(node, nodes) + return mem_bytes_needed class Partitioner: """A graph module may not fit into one device. Partitioner class helps cut one graph into subgraphs (partitions), so that each partition could fit into a different device. The main function of this class is self.partition_graph. - For self.partition_graph, first, it checks the size of the whole graph - and see if the whole graph can fit into one device. - If it does, it goes to self.find_single_partition - If the whole graph is even larger than the combined memory size of all devices, - a RuntimeError is raised. - If the whole graph cannot fit into one devices but - could be split into multiple devices, it goes to self.size_based_partition. - After the size_based_partition, it checks if the number of partitions exceeds - the number of devices. If it does, a RuntimeError is raised. - Otherwise, a DAG structure is returned + It will partition the graph based on the scheme specified in partition_config + A DAG structure is returned along with a new graph module with partitions as submodule nodes. """ def __init__(self) -> None: - self.partitions: Set[Partition] = set() + self.partitions: List[Partition] = [] + self.node_to_partitions: Dict[Node, int] = {} self.devices: List[Device] = [] - self.node_to_partitions: Dict[Node, List[int]] = {} - self.partition_to_used_mem_bytes: Dict[Partition, int] = {} def partition_graph( self, fx_module: GraphModule, torch_module: torch.nn.Module, - devices: List[Device] + partitioner_config: PartitionerConfig ) -> PartitionResult: """ - Given the fx module, torch module and devices, + Given the fx module, torch module and partitioner_config, find the partitions, do the partitions, and then return a DAG and a new fx module with submodule nodes (partitions) """ self.graph_module = fx_module - self.devices = devices self.torch_module = torch_module + self.devices = partitioner_config.devices if len(self.devices) == 0: raise RuntimeError('No devices') available_mem_bytes = self.devices[0].available_mem_bytes @@ -113,119 +165,183 @@ def partition_graph( if node.op == 'output': break total_size_of_graph += node.size_bytes.total_size - if total_size_of_graph <= available_mem_bytes: - self.find_single_partition() - elif total_size_of_graph > len(self.devices) * available_mem_bytes: + device_with_max_mem = max(self.devices, key=lambda d: d.available_mem_bytes) + if total_size_of_graph <= device_with_max_mem.available_mem_bytes: + self.find_single_partition(total_size_of_graph) + elif total_size_of_graph > sum([d.available_mem_bytes for d in self.devices]): raise RuntimeError('Devices have no enough memory for the module') else: - if not all(device.available_mem_bytes == available_mem_bytes for device in self.devices): - raise RuntimeError('All devices must have same memory size!') - self.size_based_partition(available_mem_bytes) - # Check if enought devices are provided for all partitions - if len(self.partitions) > len(self.devices): - raise RuntimeError('Lack of Devices') + if partitioner_config.is_sparse_nn: + if not all(device.available_mem_bytes == available_mem_bytes for device in self.devices): + raise RuntimeError('All devices must have same memory size!') + # sparse_nn_partition only support same memory size + # TODO: add different size support for sparse_nn_partition + self.sparse_nn_partition(available_mem_bytes) + else: + self.size_based_partition(available_mem_bytes) module_with_submodules = self.do_partition() # The DAG contains DAGNodes with info of each partition's input nodes, output nodes # and how partitions are connected. - dag = self.dump_partition_DAG(module_with_submodules) + dag = self.dump_dag(module_with_submodules) ret = PartitionResult(dag, module_with_submodules) return ret - def find_single_partition(self) -> None: + def find_single_partition(self, total_size_of_graph) -> None: """Only one partition (one graph on one device).""" partition_0 = self.create_partition() for node in self.graph_module.graph.nodes: if node.op == 'output': break - self.node_to_partitions[node] = [partition_0.partition_id] - partition_0.add_node(node) + self.node_to_partitions[node] = partition_0.partition_id + partition_0.nodes.add(node) + partition_0.used_mem_bytes = total_size_of_graph + partition_0.logical_device_ids = [0] return - def size_based_partition(self, available_mem_bytes: Union[float, int]) -> None: - """This method partitions the graph based on memory size. - We assume all devices have the same memory size. + def size_based_partition(self, available_mem_bytes: int) -> None: + """This method is to partition the graph based on memory size. + It uses greedy approach. The result may not be the best. The basic idea is: - First, create a new partition. - Then traverse the graph through self.graph_module.graph.nodes - The traversal only focuses on op nodes - (call_function, call_module, call_method). - The placeholder nodes (placeholder) and constant nodes (get_attr) are skipped. - A placeholder (placeholder) or a constant (get_attr) - is added into a partition when it is a input node for a op node. - From one op node to another, check if a op node and its input nodes - can fit into the current partition. - If the current partition is full, create a new one - and continue traversing op nodes. - Then through self.combine_partition_based_on_size(), - partitions will be combined to keep - as less partitions as possible. - self.check_partition_dependecy checks if the combination of - partitions leads to a circular dependency + Step 1: + Find a device which has enough memory to fit the first node, create a empty partition + with the size of that device. + Then keep adding the following nodes into the partition until the partition is full. + Step 2: + Repeat Step 1 until no device left + Step 3: + If some nodes are left, create a partition for each left node (single node partition). + Try to combine those single node partitions with the non single node partitions + from Step 1 and Step 2. + If two partitions cannot be combined, but could fit into the same logical device, + Two partitions use the same logical device. """ - # Create the first partition + def find_device_based_on_size(node) -> Device: + """Given a node, this function is to find a logical device + that could fit the node. + """ + mem_size_needed = get_extra_size_of(node, set()) + device = Device('', -1, -1) + for d in self.devices: + if d not in occupied_devices and d.available_mem_bytes >= mem_size_needed: + device = d + break + if device.available_mem_bytes < 0: + raise RuntimeError(str(node) + 'is too large to fit any device') + occupied_devices.append(device) + return device + + def create_single_node_partition(node): + """Create a partition for a single node + """ + partition = self.create_partition() + total_size_needed = get_extra_size_of(node, set()) + partition.add_node(node) + partition.used_mem_bytes = total_size_needed + single_node_partitions.append(partition) + + # Track all single node partitions in Step 3 + single_node_partitions: List[Partition] = [] + # Track all non single node partitions in Step 1 and Step 2 + non_single_node_partitions: List[Partition] = [] + # Track partition and its left mem size + partition_to_left_mem_bytes: Dict[Partition, int] = {} + # Track all the devices that have been used + occupied_devices: List[Device] = [] partition = self.create_partition() - # Track the used mem for the current partition - used_mem_bytes = 0 for node in self.graph_module.graph.nodes: if node.op in {'call_module', 'call_method', 'call_function'}: - # Find all its input nodes - input_nodes: Dict[Node, None] = {} - map_arg(node.args, lambda n: input_nodes.setdefault(n)) - map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) - # Calculate total size of related nodes - total_size_of_input_nodes = 0 - for n in input_nodes: - # Make sure this node hasn't been in this partition yet - if n not in partition.nodes: - size_bytes = getattr(n, 'size_bytes', None) - if size_bytes: - total_size_of_input_nodes += size_bytes.output_size - else: - raise RuntimeError('node has no size_bytes attr') - # Don't forget the op node itself - size_bytes = getattr(node, 'size_bytes', None) - if size_bytes: - total_size_of_input_nodes += size_bytes.total_size + # Check if there are devices left + if len(self.partitions) <= len(self.devices): + total_size_of_input_nodes = get_extra_size_of(node, partition.nodes) + # Check if the current partition is the very first partition + if partition.used_mem_bytes == 0: + # Find a device to fit the first node, return available mem size + device = find_device_based_on_size(node) + occupied_devices.append(device) + # Update partition and its left mem size + partition_to_left_mem_bytes[partition] = device.available_mem_bytes + # Update available mem for the current partitio + partition.logical_device_ids.append(device.logical_id) + else: + # The current partition is not the first partition + # Check if the current node can fit into this partition + if partition_to_left_mem_bytes[partition] < total_size_of_input_nodes: + # Check if no device is left + if len(self.partitions) == len(self.devices): + # No device left, all the partitions before are non single node partitions + non_single_node_partitions = self.partitions[:] + # Create the first single node partition for the current node + create_single_node_partition(node) + continue + # Some devices are still left + device = find_device_based_on_size(node) + partition = self.create_partition() + total_size_of_input_nodes = get_extra_size_of(node, partition.nodes) + partition_to_left_mem_bytes[partition] = device.available_mem_bytes + partition.logical_device_ids.append(device.logical_id) + partition.add_node(node) + partition_to_left_mem_bytes[partition] -= total_size_of_input_nodes + partition.used_mem_bytes += total_size_of_input_nodes + # No device left, create single node partitions else: - raise RuntimeError('node has no size_bytes attr') - # The current node with its inputs cannot fit into the current partition - if used_mem_bytes + total_size_of_input_nodes > available_mem_bytes: - self.partition_to_used_mem_bytes[partition] = used_mem_bytes - partition = self.create_partition() - used_mem_bytes = 0 - # The current node may be too large to fit into a whole new partition - if total_size_of_input_nodes > available_mem_bytes: - raise RuntimeError(node.target + 'is too large to fit into a device') - # Add the current node into the current partition - partition.add_node(node) - # Add all input nodes if they are placeholders or constants - for n in input_nodes: - if (n not in partition.nodes) and (n.op in {'placeholder', 'get_attr'}): - partition.add_node(n) - used_mem_bytes = used_mem_bytes + total_size_of_input_nodes - # Update used mem mapping for the last partition - self.partition_to_used_mem_bytes[partition] = used_mem_bytes - # Find parent partitions and child partitions for each partition. + create_single_node_partition(node) self.set_parents_and_children() - # Combine small partitions - self.combine_partitions_based_on_size(available_mem_bytes) - # Reassign partition ids and update self.node_to_partitions. + # Check if having single node partitions + # If not, partition is done + if len(single_node_partitions) != 0: + # Going through all single node partitions, + # see if it can be combined with non single node partitions + # or at least fit into a logical device as a standaline partition + while single_node_partitions: + self.get_bfs_level_partition() + # Pick a single node partition + p1 = single_node_partitions.pop(0) + # Set up a flag + find_device = False + # Going through all non single partitions + # and find a device to fit p1 + for p2 in non_single_node_partitions: + # Calculate how many bytes are needed if combining p1 and p2 + mem_size_needed = calculate_mem_bytes_needed(p1, p2) + # Get the available size of p2 + available_mem_bytes = p2.used_mem_bytes + partition_to_left_mem_bytes[p2] + if mem_size_needed <= available_mem_bytes: + # Two partitions can be fit on the same device, + # check if combining them to be one partition + if abs(p1.bfs_level - p2.bfs_level) <= 1: + # Combining p1 and p2 into p0 + p0 = self.combine_two_partitions(p1, p2) + p0.logical_device_ids = p2.logical_device_ids + # Remove p2 from non_single_node_partitions + non_single_node_partitions.remove(p2) + # Add p0 to non_single_partitions + non_single_node_partitions.append(p0) + # Update partition_to_left_mem_bytes + partition_to_left_mem_bytes[p0] = available_mem_bytes - mem_size_needed + del partition_to_left_mem_bytes[p2] + else: + # Cannot combine two partitions, + # but two partitions can fit into p2's device + p1.logical_device_ids = p2.logical_device_ids + # Update partition_to_left_mem_bytes for p2 + partition_to_left_mem_bytes[p2] = available_mem_bytes - mem_size_needed + find_device = True + break + if not find_device: + raise RuntimeError('Lack of Devices') self.reorganize_partitions() return def do_partition(self) -> GraphModule: """Return a module with submodules (partitions).""" - for node in self.graph_module.graph.nodes: - if node.op == 'output': - break module_with_submodules = split_module( self.graph_module, self.torch_module, - lambda node: self.node_to_partitions[node][0] + lambda node: self.node_to_partitions[node] ) return module_with_submodules - def dump_partition_DAG(self, module_with_submodules: GraphModule) -> DAG: + def dump_dag(self, module_with_submodules: GraphModule) -> DAG: dag = DAG() for node in module_with_submodules.graph.nodes: if node.op == 'output': @@ -245,18 +361,24 @@ def dump_partition_DAG(self, module_with_submodules: GraphModule) -> DAG: output_nodes = list(node.users) else: output_nodes = [node] - dag.create_node(node, list(input_nodes), output_nodes) + partition_id = int(node.name.rsplit('_', 1)[-1]) + device_ids = self.partitions[partition_id].logical_device_ids + size_bytes = self.partitions[partition_id].used_mem_bytes + dag.create_node(node, list(input_nodes), output_nodes, device_ids, size_bytes) return dag def create_partition(self) -> Partition: """Create a partition and append it to self.partitions.""" partition_id = len(self.partitions) - assert isinstance(self.graph_module, GraphModule) - partition = Partition(partition_id, self.graph_module) - self.partitions.add(partition) + partition = Partition(partition_id) + self.partitions.append(partition) return partition - def combine_partitions_based_on_size(self, available_mem_bytes) -> None: + def combine_partitions_based_on_size( + self, + partitions: List[Partition], + available_mem_bytes: int + ) -> None: """Combining small partitions together to keep as less partitions as possible. Here is an example of the algorithm to do this: Assume some partitions, we first sort them based on partiiton used memory size. @@ -274,67 +396,59 @@ def combine_partitions_based_on_size(self, available_mem_bytes) -> None: find_combination = True while find_combination: # Sort partitions based on memory size - sorted_partitions = sorted(self.partition_to_used_mem_bytes.items(), key=lambda item: item[1]) + sorted_partitions = sorted(partitions, key=lambda p: p.used_mem_bytes) # Mark bfs level self.get_bfs_level_partition() - find_combination = self.find_partition_to_combine_based_on_size(sorted_partitions, available_mem_bytes) + find_combination, partitions = \ + self.find_partition_to_combine_based_on_size( + sorted_partitions, + available_mem_bytes, + partitions + ) return def find_partition_to_combine_based_on_size( self, - sorted_partitions: List[Tuple[Partition, int]], - available_mem_bytes: int - ) -> bool: + sorted_partitions: List[Partition], + available_mem_bytes: int, + partitions: List[Partition] + ) -> Tuple[bool, List[Partition]]: """step 1 in self.combine_partition_based_on_size()""" + find_combination = False - smallest_partition = sorted_partitions.pop(0)[0] - left_mem = available_mem_bytes - self.partition_to_used_mem_bytes[smallest_partition] - for t in sorted_partitions[::-1]: - if t[1] <= left_mem and abs(smallest_partition.bfs_level - t[0].bfs_level) <= 1: - self.combine_two_partitions(t[0], smallest_partition) - find_combination = True - break - return find_combination + smallest_partition = sorted_partitions.pop(0) + for p in sorted_partitions[::-1]: + if abs(smallest_partition.bfs_level - p.bfs_level) <= 1: + # Calculate how many bytes needed if combined + mem_bytes_needed = calculate_mem_bytes_needed(p, smallest_partition) + if mem_bytes_needed <= available_mem_bytes: + self.combine_two_partitions(p, smallest_partition) + partitions.remove(smallest_partition) + partitions.remove(p) + partitions.append(self.partitions[-1]) + find_combination = True + break + return find_combination, partitions - def combine_two_partitions(self, partition_0: Partition, partition_1: Partition) -> None: + def combine_two_partitions( + self, + partition_0: Partition, + partition_1: Partition, + ) -> Partition: """Given two partitions, combine them into a new one - and remove the previous two partitions + and remove the previous two partitions from self.partitions """ partition = self.create_partition() partition.nodes = partition_0.nodes.union(partition_1.nodes) - partition.parents = partition_0.parents.union(partition_1.parents) - partition.children = partition_0.children.union(partition_1.children) - partition.bfs_level = max(partition_0.bfs_level, partition_1.bfs_level) - if partition_0 in partition.children: - partition.children.remove(partition_0) - if partition_0 in partition.parents: - partition.parents.remove(partition_0) - if partition_1 in partition.children: - partition.children.remove(partition_1) - if partition_1 in partition.parents: - partition.parents.remove(partition_1) - self.partition_to_used_mem_bytes[partition] = self.partition_to_used_mem_bytes[partition_0] + \ - self.partition_to_used_mem_bytes[partition_1] - del self.partition_to_used_mem_bytes[partition_0] - del self.partition_to_used_mem_bytes[partition_1] - # Replace partition_0 and partition_1 with the new partition in children and parents - for p in partition.parents: - if partition_0 in p.children: - p.children.remove(partition_0) - p.children.add(partition) - if partition_1 in p.children: - p.children.remove(partition_1) - p.children.add(partition) - for p in partition.children: - if partition_0 in p.parents: - p.parents.remove(partition_0) - p.parents.add(partition) - if partition_1 in p.parents: - p.parents.remove(partition_1) - p.parents.add(partition_1) + partition.recalculate_mem_size() self.partitions.remove(partition_0) self.partitions.remove(partition_1) - return + # reset parents and children for all partitions + for partition in self.partitions: + partition.parents = set() + partition.children = set() + self.set_parents_and_children() + return partition def set_parents_and_children(self) -> None: # Go through all nodes in a partition. @@ -351,10 +465,8 @@ def set_parents_and_children(self) -> None: # that partition is not the child of the current partition for p in self.partitions: if p != partition and n in p.nodes and node not in p.nodes: - if p not in partition.children: - partition.add_child(p) - if partition not in p.parents: - p.add_parent(partition) + partition.children.add(p) + p.parents.add(partition) return def reorganize_partitions(self) -> None: @@ -364,10 +476,7 @@ def reorganize_partitions(self) -> None: # Update self.node_to_partitions accordingly for partition in self.partitions: for node in partition.nodes: - if node not in self.node_to_partitions: - self.node_to_partitions[node] = [partition.partition_id] - else: - self.node_to_partitions[node].append(partition.partition_id) + self.node_to_partitions[node] = partition.partition_id return def get_bfs_level_partition(self) -> None: @@ -393,3 +502,98 @@ def get_bfs_level_partition(self) -> None: next_level = set() level += 1 return + + def sparse_nn_partition(self, available_mem_bytes: int) -> None: + """This method partition a sparse nn module. + It first traverse all the nodes and do the partitions based on memory size. + If the current partition has no enough memory left for a new op node + (call_module, call_method, call_function), a new partition is created. + Different for size_based_partition, when traversing cross the boundary between + non-embedding nodes and embedding nodes, a new partition is created regardlessly. + For example, if the current node is a non-embedding node but the next node is an + embedding node, a new partition is created for the next node. + After the partition, the partitions are combined as much as possible. + The rule is that a non-embedding partition only + combines with another non-embedding one. + So as the embedding partitions. + """ + def reset_partition_in_sparse_nn(partition, new_partition=True): + if in_embedding_region: + embedding_partitions.append(partition) + else: + non_embedding_partitions.append(partition) + if new_partition: + partition = self.create_partition() + partition.left_mem_bytes = available_mem_bytes + return partition + return None + + def is_embedding_node(node: Node) -> bool: + """Check if a node is an embedding node""" + if node.op == 'call_module': + submodule = self.graph_module + for atom in str(node.target).split('.'): + if not hasattr(submodule, atom): + raise RuntimeError(f'Module {submodule} has no attribute {atom}') + submodule = getattr(submodule, atom) + if 'Embedding' in str(submodule): + return True + return False + + # Track embedding partitons and non-embedding partitions separately + embedding_partitions: List[Partition] = [] + non_embedding_partitions: List[Partition] = [] + # A Flag to check the boundary + in_embedding_region: bool = False + partition = self.create_partition() + for node in self.graph_module.graph.nodes: + if node.op in {'call_module', 'call_method', 'call_function'}: + # Check if crossing the boundary between embedding nodes and non embedding nodes + if is_embedding_node(node) != in_embedding_region: + # Crossing the boundary + # Check if the current partition is an empty partition + if partition.used_mem_bytes != 0: + # The current partition isn't an empty partition. Create a new one. + partition = reset_partition_in_sparse_nn(partition) + in_embedding_region = not in_embedding_region + total_size_of_input_nodes = get_extra_size_of(node, partition.nodes) + if total_size_of_input_nodes + partition.used_mem_bytes > available_mem_bytes: + partition = reset_partition_in_sparse_nn(partition) + total_size_of_input_nodes = get_extra_size_of(node, partition.nodes) + if total_size_of_input_nodes > available_mem_bytes: + raise RuntimeError(node.target + 'is too large to fit into a device') + partition.add_node(node) + partition.used_mem_bytes += total_size_of_input_nodes + reset_partition_in_sparse_nn(partition, new_partition=False) + # Set parents and children for each partition + self.set_parents_and_children() + # Combining non-embedding partitions + self.combine_partitions_based_on_size(non_embedding_partitions, available_mem_bytes) + # Combining embedding partitions + self.combine_partitions_based_on_size(embedding_partitions, available_mem_bytes) + self.reorganize_partitions() + total_size_of_non_embedding_partitions = 0 + for partition in non_embedding_partitions: + total_size_of_non_embedding_partitions += partition.used_mem_bytes + # Check if devices are enough for all partitions + if len(embedding_partitions) > len(self.devices): + msg = 'Need ' + str(len(embedding_partitions)) + ' devices, but only ' \ + + str(len(self.devices)) + ' provided' + raise RuntimeError(msg) + occupied_devices = [] + for i, partition in enumerate(embedding_partitions): + # Check if all non-embedding partitions can fit into embedding partition devices + if total_size_of_non_embedding_partitions + partition.used_mem_bytes > available_mem_bytes: + raise RuntimeError( + 'partition_' + + str(partition.partition_id) + + '(embedding partition) and non embedding partitions can not fit into one device' + ) + else: + # Add logical device to the partition + partition.logical_device_ids = [self.devices[i].logical_id] + occupied_devices.append(self.devices[i].logical_id) + # Add logical devices to the non_embedding_partitions + for partition in non_embedding_partitions: + partition.logical_device_ids = occupied_devices + return diff --git a/torch/fx/experimental/subgraph_creation_example.py b/torch/fx/experimental/subgraph_creation_example.py index dc473d53505d..930e8f35426e 100644 --- a/torch/fx/experimental/subgraph_creation_example.py +++ b/torch/fx/experimental/subgraph_creation_example.py @@ -115,8 +115,8 @@ def record_cross_partition_use(def_node : torch.fx.node.Node, use_node : Optiona if not hasattr(target_attr, atom): raise RuntimeError(f'Operator target {node.target} not found!') target_attr = getattr(target_attr, atom) - partition.targets[node.target] = target_attr target = target_atoms[-1] + partition.targets[target] = target_attr assert isinstance(gathered_args, tuple) assert isinstance(gathered_kwargs, dict) diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 983ba9a90ed7..d76dd2e6ba22 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -8,7 +8,7 @@ import re def _shadows_builtin_name(name: str) -> bool: - return name in builtins.__dict__ or name in keyword.kwlist + return name in builtins.__dict__ or name in keyword.kwlist or name in {'inf', 'nan'} def _is_magic(x: str) -> bool: return x.startswith('__') and x.endswith('__') @@ -271,11 +271,12 @@ def node_copy(self, node: Node, arg_transform: Callable[[Node], Argument] = lamb sanitized_name = node.name if '_' in node.name: base, maybe_idx = node.name.rsplit('_', 1) - try: - int(maybe_idx) - sanitized_name = base - except ValueError: - pass + if base != '': + try: + int(maybe_idx) + sanitized_name = base + except ValueError: + pass name = self._name(sanitized_name) return self.create_node(node.op, node.target, args, kwargs, name, node.type) @@ -380,7 +381,10 @@ def type_repr(o : Any): continue raise NotImplementedError(f'node: {node.op} {node.target}') - import_block = '\n'.join(f'import {name}' for name in sorted(modules_used)) + # repr() for inf and nan floating point values aren't parseable by + # python as literals. Explicitly import the names from the `math` module. + import_strs = [f'import {name}' for name in sorted(modules_used)] + import_block = '\n'.join(import_strs) code = ''.join(body) code = '\n'.join(' ' + line for line in code.split('\n')) + '\n' diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index 6f72a29be184..24dde1ea13d4 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -4,6 +4,7 @@ from typing import Type, Dict, List, Any, Union from .graph import Graph import copy +import math # normal exec loses the source code, however we can patch # the linecache module to still recover it. @@ -28,7 +29,7 @@ def patched_getline(*args, **kwargs): linecache.getlines = patched_getline def _forward_from_src(src : str): - gbls: Dict[str, Any] = {} + gbls: Dict[str, Any] = {'inf': math.inf, 'nan': math.nan} exec_with_source(src, gbls) return gbls['forward'] diff --git a/torch/fx/symbolic_trace.py b/torch/fx/symbolic_trace.py index ab8a871adcf1..2865c97da22b 100644 --- a/torch/fx/symbolic_trace.py +++ b/torch/fx/symbolic_trace.py @@ -173,6 +173,21 @@ def trace(self, root: Union[torch.nn.Module, Callable]) -> Graph: fn, args = self.create_args_for_root(fn, isinstance(root, torch.nn.Module)) orig_call = torch.nn.Module.__call__ + orig_getattr = torch.nn.Module.__getattr__ + + parameter_proxy_cache = {} # Reduce number of get_attr calls + + # Method dispatch on parameters is not recorded unless it's directly used. + # Thus, we need to insert a proxy when __getattr__ requests a parameter. + def module_getattr_wrapper(mod, attr): + attr_val = orig_getattr(mod, attr) + if isinstance(attr_val, torch.nn.Parameter): + for n, p in self.root.named_parameters(): + if attr_val is p: + if n not in parameter_proxy_cache: + parameter_proxy_cache[n] = self.create_proxy('get_attr', n, (), {}) + return parameter_proxy_cache[n] + return attr_val def module_call_wrapper(mod, *args, **kwargs): def forward(*args, **kwargs): @@ -181,11 +196,14 @@ def forward(*args, **kwargs): return self.call_module(mod, forward, args, kwargs) try: + # Seems to be a mypy limitation: https://github.com/python/mypy/issues/2427 + torch.nn.Module.__getattr__ = module_getattr_wrapper # type: ignore torch.nn.Module.__call__ = module_call_wrapper self.create_node('output', 'output', (self.create_arg(fn(*args)),), {}, type_expr=fn.__annotations__.get('return', None)) finally: torch.nn.Module.__call__ = orig_call + torch.nn.Module.__getattr__ = orig_getattr # type: ignore return self.graph # Symbolic tracing API diff --git a/torch/jit/_script.py b/torch/jit/_script.py index 19cce3a86945..d4f6f96c3da2 100644 --- a/torch/jit/_script.py +++ b/torch/jit/_script.py @@ -443,7 +443,7 @@ def graph(self): Returns a string representation of the internal graph for the ``forward`` method. See :ref:`interpreting-graphs` for details. """ - return self.forward.graph + return self._c._get_method("forward").graph @property def inlined_graph(self): diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp index 36880f22b36d..218ed460535c 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.cpp +++ b/torch/lib/c10d/ProcessGroupNCCL.cpp @@ -392,10 +392,16 @@ void ProcessGroupNCCL::WorkNCCL::synchronizeInternal( LOG(INFO) << "[Rank " << rank_ << "] Wrote aborted communicator id to store: " << storeKey; } - LOG(INFO) << "[Rank " << rank_ - << "] Caught collective operation timeout for work: " - << (*this); - throw std::runtime_error("Operation timed out!"); + auto currentTimepoint = std::chrono::steady_clock::now(); + auto timeElapsed = std::chrono::duration_cast( + currentTimepoint - workStartTime_); + std::string exceptionMsg = c10::str("[Rank ", rank_, "] ", + "Caught collective operation timeout: ", + (*this), + " ran for ", + timeElapsed.count(), + " milliseconds before timing out."); + throw std::runtime_error(exceptionMsg); } // Check for errors and throw appropriate exception. checkAndThrowException(); @@ -504,12 +510,18 @@ void ProcessGroupNCCL::abortTimedOutCollectives(std::unordered_set& // Check for Timeouts in the WorkNCCL Operations, and abort all // communicators accordingly. if (work.timedOut()) { - LOG(INFO) - << "[Rank " << rank_ - << "] Watchdog caught collective operation timeout for work: " - << work; + auto currentTimepoint = std::chrono::steady_clock::now(); + auto timeElapsed = std::chrono::duration_cast( + currentTimepoint - work.workStartTime_); + std::string exceptionMsg = c10::str("[Rank ", rank_, "] ", + "Watchdog caught collective operation timeout: ", + work, + " ran for ", + timeElapsed.count(), + " milliseconds before timing out."); + LOG(INFO) << exceptionMsg; std::exception_ptr exception_ptr = std::make_exception_ptr( - std::runtime_error("NCCL Operation Timed Out")); + std::runtime_error(exceptionMsg)); work.setException(exception_ptr); for (const auto& ncclComm : work.ncclComms_) { ncclComm->ncclCommAbort(); diff --git a/torch/lib/c10d/ProcessGroupNCCL.hpp b/torch/lib/c10d/ProcessGroupNCCL.hpp index 3b52616ee3fa..e6f589275a11 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.hpp +++ b/torch/lib/c10d/ProcessGroupNCCL.hpp @@ -14,6 +14,8 @@ #include #include #include +#include +#include namespace c10d { @@ -302,8 +304,29 @@ class ProcessGroupNCCL : public ProcessGroup { // Do not free the underlying data storage of value_ before its // usage on futureNCCLCallbackStream_ finish. - TORCH_INTERNAL_ASSERT(record_stream_cb_); - record_stream_cb_(value_, futureNCCLCallbackStream_->unwrap()); + if (record_stream_cb_ != nullptr) { + // If a Python communication hook is used, record_stream_cb_ will be + // set in torch/csrc/jit/python/pybind_utils.h, which allows Python + // dependency to be imported. + record_stream_cb_(value_, futureNCCLCallbackStream_->unwrap()); + } else { + // If a C++ communication hook is used, create and set a record stream + // callback. + TORCH_INTERNAL_ASSERT( + value_.isTensorList() || value_.isTensor(), + "the future value must be either a tensor list or a tensor."); + at::Tensor tensor; + if (value_.isTensorList()) { + const auto tensors = value_.toTensorVector(); + TORCH_INTERNAL_ASSERT( + tensors.size() == 1, "expected exactly 1 tensor"); + tensor = tensors[0]; + } else { + tensor = value_.toTensor(); + } + c10::cuda::CUDACachingAllocator::recordStream( + tensor.storage().data_ptr(), *futureNCCLCallbackStream_); + } // Use the dedicated callback stream to run callback. // Cannot move capture std::function in lambda, because it cannot deduce @@ -558,7 +581,8 @@ class ProcessGroupNCCL : public ProcessGroup { // This function iterates through the list of WorkNCCL objects in the // workList_ corresponding to incomplete collectives and then aborts NCCL // communicators associated with timed out collectives. - void abortTimedOutCollectives(std::unordered_set& abortedCommIds); + void abortTimedOutCollectives( + std::unordered_set& abortedCommIds); void workCleanupLoop(); diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 5e2b59c45c80..ad0badf5eed9 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -68,8 +68,7 @@ :attr:`dtype` before performing the operation, and the returned tensor's type will be :attr:`dtype`. If this argument is used in conjunction with the :attr:`out` argument, the output tensor's type must match this argument or a - RuntimeError will be raised. This argument is not currently supported for - :attr:`ord='nuc'` or :attr:`ord='fro'`. Default: ``None`` + RuntimeError will be raised. Default: ``None`` Examples:: @@ -140,3 +139,48 @@ >>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :]) (tensor(3.7417), tensor(11.2250)) """) + +tensorsolve = _add_docstr(_linalg.linalg_tensorsolve, r""" +linalg.tensorsolve(input, other, dims=None, *, out=None) -> Tensor + +Computes a tensor ``x`` such that ``tensordot(input, x, dims=x.ndim) = other``. +The resulting tensor ``x`` has the same shape as ``input[other.ndim:]``. + +Supports real-valued and, only on the CPU, complex-valued inputs. + +.. note:: If :attr:`input` does not satisfy the requirement + ``prod(input.shape[other.ndim:]) == prod(input.shape[:other.ndim])`` + after (optionally) moving the dimensions using :attr:`dims`, then a RuntimeError will be thrown. + +Args: + input (Tensor): "left-hand-side" tensor, it must satisfy the requirement + ``prod(input.shape[other.ndim:]) == prod(input.shape[:other.ndim])``. + other (Tensor): "right-hand-side" tensor of shape ``input.shape[other.ndim]``. + dims (Tuple[int]): dimensions of :attr:`input` to be moved before the computation. + Equivalent to calling ``input = movedim(input, dims, range(len(dims) - input.ndim, 0))``. + If None (default), no dimensions are moved. + +Keyword args: + out (Tensor, optional): The output tensor. Ignored if ``None``. Default: ``None`` + +Examples:: + + >>> a = torch.eye(2 * 3 * 4).reshape((2 * 3, 4, 2, 3, 4)) + >>> b = torch.randn(2 * 3, 4) + >>> x = torch.linalg.tensorsolve(a, b) + >>> x.shape + torch.Size([2, 3, 4]) + >>> torch.allclose(torch.tensordot(a, x, dims=x.ndim), b) + True + + >>> a = torch.randn(6, 4, 4, 3, 2) + >>> b = torch.randn(4, 3, 2) + >>> x = torch.linalg.tensorsolve(a, b, dims=(0, 2)) + >>> x.shape + torch.Size([6, 4]) + >>> a = a.permute(1, 3, 4, 0, 2) + >>> a.shape[b.ndim:] + torch.Size([6, 4]) + >>> torch.allclose(torch.tensordot(a, x, dims=x.ndim), b, atol=1e-6) + True +""") diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py index 7280eab37caa..20a1d49619b0 100644 --- a/torch/nn/modules/conv.py +++ b/torch/nn/modules/conv.py @@ -69,7 +69,7 @@ def __init__(self, transposed: bool, output_padding: _size_1_t, groups: int, - bias: Optional[Tensor], + bias: bool, padding_mode: str) -> None: super(_ConvNd, self).__init__() if in_channels % groups != 0: diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py index 796861b0f4b3..29f07543f7a8 100644 --- a/torch/nn/modules/loss.py +++ b/torch/nn/modules/loss.py @@ -1323,9 +1323,9 @@ class TripletMarginWithDistanceLoss(_Loss): >>> # Initialize embeddings >>> embedding = nn.Embedding(1000, 128) - >>> anchor_ids = torch.randint(0, 1000, (1,), requires_grad=True) - >>> positive_ids = torch.randint(0, 1000, (1,), requires_grad=True) - >>> negative_ids = torch.randint(0, 1000, (1,), requires_grad=True) + >>> anchor_ids = torch.randint(0, 1000, (1,)) + >>> positive_ids = torch.randint(0, 1000, (1,)) + >>> negative_ids = torch.randint(0, 1000, (1,)) >>> anchor = embedding(anchor_ids) >>> positive = embedding(positive_ids) >>> negative = embedding(negative_ids) diff --git a/torch/nn/modules/pooling.py b/torch/nn/modules/pooling.py index 734912684d8f..7a43fcc2ea2d 100644 --- a/torch/nn/modules/pooling.py +++ b/torch/nn/modules/pooling.py @@ -45,6 +45,10 @@ class MaxPool1d(_MaxPoolNd): for :attr:`padding` number of points. :attr:`dilation` is the stride between the elements within the sliding window. This `link`_ has a nice visualization of the pooling parameters. + Note: + When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding + or the input. Sliding windows that would start in the right padded region are ignored. + Args: kernel_size: The size of the sliding window, must be > 0. stride: The stride of the sliding window, must be > 0. Default value is :attr:`kernel_size`. @@ -104,6 +108,10 @@ class MaxPool2d(_MaxPoolNd): for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. + Note: + When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding + or the input. Sliding windows that would start in the right padded region are ignored. + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: - a single ``int`` -- in which case the same value is used for the height and width dimension @@ -174,6 +182,10 @@ class MaxPool3d(_MaxPoolNd): for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. + Note: + When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding + or the input. Sliding windows that would start in the right padded region are ignored. + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: - a single ``int`` -- in which case the same value is used for the depth, height and width dimension @@ -474,6 +486,10 @@ class AvgPool1d(_AvgPoolNd): If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides for :attr:`padding` number of points. + Note: + When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding + or the input. Sliding windows that would start in the right padded region are ignored. + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding` can each be an ``int`` or a one-element tuple. @@ -537,6 +553,10 @@ class AvgPool2d(_AvgPoolNd): If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides for :attr:`padding` number of points. + Note: + When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding + or the input. Sliding windows that would start in the right padded region are ignored. + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding` can either be: - a single ``int`` -- in which case the same value is used for the height and width dimension @@ -614,6 +634,10 @@ class AvgPool3d(_AvgPoolNd): If :attr:`padding` is non-zero, then the input is implicitly zero-padded on all three sides for :attr:`padding` number of points. + Note: + When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding + or the input. Sliding windows that would start in the right padded region are ignored. + The parameters :attr:`kernel_size`, :attr:`stride` can either be: - a single ``int`` -- in which case the same value is used for the depth, height and width dimension diff --git a/torch/nn/parallel/comm.py b/torch/nn/parallel/comm.py index f38d7fcaafc4..331d3885bd30 100644 --- a/torch/nn/parallel/comm.py +++ b/torch/nn/parallel/comm.py @@ -3,6 +3,7 @@ from torch.cuda import nccl from torch._utils import _take_tensors, _flatten_dense_tensors, \ _unflatten_dense_tensors, _reorder_tensors_as, _get_device_index +from typing import List def broadcast(tensor, devices=None, *, out=None): @@ -121,7 +122,7 @@ def reduce_add_coalesced(inputs, destination=None, buffer_size=10485760): """ # TODO: When `len(inputs) == 1` and all inputs are on `destination`, just # return `inputs`. - dense_tensors = [[] for _ in inputs] # shape (num_gpus, num_tensors) + dense_tensors: List[List] = [[] for _ in inputs] # shape (num_gpus, num_tensors) output = [] ref_order = [] # process sparse ones first since they may have different sizes on different gpus diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py index 0fceb2137a3b..1789fea4816b 100644 --- a/torch/nn/parallel/distributed.py +++ b/torch/nn/parallel/distributed.py @@ -329,7 +329,7 @@ class DistributedDataParallel(Module): Example:: >>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...') - >>> net = torch.nn.DistributedDataParallel(model, pg) + >>> net = torch.nn.parallel.DistributedDataParallel(model, pg) """ def __init__(self, module, device_ids=None, output_device=None, dim=0, broadcast_buffers=True, @@ -626,7 +626,7 @@ def no_sync(self): Example:: - >>> ddp = torch.nn.DistributedDataParallel(model, pg) + >>> ddp = torch.nn.parallel.DistributedDataParallel(model, pg) >>> with ddp.no_sync(): >>> for input in inputs: >>> ddp(input).backward() # no synchronization, accumulate grads @@ -975,7 +975,7 @@ def join(self, divide_by_initial_world_size=True, enable=True): def _register_comm_hook(self, state: object, hook: callable): r""" - Register a communication hook which is an enhancement that provides a + Registers a communication hook which is an enhancement that provides a flexible hook to users where they can specify how DDP aggregates gradients across multiple workers. @@ -1060,6 +1060,40 @@ def _register_comm_hook(self, state: object, hook: callable): self._check_comm_hook(hook) dist._register_comm_hook(self.reducer, state, hook) + def _register_builtin_comm_hook( + self, comm_hook_type: dist.BuiltinCommHookType + ): + r""" + Registers a built-in communication hook that specifies how DDP + aggregates gradients across multiple workers. + The built-in hooks aim to provide efficient C++ implementations for certain hooks, + which might not be as efficient if implemented in Python using a Python communication hook. + + Arguments: + comm_hook_type (dist.BuiltinCommHookType): type of communication hook, such as + ALLREDUCE, FP16_COMPRESS, etc. + + .. warning :: + DDP communication hook can only be registered once and should be registered + before calling backward. + + .. warning :: + DDP communication hook does not support single-process multiple-device mode. + Gradbucket tensors should consist of only a single tensor. + + .. warning :: + DDP communication hook is experimental and subject to change. + + Example:: + Below is an example of a FP16 compression where gradients are + compressed into 16-bit floating-point numbers before allreduce, and + then decompressed after allreduce. + + >>> ddp._register_builtin_comm_hook(dist.BuiltinCommHookType.FP16_COMPRESS) + + """ + dist._register_builtin_comm_hook(self.reducer, comm_hook_type) + def _distributed_broadcast_coalesced( self, tensors, buffer_size, authoritative_rank=0 ): diff --git a/torch/nn/quantized/modules/__init__.py b/torch/nn/quantized/modules/__init__.py index a40a3e3fbcac..a064c72dda98 100644 --- a/torch/nn/quantized/modules/__init__.py +++ b/torch/nn/quantized/modules/__init__.py @@ -7,7 +7,7 @@ from .normalization import LayerNorm, GroupNorm, InstanceNorm1d, \ InstanceNorm2d, InstanceNorm3d from .conv import Conv1d, Conv2d, Conv3d -from .conv import ConvTranspose1d, ConvTranspose2d +from .conv import ConvTranspose1d, ConvTranspose2d, ConvTranspose3d from .linear import Linear from .embedding_ops import Embedding, EmbeddingBag @@ -91,6 +91,7 @@ def from_float(mod): 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', + 'ConvTranspose3d', 'DeQuantize', 'ELU', 'Embedding', diff --git a/torch/nn/quantized/modules/conv.py b/torch/nn/quantized/modules/conv.py index 31c914d2bf35..e9f4a4c701eb 100644 --- a/torch/nn/quantized/modules/conv.py +++ b/torch/nn/quantized/modules/conv.py @@ -606,9 +606,6 @@ class ConvTranspose2d(_ConvTransposeNd): For details on input arguments, parameters, and implementation see :class:`~torch.nn.ConvTranspose2d`. - .. note:: Currently only the QNNPACK engine is implemented. - Please, set the `torch.backends.quantized.engine = 'qnnpack'` - For special notes, please, see :class:`~torch.nn.quantized.Conv2d` Attributes: @@ -620,6 +617,7 @@ class ConvTranspose2d(_ConvTransposeNd): Examples:: + >>> # QNNPACK or FBGEMM as backend >>> torch.backends.quantized.engine = 'qnnpack' >>> # With square kernels and equal stride >>> m = nnq.ConvTranspose2d(16, 33, 3, stride=2) @@ -684,3 +682,88 @@ def forward(self, input): raise ValueError("Input shape must be `(N, C, H, W)`!") return ops.quantized.conv_transpose2d( input, self._packed_params, self.scale, self.zero_point) + +class ConvTranspose3d(_ConvTransposeNd): + r"""Applies a 3D transposed convolution operator over an input image + composed of several input planes. + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.ConvTranspose3d`. + + .. note:: Currently only the FBGEMM engine is implemented. + Please, set the `torch.backends.quantized.engine = 'fbgemm'` + + For special notes, please, see :class:`~torch.nn.quantized.Conv3d` + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + See :class:`~torch.nn.ConvTranspose3d` for other attributes. + + Examples:: + + >>> torch.backends.quantized.engine = 'fbgemm' + >>> # With cubic kernels and equal stride + >>> m = nnq.ConvTranspose3d(16, 33, 3, stride=2) + >>> # non-cubic kernels and unequal stride and with padding + >>> m = nnq.ConvTranspose3d(16, 33, (3, 3, 5), stride=(2, 1, 1), padding=(4, 2, 2)) + >>> input = torch.randn(20, 16, 50, 100, 100) + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8) + >>> output = m(q_input) + >>> # exact output size can be also specified as an argument + >>> input = torch.randn(1, 16, 12, 12, 12) + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8) + >>> downsample = nnq.Conv3d(16, 16, 3, stride=2, padding=1) + >>> upsample = nnq.ConvTranspose3d(16, 16, 3, stride=2, padding=1) + >>> h = downsample(q_input) + >>> h.size() + torch.Size([1, 16, 6, 6, 6]) + >>> output = upsample(h, output_size=input.size()) + >>> output.size() + torch.Size([1, 16, 12, 12, 12]) + """ + + _FLOAT_MODULE = nn.ConvTranspose3d + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, output_padding=0, groups=1, bias=True, + dilation=1, padding_mode='zeros'): + kernel_size = _pair(kernel_size) + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + output_padding = _pair(output_padding) + + super(ConvTranspose3d, self).__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + True, output_padding, groups, bias, padding_mode) + + def _get_name(self): + return 'QuantizedConvTranpose3d' + + def set_weight_bias(self, w, b): + # type: (torch.Tensor, Optional[torch.Tensor]) -> None + self._packed_params = torch.ops.quantized.conv_transpose3d_prepack( + w, b, self.stride, self.padding, self.output_padding, self.dilation, + self.groups) + + def _weight_bias(self): + w, b = torch.ops.quantized.conv3d_unpack(self._packed_params) + return w, b + + def weight(self): + (w, _) = self._weight_bias() + return w + + def bias(self): + (_, b) = self._weight_bias() + return b + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 5: + raise ValueError("Input shape must be `(N, C, T, H, W)`!") + return ops.quantized.conv_transpose3d( + input, self._packed_params, self.scale, self.zero_point) diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index 255c15b9da4a..421c23ebed6e 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -44,7 +44,7 @@ def export(model, args, f, export_params=True, verbose=False, training=TrainingM model (torch.nn.Module): the model to be exported. args (tuple of arguments or torch.Tensor): the inputs to the model, e.g., such that ``model(*args)`` is a valid - invocation of the model. Any non-Tensor arguments will + invocation of the model. Any non-Tensor arguments (including None) will be hard-coded into the exported model; any Tensor arguments will become inputs of the exported model, in the order they occur in args. If args is a Tensor, this is equivalent diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py index 5a266a429965..f6acc4120dc2 100644 --- a/torch/onnx/symbolic_opset11.py +++ b/torch/onnx/symbolic_opset11.py @@ -581,6 +581,26 @@ def mm(g, self, other): return g.op("Gemm", self, other, beta_f=0.0, alpha_f=1.0) +def index(g, self, index): + if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK: + return g.op("ATen", self, index, operator_s="index") + + if sym_help._is_packed_list(index): + indices = sym_help._unpack_list(index) + else: + indices = [index] + + # Handle single mask index. + if len(indices) == 1: + index = indices[0] + if not sym_help._is_none(index) and (index.type().scalarType() == "Bool" or index.type().scalarType() == "Byte"): + from torch.onnx.symbolic_opset9 import nonzero + index = nonzero(g, index) + return g.op('GatherND', self, index) + from torch.onnx.symbolic_opset9 import index as index_opset9 + return index_opset9(g, self, index) + + def index_fill(g, self, dim, index, value): dim_value = sym_help._parse_arg(dim, 'i') if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK: diff --git a/torch/overrides.py b/torch/overrides.py index 281c02f6a6cf..88cf0b10868c 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -279,6 +279,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.clip: lambda input, min=None, max=None, out=None: -1, torch.clamp_min: lambda input, min, out=None: -1, torch.clamp_max: lambda input, max, out=None: -1, + torch.column_stack: lambda tensors, out=None: -1, torch.clone: lambda input: -1, torch.combinations: lambda input, r=2, with_replacement=False: -1, torch.complex: lambda real, imag: -1, @@ -388,6 +389,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.hstack: lambda tensors, out=None: -1, torch.hypot: lambda input, other, out=None: -1, torch.ifft: lambda input, signal_ndim, normalized=False: -1, + torch.igamma: lambda input, other, out=None: -1, torch.imag: lambda input, out=None: -1, torch.index_add: lambda input, dim, index, source: -1, torch.index_copy: lambda input, dim, index, source: -1, @@ -694,6 +696,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.roll: lambda input, shifts, dims=None: -1, torch.rot90: lambda input, k=1, dims=(0, 1): -1, torch.round: lambda input, out=None: -1, + torch.row_stack: lambda tensors, out=None: -1, # alias for torch.vstack torch.rowwise_prune: (lambda weight, mask, compressed_indices_dtype: -1), torch.rrelu: lambda input, lower=1. / 8, upper=1. / 3, training=False, inplace=False: -1, torch.rsqrt: lambda input, out=None: -1, @@ -739,6 +742,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.tan: lambda input, out=None: -1, torch.tanh: lambda input, out=None: -1, torch.tensordot: lambda a, b, dims=2: -1, + torch.linalg.tensorsolve: lambda a, b, dims=None: -1, torch.tensor_split: lambda input, indices_or_sections, dim=0: -1, torch.threshold: lambda input, threshold, value, inplace=False: -1, torch.topk: lambda input, k, dim=-1, descending=False, out=None: -1, diff --git a/torch/quantization/__init__.py b/torch/quantization/__init__.py index ee506b6fc6a7..24e929b5fc8e 100644 --- a/torch/quantization/__init__.py +++ b/torch/quantization/__init__.py @@ -6,7 +6,7 @@ from .stubs import * from .quant_type import * from .quantize_jit import * -from .quantize_fx import * +# from .quantize_fx import * from .quantization_mappings import * from .fuser_method_mappings import * @@ -26,9 +26,9 @@ def default_eval_fn(model, calib_data): # Top level API for graph mode quantization on TorchScript 'quantize_jit', 'quantize_dynamic_jit', # Top level API for graph mode quantization on GraphModule(torch.fx) - 'fuse_fx', 'quantize_fx', # TODO: add quantize_dynamic_fx - 'prepare_fx', 'prepare_dynamic_fx', 'convert_fx', - 'QuantType', # quantization type + # 'fuse_fx', 'quantize_fx', # TODO: add quantize_dynamic_fx + # 'prepare_fx', 'prepare_dynamic_fx', 'convert_fx', + 'QuantType', 'quant_type_to_str', # quantization type # custom module APIs 'get_default_static_quant_module_mappings', 'get_static_quant_module_class', 'get_default_dynamic_quant_module_mappings', diff --git a/torch/quantization/fake_quantize.py b/torch/quantization/fake_quantize.py index dd2333883325..cacaad6bea0b 100644 --- a/torch/quantization/fake_quantize.py +++ b/torch/quantization/fake_quantize.py @@ -159,12 +159,12 @@ def forward(self, X): @torch.jit.export def extra_repr(self): - return 'fake_quant_enabled={}, observer_enabled={},\ - quant_min={}, quant_max={}, dtype={}, qscheme={}, ch_axis={}, \ - scale={}, zero_point={}'.format( - self.fake_quant_enabled, self.observer_enabled, - self.quant_min, self.quant_max, - self.dtype, self.qscheme, self.ch_axis, self.scale, self.zero_point) + return 'fake_quant_enabled={}, observer_enabled={}, ' \ + 'quant_min={}, quant_max={}, dtype={}, qscheme={}, ch_axis={}, ' \ + 'scale={}, zero_point={}'.format( + self.fake_quant_enabled, self.observer_enabled, + self.quant_min, self.quant_max, + self.dtype, self.qscheme, self.ch_axis, self.scale, self.zero_point) def _save_to_state_dict(self, destination, prefix, keep_vars): # We cannot currently register scalar values as buffers, so need to manually @@ -226,9 +226,19 @@ def forward(self, X): self.quant_max) return X + @torch.jit.export def calculate_qparams(self): return self.scale, self.zero_point + @torch.jit.export + def extra_repr(self): + return 'fake_quant_enabled={}, observer_enabled={}, scale={}, zero_point={}, ' \ + 'dtype={}, quant_min={}, quant_max={}, qscheme={}'.format( + self.fake_quant_enabled, self.observer_enabled, + self.scale, self.zero_point, self.dtype, + self.quant_min, self.quant_max, self.qscheme) + + default_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=True) default_weight_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=-128, quant_max=127, diff --git a/torch/quantization/fx/fuse.py b/torch/quantization/fx/fuse.py index 989f8937dd7c..56b375a02c00 100644 --- a/torch/quantization/fx/fuse.py +++ b/torch/quantization/fx/fuse.py @@ -12,13 +12,10 @@ from .fusion_patterns import * # noqa: F401 -import copy class Fuser: - def fuse(self, model, inplace=False, fuse_custom_config_dict=None): + def fuse(self, model, fuse_custom_config_dict=None): if fuse_custom_config_dict is None: fuse_custom_config_dict = {} - if not inplace: - model = copy.deepcopy(model) input_root = model input_graph = model.graph diff --git a/torch/quantization/fx/fusion_patterns.py b/torch/quantization/fx/fusion_patterns.py index d537630c2406..4c92192dc5be 100644 --- a/torch/quantization/fx/fusion_patterns.py +++ b/torch/quantization/fx/fusion_patterns.py @@ -9,15 +9,18 @@ # Fusion Patterns # --------------------- -@register_fusion_pattern((torch.nn.BatchNorm2d, torch.nn.Conv2d)) @register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv1d)) @register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv2d)) @register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv3d)) @register_fusion_pattern((torch.nn.functional.relu, torch.nn.Conv1d)) @register_fusion_pattern((torch.nn.functional.relu, torch.nn.Conv2d)) @register_fusion_pattern((torch.nn.functional.relu, torch.nn.Conv3d)) +@register_fusion_pattern((torch.nn.BatchNorm2d, torch.nn.Conv2d)) +@register_fusion_pattern((torch.nn.BatchNorm3d, torch.nn.Conv3d)) @register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d))) +@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm3d, torch.nn.Conv3d))) @register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm2d, torch.nn.Conv2d))) +@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm3d, torch.nn.Conv3d))) class ConvBNReLUFusion(): def __init__(self, quantizer, node): super().__init__() @@ -66,7 +69,8 @@ def fuse(self, quantizer, load_arg, fuse_custom_config_dict=None): fuser_method = get_fuser_method(op_type_list, additional_fuser_method_mapping) if fuser_method is None: raise NotImplementedError("Cannot fuse modules: {}".format(types)) - setattr(quantizer.modules[conv_parent_name], conv_name, fuser_method(*op_list)) + fused = fuser_method(*op_list) + setattr(quantizer.modules[conv_parent_name], conv_name, fused) # TODO: do we need to make sure bn is only used once? if self.bn_node is not None: @@ -77,8 +81,6 @@ def fuse(self, quantizer, load_arg, fuse_custom_config_dict=None): @register_fusion_pattern((torch.nn.functional.relu, torch.nn.Linear)) @register_fusion_pattern((torch.nn.ReLU, torch.nn.Linear)) -@register_fusion_pattern((torch.nn.functional.relu, torch.nn.BatchNorm1d)) -@register_fusion_pattern((torch.nn.ReLU, torch.nn.BatchNorm1d)) @register_fusion_pattern((torch.nn.functional.relu, torch.nn.BatchNorm2d)) @register_fusion_pattern((torch.nn.ReLU, torch.nn.BatchNorm2d)) @register_fusion_pattern((torch.nn.functional.relu, torch.nn.BatchNorm3d)) diff --git a/torch/quantization/fx/observed_module.py b/torch/quantization/fx/observed_module.py new file mode 100644 index 000000000000..7b55033ecba6 --- /dev/null +++ b/torch/quantization/fx/observed_module.py @@ -0,0 +1,52 @@ +import torch +import copy +from torch.fx import GraphModule + +class ObservedGraphModule(GraphModule): + + def get_preserved_attr_names(self): + return ['_activation_post_process_map', + '_patterns', + '_qconfig_map'] + + def __init__(self, root, graph): + preserved_attrs = dict() + for attr in self.get_preserved_attr_names(): + preserved_attrs[attr] = getattr(root, attr) + super().__init__(root, graph) + for attr in preserved_attrs: + setattr(self, attr, preserved_attrs[attr]) + + # GraphModule does not copy attributes which are not in the __dict__ + # of vanilla nn.Module. So, we override __deepcopy__ in order + # to copy the quantization specific attributes correctly. + def __deepcopy__(self, memo): + fake_mod = torch.nn.Module() + fake_mod.__dict__ = copy.deepcopy(self.__dict__) + return ObservedGraphModule(fake_mod, self.graph) + +def mark_observed_module(module): + return ObservedGraphModule(module, module.graph) + +def is_observed_module(module): + return isinstance(module, ObservedGraphModule) + +class ObservedStandaloneGraphModule(ObservedGraphModule): + + def get_preserved_attr_names(self): + return ['_activation_post_process_map', + '_patterns', + '_qconfig_map', + '_standalone_module_observed_input_idxs', + '_output_is_observed'] + + def __deepcopy__(self, memo): + fake_mod = torch.nn.Module() + fake_mod.__dict__ = copy.deepcopy(self.__dict__) + return ObservedStandaloneGraphModule(fake_mod, self.graph) + +def mark_observed_standalone_module(module): + return ObservedStandaloneGraphModule(module, module.graph) + +def is_observed_standalone_module(module): + return isinstance(module, ObservedStandaloneGraphModule) diff --git a/torch/quantization/fx/pattern_utils.py b/torch/quantization/fx/pattern_utils.py index efc1cd492d1b..ccbd7cc2a2c4 100644 --- a/torch/quantization/fx/pattern_utils.py +++ b/torch/quantization/fx/pattern_utils.py @@ -14,10 +14,17 @@ def get_default_fusion_patterns(): return DEFAULT_FUSION_PATTERNS DEFAULT_QUANTIZATION_PATTERNS = OrderedDict() +# a map from pattern to activation_post_process(observer/fake_quant) consstructor for output activation +# e.g. pattern: torch.sigmoid, +# output_activation_post_process: default_affine_fixed_qparam_fake_quant +DEFAULT_OUTPUT_ACTIVATION_POST_PROCESS_MAP = dict() + # Register pattern for both static quantization and qat -def register_quant_pattern(pattern): +def register_quant_pattern(pattern, output_activation_post_process=None): def insert(fn): DEFAULT_QUANTIZATION_PATTERNS[pattern] = fn + if output_activation_post_process is not None: + DEFAULT_OUTPUT_ACTIVATION_POST_PROCESS_MAP[pattern] = output_activation_post_process return fn return insert @@ -25,6 +32,12 @@ def insert(fn): def get_default_quant_patterns(): return DEFAULT_QUANTIZATION_PATTERNS +# a map from pattern to output activation post process constructor +# e.g. torch.sigmoid -> default_affine_fixed_qparam_fake_quant +def get_default_output_activation_post_process_map(): + return DEFAULT_OUTPUT_ACTIVATION_POST_PROCESS_MAP + + # Example use of register pattern function: # @register_fusion_pattern(torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d))) # class ConvBNReLUFusion(): @@ -62,6 +75,9 @@ def is_match(modules, node, pattern, max_uses=sys.maxsize): elif node.target is getattr: if node.args[1] != pattern[1]: return False + elif isinstance(self_match, str): + if node.op != 'call_method' or node.target != self_match: + return False elif node.target != self_match: return False diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py index 12787e4a87db..c588354a3192 100644 --- a/torch/quantization/fx/quantization_patterns.py +++ b/torch/quantization/fx/quantization_patterns.py @@ -4,6 +4,10 @@ ) import torch.nn.quantized as nnq import torch.nn.quantized.dynamic as nnqd +from torch.quantization import ( + default_affine_fixed_qparams_fake_quant, + default_symmetric_fixed_qparams_fake_quant, +) from ..quantization_mappings import ( get_static_quant_module_class, @@ -16,6 +20,7 @@ _parent_name, quantize_node, get_per_tensor_qparams, + get_swapped_custom_module_class, activation_is_statically_quantized, weight_is_quantized, weight_dtype, @@ -42,7 +47,8 @@ def __init__(self, quantizer, node): # this is an indicator of whether all the inputs are Node or not # since some op might be quantized differently depending on whether # all inputs are tensors or not, e.g. add/mul - self.all_nodes = True + self.num_node_args = len(node.args) + self.all_node_args = True @abstractmethod def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None): @@ -67,18 +73,24 @@ def __init__(self, quantizer, node): node = node.args[0] assert node.op == 'call_function' and node.target in [operator.add, torch.add] self.add_node = node - self.all_nodes = all([isinstance(a, Node) for a in self.add_node.args[:2]]) + self.num_node_args = len([a for a in self.add_node.args[:2] if isinstance(a, Node)]) def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None): - if not self.all_nodes: + if self.num_node_args == 1: # add scalar if self.relu_node is not None: op = torch.ops.quantized.add_relu else: op = torch.ops.quantized.add + + if isinstance(self.add_node.args[0], Node): + quantized_index = 0 + else: + quantized_index = 1 + return quantizer.quantized_graph.create_node( 'call_function', op, - load_arg(quantized=[0])(self.add_node.args), self.add_node.kwargs) + load_arg(quantized=[quantized_index])(self.add_node.args), self.add_node.kwargs) else: activation_post_process = quantizer.activation_post_process_map[node.name] scale, zero_point = activation_post_process.calculate_qparams() @@ -92,6 +104,7 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_ return quantizer.quantized_graph.create_node( 'call_function', op, load_arg(quantized=True)(self.add_node.args), kwargs) +# TODO: merge with Add @register_quant_pattern(operator.mul) @register_quant_pattern(torch.mul) @register_quant_pattern((torch.nn.ReLU, operator.mul)) @@ -108,17 +121,23 @@ def __init__(self, quantizer, node): node = node.args[0] assert node.op == 'call_function' and node.target in [operator.mul, torch.mul] self.mul_node = node - self.all_nodes = all([isinstance(a, Node) for a in self.mul_node.args[:2]]) + self.num_node_args = len([a for a in self.mul_node.args[:2] if isinstance(a, Node)]) def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None): - if not self.all_nodes: + if self.num_node_args == 1: # mul scalar if self.relu_node is not None: op = torch.ops.quantized.mul_relu else: op = torch.ops.quantized.mul + + if isinstance(self.mul_node.args[0], Node): + quantized_index = 0 + else: + quantized_index = 1 + return quantizer.quantized_graph.create_node( - 'call_function', op, load_arg(quantized=[0])(self.mul_node.args), self.mul_node.kwargs) + 'call_function', op, load_arg(quantized=[quantized_index])(self.mul_node.args), self.mul_node.kwargs) else: activation_post_process = quantizer.activation_post_process_map[node.name] scale, zero_point = activation_post_process.calculate_qparams() @@ -134,7 +153,7 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_ @register_quant_pattern(torch.cat) class Cat(QuantizeHandler): def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None): - if not self.all_nodes: + if not self.all_node_args: return NotImplemented activation_post_process = quantizer.activation_post_process_map[node.name] scale, zero_point = activation_post_process.calculate_qparams() @@ -176,6 +195,12 @@ def __init__(self, quantizer, node): def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None): # TODO: debug option for conv module + qconfig = quantizer.qconfig_map[node.name] + activation_statically_quantized = activation_is_statically_quantized(qconfig) + # only static qunatization (for both ptq and qat) is supported for conv + if not activation_statically_quantized: + return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None)) + if self.conv_node.op == 'call_module': # note that relu should already be fused into conv module in the fusion step assert self.relu_node is None, 'conv module and relu fusion is not executed, ' \ @@ -434,7 +459,7 @@ class DefaultNode(QuantizeHandler): ''' Common quantized op, first input and first output will be quantized ''' def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None): - if not self.all_nodes: + if not self.all_node_args: return NotImplemented assert node.op in ['call_module', 'call_function'], 'Only call_module and ' + \ 'call_function are handled in DefaultNode' @@ -487,6 +512,22 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_ return quantizer.quantized_graph.create_node( 'call_function', quantized_op, args, kwargs) +@register_quant_pattern(torch.nn.Hardsigmoid, default_affine_fixed_qparams_fake_quant) +@register_quant_pattern(torch.nn.functional.hardsigmoid, default_affine_fixed_qparams_fake_quant) +@register_quant_pattern('hardsigmoid', default_affine_fixed_qparams_fake_quant) +@register_quant_pattern('hardsigmoid_', default_affine_fixed_qparams_fake_quant) +@register_quant_pattern(torch.nn.Sigmoid, default_affine_fixed_qparams_fake_quant) +@register_quant_pattern(torch.sigmoid, default_affine_fixed_qparams_fake_quant) +@register_quant_pattern('sigmoid', default_affine_fixed_qparams_fake_quant) +@register_quant_pattern('sigmoid_', default_affine_fixed_qparams_fake_quant) +@register_quant_pattern(torch.nn.Tanh, default_symmetric_fixed_qparams_fake_quant) +@register_quant_pattern(torch.tanh, default_symmetric_fixed_qparams_fake_quant) +@register_quant_pattern('tanh', default_symmetric_fixed_qparams_fake_quant) +@register_quant_pattern('tanh_', default_symmetric_fixed_qparams_fake_quant) +class FixedQParamsOpQuantizeHandler(QuantizeHandler): + def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None): + return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None)) + # these ops have quantized equivalents that do not need any extra information @register_quant_pattern(torch.nn.AdaptiveAvgPool1d) @register_quant_pattern(torch.nn.AdaptiveAvgPool2d) @@ -495,20 +536,16 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_ @register_quant_pattern(torch.nn.AvgPool2d) @register_quant_pattern(torch.nn.AvgPool3d) @register_quant_pattern(torch.nn.Dropout) -@register_quant_pattern(torch.nn.Hardsigmoid) @register_quant_pattern(torch.nn.Hardtanh) @register_quant_pattern(torch.nn.MaxPool1d) @register_quant_pattern(torch.nn.MaxPool2d) @register_quant_pattern(torch.nn.MaxPool3d) @register_quant_pattern(torch.nn.ReLU) @register_quant_pattern(torch.nn.ReLU6) -@register_quant_pattern(torch.nn.Sigmoid) -@register_quant_pattern(torch.nn.Tanh) @register_quant_pattern(torch.adaptive_avg_pool1d) @register_quant_pattern(torch.nn.functional.adaptive_avg_pool2d) @register_quant_pattern(torch.nn.functional.adaptive_avg_pool3d) @register_quant_pattern(torch.nn.functional.dropout) -@register_quant_pattern(torch.nn.functional.hardsigmoid) @register_quant_pattern(torch.nn.functional.hardtanh) @register_quant_pattern(torch.nn.functional.hardtanh_) @register_quant_pattern(torch.nn.functional.interpolate) @@ -528,11 +565,9 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_ @register_quant_pattern(torch.mean) @register_quant_pattern(torch.min) @register_quant_pattern(torch.repeat_interleave) -@register_quant_pattern(torch.sigmoid) @register_quant_pattern(torch.sort) @register_quant_pattern(torch.squeeze) @register_quant_pattern(torch.stack) -@register_quant_pattern(torch.tanh) @register_quant_pattern(torch.unsqueeze) @register_quant_pattern(operator.getitem) @register_quant_pattern(operator.floordiv) @@ -541,8 +576,6 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_ @register_quant_pattern('contiguous') @register_quant_pattern('detach') @register_quant_pattern('detach_') -@register_quant_pattern('hardsigmoid') -@register_quant_pattern('hardsigmoid_') @register_quant_pattern('mean') @register_quant_pattern('numel') @register_quant_pattern('permute') @@ -553,13 +586,9 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_ @register_quant_pattern('reshape') @register_quant_pattern('resize_') @register_quant_pattern('shape') -@register_quant_pattern('sigmoid') -@register_quant_pattern('sigmoid_') @register_quant_pattern('size') @register_quant_pattern('squeeze') @register_quant_pattern('squeeze_') -@register_quant_pattern('tanh') -@register_quant_pattern('tanh_') @register_quant_pattern('transpose') @register_quant_pattern('unsqueeze') @register_quant_pattern('unsqueeze_') @@ -570,9 +599,9 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_ # Default quantization handler, used for quantization of input and output # of quantizable objects (e.g. modules and functionals) -class DefaultQuant(QuantizeHandler): +class DefaultQuantizeHandler(QuantizeHandler): def convert(self, quantizer, node): - assert self.all_nodes + assert self.all_node_args root_module = quantizer.modules[''] return quantize_node( root_module, @@ -587,13 +616,14 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_ assert convert_custom_config_dict is not None custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", None) assert custom_module_class_mapping is not None + qconfig = quantizer.qconfig_map[node.name] observed_custom_module = quantizer.modules[node.target] - if node.name in quantizer.activation_post_process_map: + if activation_is_statically_quantized(qconfig): + assert node.name in quantizer.activation_post_process_map observed_custom_module.activation_post_process = \ quantizer.activation_post_process_map[node.name] - quantized_custom_module_class = custom_module_class_mapping.get(type(observed_custom_module), None) - assert quantized_custom_module_class is not None, "did not found quantized custom module for:" + \ - str(type(observed_custom_module)) + quantized_custom_module_class = get_swapped_custom_module_class( + observed_custom_module, custom_module_class_mapping, qconfig) quantized_custom_module = \ quantized_custom_module_class.from_observed(observed_custom_module) parent_name, name = _parent_name(node.target) diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py index 3eebef4ff10a..732f2efdedfe 100644 --- a/torch/quantization/fx/quantize.py +++ b/torch/quantization/fx/quantize.py @@ -19,14 +19,20 @@ get_default_qat_module_mappings, ) -from ..quantize import _remove_qconfig +from ..quantize import ( + _remove_qconfig, + is_activation_post_process +) from .pattern_utils import ( is_match, get_default_quant_patterns, + get_default_output_activation_post_process_map, ) -from .standalone_module import ( +from .observed_module import ( + mark_observed_module, + is_observed_module, mark_observed_standalone_module, is_observed_standalone_module, ) @@ -36,12 +42,13 @@ from .utils import ( _parent_name, quantize_node, + get_custom_module_class_keys, + get_swapped_custom_module_class, activation_is_statically_quantized, ) from collections import OrderedDict import warnings -import copy import re from typing import Optional @@ -131,11 +138,6 @@ def assert_and_get_unique_device(module): device = next(iter(devices)) if len(devices) > 0 else None return device -def is_activation_post_process(module): - return (isinstance(module, torch.quantization.ObserverBase) or - isinstance(module, torch.quantization.FakeQuantize) or - isinstance(module, torch.quantization.FakeQuantizeBase)) - def is_submodule_of_fake_quant(name, module, named_modules): parent_name, _ = _parent_name(name) return is_activation_post_process(named_modules[parent_name]) @@ -309,7 +311,7 @@ def get_qconfig(module_name): self.modules[node.target].qconfig = module_qconfig self.qconfig_map[node.name] = module_qconfig - def _prepare(self, model, qconfig_dict, inplace, prepare_custom_config_dict, is_standalone_module): + def _prepare(self, model, qconfig_dict, prepare_custom_config_dict, is_standalone_module): """ standalone_module means it a submodule that is not inlined in parent module, and will be quantized separately as one unit. @@ -325,8 +327,6 @@ def _prepare(self, model, qconfig_dict, inplace, prepare_custom_config_dict, is_ """ if prepare_custom_config_dict is None: prepare_custom_config_dict = {} - if not inplace: - model = copy.deepcopy(model) additional_quant_patterns = prepare_custom_config_dict.get("additional_quant_pattern", {}) self.patterns = get_default_quant_patterns().copy() for k, v in additional_quant_patterns.items(): @@ -347,13 +347,13 @@ def _prepare(self, model, qconfig_dict, inplace, prepare_custom_config_dict, is_ # match the patterns that will get quantized standalone_module_names = prepare_custom_config_dict.get("standalone_module_name", None) - custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", None) + custom_module_classes = get_custom_module_class_keys(prepare_custom_config_dict, "float_to_observed_custom_module_class") matches = self._find_matches( - model.graph, self.modules, self.patterns, standalone_module_names, custom_module_class_mapping) + model.graph, self.modules, self.patterns, standalone_module_names, custom_module_classes) # find _inputs_ to matched nodes that are not quantized, these # have to be quantized, which requires measuring stats, - # initialize an DefaultQuant object for each + # initialize an DefaultQuantizeHandler object for each quants = self._find_quants(model.graph, matches) self.activation_post_process_map = dict() @@ -383,7 +383,7 @@ def load_arg(a): continue prefix = node.name + '_activation_post_process_' - root_node, _, obj, qconfig = matches.get(node.name, (None, None, None, None)) + root_node, matched_nodes, pattern, obj, qconfig = matches.get(node.name, (None, None, None, None, None)) if root_node is None: env[node.name] = observed_graph.node_copy(node, load_arg) elif root_node is node: @@ -403,8 +403,9 @@ def insert_observer(node, observer, device): if isinstance(obj, CustomModuleQuantizeHandler): custom_module = self.modules[node.target] + custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {}) observed_custom_module_class = \ - custom_module_class_mapping[type(custom_module)] + get_swapped_custom_module_class(custom_module, custom_module_class_mapping, qconfig) observed_custom_module = \ observed_custom_module_class.from_float(custom_module) parent_name, name = _parent_name(node.target) @@ -430,9 +431,19 @@ def insert_observer(node, observer, device): if not activation_is_statically_quantized(qconfig): continue - # inserting observers for output of observed module, or mark the output - # as observed - if isinstance(obj, CopyNode): + if isinstance(obj, FixedQParamsOpQuantizeHandler) and model.training: + # we only insert fake quantize module in qat + activation_post_process_ctr = \ + get_default_output_activation_post_process_map().get(pattern, None) + assert activation_post_process_ctr is not None, \ + 'activation_post_process constructor not provided for ' + \ + 'pattern:' + str(pattern) + device = assert_and_get_unique_device(model) + insert_observer(node, activation_post_process_ctr(), device) + elif (isinstance(obj, FixedQParamsOpQuantizeHandler) and + not model.training) or isinstance(obj, CopyNode): + # inserting observers for output of observed module, or mark the output + # as observed assert node.op in [ 'call_module', 'call_function', @@ -447,15 +458,23 @@ def is_observed(input_arg): # propagate observed property from input if is_observed(node.args[0]): observed_node_names_set.add(node.name) - elif (isinstance(obj, Add) or isinstance(obj, Mul)) and not obj.all_nodes: - if node.args[0].name in observed_node_names_set: + elif (isinstance(obj, Add) or isinstance(obj, Mul)) and obj.num_node_args == 1: + input_node = matched_nodes[-1] # first node in the sequence + + def input_is_observed(arg): + return isinstance(arg, Node) and arg.name in observed_node_names_set + # This is checking if one of the argument of add/mul + # is an observed node + # If both of the inputs are number, + # we will not consider the output to be observed + if input_is_observed(input_node.args[0]) or input_is_observed(input_node.args[1]): observed_node_names_set.add(node.name) elif isinstance(obj, StandaloneModuleQuantizeHandler): assert node.op == 'call_module' output_is_observed = self.modules[node.target]._output_is_observed if output_is_observed: observed_node_names_set.add(node.name) - elif qconfig is not None and obj.all_nodes: + elif qconfig is not None and obj.all_node_args: # observer for outputs new_observer = qconfig.activation() # respect device affinity when adding observers @@ -472,6 +491,7 @@ def is_observed(input_arg): else: env[node.name] = observed_graph.node_copy(node, load_arg) + # insert observer for output of the node if node.name not in observed_node_names_set and node.name in quants: if is_standalone_module and node.name in graph_inputs: # we'll insert observer for input of standalone module @@ -480,11 +500,11 @@ def is_observed(input_arg): continue get_new_observer_name = get_new_attr_name_with_prefix(prefix) observer_name = get_new_observer_name(model) - _, qconfig, is_weight = quants[node.name] - if qconfig is not None: + _, activation_post_process_ctr = quants[node.name] + if activation_post_process_ctr is not None: # TODO: use insert_observer - new_observer = \ - qconfig.weight() if is_weight else qconfig.activation() + new_observer = activation_post_process_ctr() + # respect device affinity when adding observers device = assert_and_get_unique_device(model) if device: @@ -496,6 +516,7 @@ def is_observed(input_arg): model = GraphModule(model, observed_graph) self.save_state(model) + model = mark_observed_module(model) if is_standalone_module: assert result_node is not None assert isinstance(result_node.args[0], Node), \ @@ -513,19 +534,13 @@ def save_state(self, observed): observed._qconfig_map = self.qconfig_map def restore_state(self, observed): - err_msg = 'please make sure the model is produced by prepare' - assert hasattr(observed, '_activation_post_process_map'), 'did not found ' + \ - '_activation_post_process attribute ' + err_msg - assert hasattr(observed, '_patterns'), 'did not found ' + \ - '_patterns attribute ' + err_msg - assert hasattr(observed, '_qconfig_map'), 'did not found ' + \ - '_qconfig_map attribute ' + err_msg + assert is_observed_module(observed), 'incoming model must be produced by prepare_fx' self.activation_post_process_map = observed._activation_post_process_map self.patterns = observed._patterns self.qconfig_map = observed._qconfig_map - def prepare(self, model, qconfig_dict, inplace=False, prepare_custom_config_dict=None, is_standalone_module=False): - return self._prepare(model, qconfig_dict, inplace, prepare_custom_config_dict, is_standalone_module) + def prepare(self, model, qconfig_dict, prepare_custom_config_dict=None, is_standalone_module=False): + return self._prepare(model, qconfig_dict, prepare_custom_config_dict, is_standalone_module) def _run_weight_observers(self, observed): r''' Extract the subgraph that produces the weight for dynamic quant @@ -546,7 +561,7 @@ def _run_weight_observers(self, observed): weight_observer_module() return - def _convert(self, model, inplace=False, debug=False, convert_custom_config_dict=None, is_standalone_module=False): + def _convert(self, model, debug=False, convert_custom_config_dict=None, is_standalone_module=False): """ standalone_module means it a submodule that is not inlined in parent module, and will be quantized separately as one unit. For standalone module: the inputs will be quantized by parent module, @@ -559,8 +574,6 @@ def _convert(self, model, inplace=False, debug=False, convert_custom_config_dict if convert_custom_config_dict is None: convert_custom_config_dict = {} self.restore_state(model) - if not inplace: - model = copy.deepcopy(model) # always run weight observers in the top level forward method # for dynamic quant ops or weight only quant ops self._run_weight_observers(model) @@ -569,10 +582,11 @@ def _convert(self, model, inplace=False, debug=False, convert_custom_config_dict model.eval().cpu() self.modules = dict(model.named_modules()) - custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", None) + custom_module_classes = get_custom_module_class_keys( + convert_custom_config_dict, "observed_to_quantized_custom_module_class") matches = self._find_matches( model.graph, self.modules, self.patterns, - custom_module_class_mapping=custom_module_class_mapping) + custom_module_classes=custom_module_classes) quants = self._find_quants(model.graph, matches) @@ -669,7 +683,7 @@ def is_quantized(node): graph_output = map_arg(node.args[0], load_non_quantized) self.quantized_graph.output(graph_output) continue - root_node, matched, obj, qconfig = matches.get(node.name, (None, None, None, None)) + root_node, matched, matched_pattern, obj, qconfig = matches.get(node.name, (None, None, None, None, None)) if root_node is node: if qconfig is None: result = self.quantized_graph.node_copy(node, load_non_quantized) @@ -682,7 +696,10 @@ def is_quantized(node): quantized = True # Need to get correct quantized/non-quantized state for the output of CopyNode - if isinstance(obj, CopyNode): + if type(obj) in [ + CopyNode, + FixedQParamsOpQuantizeHandler + ]: assert node.op in [ 'call_module', 'call_function', @@ -731,8 +748,8 @@ def is_quantized(node): # the node is quantized in parent module quant_env[node.name] = self.quantized_graph.node_copy(node, load_non_quantized) else: - # dequantize inputs for the node that are not quantized - env[node.name] = self.quantized_graph.node_copy(node, load_non_quantized) + # copy quantized or non-quantized node + env[node.name] = self.quantized_graph.node_copy(node, load_x) # remove activation post process act_post_process_removed_graph = Graph() @@ -751,13 +768,7 @@ def load_arg(a): else: env[node.name] = act_post_process_removed_graph.node_copy(node, load_arg) - module_dict = dict(model.named_modules()) - to_be_removed = [] - for name, module in model.named_modules(): - if is_activation_post_process(module) and not is_submodule_of_fake_quant(name, module, module_dict): - to_be_removed.append(name) - for n in to_be_removed: - delattr(model, n) + # removes qconfig and activation_post_process modules _remove_qconfig(model) model = GraphModule(model, act_post_process_removed_graph) return model @@ -810,15 +821,15 @@ def load_arg(a): quantized = GraphModule(quantized_root, folded_graph) return quantized - def convert(self, model, inplace=False, debug=False, convert_custom_config_dict=None, is_standalone_module=False): - quantized = self._convert(model, inplace, debug, convert_custom_config_dict, is_standalone_module) + def convert(self, model, debug=False, convert_custom_config_dict=None, is_standalone_module=False): + quantized = self._convert(model, debug, convert_custom_config_dict, is_standalone_module) if not debug: quantized = self._fold_weight(quantized) return quantized def _find_matches( self, graph, modules, patterns, - standalone_module_names=None, custom_module_class_mapping=None): + standalone_module_names=None, custom_module_classes=None): """ Matches the nodes in the input graph to quantization patterns, and outputs the information needed to quantize them in future steps. @@ -832,15 +843,15 @@ def _find_matches( Outputs a map of node_name -> - (node, matched_values, QuantizeHandler instance, qconfig) + (node, matched_values, matched_pattern, QuantizeHandler instance, qconfig) For example, { - 'relu_1': (relu_1, [relu_1], , QConfig(...)), + 'relu_1': (relu_1, [relu_1], torch.nn.functional.relu, , QConfig(...)), ... } """ - if custom_module_class_mapping is None: - custom_module_class_mapping = {} + if custom_module_classes is None: + custom_module_classes = [] match_map = {} all_matched = set() @@ -862,7 +873,7 @@ def record_match(pattern, node, matched): matched = [] record_match(pattern, node, matched) for n in matched: - match_map[n.name] = (node, matched, value(self, node), self.qconfig_map[n.name]) + match_map[n.name] = (node, matched, pattern, value(self, node), self.qconfig_map[n.name]) all_matched.add(n.name) # break after finding the first match break @@ -870,10 +881,10 @@ def record_match(pattern, node, matched): # add custom module instances to the match result for node in graph.nodes: if node.op == 'call_module' and \ - type(self.modules[node.target]) in custom_module_class_mapping: + type(self.modules[node.target]) in custom_module_classes: custom_module_qconfig = self.qconfig_map[node.name] match_map[node.name] = ( - node, [node], CustomModuleQuantizeHandler(self, node), custom_module_qconfig) + node, [node], None, CustomModuleQuantizeHandler(self, node), custom_module_qconfig) def is_standalone_module(module_path): if standalone_module_names is None: @@ -888,7 +899,7 @@ def is_standalone_module(module_path): # add node to matched nodes custom_module_qconfig = self.qconfig_map[node.name] match_map[node.name] = ( - node, [node], StandaloneModuleQuantizeHandler(self, node), custom_module_qconfig) + node, [node], None, StandaloneModuleQuantizeHandler(self, node), custom_module_qconfig) return match_map @@ -902,16 +913,13 @@ def _find_quants(self, graph, matches): - matches: output of self._find_matches function Outputs a map of - node_name -> (QuantizeHandler instance (always DefaultQuant), qconfig) + node_name -> (QuantizeHandler instance (always DefaultQuantizeHandler), + activation_post_process (observer/fake_quantize module) constructor) """ quants = {} - def visit(node, qconfig): + def visit(node, matched_pattern, qconfig): def visit_arg(arg): - # note: we have to measure quantization information - # even for nodes where we might not use it because it is already - # quantized. This is because each match has the option to - # say NotImplemented (if for instance, it is an __add__ and the data type is not appropriate) is_weight = False if isinstance(node, Node) and node.op == 'call_function' and node.target in WEIGHT_INDEX_DICT: for i, node_arg in enumerate(node.args): @@ -919,26 +927,39 @@ def visit_arg(arg): is_weight = True if qconfig is not None and \ (activation_is_statically_quantized(qconfig) or is_weight): - # overwrite previous quant config - quants[arg.name] = (DefaultQuant(self, arg), qconfig, is_weight) + act_post_process_ctr = qconfig.weight if is_weight else qconfig.activation + quants[arg.name] = (DefaultQuantizeHandler(self, arg), qconfig, is_weight) + # overwrite the constructor from qconfig + act_post_process_ctr = \ + get_default_output_activation_post_process_map().get( + matched_pattern, + act_post_process_ctr) + # overwrite previous activation post process constructor if necessary + quants[arg.name] = (DefaultQuantizeHandler(self, arg), act_post_process_ctr) return visit_arg for node in graph.nodes: if node.name in matches: - root_node, matched, obj, qconfig = matches[node.name] + root_node, matched_nodes, matched_pattern, quantize_handler, qconfig = matches[node.name] # don't attach observer/fake_quant for CopyNode - if isinstance(obj, CopyNode): + if isinstance(quantize_handler, CopyNode): qconfig = None if root_node is node: - # matched[-1] is the first op in the sequence and - # matched[0] is the last op in the sequence + # matched_nodes[-1] is the first op in the sequence and + # matched_nodes[0] is the last op in the sequence # inputs - map_arg(matched[-1].args, visit(matched[-1], qconfig)) - map_arg(matched[-1].kwargs, visit(matched[-1], qconfig)) + # matched_pattern is set to None for inputs because + # we only want to select QuantizeHandler object based + # on pattern for output, inputs will always use + # DefaultQuantizeHandler + map_arg(matched_nodes[-1].args, visit(matched_nodes[-1], None, qconfig)) + map_arg(matched_nodes[-1].kwargs, visit(matched_nodes[-1], None, qconfig)) + # output - if isinstance(obj, StandaloneModuleQuantizeHandler): - # we don't insert observer for output of custom - # module - continue - map_arg(matched[0], visit(None, qconfig)) + # we don't insert observer for output of standalone module + if not isinstance(quantize_handler, StandaloneModuleQuantizeHandler): + # passing in matched_pattern here so that we can customize + # activation_post_process constructor for output based on the pattern, e.g. + # for sigmoid op we'll use default_affine_fixed_qparam_fake_quant + map_arg(matched_nodes[0], visit(None, matched_pattern, qconfig)) return quants diff --git a/torch/quantization/fx/standalone_module.py b/torch/quantization/fx/standalone_module.py deleted file mode 100644 index 55aa8e21f98f..000000000000 --- a/torch/quantization/fx/standalone_module.py +++ /dev/null @@ -1,30 +0,0 @@ -import torch -import copy -from torch.fx import GraphModule - -class ObservedStandaloneGraphModule(GraphModule): - _PRESERVED_ATTR_NAMES = [ - '_activation_post_process_map', - '_patterns', - '_qconfig_map', - '_standalone_module_observed_input_idxs', - '_output_is_observed'] - - def __init__(self, root, graph): - preserved_attrs = dict() - for attr in self._PRESERVED_ATTR_NAMES: - preserved_attrs[attr] = getattr(root, attr) - super().__init__(root, graph) - for attr in preserved_attrs: - setattr(self, attr, preserved_attrs[attr]) - - def __deepcopy__(self, memo): - fake_mod = torch.nn.Module() - fake_mod.__dict__ = copy.deepcopy(self.__dict__) - return ObservedStandaloneGraphModule(fake_mod, self.graph) - -def mark_observed_standalone_module(module): - return ObservedStandaloneGraphModule(module, module.graph) - -def is_observed_standalone_module(module): - return isinstance(module, ObservedStandaloneGraphModule) diff --git a/torch/quantization/fx/utils.py b/torch/quantization/fx/utils.py index 98f94a0633a0..366970cec4c0 100644 --- a/torch/quantization/fx/utils.py +++ b/torch/quantization/fx/utils.py @@ -1,5 +1,6 @@ import re import torch +from ..quant_type import QuantType, quant_type_to_str # turn foo.bar -> ['foo', 'bar'] def _parent_name(target): @@ -139,6 +140,55 @@ def get_next_qparams_idx(module, qparams): inputs.append(graph.create_node('get_attr', qparam_full_path)) return graph.create_node('call_function', quantize_op, tuple(inputs), {}) +def get_custom_module_class_keys(custom_config_dict, custom_config_dict_key): + r""" Get all the unique custom module keys in the custom config dict + e.g. + Input: + custom_config_dict = { + "float_to_observed_custom_module_class": { + "static": { + CustomModule1: ObservedCustomModule + }, + "dynamic": { + CustomModule2: DynamicObservedCustomModule + }, + "weight_only": { + CustomModule3: WeightOnlyObservedCustomModule + }, + }, + } + + Output: + # extract all the keys in "static", "dynamic" and "weight_only" dict + [CustomModule1, CustomModule2, CustomModule3] + """ + # using set to dedup + float_custom_module_classes = set() + custom_module_mapping = custom_config_dict.get(custom_config_dict_key, {}) + for quant_mode in ["static", "dynamic", "weight_only"]: + quant_mode_custom_module_config = custom_module_mapping.get(quant_mode, {}) + quant_mode_custom_module_classes = set(quant_mode_custom_module_config.keys()) + float_custom_module_classes |= quant_mode_custom_module_classes + return list(float_custom_module_classes) + +def get_swapped_custom_module_class(custom_module, custom_module_class_mapping, qconfig): + """ Get the observed/quantized custom module class that we need + to swap `custom_module` to + Input: + custom_module: input, can be an instance of either a float or observed custom module + custom_module_class_mapping: the float to observed or observed to quantized custom module class mapping + qconfig: qconfig configured for the custom module + + Output: + corresponding observed/quantized custom module class for input custom module instance + """ + quant_type = get_quant_type(qconfig) + quant_type_str = quant_type_to_str(quant_type) + class_mapping = custom_module_class_mapping.get(quant_type_str, {}) + assert type(custom_module) in class_mapping, "did not found corresponding observed " \ + "module class for {} in mapping: {}".format(type(custom_module), class_mapping) + return class_mapping[type(custom_module)] + def activation_is_statically_quantized(qconfig): """ Given a qconfig, decide if the activation needs to be statically quantized or not @@ -158,6 +208,26 @@ def weight_is_quantized(qconfig): """ return weight_dtype(qconfig) in [torch.quint8, torch.qint8] +def get_quant_type(qconfig): + assert qconfig is not None + activation = qconfig.activation() + weight = qconfig.weight() + static_dtypes = [torch.quint8, torch.qint8] + if weight.dtype in static_dtypes: + if activation.dtype in static_dtypes: + return QuantType.STATIC + elif hasattr(activation, 'compute_dtype') and activation.compute_dtype in static_dtypes: + return QuantType.DYNAMIC + else: + return QuantType.WEIGHT_ONLY + + if weight.dtype == torch.float16: + if activation.dtype == torch.float: + return QuantType.WEIGHT_ONLY + + raise Exception("Unrecognized dtype combination in get_quant_type: activation({})," + "weight({})".format(activation.dtype, weight.dtype)) + def get_linear_prepack_op_for_dtype(dtype): if dtype == torch.float16: return torch.ops.quantized.linear_prepack_fp16 diff --git a/torch/quantization/quant_type.py b/torch/quantization/quant_type.py index 212dec1fe28c..463d086b39b6 100644 --- a/torch/quantization/quant_type.py +++ b/torch/quantization/quant_type.py @@ -1,4 +1,3 @@ - import enum # Quantization type (dynamic quantization, static quantization). @@ -7,3 +6,14 @@ class QuantType(enum.IntEnum): DYNAMIC = 0 STATIC = 1 QAT = 2 + WEIGHT_ONLY = 3 + + +def quant_type_to_str(quant_type): + m = { + QuantType.STATIC: "static", + QuantType.DYNAMIC: "dynamic", + QuantType.QAT: "qat", + QuantType.WEIGHT_ONLY: "weight_only", + } + return m[quant_type] diff --git a/torch/quantization/quantization_mappings.py b/torch/quantization/quantization_mappings.py index 40fef784ce2d..82340a49309c 100644 --- a/torch/quantization/quantization_mappings.py +++ b/torch/quantization/quantization_mappings.py @@ -10,6 +10,10 @@ import torch.nn.qat as nnqat from .stubs import QuantStub, DeQuantStub +from .fake_quantize import ( + default_affine_fixed_qparams_fake_quant, + default_symmetric_fixed_qparams_fake_quant, +) # Default map for swapping float module to quantized ones DEFAULT_STATIC_QUANT_MODULE_MAPPINGS = { @@ -91,6 +95,13 @@ F.leaky_relu: torch._ops.ops.quantized.leaky_relu, } +# mapping from module to output activation post process class +DEFAULT_MODULE_TO_ACT_POST_PROCESS = { + nn.Hardsigmoid: default_affine_fixed_qparams_fake_quant, + nn.Sigmoid: default_affine_fixed_qparams_fake_quant, + nn.Tanh: default_symmetric_fixed_qparams_fake_quant, +} + def get_default_static_quant_module_mappings(): ''' Get module mapping for post training static quantization ''' @@ -158,3 +169,15 @@ def get_quantized_operator(float_op): assert quantized_op is not None, \ 'Operator {} does not have corresponding quantized op'.format(str(float_op)) return quantized_op + +def get_default_special_act_post_process(module_cls): + r""" Get the special activation post process for `module`, this has + higher priority than the activation post process in `qconfig` + e.g. + input: torch.nn.Sigmoid + output: default_affine_fixed_qparam_fake_quant + """ + return DEFAULT_MODULE_TO_ACT_POST_PROCESS.get(module_cls, None) + +def has_special_act_post_process(module_cls): + return module_cls in DEFAULT_MODULE_TO_ACT_POST_PROCESS diff --git a/torch/quantization/quantize.py b/torch/quantization/quantize.py index 9aa52d373ff6..f217c5ffb65c 100644 --- a/torch/quantization/quantize.py +++ b/torch/quantization/quantize.py @@ -9,14 +9,23 @@ import torch.nn.quantized as nnq import torch.nn.intrinsic.qat as nniqat -from .quantization_mappings import (get_default_dynamic_quant_module_mappings, - get_default_static_quant_module_mappings, - get_default_qat_module_mappings, - get_default_qconfig_propagation_list) +from .quantization_mappings import ( + get_default_dynamic_quant_module_mappings, + get_default_static_quant_module_mappings, + get_default_qat_module_mappings, + get_default_qconfig_propagation_list, + has_special_act_post_process, + get_default_special_act_post_process, +) from .stubs import DeQuantStub, QuantWrapper from .qconfig import default_dynamic_qconfig, float16_dynamic_qconfig, float_qparams_dynamic_qconfig +def is_activation_post_process(module): + return (isinstance(module, torch.quantization.ObserverBase) or + isinstance(module, torch.quantization.FakeQuantize) or + isinstance(module, torch.quantization.FakeQuantizeBase)) + def _propagate_qconfig_helper(module, qconfig_dict, allow_list=None, qconfig_parent=None, prefix=''): r"""This is a helper function for `propagate_qconfig_` @@ -107,8 +116,8 @@ def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=No ) device = next(iter(devices)) if len(devices) > 0 else None - def get_activation_post_process(qconfig, device): - activation = qconfig.activation() + def get_activation_post_process(qconfig, device, special_act_post_process=None): + activation = qconfig.activation() if special_act_post_process is None else special_act_post_process() if device is not None: activation.to(device) return activation @@ -116,13 +125,13 @@ def get_activation_post_process(qconfig, device): def needs_observation(m): return hasattr(m, 'qconfig') and m.qconfig is not None - def insert_activation_post_process(m): + def insert_activation_post_process(m, special_act_post_process=None): """ Adds an activation post process module and register a post hook that calls the module """ if needs_observation(m): # observer and hook will be gone after we swap the module - m.add_module('activation_post_process', get_activation_post_process(m.qconfig, device)) + m.add_module('activation_post_process', get_activation_post_process(m.qconfig, device, special_act_post_process)) # Register observer as the first entry in the hook list # All post forward hooks are preserved and will be executed after the observer before convert handle = register_activation_post_process_hook(m) @@ -130,8 +139,11 @@ def insert_activation_post_process(m): for name, child in module.named_children(): if type(child) == nnq.FloatFunctional or type(child) == nnq.QFunctional: - if hasattr(child, 'qconfig') and child.qconfig is not None: + if needs_observation(child): child.activation_post_process = get_activation_post_process(child.qconfig, device) + elif has_special_act_post_process(type(child)): + special_act_post_process = get_default_special_act_post_process(type(child)) + insert_activation_post_process(child, special_act_post_process) elif non_leaf_module_list is not None and type(child) in non_leaf_module_list: insert_activation_post_process(child) elif needs_observation(child) and type(child) in custom_module_class_mapping: @@ -229,6 +241,19 @@ def prepare(model, inplace=False, allow_list=None, custom_module_class_mapping=custom_module_class_mapping) return model +def _remove_activation_post_process(module): + # TODO: maybe we should change activation_post_process to _activation_post_process + # to prevent it from being used by user + if hasattr(module, 'activation_post_process') and \ + is_activation_post_process(module.activation_post_process): + delattr(module, 'activation_post_process') + + # remove activation_post_proceess hook + for handle_id, hook_fn in module._forward_hooks.items(): + if hook_fn is _observer_forward_hook: + module._forward_hooks.pop(handle_id) + +# TODO: rename to something more general def _remove_qconfig(module): r"""Clean up the qconfig left in the module so that new qconfig can be propagated. @@ -242,6 +267,8 @@ def _remove_qconfig(module): if hasattr(module, "qconfig"): del module.qconfig + _remove_activation_post_process(module) + def quantize(model, run_fn, run_args, mapping=None, inplace=False): r"""Quantize the input float model with post training static quantization. diff --git a/torch/quantization/quantize_fx.py b/torch/quantization/quantize_fx.py index 001c797ad6b1..93043559bf48 100644 --- a/torch/quantization/quantize_fx.py +++ b/torch/quantization/quantize_fx.py @@ -4,6 +4,7 @@ from .fx import Fuser # noqa: F401 from .fx import Quantizer # noqa: F401 from .fx.utils import graph_pretty_str # noqa: F401 +from .fx.utils import get_custom_module_class_keys # noqa: F401 def _check_is_graph_module(model): if not isinstance(model, GraphModule): @@ -26,7 +27,7 @@ def _swap_ff_with_fxff(model): del model._modules[name] model._modules[name] = torch.nn.quantized.FXFloatFunctional() -def _fuse_fx(graph_module, inplace=False, fuse_custom_config_dict=None): +def _fuse_fx(graph_module, fuse_custom_config_dict=None): r""" Internal helper function to fuse modules in preparation for quantization Args: @@ -34,7 +35,7 @@ def _fuse_fx(graph_module, inplace=False, fuse_custom_config_dict=None): """ _check_is_graph_module(graph_module) fuser = Fuser() - return fuser.fuse(graph_module, inplace, fuse_custom_config_dict) + return fuser.fuse(graph_module, fuse_custom_config_dict) class CustomTracer(Tracer): def __init__(self, skipped_module_names, skipped_module_classes): @@ -49,10 +50,10 @@ def is_leaf_module(self, m, module_qualified_name): type(m) in self.skipped_module_classes -def _prepare_fx(model, qconfig_dict, inplace, prepare_custom_config_dict=None, is_standalone_module=False): +def _prepare_fx(model, qconfig_dict, prepare_custom_config_dict=None, is_standalone_module=False): r""" Internal helper function for prepare_fx Args: - `model`, `qconfig_dict`, `inplace` `prepare_custom_config_dict`: see docs for :func:`~torch.quantization.prepare_fx` + `model`, `qconfig_dict`, `prepare_custom_config_dict`: see docs for :func:`~torch.quantization.prepare_fx` `is_standalone_module`: a boolean flag indicates whether we are quantizing a standalone module or not, a standalone module is a submodule of the parent module that is not inlined in the @@ -74,21 +75,20 @@ def _prepare_fx(model, qconfig_dict, inplace, prepare_custom_config_dict=None, i # standalone module and custom module config are applied in top level module standalone_module_names = prepare_custom_config_dict.get('standalone_module_name', []) skipped_module_names += standalone_module_names - custom_module_config = prepare_custom_config_dict.get('float_to_observed_custom_module_class', {}) - custom_module_classes = list(custom_module_config.keys()) - skipped_module_classes += custom_module_classes + float_custom_module_classes = get_custom_module_class_keys( + prepare_custom_config_dict, "float_to_observed_custom_module_class") + skipped_module_classes += float_custom_module_classes tracer = CustomTracer(skipped_module_names, skipped_module_classes) graph_module = GraphModule(model, tracer.trace(model)) - graph_module = _fuse_fx(graph_module, inplace, prepare_custom_config_dict) + graph_module = _fuse_fx(graph_module, prepare_custom_config_dict) quantizer = Quantizer() return quantizer.prepare( graph_module, qconfig_dict, - inplace=True, prepare_custom_config_dict=prepare_custom_config_dict, is_standalone_module=is_standalone_module) -def _prepare_standalone_module_fx(model, qconfig_dict, inplace=False, prepare_custom_config_dict=None): +def _prepare_standalone_module_fx(model, qconfig_dict, prepare_custom_config_dict=None): r""" [Internal use only] Prepare a standalone module, so that it can be used when quantizing the parent module. standalone_module means it a submodule that is not inlined in parent module, @@ -103,16 +103,13 @@ def _prepare_standalone_module_fx(model, qconfig_dict, inplace=False, prepare_cu custom module is observed or not """ - torch._C._log_api_usage_once("quantization_api.quantize_fx._prepare_standalone_module_fx") - return _prepare_fx(model, qconfig_dict, inplace, prepare_custom_config_dict, is_standalone_module=True) + return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict, is_standalone_module=True) - -def fuse_fx(model, inplace=False, fuse_custom_config_dict=None): +def fuse_fx(model, fuse_custom_config_dict=None): r""" Fuse modules like conv+bn, conv+bn+relu etc, model must be in eval mode. Fusion rules are defined in torch.quantization.fx.fusion_pattern.py Args: `model`: a torch.nn.Module model - `inplace`: flag for whether we fuse modules inplace or out of place `fuse_custom_config_dict`: Dictionary for custom configurations for fuse_fx, e.g. fuse_custom_config_dict = { "additional_fuser_method_mapping": { @@ -130,9 +127,9 @@ def fuse_fx(model, inplace=False, fuse_custom_config_dict=None): torch._C._log_api_usage_once("quantization_api.quantize_fx.fuse_fx") assert not model.training, 'fuse_fx only works on models in eval mode' graph_module = torch.fx.symbolic_trace(model) - return _fuse_fx(graph_module, inplace, fuse_custom_config_dict) + return _fuse_fx(graph_module, fuse_custom_config_dict) -def prepare_fx(model, qconfig_dict, inplace=False, prepare_custom_config_dict=None): +def prepare_fx(model, qconfig_dict, prepare_custom_config_dict=None): r""" Prepare a model for post training static quantization Args: @@ -165,8 +162,6 @@ def prepare_fx(model, qconfig_dict, inplace=False, prepare_custom_config_dict=No # qconfig == None means fusion and quantization should be skipped for anything # matching the rule } - `inplace`: flag for carry out model transformations in-place, - the original module is mutated `prepare_custom_config_dict`: customization configuration dictionary for quantization tool: prepare_custom_config_dict = { @@ -178,8 +173,11 @@ def prepare_fx(model, qconfig_dict, inplace=False, prepare_custom_config_dict=No # user will manually define the corresponding observed # module class which has a from_float class method that converts # float custom module to observed custom module + # (only needed for static quantization) "float_to_observed_custom_module_class": { - CustomModule: ObservedCustomModule + "static": { + CustomModule: ObservedCustomModule + } }, # the qualified names for the submodule that are not symbolically traceable @@ -188,6 +186,7 @@ def prepare_fx(model, qconfig_dict, inplace=False, prepare_custom_config_dict=No ], # the module classes that are not symbolically traceable + # we'll also put dynamic/weight_only custom module here "non_traceable_module_class": [ NonTraceableModule ], @@ -242,15 +241,13 @@ def calibrate(model, data_loader): torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_fx") assert not model.training, 'prepare_fx only works for models in' + \ 'eval mode' - return _prepare_fx(model, qconfig_dict, inplace, prepare_custom_config_dict) + return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict) -def prepare_qat_fx(model, qconfig_dict, inplace=False, prepare_custom_config_dict=None): +def prepare_qat_fx(model, qconfig_dict, prepare_custom_config_dict=None): r""" Prepare a model for quantization aware training Args: `model`: torch.nn.Module model, must be in train mode `qconfig_dict`: see :func:`~torch.quantization.prepare_fx` - `inplace`: flag for carry out model transformations in-place, - the original module is mutated `prepare_custom_config_dict`: see :func:`~torch.quantization.prepare_fx` Return: @@ -279,21 +276,19 @@ def train_loop(model, train_data): torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_qat_fx") assert model.training, 'prepare_qat_fx only works for models in ' + \ 'train mode' - return _prepare_fx(model, qconfig_dict, inplace, prepare_custom_config_dict) + return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict) -def _convert_fx(graph_module, inplace, debug, convert_custom_config_dict=None, is_standalone_module=False): +def _convert_fx(graph_module, debug, convert_custom_config_dict=None, is_standalone_module=False): """ `is_standalone_module`: see docs in :func:`~torch.quantization.prepare_standalone_module_fx` """ _check_is_graph_module(graph_module) quantizer = Quantizer() - return quantizer.convert(graph_module, inplace, debug, convert_custom_config_dict, is_standalone_module) + return quantizer.convert(graph_module, debug, convert_custom_config_dict, is_standalone_module) -def convert_fx(graph_module, inplace=False, debug=False, convert_custom_config_dict=None): +def convert_fx(graph_module, debug=False, convert_custom_config_dict=None): r""" Convert a calibrated or trained model to a quantized model Args: `graph_module`: A prepared and calibrated/trained model (GraphModule) - `inplace`: flag for carry out model transformations in-place, - the original module is mutated `debug`: flag for producing a debug friendly model (preserve weight attribute) `convert_custom_config_dict`: dictionary for custom configurations for convert function: convert_custom_config_dict = { @@ -313,7 +308,15 @@ def convert_fx(graph_module, inplace=False, debug=False, convert_custom_config_d # module class which has a from_observed class method that converts # observed custom module to quantized custom module "observed_to_quantized_custom_module_class": { - ObservedCustomModule: QuantizedCustomModule + "static": { + ObservedCustomModule: QuantizedCustomModule + }, + "dynamic": { + ObservedCustomModule: QuantizedCustomModule + }, + "weight_only": { + ObservedCustomModule: QuantizedCustomModule + } } } @@ -327,9 +330,9 @@ def convert_fx(graph_module, inplace=False, debug=False, convert_custom_config_d ``` """ torch._C._log_api_usage_once("quantization_api.quantize_fx.convert_fx") - return _convert_fx(graph_module, inplace, debug, convert_custom_config_dict) + return _convert_fx(graph_module, debug, convert_custom_config_dict) -def _convert_standalone_module_fx(graph_module, inplace=False, debug=False, convert_custom_config_dict=None): +def _convert_standalone_module_fx(graph_module, debug=False, convert_custom_config_dict=None): r""" [Internal use only] Convert a model produced by :func:`~torch.quantization.prepare_standalone_module_fx` and convert it to a quantized model @@ -340,5 +343,4 @@ def _convert_standalone_module_fx(graph_module, inplace=False, debug=False, conv A quantized standalone module which accepts quantized input(if needed) and produces quantized output (if needed). """ - torch._C._log_api_usage_once("quantization_api.quantize_fx._convert_standalone_module_fx") - return _convert_fx(graph_module, inplace, debug, convert_custom_config_dict, is_standalone_module=True) + return _convert_fx(graph_module, debug, convert_custom_config_dict, is_standalone_module=True) diff --git a/torch/testing/__init__.py b/torch/testing/__init__.py index 396d0718efbc..80120b019a99 100644 --- a/torch/testing/__init__.py +++ b/torch/testing/__init__.py @@ -6,6 +6,7 @@ import random import math from typing import cast, List, Optional, Tuple, Union +from .check_kernel_launches import check_cuda_kernel_launches, check_code_for_cuda_kernel_launches FileCheck = torch._C.FileCheck @@ -24,6 +25,9 @@ def is_integral(dtype: torch.dtype) -> bool: dtypes = [x for x in get_all_dtypes() if x not in get_all_complex_dtypes()] return dtype in dtypes and not dtype.is_floating_point +def is_quantized(dtype: torch.dtype) -> bool: + return dtype in (torch.quint8, torch.qint8, torch.qint32, torch.quint4x2) + # Helper function that maps a flattened index back into the given shape # TODO: consider adding torch.unravel_index def _unravel_index(flat_index, shape): @@ -70,7 +74,11 @@ def _compare_tensors_internal(a: torch.Tensor, b: torch.Tensor, *, rtol, atol, e debug_msg : Optional[str] # Integer (including bool) comparisons are identity comparisons # when rtol is zero and atol is less than one - if (is_integral(a.dtype) and rtol == 0 and atol < 1) or a.dtype is torch.bool: + if ( + (is_integral(a.dtype) and rtol == 0 and atol < 1) + or a.dtype is torch.bool + or is_quantized(a.dtype) + ): if (a == b).all().item(): return (True, None) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 16144f3dc0d1..8bbd2bfdc944 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -274,8 +274,11 @@ def sample_inputs(self, device, dtype, requires_grad=False): test_inplace_grad=False), UnaryUfuncInfo('cos', ref=np.cos, - dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), handles_large_floats=False, + promotes_integers_to_float=True, decorators=(precisionOverride({torch.bfloat16: 1e-2}),), skips=( SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', @@ -297,7 +300,10 @@ def sample_inputs(self, device, dtype, requires_grad=False): UnaryUfuncInfo('log', ref=np.log, domain=(0, float('inf')), - dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + promotes_integers_to_float=True, decorators=(precisionOverride({torch.bfloat16: 5e-2}),), skips=( SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', @@ -312,7 +318,10 @@ def sample_inputs(self, device, dtype, requires_grad=False): ref=np.log10, domain=(0, float('inf')), decorators=(precisionOverride({torch.bfloat16: 5e-2}),), - dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + promotes_integers_to_float=True, skips=( SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', device_type='cuda', dtypes=[torch.cfloat, torch.cdouble]), @@ -329,7 +338,10 @@ def sample_inputs(self, device, dtype, requires_grad=False): UnaryUfuncInfo('log2', ref=np.log2, domain=(0, float('inf')), - dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + promotes_integers_to_float=True, decorators=(precisionOverride({torch.bfloat16: 1e-1}),), skips=( SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', @@ -371,6 +383,10 @@ def sample_inputs(self, device, dtype, requires_grad=False): )), UnaryUfuncInfo('tan', ref=np.tan, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half), + promotes_integers_to_float=True, skips=( SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', device_type='cuda', dtypes=[torch.cfloat, torch.cdouble]), @@ -574,6 +590,19 @@ def method_tests(): (True, [], ['aten::mul', 'aten::reciprocal'])), ('__rdiv__', uniform_scalar(1e-1, requires_grad=True), (3.14,), 'scalar_constant', (True, [], ['aten::mul', 'aten::reciprocal'])), + ('__rdiv__', torch.rand(S, S, S, dtype=torch.cdouble) + 1e-1, (3.14j,), 'complex_constant', + (True, [], ['aten::mul', 'aten::reciprocal'])), + ('__rdiv__', uniform_scalar(1e-1 * (1 + 1j), requires_grad=True), (3.14j,), 'complex_scalar_constant', + (True, [], ['aten::mul', 'aten::reciprocal'])), + ('div', (S, S, S), (torch.rand(S, S, S, dtype=torch.cdouble) + 0.1,), 'complex', (True,)), + ('div', (S, S, S), (torch.rand(S, S, dtype=torch.cdouble) + 0.1,), 'complex_broadcast_rhs', (True,)), + ('div', (S, S), (torch.rand(S, S, S, dtype=torch.cdouble) + 0.1,), 'complex_broadcast_lhs', (True,)), + ('div', (S, 1, S), (torch.rand(M, S, dtype=torch.cdouble) + 0.1,), 'complex_broadcast_all', (True,)), + ('div', (), (uniform_scalar(0.1j),), 'complex_scalar', (True,)), + ('div', (S, S, S), (uniform_scalar(0.1j),), 'complex_scalar_broadcast_rhs', (True,)), + ('div', (), (uniform_scalar(0.1j),), 'complex_scalar_broadcast_lhs', (True,)), + ('div', torch.rand(S, S, S, dtype=torch.cdouble) + 1e-1, (3.14j,), 'complex_constant', (True,)), + ('div', uniform_scalar(1e-1j, requires_grad=True), (3.14j,), 'complex_scalar_constant', (True,)), ('pow', torch.rand(S, S, S) + 1e-3, (torch.rand(S, S, S) + 0.1,), '', (True,)), ('pow', torch.rand(S, S, S) + 1e-3, (torch.rand(1,) + 0.1,), 'broadcast_rhs', (True,)), ('pow', torch.rand(1,) + 1e-3, (torch.rand(S, S, S) + 0.1,), 'broadcast_lhs', (True,)), @@ -582,8 +611,11 @@ def method_tests(): ('pow', torch.rand(S, S, S) + 1e-3, (uniform_scalar(0.1),), 'scalar_broadcast_rhs', (True,)), ('pow', uniform_scalar(1e-3, requires_grad=True), (torch.rand(S, S, S) + 0.1,), 'scalar_broadcast_lhs', (True,)), ('pow', torch.rand(S, S, S) + 1e-3, (3.14,), 'constant', (True,)), + ('pow', torch.rand(S, S, S, dtype=torch.cdouble) + 1e-3 * (1 + 1j), (3.14,), 'complex_constant', (True,)), ('__rpow__', torch.rand(S, S, S) + 1e-3, (3.14,), 'constant', (True, 'aten::pow')), ('pow', uniform_scalar(1e-3, requires_grad=True), (3.14,), 'scalar_constant', (True,)), + ('pow', uniform_scalar(1e-3 * (1 + 1j), requires_grad=True), (3.14,), 'complex_scalar_constant', (True,)), + ('pow', uniform_scalar(1e-3 * (1 + 1j), requires_grad=True), (3.14j,), 'complex_imaginary_exponent', (True,)), ('__rpow__', uniform_scalar(1e-3, requires_grad=True), (3.14,), 'scalar_constant', (True, 'aten::pow')), ('transpose', (1, 2, 3), (1, 2), 'dim', (False,), [0, 1]), ('transpose', (), (0, 0), 'scalar', (False,)), @@ -655,13 +687,12 @@ def method_tests(): ('log1p', uniform_scalar(requires_grad=True), NO_ARGS, 'scalar', (True,)), ('log2', torch.rand(S, S, S) + 1e-2, NO_ARGS, '', (True,)), ('log2', uniform_scalar(1e-2, requires_grad=True), NO_ARGS, 'scalar', (True,)), - # TODO(@anjali411): add the commented tests back after updating the formula based on tensorflow definition. - # ('log', torch.randn(S, S, S, dtype=torch.cfloat) + 1e-2, NO_ARGS, 'complex', (True,)), - # ('log', uniform_scalar(1e-2j, requires_grad=True), NO_ARGS, 'complex_scalar', (True,)), - # ('log10', torch.randn(S, S, S, dtype=torch.cfloat) + 1e-2, NO_ARGS, 'complex', (True,)), - # ('log10', uniform_scalar(1e-2j, requires_grad=True), NO_ARGS, 'complex_scalar', (True,)), - # ('log2', torch.randn(S, S, S, dtype=torch.cfloat) + 1e-2, NO_ARGS, 'complex', (True,)), - # ('log2', uniform_scalar(1e-2j, requires_grad=True), NO_ARGS, 'complex_scalar', (True,)), + ('log', torch.randn(S, S, S, dtype=torch.cdouble) + 1e-2, NO_ARGS, 'complex', (True,)), + ('log', uniform_scalar(1e-2j, requires_grad=True), NO_ARGS, 'complex_scalar', (True,)), + ('log10', torch.randn(S, S, S, dtype=torch.cdouble) + 1e-2, NO_ARGS, 'complex', (True,)), + ('log10', uniform_scalar(1e-2j, requires_grad=True), NO_ARGS, 'complex_scalar', (True,)), + ('log2', torch.randn(S, S, S, dtype=torch.cdouble) + 1e-2, NO_ARGS, 'complex', (True,)), + ('log2', uniform_scalar(1e-2j, requires_grad=True), NO_ARGS, 'complex_scalar', (True,)), ('tanh', (S, S, S), NO_ARGS, '', (True,)), ('tanh', (), NO_ARGS, 'scalar', (True,)), ('sigmoid', (S, S, S), NO_ARGS, '', (True,)), @@ -682,6 +713,8 @@ def method_tests(): ('complex', (S, S, S), ((S, S, S),), ''), ('abs', (S, S, S), NO_ARGS, '', (True,)), ('abs', (), NO_ARGS, 'scalar', (True,)), + ('angle', (S, S, S), NO_ARGS, '', (True,)), + ('angle', (), NO_ARGS, 'scalar', (True,)), ('clamp', (S, S, S), (0, 1), '', (True,)), ('clamp', (S, S, S), (None, 0.5), 'min', (True,)), ('clamp', (S, S, S), (0.5, None), 'max', (True,)), @@ -696,8 +729,7 @@ def method_tests(): ('cos', (S, S, S), NO_ARGS, '', (True,)), ('cos', (), NO_ARGS, 'scalar', (True,)), ('tan', torch.randn(S, S, S).clamp(-1, 1), NO_ARGS, '', (True,)), - # TODO(@anjali411): add the commented test back after updating the formula based on tensorflow definition. - # ('tan', (S, S, S), NO_ARGS, 'complex', (True,)), + ('tan', (S, S, S), NO_ARGS, 'complex', (True,)), ('asin', torch.randn(S, S, S).clamp(-0.9, 0.9), NO_ARGS, '', (True,)), ('acos', torch.randn(S, S, S).clamp(-0.9, 0.9), NO_ARGS, '', (True,)), ('atan', (S, S, S), NO_ARGS, '', (True,)), @@ -709,9 +741,8 @@ def method_tests(): ('atan2', (S, 1, S), ((S, S),), 'broadcast_all'), ('reciprocal', torch.rand(S, S, S) + 0.1, NO_ARGS, '', (True,)), ('reciprocal', uniform_scalar(0.1, requires_grad=True), NO_ARGS, 'scalar', (True,)), - # TODO(@anjali411): add the commented tests back after updating the formula based on tensorflow definition. - # ('reciprocal', torch.randn(S, S, S, dtype=torch.cdouble) + 0.1, NO_ARGS, 'complex', (True,)), - # ('reciprocal', uniform_scalar(0.1j), NO_ARGS, 'complex_scalar', (True,)), + ('reciprocal', torch.randn(S, S, S, dtype=torch.cdouble) + 0.1, NO_ARGS, 'complex', (True,)), + ('reciprocal', uniform_scalar(0.1j), NO_ARGS, 'complex_scalar', (True,)), ('round', (S, S, S), NO_ARGS, '', (True,)), ('round', (), NO_ARGS, 'scalar', (True,)), ('sign', (S, S, S), NO_ARGS), @@ -728,6 +759,8 @@ def method_tests(): ('deg2rad', (S, S, S), NO_ARGS), ('rsqrt', torch.rand(S, S, S) + 1e-2, NO_ARGS, '', (True,)), ('rsqrt', uniform_scalar(1e-2, requires_grad=True), NO_ARGS, 'scalar', (True,)), + ('rsqrt', torch.rand(S, S, S, dtype=torch.cfloat) + 1e-2, NO_ARGS, 'complex', (True,)), + ('rsqrt', uniform_scalar(1e-2 * (1 + 1j), requires_grad=True), NO_ARGS, 'complex_scalar', (True,)), ('frac', (S, S, S), NO_ARGS, '', (True,)), ('frac', (), NO_ARGS, 'scalar', (True,)), ('fmod', (S, S, S), (1.5,), '', (True,)), diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py index 440e59cf9174..1b2b4165b044 100644 --- a/torch/testing/_internal/common_quantization.py +++ b/torch/testing/_internal/common_quantization.py @@ -12,22 +12,23 @@ from torch.quantization import QuantWrapper, QuantStub, DeQuantStub, \ default_qconfig, default_dynamic_qconfig, default_per_channel_qconfig, QConfig, default_observer, default_weight_observer, \ propagate_qconfig_, convert, get_default_qconfig, quantize_dynamic_jit, quantize_jit, float_qparams_dynamic_qconfig, \ - get_default_qat_qconfig, PerChannelMinMaxObserver, default_dynamic_quant_observer, QConfigDynamic + get_default_qat_qconfig, PerChannelMinMaxObserver, default_dynamic_quant_observer, QConfigDynamic, QuantType from torch.quantization.quantization_mappings import ( get_default_dynamic_quant_module_mappings, get_default_qconfig_propagation_list, get_default_qat_module_mappings, ) -# symbolic trace -from torch.fx import symbolic_trace - -# graph mode quantization based on fx -from torch.quantization import ( - QuantType, - prepare_fx, - prepare_qat_fx, - convert_fx, -) + +try: + # graph mode quantization based on fx + from torch.quantization.quantize_fx import ( + prepare_fx, + prepare_qat_fx, + convert_fx, + ) + HAS_FX = True +except ImportError: + HAS_FX = False import copy import io @@ -599,77 +600,77 @@ def printGraphModule(self, graph_module, print_str=True): print(str_to_print) return str_to_print - def checkGraphModeFxOp(self, model, inputs, quant_type, - expected_node=None, - expected_node_occurrence=None, - expected_node_list=None, - debug=False, - print_debug_info=False, - custom_qconfig=None): - """ Quantizes model with graph mode quantization on fx and check if the - quantized model contains the quantized_node + if HAS_FX: + def checkGraphModeFxOp(self, model, inputs, quant_type, + expected_node=None, + expected_node_occurrence=None, + expected_node_list=None, + debug=False, + print_debug_info=False, + custom_qconfig=None): + """ Quantizes model with graph mode quantization on fx and check if the + quantized model contains the quantized_node + + Args: + model: floating point torch.nn.Module + inputs: one positional sample input arguments for model + expected_node: NodeSpec + e.g. NodeSpec.call_function(torch.quantize_per_tensor) + expected_node_occurrence: a dict from NodeSpec to + expected number of occurences (int) + e.g. {NodeSpec.call_function(torch.quantize_per_tensor) : 1, + NodeSpec.call_method('dequantize'): 1} + expected_node_list: a list of NodeSpec, used to check the order + of the occurrence of Node + e.g. [NodeSpec.call_function(torch.quantize_per_tensor), + NodeSpec.call_module(nnq.Conv2d), + NodeSpec.call_function(F.hardtanh_), + NodeSpec.call_method('dequantize')] + """ + # TODO: make img_data a single example instead of a list + if type(inputs) == list: + inputs = inputs[0] - Args: - model: floating point torch.nn.Module - inputs: one positional sample input arguments for model - expected_node: NodeSpec - e.g. NodeSpec.call_function(torch.quantize_per_tensor) - expected_node_occurrence: a dict from NodeSpec to - expected number of occurences (int) - e.g. {NodeSpec.call_function(torch.quantize_per_tensor) : 1, - NodeSpec.call_method('dequantize'): 1} - expected_node_list: a list of NodeSpec, used to check the order - of the occurrence of Node - e.g. [NodeSpec.call_function(torch.quantize_per_tensor), - NodeSpec.call_module(nnq.Conv2d), - NodeSpec.call_function(F.hardtanh_), - NodeSpec.call_method('dequantize')] - """ - # TODO: make img_data a single example instead of a list - if type(inputs) == list: - inputs = inputs[0] - if custom_qconfig is None: if quant_type == QuantType.QAT: qconfig = get_default_qat_qconfig(torch.backends.quantized.engine) + model.train() elif quant_type == QuantType.STATIC: qconfig = get_default_qconfig(torch.backends.quantized.engine) + model.eval() else: qconfig = default_dynamic_qconfig - else: - qconfig = custom_qconfig + model.eval() - if quant_type == QuantType.QAT: - model.train() - else: - model.eval() + # overwrite qconfig with custom_qconfig + if custom_qconfig is not None: + qconfig = custom_qconfig - original = symbolic_trace(model) - if quant_type == QuantType.QAT: - prepare = prepare_qat_fx - else: - prepare = prepare_fx - - qconfig_dict = {'': qconfig} - prepared = prepare(original, qconfig_dict) - if not quant_type == QuantType.DYNAMIC: - prepared(*inputs) - qgraph = convert_fx(prepared) - qgraph_debug = convert_fx(prepared, debug=True) - result = qgraph(*inputs) - result_debug = qgraph_debug(*inputs) - - qgraph_to_check = qgraph_debug if debug else qgraph - if print_debug_info: - print() - print('quant type:', quant_type) - print('origianl graph module:', type(model)) - self.printGraphModule(original) - print() - print('quantized graph module:', type(qgraph_to_check)) - self.printGraphModule(qgraph_to_check) - print() - self.checkGraphModuleNodes( - qgraph_to_check, expected_node, expected_node_occurrence, expected_node_list) + if quant_type == QuantType.QAT: + prepare = prepare_qat_fx + else: + prepare = prepare_fx + + qconfig_dict = {'': qconfig} + prepared = prepare(model, qconfig_dict) + if not quant_type == QuantType.DYNAMIC: + prepared(*inputs) + prepared_copy = copy.deepcopy(prepared) + qgraph = convert_fx(prepared) + qgraph_debug = convert_fx(prepared_copy, debug=True) + result = qgraph(*inputs) + result_debug = qgraph_debug(*inputs) + + qgraph_to_check = qgraph_debug if debug else qgraph + if print_debug_info: + print() + print('quant type:', quant_type) + print('original model:', model) + print() + print('quantized model:', qgraph_to_check) + self.printGraphModule(qgraph_to_check) + print() + self.checkGraphModuleNodes( + qgraph_to_check, expected_node, expected_node_occurrence, expected_node_list) def checkEmbeddingSerialization(self, qemb, num_embeddings, embedding_dim, indices, offsets, diff --git a/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py b/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py index 305b0fcb82bf..84768496b5ff 100644 --- a/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py +++ b/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py @@ -20,7 +20,7 @@ skip_if_lt_x_gpu, skip_if_rocm, ) -from torch.testing._internal.dist_utils import dist_init, INIT_METHOD_TEMPLATE +from torch.testing._internal.dist_utils import INIT_METHOD_TEMPLATE, dist_init from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import ( RpcAgentTestFixture, ) @@ -619,35 +619,38 @@ def test_ddp_dist_autograd_local_vs_remote(self): rank=self.rank, ) - remote_layer1 = RemoteModule( - "worker0", device="cpu", module_cls=nn.Linear, args=(10, 5, False) - ) - layer1 = nn.Linear(10, 5, False) - # Start with the same parameters for remote and local - layer1.weight = remote_layer1.module_rref.to_here().weight - - # Run local case. - layer2 = nn.Linear(5, 1) - inputs = torch.rand((10, 10)) - ddp_model = DistributedDataParallel(layer2) - loss = ddp_model(layer1(inputs)).sum() - loss.backward() - - # Run remote case. - with dist_autograd.context() as context_id: - loss = ddp_model(remote_layer1(inputs)).sum() - dist_autograd.backward(context_id, [loss]) - grads_dict = dist_autograd.get_gradients(context_id) - dist.barrier() - self.assertEqual(layer2.weight.grad, grads_dict[layer2.weight]) - self.assertEqual( - layer1.weight.grad, - rpc.rpc_sync( - "worker0", - DdpComparisonTest.get_remote_grads, - args=(remote_layer1.module_rref, context_id), - ), + # Use two different remote device input string, w/ and w/o the default + # device string "cpu", respectively. + for remote_device in ["worker0/cpu", "worker0"]: + remote_layer1 = RemoteModule( + remote_device=remote_device, module_cls=nn.Linear, args=(10, 5, False) ) + layer1 = nn.Linear(10, 5, False) + # Start with the same parameters for remote and local + layer1.weight = remote_layer1.module_rref.to_here().weight + + # Run local case. + layer2 = nn.Linear(5, 1) + inputs = torch.rand((10, 10)) + ddp_model = DistributedDataParallel(layer2) + loss = ddp_model(layer1(inputs)).sum() + loss.backward() + + # Run remote case. + with dist_autograd.context() as context_id: + loss = ddp_model(remote_layer1(inputs)).sum() + dist_autograd.backward(context_id, [loss]) + grads_dict = dist_autograd.get_gradients(context_id) + dist.barrier() + self.assertEqual(layer2.weight.grad, grads_dict[layer2.weight]) + self.assertEqual( + layer1.weight.grad, + rpc.rpc_sync( + "worker0", + DdpComparisonTest.get_remote_grads, + args=(remote_layer1.module_rref, context_id), + ), + ) @skip_if_lt_x_gpu(NUM_TRAINERS) @requires_nccl() @@ -667,7 +670,7 @@ def test_ddp_dist_autograd_local_vs_remote_gpu(self): ) remote_layer1 = RemoteModule( - "worker0", device="cpu", module_cls=nn.Linear, args=(10, 7, False) + remote_device="worker0/cpu", module_cls=nn.Linear, args=(10, 7, False) ) layer1 = nn.Linear(10, 7, False) # Start with the same parameters for remote and local @@ -677,7 +680,7 @@ def test_ddp_dist_autograd_local_vs_remote_gpu(self): ddp_layer2 = DistributedDataParallel(layer2, device_ids=[self.rank]) remote_layer3 = RemoteModule( - "worker0", device="cpu", module_cls=nn.Linear, args=(5, 3, False) + remote_device="worker0/cpu", module_cls=nn.Linear, args=(5, 3, False) ) layer3 = nn.Linear(5, 3, False) # Start with the same parameters for remote and local diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index f304e37389b5..b12a72e43f35 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -3140,6 +3140,13 @@ def validate_global_samples(local_num_samples): @require_backend({"nccl", "gloo"}) @require_n_gpus_for_nccl_backend(int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"]) def test_allgather_object(self): + # Only set device for NCCL backend since it must use GPUs. + backend = os.environ["BACKEND"] + if backend == "nccl": + # Case where rank != GPU device. + next_rank = (self.rank + 1) % int(self.world_size) + torch.cuda.set_device(next_rank) + gather_objects = collectives_object_test_list output_gathered = [None for _ in range(dist.get_world_size())] dist.all_gather_object( @@ -3194,7 +3201,10 @@ class Bar: def test_nccl_gather_object_err(self): output_gathered = [None for _ in range(dist.get_world_size())] gather_on_rank = 0 + # Case where rank != GPU device. my_rank = dist.get_rank() + next_rank = (my_rank + 1) % dist.get_world_size() + torch.cuda.set_device(next_rank) with self.assertRaisesRegex( RuntimeError, "ProcessGroupNCCL does not support gather" ): @@ -3665,6 +3675,13 @@ def test_ddp_uneven_inputs_replicated_error(self): @require_backend({"nccl", "gloo"}) @require_n_gpus_for_nccl_backend(int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"]) def test_broadcast_object_list(self): + # Only set device for NCCL backend since it must use GPUs. + backend = os.environ["BACKEND"] + if backend == "nccl": + # Case where rank != GPU device. + next_rank = (self.rank + 1) % int(self.world_size) + torch.cuda.set_device(next_rank) + src_rank = 0 objects = collectives_object_test_list if self.rank == src_rank else [None for _ in collectives_object_test_list] @@ -3800,6 +3817,39 @@ def forward(self, x): else: ddp(inp).sum().backward() + @require_backend({"gloo", "nccl"}) + @require_backends_available({"gloo", "nccl"}) + @skip_if_lt_x_gpu(2) + @skip_if_rocm + def test_ddp_shared_grad_acc_unused_params(self): + # When find_unused_parameters=True, ensure we mark unused parameters + # even if they share gradient accumulators. + class ToyModel(nn.Module): + def __init__(self): + super(ToyModel, self).__init__() + # net1, bias, and net1.bias are all unused params. + self.net1 = nn.Linear(10, 5, bias=False) + self.bias = nn.Parameter(torch.zeros(5)) + # net1.bias and self.bias are names for the same underlying + # parameter, so they share the same grad acc. This caused + # the bug reported in https://github.com/pytorch/pytorch/issues/41324. + self.net1.bias = self.bias + self.net2 = nn.Linear(10, 5) + + def forward(self, x): + return self.net2(x) + + torch.cuda.set_device(self.rank) + model = ToyModel().to(torch.cuda.current_device()) + ddp_model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[self.rank], find_unused_parameters=True + ) + inp = torch.randn(20, 10, device=self.rank) + for i in range(6): + out = ddp_model(inp) + loss = out.sum() + loss.backward() + @require_backend({"gloo", "nccl"}) @require_backends_available({"gloo", "nccl"}) @skip_if_lt_x_gpu(2) diff --git a/torch/testing/_internal/distributed/nn/api/remote_module_test.py b/torch/testing/_internal/distributed/nn/api/remote_module_test.py index da81b3b16e53..d6b3d816fe68 100644 --- a/torch/testing/_internal/distributed/nn/api/remote_module_test.py +++ b/torch/testing/_internal/distributed/nn/api/remote_module_test.py @@ -78,7 +78,7 @@ def world_size(self): # Override setting in RpcAgentTestFixture return 2 @staticmethod - def _create_remote_module_iter(dst_worker_name, device="cpu", modes=None): + def _create_remote_module_iter(remote_device, modes=None): if modes is None: modes = ModuleCreationMode.__members__.values() @@ -86,15 +86,12 @@ def _create_remote_module_iter(dst_worker_name, device="cpu", modes=None): kwargs = dict(first_kwarg=2) if ModuleCreationMode.MODULE_CTOR in modes: - remote_module = RemoteModule( - dst_worker_name, device, MyModule, args, kwargs - ) + remote_module = RemoteModule(remote_device, MyModule, args, kwargs) yield remote_module if ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE in modes: remote_module = _RemoteModule( - dst_worker_name, - device, + remote_device, create_scripted_module, args, kwargs, @@ -108,6 +105,7 @@ def test_bad_module(self): if self.rank != 0: return dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) + remote_device = "{}/cpu".format(dst_worker_name) args = (1,) kwargs = dict(first_kwarg=2) @@ -115,13 +113,13 @@ def test_bad_module(self): ValueError, r"Expect `module_cls\(\*args, \*\*kwargs\)` returns an instance of ,", ): - RemoteModule(dst_worker_name, "cpu", BadModule, args, kwargs) + RemoteModule(remote_device, BadModule, args, kwargs) with self.assertRaisesRegex( ValueError, r"Expect `module_cls\(\*args, \*\*kwargs\)` returns an instance of ,", ): - RemoteModule(dst_worker_name, "cpu", BadModule, args, kwargs) + RemoteModule(remote_device, BadModule, args, kwargs) @dist_utils.dist_init def test_forward_async(self): @@ -227,7 +225,7 @@ def test_valid_device(self): dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) for remote_module in self._create_remote_module_iter( - dst_worker_name, device="cuda:0", modes=[ModuleCreationMode.MODULE_CTOR] + "{}/cuda:0".format(dst_worker_name), modes=[ModuleCreationMode.MODULE_CTOR] ): device = rpc.rpc_sync( dst_worker_name, remote_device, (remote_module.module_rref,) @@ -249,8 +247,7 @@ def test_invalid_devices(self): ): list( self._create_remote_module_iter( - dst_worker_name, - device="foo", + "{}/foo".format(dst_worker_name), modes=[ModuleCreationMode.MODULE_CTOR], ) ) @@ -260,8 +257,7 @@ def test_invalid_devices(self): ): list( self._create_remote_module_iter( - dst_worker_name, - device="cuda:100", + "{}/cuda:100".format(dst_worker_name), modes=[ModuleCreationMode.MODULE_CTOR], ) ) @@ -269,9 +265,8 @@ def test_invalid_devices(self): with self.assertRaisesRegex(RuntimeError, r"Invalid device string: 'cpu2'"): list( self._create_remote_module_iter( - dst_worker_name, + "{}/cpu2".format(dst_worker_name), modes=[ModuleCreationMode.MODULE_CTOR], - device="cpu2", ) ) @@ -280,8 +275,48 @@ def test_invalid_devices(self): ): list( self._create_remote_module_iter( - dst_worker_name, - device="cpu:2", + "{}/cpu:2".format(dst_worker_name), + modes=[ModuleCreationMode.MODULE_CTOR], + ) + ) + + with self.assertRaisesRegex(RuntimeError, r"Device string must not be empty"): + list( + self._create_remote_module_iter( + "{}/".format(dst_worker_name), + modes=[ModuleCreationMode.MODULE_CTOR], + ) + ) + + with self.assertRaisesRegex( + RuntimeError, + r"Could not parse remote_device: worker1/cuda:0/cuda:1. The valid format is '/'", + ): + list( + self._create_remote_module_iter( + "{}/cuda:0/cuda:1".format(dst_worker_name), + modes=[ModuleCreationMode.MODULE_CTOR], + ) + ) + + with self.assertRaisesRegex( + RuntimeError, + r"The workername in remote_device '/' cannot be empty. The valid format is '/'", + ): + list( + self._create_remote_module_iter( + "/", + modes=[ModuleCreationMode.MODULE_CTOR], + ) + ) + + with self.assertRaisesRegex( + RuntimeError, + r"The workername in remote_device '/cuda:0' cannot be empty. The valid format is '/'", + ): + list( + self._create_remote_module_iter( + "/cuda:0", modes=[ModuleCreationMode.MODULE_CTOR], ) ) diff --git a/torch/testing/_internal/distributed/rpc/rpc_test.py b/torch/testing/_internal/distributed/rpc/rpc_test.py index 41006be6b39b..8dc25a6a56da 100644 --- a/torch/testing/_internal/distributed/rpc/rpc_test.py +++ b/torch/testing/_internal/distributed/rpc/rpc_test.py @@ -15,7 +15,7 @@ import torch.distributed.rpc as rpc import torch.distributed.autograd as dist_autograd from torch.distributed.rpc import RRef, _get_debug_info, _rref_context_get_debug_info -from torch.distributed.rpc.api import _delete_all_user_and_unforked_owner_rrefs, _use_rpc_pickler +from torch.distributed.rpc.api import _delete_all_user_and_unforked_owner_rrefs, _use_rpc_pickler, _thread_local_var, _wait_all from torch.distributed.rpc.internal import ( PythonUDF, RPCExecMode, @@ -2856,6 +2856,58 @@ class TestPickler: torch.distributed.rpc.api._default_pickler is _internal_rpc_pickler ) + @dist_init + def test_wait_all(self): + with _wait_all(): + self.assertTrue(_thread_local_var.future_list == []) + dst = worker_name((self.rank + 1) % self.world_size) + fut = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1)) + self.assertTrue(len(_thread_local_var.future_list) == 1) + self.assertTrue(isinstance(_thread_local_var.future_list[0], torch._C.Future)) + self.assertTrue(fut.done()) + self.assertEqual(fut.wait(), torch.ones(2, 2) + 1) + self.assertFalse(hasattr(_thread_local_var, "future_list")) + + @dist_init + def test_wait_all_multiple_call(self): + with _wait_all(): + self.assertTrue(_thread_local_var.future_list == []) + dst = worker_name((self.rank + 1) % self.world_size) + for i in range(20): + fut = rpc.rpc_async(dst, torch.add, (torch.ones(i, i), 1)) + res = rpc.rpc_sync(dst, torch.add, (torch.ones(i, i), 1)) + self.assertEqual(res, torch.ones(i, i) + 1) + self.assertEqual(fut.wait(), torch.ones(i, i) + 1) + self.assertTrue(len(_thread_local_var.future_list) == 20) + self.assertFalse(hasattr(_thread_local_var, "future_list")) + + @dist_init + def test_wait_all_timeout(self): + expected_error = self.get_timeout_error_regex() + with self.assertRaisesRegex(RuntimeError, expected_error): + with _wait_all(): + self.assertTrue(_thread_local_var.future_list == []) + dst = worker_name((self.rank + 1) % self.world_size) + timeout = 0.1 # 100 ms + fut = rpc.rpc_async(dst, my_sleep_func, args=(1,), timeout=timeout) + self.assertFalse(hasattr(_thread_local_var, "future_list")) + + @dist_init + def test_wait_all_raise_in_user_func(self): + with self.assertRaises(ValueError): + with _wait_all(): + self.assertTrue(_thread_local_var.future_list == []) + dst = worker_name((self.rank + 1) % self.world_size) + fut = rpc.rpc_async(dst, raise_func) + self.assertFalse(hasattr(_thread_local_var, "future_list")) + + @dist_init + def test_wait_all_raise_in_body(self): + with self.assertRaises(ValueError): + with _wait_all(): + raise_func() + self.assertFalse(hasattr(_thread_local_var, "future_list")) + @dist_init def test_function_not_on_callee(self): # test that if a function does not exist on a callee, we don't crash, @@ -3648,6 +3700,41 @@ def test_cannot_infer_backend_from_options(self): rpc_backend_options=rpc_backend_options, ) + @dist_init + def test_local_rref_backward(self): + dst = worker_name((self.rank + 1) % self.world_size) + t1 = torch.rand(10, 10, requires_grad=True) + rref = rpc.RRef(t1.sum() + t1.sum()) + rref.backward() + expected_grad = torch.ones_like(t1) * 2 + self.assertEqual(expected_grad, t1.grad) + + with dist_autograd.context() as context_id: + t2 = rpc.rpc_sync(dst, torch.add, args=(t1, t1)) + rref = rpc.RRef(t2.sum()) + rref.backward(context_id) + self.assertEqual(expected_grad, dist_autograd.get_gradients(context_id)[t1]) + + # Double backward. + with dist_autograd.context() as context_id: + t2 = rpc.rpc_sync(dst, torch.add, args=(t1, t1)) + rref = rpc.RRef(t2.sum()) + rref.backward(context_id, retain_graph=True) + rref.backward(context_id) + self.assertEqual(expected_grad * 2, dist_autograd.get_gradients(context_id)[t1]) + + # Test errors. + with self.assertRaisesRegex(RuntimeError, "tensors does not require grad and does not have a grad_fn"): + rpc.RRef(torch.rand(10)).backward() + + with self.assertRaisesRegex(RuntimeError, "grad can be implicitly created only for scalar outputs"): + rpc.RRef(torch.rand(10, requires_grad=True)).backward() + + with self.assertRaisesRegex(RuntimeError, "Could not find autograd context with id: 100"): + rpc.RRef(torch.rand(10, requires_grad=True).sum()).backward(100) + + with self.assertRaisesRegex(RuntimeError, "RRef should contain a tensor for .backward()"): + rpc.RRef("foo").backward() class ProcessGroupAgentRpcTest(RpcAgentTestFixture): diff --git a/torch/testing/check_kernel_launches.py b/torch/testing/check_kernel_launches.py new file mode 100644 index 000000000000..3385fcdf9618 --- /dev/null +++ b/torch/testing/check_kernel_launches.py @@ -0,0 +1,118 @@ +import os +import re +import sys + + +# Regular expression identifies a kernel launch indicator by +# finding something approximating the pattern ">>>(arguments);" +# It then requires that `TORCH_CUDA_KERNEL_LAUNCH_CHECK` be +# the next command. +# It allows a single backslash `\` between the end of the launch +# command and the beginning of the kernel check. This handles +# cases where the kernel launch is in a multiline preprocessor +# definition. +# +# There are various ways this can fail: +# * If the semicolon is in a string for some reason +# * If there's a triply-nested template +# But this should be sufficient to detect and fix most problem +# instances and can be refined before the test is made binding +kernel_launch_regex = re.compile(r""" + >>> # Identifies kernel launch + \s* # Maybe some whitespace (includes newlines) + \([^;]+\); # And then arguments in parens and semi-colon + (?! # Negative lookahead: we trigger if we don't find the launch guard + \s* # Maybe some whitespace (includes newlines) + \\? # 0 or 1 backslashes (for launches in preprocessor macros) + (?:[0-9]+: )? # Detects and ignores a line numbering, if present + \s* # Maybe some whitespace (includes newlines) + TORCH_CUDA_KERNEL_LAUNCH_CHECK\(\); # Kernel launch guard! + ) # End negative lookahead +""", flags=re.MULTILINE | re.VERBOSE) + + +def check_code_for_cuda_kernel_launches(code, filename=None): + """Checks code for CUDA kernel launches without cuda error checks. + + Args: + filename - Filename of file containing the code. Used only for display + purposes, so you can put anything here. + code - The code to check + + Returns: + The number of unsafe kernel launches in the code + """ + if filename is None: + filename = "##Python Function Call##" + + # We break the code apart and put it back together to add + # helpful line numberings for identifying problem areas + code = enumerate(code.split("\n")) # Split by line breaks + code = [f"{lineno}: {linecode}" for lineno, linecode in code] # Number the lines + code = '\n'.join(code) # Put it back together + + results = kernel_launch_regex.findall(code) # Search for bad launches + for r in results: + print(f"Missing TORCH_CUDA_KERNEL_LAUNCH_CHECK in '{filename}'. Context:\n{r}", file=sys.stderr) + return len(results) + + +def check_file(filename): + """Checks a file for CUDA kernel launches without cuda error checks + + Args: + filename - File to check + + Returns: + The number of unsafe kernel launches in the file + """ + if not (filename.endswith(".cu") or filename.endswith(".cuh")): + return 0 + contents = open(filename, "r").read() + return check_code_for_cuda_kernel_launches(contents, filename) + + +def check_cuda_kernel_launches(): + """Checks all pytorch code for CUDA kernel launches without cuda error checks + + Returns: + The number of unsafe kernel launches in the codebase + """ + torch_dir = os.path.dirname(os.path.realpath(__file__)) + torch_dir = os.path.dirname(torch_dir) # Go up to parent torch + torch_dir = os.path.dirname(torch_dir) # Go up to parent caffe2 + + kernels_without_checks = 0 + files_without_checks = [] + for root, dirnames, filenames in os.walk(torch_dir): + # `$BASE/build` and `$BASE/torch/include` are generated + # so we don't want to flag their contents + if root == os.path.join(torch_dir, "build") or root == os.path.join(torch_dir, "torch/include"): + # Curtail search by modifying dirnames and filenames in place + # Yes, this is the way to do this, see `help(os.walk)` + dirnames[:] = [] + continue + + for x in filenames: + filename = os.path.join(root, x) + file_result = check_file(filename) + if file_result > 0: + kernels_without_checks += file_result + files_without_checks.append(filename) + + if kernels_without_checks > 0: + count_str = f"Found {kernels_without_checks} instances in " \ + f"{len(files_without_checks)} files where kernel " \ + "launches didn't have checks." + print(count_str, file=sys.stderr) + print("Files without checks:", file=sys.stderr) + for x in files_without_checks: + print(f"\t{x}", file=sys.stderr) + print(count_str, file=sys.stderr) + + return kernels_without_checks + + +if __name__ == "__main__": + unsafe_launches = check_cuda_kernel_launches() + sys.exit(0) diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index 4948e6e33099..afd654c6a85b 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -50,7 +50,7 @@ def _find_cuda_home() -> Optional[str]: if not os.path.exists(cuda_home): cuda_home = None if cuda_home and not torch.cuda.is_available(): - print("No CUDA runtime is found, using CUDA_HOME='{}'".format(cuda_home)) + print(f"No CUDA runtime is found, using CUDA_HOME='{cuda_home}'") return cuda_home def _find_rocm_home() -> Optional[str]: @@ -72,7 +72,7 @@ def _find_rocm_home() -> Optional[str]: if not os.path.exists(rocm_home): rocm_home = None if rocm_home and torch.version.hip is None: - print("No ROCm runtime is found, using ROCM_HOME='{}'".format(rocm_home)) + print(f"No ROCm runtime is found, using ROCM_HOME='{rocm_home}'") return rocm_home @@ -275,13 +275,13 @@ def check_compiler_abi_compatibility(compiler) -> bool: version = (0, 0, 0) if match is None else match.groups() except Exception: _, error, _ = sys.exc_info() - warnings.warn('Error checking compiler version for {}: {}'.format(compiler, error)) + warnings.warn(f'Error checking compiler version for {compiler}: {error}') return False if tuple(map(int, version)) >= minimum_required_version: return True - compiler = '{} {}'.format(compiler, ".".join(version)) + compiler = f'{compiler} {".".join(version)}' warnings.warn(ABI_INCOMPATIBILITY_WARNING.format(compiler)) return False @@ -364,6 +364,11 @@ def build_extensions(self) -> None: extension.extra_compile_args[ext] = [] self._add_compile_flag(extension, '-DTORCH_API_INCLUDE_EXTENSION_H') + # See note [Pybind11 ABI constants] + for name in ["COMPILER_TYPE", "STDLIB", "BUILD_ABI"]: + val = getattr(torch._C, f"_PYBIND11_{name}") + if val is not None and not IS_WINDOWS: + self._add_compile_flag(extension, f'-DPYBIND11_{name}="{val}"') self._define_torch_extension_name(extension) self._add_gnu_cpp_abi_flag(extension) @@ -715,7 +720,7 @@ def _define_torch_extension_name(self, extension): # as the library name names = extension.name.split('.') name = names[-1] - define = '-DTORCH_EXTENSION_NAME={}'.format(name) + define = f'-DTORCH_EXTENSION_NAME={name}' self._add_compile_flag(extension, define) def _add_gnu_cpp_abi_flag(self, extension): @@ -1102,9 +1107,7 @@ def load_inline(name, # Make the function docstring the same as the function name. functions = dict((f, f) for f in functions) elif not isinstance(functions, dict): - raise ValueError( - "Expected 'functions' to be a list or dict, but was {}".format( - type(functions))) + raise ValueError(f"Expected 'functions' to be a list or dict, but was {type(functions)}") for function_name, docstring in functions.items(): if with_pytorch_error_handling: module_def.append( @@ -1170,9 +1173,9 @@ def _jit_compile(name, ) if version > 0: if version != old_version and verbose: - print('The input conditions for extension module {} have changed. '.format(name) + - 'Bumping to version {0} and re-building as {1}_v{0}...'.format(version, name)) - name = '{}_v{}'.format(name, version) + print(f'The input conditions for extension module {name} have changed. ' + + f'Bumping to version {version} and re-building as {name}_v{version}...') + name = f'{name}_v{version}' if version != old_version: baton = FileBaton(os.path.join(build_directory, 'lock')) @@ -1205,7 +1208,7 @@ def _jit_compile(name, baton.wait() elif verbose: print('No modifications detected for re-loaded extension ' - 'module {}, skipping build step...'.format(name)) + f'module {name}, skipping build step...') if verbose: print(f'Loading extension module {name}...') @@ -1292,11 +1295,11 @@ def _write_ninja_file_and_build_library( with_cuda=with_cuda) if verbose: - print('Building extension module {}...'.format(name)) + print(f'Building extension module {name}...') _run_ninja_build( build_directory, verbose, - error_prefix="Error building extension '{}'".format(name)) + error_prefix=f"Error building extension '{name}'") def is_ninja_available(): @@ -1342,10 +1345,10 @@ def _prepare_ldflags(extra_ldflags, with_cuda, verbose): extra_ldflags.append('-INCLUDE:?warp_size@cuda@at@@YAHXZ') extra_ldflags.append('torch.lib') extra_ldflags.append('torch_python.lib') - extra_ldflags.append('/LIBPATH:{}'.format(python_lib_path)) - extra_ldflags.append('/LIBPATH:{}'.format(lib_path)) + extra_ldflags.append(f'/LIBPATH:{python_lib_path}') + extra_ldflags.append(f'/LIBPATH:{lib_path}') else: - extra_ldflags.append('-L{}'.format(lib_path)) + extra_ldflags.append(f'-L{lib_path}') extra_ldflags.append('-lc10') if with_cuda: extra_ldflags.append('-lc10_hip' if IS_HIP_EXTENSION else '-lc10_cuda') @@ -1359,19 +1362,18 @@ def _prepare_ldflags(extra_ldflags, with_cuda, verbose): if verbose: print('Detected CUDA files, patching ldflags') if IS_WINDOWS: - extra_ldflags.append('/LIBPATH:{}'.format( - _join_cuda_home('lib/x64'))) + extra_ldflags.append(f'/LIBPATH:{_join_cuda_home("lib/x64")}') extra_ldflags.append('cudart.lib') if CUDNN_HOME is not None: extra_ldflags.append(os.path.join(CUDNN_HOME, 'lib/x64')) elif not IS_HIP_EXTENSION: - extra_ldflags.append('-L{}'.format(_join_cuda_home('lib64'))) + extra_ldflags.append(f'-L{_join_cuda_home("lib64")}') extra_ldflags.append('-lcudart') if CUDNN_HOME is not None: - extra_ldflags.append('-L{}'.format(os.path.join(CUDNN_HOME, 'lib64'))) + extra_ldflags.append(f'-L{os.path.join(CUDNN_HOME, "lib64")}') elif IS_HIP_EXTENSION: assert ROCM_VERSION is not None - extra_ldflags.append('-L{}'.format(_join_rocm_home('lib'))) + extra_ldflags.append(f'-L{_join_rocm_home("lib")}') extra_ldflags.append('-lamdhip64' if ROCM_VERSION >= (3, 5) else '-lhip_hcc') return extra_ldflags @@ -1421,7 +1423,7 @@ def _get_cuda_arch_flags(cflags: Optional[List[str]] = None) -> List[str]: # If not given, determine what's needed for the GPU that can be found if not _arch_list: capability = torch.cuda.get_device_capability() - arch_list = ['{}.{}'.format(capability[0], capability[1])] + arch_list = [f'{capability[0]}.{capability[1]}'] else: # Deal with lists that are ' ' separated (only deal with ';' after) _arch_list = _arch_list.replace(' ', ';') @@ -1434,12 +1436,12 @@ def _get_cuda_arch_flags(cflags: Optional[List[str]] = None) -> List[str]: flags = [] for arch in arch_list: if arch not in valid_arch_strings: - raise ValueError("Unknown CUDA arch ({}) or GPU not supported".format(arch)) + raise ValueError(f"Unknown CUDA arch ({arch}) or GPU not supported") else: num = arch[0] + arch[2] - flags.append('-gencode=arch=compute_{},code=sm_{}'.format(num, num)) + flags.append(f'-gencode=arch=compute_{num},code=sm_{num}') if arch.endswith('+PTX'): - flags.append('-gencode=arch=compute_{},code=compute_{}'.format(num, num)) + flags.append(f'-gencode=arch=compute_{num},code=compute_{num}') return list(set(flags)) @@ -1466,8 +1468,7 @@ def _get_build_directory(name: str, verbose: bool) -> str: root_extensions_directory = get_default_build_root() if verbose: - print('Using {} as PyTorch extensions root...'.format( - root_extensions_directory)) + print(f'Using {root_extensions_directory} as PyTorch extensions root...') build_directory = os.path.join(root_extensions_directory, name) if not os.path.exists(build_directory): @@ -1483,7 +1484,7 @@ def _get_num_workers(verbose: bool) -> Optional[int]: max_jobs = os.environ.get('MAX_JOBS') if max_jobs is not None and max_jobs.isdigit(): if verbose: - print('Using envvar MAX_JOBS ({}) as the number of workers...'.format(max_jobs)) + print(f'Using envvar MAX_JOBS ({max_jobs}) as the number of workers...') return int(max_jobs) if verbose: print('Allowing ninja to set a default number of workers... ' @@ -1550,7 +1551,7 @@ def _run_ninja_build(build_directory: str, verbose: bool, error_prefix: str) -> # `error` is a CalledProcessError (which has an `ouput`) attribute, but # mypy thinks it's Optional[BaseException] and doesn't narrow if hasattr(error, 'output') and error.output: # type: ignore - message += ": {}".format(error.output.decode()) # type: ignore + message += f": {error.output.decode()}" # type: ignore raise RuntimeError(message) from e @@ -1592,10 +1593,28 @@ def _write_ninja_file_to_build_library(path, user_includes += system_includes system_includes.clear() - common_cflags = ['-DTORCH_EXTENSION_NAME={}'.format(name)] + common_cflags = [f'-DTORCH_EXTENSION_NAME={name}'] common_cflags.append('-DTORCH_API_INCLUDE_EXTENSION_H') - common_cflags += ['-I{}'.format(include) for include in user_includes] - common_cflags += ['-isystem {}'.format(include) for include in system_includes] + + # Note [Pybind11 ABI constants] + # + # Pybind11 before 2.4 used to build an ABI strings using the following pattern: + # f"__pybind11_internals_v{PYBIND11_INTERNALS_VERSION}{PYBIND11_INTERNALS_KIND}{PYBIND11_BUILD_TYPE}__" + # Since 2.4 compier type, stdlib and build abi parameters are also encoded like this: + # f"__pybind11_internals_v{PYBIND11_INTERNALS_VERSION}{PYBIND11_INTERNALS_KIND}{PYBIND11_COMPILER_TYPE}{PYBIND11_STDLIB}{PYBIND11_BUILD_ABI}{PYBIND11_BUILD_TYPE}__" + # + # This was done in order to further narrow down the chances of compiler ABI incompatibility + # that can cause a hard to debug segfaults. + # For PyTorch extensions we want to relax those restrictions and pass compiler, stdlib and abi properties + # captured during PyTorch native library compilation in torch/csrc/Module.cpp + + for pname in ["COMPILER_TYPE", "STDLIB", "BUILD_ABI"]: + pval = getattr(torch._C, f"_PYBIND11_{pname}") + if pval is not None and not IS_WINDOWS: + common_cflags.append(f'-DPYBIND11_{pname}=\\"{pval}\\"') + + common_cflags += [f'-I{include}' for include in user_includes] + common_cflags += [f'-isystem {include}' for include in system_includes] common_cflags += ['-D_GLIBCXX_USE_CXX11_ABI=' + str(int(torch._C._GLIBCXX_USE_CXX11_ABI))] @@ -1639,9 +1658,9 @@ def object_file_path(source_file: str) -> str: if _is_cuda_file(source_file) and with_cuda: # Use a different object filename in case a C++ and CUDA file have # the same filename but different extension (.cpp vs. .cu). - target = '{}.cuda.o'.format(file_name) + target = f'{file_name}.cuda.o' else: - target = '{}.o'.format(file_name) + target = f'{file_name}.o' return target objects = [object_file_path(src) for src in sources] @@ -1657,7 +1676,7 @@ def object_file_path(source_file: str) -> str: ldflags = _nt_quote_args(ldflags) ext = 'pyd' if IS_WINDOWS else 'so' - library_target = '{}.{}'.format(name, ext) + library_target = f'{name}.{ext}' _write_ninja_file( path=path, @@ -1719,20 +1738,20 @@ def sanitize_flags(flags): # Version 1.3 is required for the `deps` directive. config = ['ninja_required_version = 1.3'] - config.append('cxx = {}'.format(compiler)) + config.append(f'cxx = {compiler}') if with_cuda: if IS_HIP_EXTENSION: nvcc = _join_rocm_home('bin', 'hipcc') else: nvcc = _join_cuda_home('bin', 'nvcc') - config.append('nvcc = {}'.format(nvcc)) + config.append(f'nvcc = {nvcc}') - flags = ['cflags = {}'.format(' '.join(cflags))] - flags.append('post_cflags = {}'.format(' '.join(post_cflags))) + flags = [f'cflags = {" ".join(cflags)}'] + flags.append(f'post_cflags = {" ".join(post_cflags)}') if with_cuda: - flags.append('cuda_cflags = {}'.format(' '.join(cuda_cflags))) - flags.append('cuda_post_cflags = {}'.format(' '.join(cuda_post_cflags))) - flags.append('ldflags = {}'.format(' '.join(ldflags))) + flags.append(f'cuda_cflags = {" ".join(cuda_cflags)}') + flags.append(f'cuda_post_cflags = {" ".join(cuda_post_cflags)}') + flags.append(f'ldflags = {" ".join(ldflags)}') # Turn into absolute paths so we can emit them into the ninja build # file wherever it is. @@ -1765,7 +1784,7 @@ def sanitize_flags(flags): object_file = object_file.replace(':', '$:') source_file = source_file.replace(" ", "$ ") object_file = object_file.replace(" ", "$ ") - build.append('build {}: {} {}'.format(object_file, rule, source_file)) + build.append(f'build {object_file}: {rule} {source_file}') if library_target is not None: link_rule = ['rule link'] @@ -1776,15 +1795,13 @@ def sanitize_flags(flags): cl_path = os.path.dirname(cl_paths[0]).replace(':', '$:') else: raise RuntimeError("MSVC is required to load C++ extensions") - link_rule.append( - ' command = "{}/link.exe" $in /nologo $ldflags /out:$out'.format( - cl_path)) + link_rule.append(f' command = "{cl_path}/link.exe" $in /nologo $ldflags /out:$out') else: link_rule.append(' command = $cxx $in $ldflags -o $out') - link = ['build {}: link {}'.format(library_target, ' '.join(objects))] + link = [f'build {library_target}: link {" ".join(objects)}'] - default = ['default {}'.format(library_target)] + default = [f'default {library_target}'] else: link_rule, link, default = [], [], [] @@ -1796,7 +1813,7 @@ def sanitize_flags(flags): with open(path, 'w') as build_file: for block in blocks: lines = '\n'.join(block) - build_file.write('{}\n\n'.format(lines)) + build_file.write(f'{lines}\n\n') def _join_cuda_home(*paths) -> str: diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py index 59c1827e1842..8d7726ebd129 100644 --- a/torch/utils/data/dataloader.py +++ b/torch/utils/data/dataloader.py @@ -5,6 +5,7 @@ in `./_utils/worker.py`. """ +import os import threading import itertools import warnings @@ -290,10 +291,13 @@ def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1, self._iterator = None + self.check_worker_number_rationality() + def _get_iterator(self) -> '_BaseDataLoaderIter': if self.num_workers == 0: return _SingleProcessDataLoaderIter(self) else: + self.check_worker_number_rationality() return _MultiProcessingDataLoaderIter(self) @property @@ -399,6 +403,83 @@ def __len__(self) -> int: else: return len(self._index_sampler) + def check_worker_number_rationality(self): + # This function check whether the dataloader's worker number is rational based on + # current system's resource. Current rule is that if the number of workers this + # Dataloader will create is bigger than the number of logical cpus that is allowed to + # use, than we will pop up a warning to let user pay attention. + # + # eg. If current system has 2 physical CPUs with 16 cores each. And each core support 2 + # threads, then the total logical cpus here is 2 * 16 * 2 = 64. Let's say current + # DataLoader process can use half of them which is 32, then the rational max number of + # worker that initiated from this process is 32. + # Now, let's say the created DataLoader has num_works = 40, which is bigger than 32. + # So the warning message is triggered to notify the user to lower the worker number if + # necessary. + # + # + # [Note] Please note that this function repects `cpuset` only when os.sched_getaffinity is + # available (available in most of Linux system, but not OSX and Windows). + # When os.sched_getaffinity is not available, os.cpu_count() is called instead, but + # it doesn't repect cpuset. + # We don't take threading into account since each worker process is single threaded + # at this time. + # + # We don't set any threading flags (eg. OMP_NUM_THREADS, MKL_NUM_THREADS, etc) + # other than `torch.set_num_threads` to 1 in the worker process, if the passing + # in functions use 3rd party modules that rely on those threading flags to determine + # how many thread to create (eg. numpy, etc), then it is caller's responsibility to + # set those flags correctly. + def _create_warning_msg(num_worker_suggest, num_worker_created, cpuset_checked): + + suggested_max_worker_msg = (( + "Our suggested max number of worker in current system is {}{}, which is smaller " + "than what this DataLoader is going to create.").format( + num_worker_suggest, + ("" if cpuset_checked else " (`cpuset` is not taken into account)")) + ) if num_worker_suggest is not None else ( + "DataLoader is not able to compute a suggested max number of worker in current system.") + + warn_msg = ( + "This DataLoader will create {} worker processes in total. {} " + "Please be aware that excessive worker creation might get DataLoader running slow or even freeze, " + "lower the worker number to avoid potential slowness/freeze if necessary.").format( + num_worker_created, + suggested_max_worker_msg) + return warn_msg + + if not self.num_workers or self.num_workers == 0: + return + + # try to compute a suggested max number of worker based on system's resource + max_num_worker_suggest = None + cpuset_checked = False + if hasattr(os, 'sched_getaffinity'): + try: + max_num_worker_suggest = len(os.sched_getaffinity(0)) + cpuset_checked = True + except Exception: + pass + if max_num_worker_suggest is None: + # os.cpu_count() could return Optional[int] + # get cpu count first and check None in order to satify mypy check + cpu_count = os.cpu_count() + if cpu_count is not None: + max_num_worker_suggest = cpu_count + + if max_num_worker_suggest is None: + warnings.warn(_create_warning_msg( + max_num_worker_suggest, + self.num_workers, + cpuset_checked)) + return + + if self.num_workers > max_num_worker_suggest: + warnings.warn(_create_warning_msg( + max_num_worker_suggest, + self.num_workers, + cpuset_checked)) + class _BaseDataLoaderIter(object): def __init__(self, loader: DataLoader) -> None: @@ -843,7 +924,7 @@ def _reset(self, loader, first_iter=False): # contains all `True`s if not using an iterable-style dataset # (i.e., if kind != Iterable). # Not that this indicates that a worker still has work to do *for this epoch*. - # It does not mean that a worker is dead. In case of `_persistent_workers`, + # It does not mean that a worker is dead. In case of `_persistent_workers`, # the worker will be reset to available in the next epoch. self._workers_status = [True for i in range(self._num_workers)] # We resume the prefetching in case it was enabled diff --git a/torch/utils/data/dataset.py b/torch/utils/data/dataset.py index c910cab9aef8..7c45c10dd812 100644 --- a/torch/utils/data/dataset.py +++ b/torch/utils/data/dataset.py @@ -164,7 +164,7 @@ class TensorDataset(Dataset[Tuple[Tensor, ...]]): tensors: Tuple[Tensor, ...] def __init__(self, *tensors: Tensor) -> None: - assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors) + assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors), "Size mismatch between tensors" self.tensors = tensors def __getitem__(self, index):