diff --git a/.circleci/cimodel/data/dimensions.py b/.circleci/cimodel/data/dimensions.py index 1f83cd61b13c..57a6055c94ae 100644 --- a/.circleci/cimodel/data/dimensions.py +++ b/.circleci/cimodel/data/dimensions.py @@ -10,6 +10,7 @@ ROCM_VERSIONS = [ "3.7", "3.8", + "3.9", ] ROCM_VERSION_LABELS = ["rocm" + v for v in ROCM_VERSIONS] diff --git a/.circleci/cimodel/data/pytorch_build_data.py b/.circleci/cimodel/data/pytorch_build_data.py index 12a40a17bed3..39f2a208a5ec 100644 --- a/.circleci/cimodel/data/pytorch_build_data.py +++ b/.circleci/cimodel/data/pytorch_build_data.py @@ -57,7 +57,7 @@ ]), ]), ]), - ("11.1", [ + ("11.0", [ ("3.8", [ X(True), ("libtorch", [ @@ -84,7 +84,11 @@ ("gcc", [ ("9", [ ("3.8", [ - ("coverage", [XImportant(True)]), + ("coverage", [ + (True, [ + ("shard_test", [XImportant(True)]), + ]), + ]), ]), ]), ]), diff --git a/.circleci/cimodel/data/pytorch_build_definitions.py b/.circleci/cimodel/data/pytorch_build_definitions.py index 0c03fac487d6..75b0e8812e1b 100644 --- a/.circleci/cimodel/data/pytorch_build_definitions.py +++ b/.circleci/cimodel/data/pytorch_build_definitions.py @@ -272,6 +272,7 @@ def instantiate_configs(): compiler_version = fc.find_prop("compiler_version") is_xla = fc.find_prop("is_xla") or False is_asan = fc.find_prop("is_asan") or False + is_coverage = fc.find_prop("is_coverage") or False is_onnx = fc.find_prop("is_onnx") or False is_pure_torch = fc.find_prop("is_pure_torch") or False is_vulkan = fc.find_prop("is_vulkan") or False @@ -311,6 +312,10 @@ def instantiate_configs(): python_version = fc.find_prop("pyver") parms_list[0] = fc.find_prop("abbreviated_pyver") + if is_coverage: + parms_list_ignored_for_docker_image.append("coverage") + python_version = fc.find_prop("pyver") + if is_onnx: parms_list.append("onnx") python_version = fc.find_prop("pyver") @@ -325,7 +330,6 @@ def instantiate_configs(): is_important = fc.find_prop("is_important") or False parallel_backend = fc.find_prop("parallel_backend") or None build_only = fc.find_prop("build_only") or False - is_coverage = fc.find_prop("is_coverage") or False shard_test = fc.find_prop("shard_test") or False # TODO: fix pure_torch python test packaging issue. if shard_test: @@ -333,9 +337,6 @@ def instantiate_configs(): restrict_phases.extend(["test1", "test2"]) if build_only or is_pure_torch: restrict_phases = ["build"] - if is_coverage and restrict_phases is None: - restrict_phases = ["build", "coverage_test"] - gpu_resource = None if cuda_version and cuda_version != "10": diff --git a/.circleci/cimodel/data/simple/docker_definitions.py b/.circleci/cimodel/data/simple/docker_definitions.py index a216b084de59..fa77b2555073 100644 --- a/.circleci/cimodel/data/simple/docker_definitions.py +++ b/.circleci/cimodel/data/simple/docker_definitions.py @@ -43,7 +43,7 @@ def get_workflow_jobs(): parameters = OrderedDict({ "name": quote(f"docker-{image_name}"), "image_name": quote(image_name), - }) + }) if image_name == "pytorch-linux-xenial-py3.6-gcc5.4": # pushing documentation on tags requires CircleCI to also # build all the dependencies on tags, including this docker image diff --git a/.circleci/config.yml b/.circleci/config.yml index 38ec1ff25a3a..3095f2840772 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -655,9 +655,11 @@ jobs: echo "Retrieving test reports" docker cp $id:/var/lib/jenkins/workspace/test/test-reports ./ || echo 'No test reports found!' if [[ ${BUILD_ENVIRONMENT} == *"coverage"* ]]; then - echo "Retrieving coverage report" + echo "Retrieving Python coverage report" docker cp $id:/var/lib/jenkins/workspace/test/.coverage ./test docker cp $id:/var/lib/jenkins/workspace/test/coverage.xml ./test + echo "Retrieving C++ coverage report" + docker cp $id:/var/lib/jenkins/workspace/build/coverage.info ./test python3 -mpip install codecov python3 -mcodecov fi @@ -2179,6 +2181,39 @@ workflows: only: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ docker_image: "pytorch/manylinux-rocm:3.8" + - binary_linux_build: + name: binary_linux_manywheel_3_6m_rocm3_9_devtoolset7_nightly_build + build_environment: "manywheel 3.6m rocm3.9 devtoolset7" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + docker_image: "pytorch/manylinux-rocm:3.9" + - binary_linux_build: + name: binary_linux_manywheel_3_7m_rocm3_9_devtoolset7_nightly_build + build_environment: "manywheel 3.7m rocm3.9 devtoolset7" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + docker_image: "pytorch/manylinux-rocm:3.9" + - binary_linux_build: + name: binary_linux_manywheel_3_8m_rocm3_9_devtoolset7_nightly_build + build_environment: "manywheel 3.8m rocm3.9 devtoolset7" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + docker_image: "pytorch/manylinux-rocm:3.9" - binary_linux_build: name: binary_linux_conda_3_6_cpu_devtoolset7_nightly_build build_environment: "conda 3.6 cpu devtoolset7" @@ -3523,6 +3558,51 @@ workflows: docker_image: "pytorch/manylinux-rocm:3.8" use_cuda_docker_runtime: "1" resource_class: gpu.medium + - binary_linux_test: + name: binary_linux_manywheel_3_6m_rocm3_9_devtoolset7_nightly_test + build_environment: "manywheel 3.6m rocm3.9 devtoolset7" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + requires: + - binary_linux_manywheel_3_6m_rocm3_9_devtoolset7_nightly_build + docker_image: "pytorch/manylinux-rocm:3.9" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium + - binary_linux_test: + name: binary_linux_manywheel_3_7m_rocm3_9_devtoolset7_nightly_test + build_environment: "manywheel 3.7m rocm3.9 devtoolset7" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + requires: + - binary_linux_manywheel_3_7m_rocm3_9_devtoolset7_nightly_build + docker_image: "pytorch/manylinux-rocm:3.9" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium + - binary_linux_test: + name: binary_linux_manywheel_3_8m_rocm3_9_devtoolset7_nightly_test + build_environment: "manywheel 3.8m rocm3.9 devtoolset7" + filters: + branches: + only: + - /.*/ + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + requires: + - binary_linux_manywheel_3_8m_rocm3_9_devtoolset7_nightly_build + docker_image: "pytorch/manylinux-rocm:3.9" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium - binary_linux_test: name: binary_linux_conda_3_6_cpu_devtoolset7_nightly_test build_environment: "conda 3.6 cpu devtoolset7" @@ -5068,6 +5148,48 @@ workflows: - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ package_type: manywheel upload_subfolder: rocm3.8 + - binary_upload: + name: binary_linux_manywheel_3_6m_rocm3_9_devtoolset7_nightly_upload + context: org-member + requires: + - binary_linux_manywheel_3_6m_rocm3_9_devtoolset7_nightly_test + filters: + branches: + only: + - nightly + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + package_type: manywheel + upload_subfolder: rocm3.9 + - binary_upload: + name: binary_linux_manywheel_3_7m_rocm3_9_devtoolset7_nightly_upload + context: org-member + requires: + - binary_linux_manywheel_3_7m_rocm3_9_devtoolset7_nightly_test + filters: + branches: + only: + - nightly + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + package_type: manywheel + upload_subfolder: rocm3.9 + - binary_upload: + name: binary_linux_manywheel_3_8m_rocm3_9_devtoolset7_nightly_upload + context: org-member + requires: + - binary_linux_manywheel_3_8m_rocm3_9_devtoolset7_nightly_test + filters: + branches: + only: + - nightly + tags: + only: + - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ + package_type: manywheel + upload_subfolder: rocm3.9 - binary_upload: name: binary_linux_conda_3_6_cpu_devtoolset7_nightly_upload context: org-member @@ -6806,37 +6928,37 @@ workflows: build_environment: "pytorch-libtorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7-build" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7" - pytorch_linux_build: - name: pytorch_linux_xenial_cuda11_1_cudnn8_py3_gcc7_build + name: pytorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_build requires: - - "docker-pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7" + - "docker-pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7" filters: branches: only: - master - /ci-all\/.*/ - /release\/.*/ - build_environment: "pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7" + build_environment: "pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-build" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7" - pytorch_linux_test: - name: pytorch_linux_xenial_cuda11_1_cudnn8_py3_gcc7_test + name: pytorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_test requires: - - pytorch_linux_xenial_cuda11_1_cudnn8_py3_gcc7_build + - pytorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_build filters: branches: only: - master - /ci-all\/.*/ - /release\/.*/ - build_environment: "pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7" + build_environment: "pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-test" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7" use_cuda_docker_runtime: "1" resource_class: gpu.medium - pytorch_linux_build: - name: pytorch_libtorch_linux_xenial_cuda11_1_cudnn8_py3_gcc7_build + name: pytorch_libtorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_build requires: - - "docker-pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7" - build_environment: "pytorch-libtorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7" + - "docker-pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7" + build_environment: "pytorch-libtorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-build" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7" - pytorch_linux_build: name: pytorch_linux_bionic_py3_6_clang9_build requires: @@ -6877,16 +6999,23 @@ workflows: docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-bionic-py3.6-clang9" resource_class: large - pytorch_linux_build: - name: pytorch_linux_bionic_py3_8_gcc9_build + name: pytorch_linux_bionic_py3_8_gcc9_coverage_build requires: - "docker-pytorch-linux-bionic-py3.8-gcc9" - build_environment: "pytorch-linux-bionic-py3.8-gcc9-build" + build_environment: "pytorch-linux-bionic-py3.8-gcc9-coverage-build" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-bionic-py3.8-gcc9" + - pytorch_linux_test: + name: pytorch_linux_bionic_py3_8_gcc9_coverage_test1 + requires: + - pytorch_linux_bionic_py3_8_gcc9_coverage_build + build_environment: "pytorch-linux-bionic-py3.8-gcc9-coverage-test1" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-bionic-py3.8-gcc9" + resource_class: large - pytorch_linux_test: - name: pytorch_linux_bionic_py3_8_gcc9_coverage_test + name: pytorch_linux_bionic_py3_8_gcc9_coverage_test2 requires: - - pytorch_linux_bionic_py3_8_gcc9_build - build_environment: "pytorch-linux-bionic-py3.8-gcc9-coverage_test" + - pytorch_linux_bionic_py3_8_gcc9_coverage_build + build_environment: "pytorch-linux-bionic-py3.8-gcc9-coverage-test2" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-bionic-py3.8-gcc9" resource_class: large - pytorch_linux_build: @@ -7661,6 +7790,42 @@ workflows: docker_image: "pytorch/manylinux-rocm:3.8" use_cuda_docker_runtime: "1" resource_class: gpu.medium + - smoke_linux_test: + name: smoke_linux_manywheel_3_6m_rocm3_9_devtoolset7_nightly + build_environment: "manywheel 3.6m rocm3.9 devtoolset7" + requires: + - update_s3_htmls + filters: + branches: + only: + - postnightly + docker_image: "pytorch/manylinux-rocm:3.9" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium + - smoke_linux_test: + name: smoke_linux_manywheel_3_7m_rocm3_9_devtoolset7_nightly + build_environment: "manywheel 3.7m rocm3.9 devtoolset7" + requires: + - update_s3_htmls + filters: + branches: + only: + - postnightly + docker_image: "pytorch/manylinux-rocm:3.9" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium + - smoke_linux_test: + name: smoke_linux_manywheel_3_8m_rocm3_9_devtoolset7_nightly + build_environment: "manywheel 3.8m rocm3.9 devtoolset7" + requires: + - update_s3_htmls + filters: + branches: + only: + - postnightly + docker_image: "pytorch/manylinux-rocm:3.9" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium - smoke_linux_test: name: smoke_linux_conda_3_6_cpu_devtoolset7_nightly build_environment: "conda 3.6 cpu devtoolset7" diff --git a/.circleci/docker/build.sh b/.circleci/docker/build.sh index acb10b2e5f48..267a6eed7855 100755 --- a/.circleci/docker/build.sh +++ b/.circleci/docker/build.sh @@ -77,9 +77,7 @@ TRAVIS_DL_URL_PREFIX="https://s3.amazonaws.com/travis-python-archives/binaries/u # from scratch case "$image" in pytorch-linux-xenial-py3.8) - # TODO: This is a hack, get rid of this as soon as you get rid of the travis downloads - TRAVIS_DL_URL_PREFIX="https://s3.amazonaws.com/travis-python-archives/binaries/ubuntu/16.04/x86_64" - TRAVIS_PYTHON_VERSION=3.8 + ANACONDA_PYTHON_VERSION=3.8 GCC_VERSION=7 # Do not install PROTOBUF, DB, and VISION as a test ;; @@ -362,7 +360,6 @@ docker build \ --build-arg "GLIBC_VERSION=${GLIBC_VERSION}" \ --build-arg "CLANG_VERSION=${CLANG_VERSION}" \ --build-arg "ANACONDA_PYTHON_VERSION=${ANACONDA_PYTHON_VERSION}" \ - --build-arg "TRAVIS_PYTHON_VERSION=${TRAVIS_PYTHON_VERSION}" \ --build-arg "GCC_VERSION=${GCC_VERSION}" \ --build-arg "CUDA_VERSION=${CUDA_VERSION}" \ --build-arg "CUDNN_VERSION=${CUDNN_VERSION}" \ @@ -405,19 +402,6 @@ if [[ "$OS" == "ubuntu" ]]; then fi fi -if [ -n "$TRAVIS_PYTHON_VERSION" ]; then - if [[ "$TRAVIS_PYTHON_VERSION" != nightly ]]; then - if !(drun python --version 2>&1 | grep -qF "Python $TRAVIS_PYTHON_VERSION"); then - echo "TRAVIS_PYTHON_VERSION=$TRAVIS_PYTHON_VERSION, but:" - drun python --version - exit 1 - fi - else - echo "Please manually check nightly is OK:" - drun python --version - fi -fi - if [ -n "$ANACONDA_PYTHON_VERSION" ]; then if !(drun python --version 2>&1 | grep -qF "Python $ANACONDA_PYTHON_VERSION"); then echo "ANACONDA_PYTHON_VERSION=$ANACONDA_PYTHON_VERSION, but:" diff --git a/.circleci/docker/centos-rocm/Dockerfile b/.circleci/docker/centos-rocm/Dockerfile index 1bc7b0deea32..a94a7167a7f4 100644 --- a/.circleci/docker/centos-rocm/Dockerfile +++ b/.circleci/docker/centos-rocm/Dockerfile @@ -27,7 +27,7 @@ RUN rm install_glibc.sh ADD ./common/install_user.sh install_user.sh RUN bash ./install_user.sh && rm install_user.sh -# Install conda +# Install conda and other packages (e.g., numpy, coverage, pytest) ENV PATH /opt/conda/bin:$PATH ARG ANACONDA_PYTHON_VERSION ADD ./common/install_conda.sh install_conda.sh diff --git a/.circleci/docker/common/install_base.sh b/.circleci/docker/common/install_base.sh index bd8ca7c40109..191b4732452d 100755 --- a/.circleci/docker/common/install_base.sh +++ b/.circleci/docker/common/install_base.sh @@ -18,7 +18,6 @@ install_ubuntu() { # Install common dependencies apt-get update # TODO: Some of these may not be necessary - # TODO: libiomp also gets installed by conda, aka there's a conflict ccache_deps="asciidoc docbook-xml docbook-xsl xsltproc" numpy_deps="gfortran" apt-get install -y --no-install-recommends \ @@ -40,10 +39,6 @@ install_ubuntu() { libjpeg-dev \ libasound2-dev \ libsndfile-dev \ - python \ - python-dev \ - python-setuptools \ - python-wheel \ software-properties-common \ sudo \ wget \ diff --git a/.circleci/docker/common/install_conda.sh b/.circleci/docker/common/install_conda.sh index db8f1a457ecf..c63e28029f07 100755 --- a/.circleci/docker/common/install_conda.sh +++ b/.circleci/docker/common/install_conda.sh @@ -96,13 +96,13 @@ if [ -n "$ANACONDA_PYTHON_VERSION" ]; then # TODO: This isn't working atm conda_install nnpack -c killeent - # Install some other packages + # Install some other packages, including those needed for Python test reporting # TODO: Why is scipy pinned # numba & llvmlite is pinned because of https://github.com/numba/numba/issues/4368 # scikit-learn is pinned because of # https://github.com/scikit-learn/scikit-learn/issues/14485 (affects gcc 5.5 # only) - as_jenkins pip install --progress-bar off pytest scipy==1.1.0 scikit-learn==0.20.3 scikit-image librosa>=0.6.2 psutil numba==0.46.0 llvmlite==0.30.0 + as_jenkins pip install --progress-bar off pytest scipy==1.1.0 scikit-learn==0.20.3 scikit-image librosa>=0.6.2 psutil numba==0.46.0 llvmlite==0.30.0 unittest-xml-reporting coverage popd fi diff --git a/.circleci/docker/common/install_travis_python.sh b/.circleci/docker/common/install_travis_python.sh deleted file mode 100755 index 41ad2dd32eb4..000000000000 --- a/.circleci/docker/common/install_travis_python.sh +++ /dev/null @@ -1,79 +0,0 @@ -#!/bin/bash - -set -ex - -as_jenkins() { - # NB: Preserve PATH and LD_LIBRARY_PATH changes - sudo -H -u jenkins env "PATH=$PATH" "LD_LIBRARY_PATH=$LD_LIBRARY_PATH" $* -} - -if [ -n "$TRAVIS_PYTHON_VERSION" ]; then - - mkdir -p /opt/python - chown jenkins:jenkins /opt/python - - # Download Python binary from Travis - pushd tmp - as_jenkins wget --quiet ${TRAVIS_DL_URL_PREFIX}/python-$TRAVIS_PYTHON_VERSION.tar.bz2 - # NB: The tarball also comes with /home/travis virtualenv that we - # don't care about. (Maybe we should, but we've worked around the - # "how do I install to python" issue by making this entire directory - # user-writable "lol") - # NB: Relative ordering of opt/python and flags matters - as_jenkins tar xjf python-$TRAVIS_PYTHON_VERSION.tar.bz2 --strip-components=2 --directory /opt/python opt/python - popd - - echo "/opt/python/$TRAVIS_PYTHON_VERSION/lib" > /etc/ld.so.conf.d/travis-python.conf - ldconfig - sed -e 's|PATH="\(.*\)"|PATH="/opt/python/'"$TRAVIS_PYTHON_VERSION"'/bin:\1"|g' -i /etc/environment - export PATH="/opt/python/$TRAVIS_PYTHON_VERSION/bin:$PATH" - - python --version - pip --version - - # Install pip from source. - # The python-pip package on Ubuntu Trusty is old - # and upon install numpy doesn't use the binary - # distribution, and fails to compile it from source. - pushd tmp - as_jenkins curl -L -O https://pypi.python.org/packages/11/b6/abcb525026a4be042b486df43905d6893fb04f05aac21c32c638e939e447/pip-9.0.1.tar.gz - as_jenkins tar zxf pip-9.0.1.tar.gz - pushd pip-9.0.1 - as_jenkins python setup.py install - popd - rm -rf pip-9.0.1* - popd - - # Install pip packages - as_jenkins pip install --upgrade pip - - pip --version - - as_jenkins pip install numpy pyyaml - - as_jenkins pip install \ - future \ - hypothesis \ - protobuf \ - pytest \ - pillow \ - typing \ - dataclasses - - as_jenkins pip install mkl mkl-devel - - # SciPy does not support Python 3.7 or Python 2.7.9 - if [[ "$TRAVIS_PYTHON_VERSION" != nightly ]] && [[ "$TRAVIS_PYTHON_VERSION" != "2.7.9" ]]; then - as_jenkins pip install scipy==1.1.0 scikit-image librosa>=0.6.2 - fi - - # Install psutil for dataloader tests - as_jenkins pip install psutil - - # Install dill for serialization tests - as_jenkins pip install "dill>=0.3.1" - - # Cleanup package manager - apt-get autoclean && apt-get clean - rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* -fi diff --git a/.circleci/docker/ubuntu-cuda/Dockerfile b/.circleci/docker/ubuntu-cuda/Dockerfile index d3a9027d5f06..f512180f1616 100644 --- a/.circleci/docker/ubuntu-cuda/Dockerfile +++ b/.circleci/docker/ubuntu-cuda/Dockerfile @@ -24,7 +24,7 @@ ARG KATEX ADD ./common/install_katex.sh install_katex.sh RUN bash ./install_katex.sh && rm install_katex.sh -# Install conda +# Install conda and other packages (e.g., numpy, coverage, pytest) ENV PATH /opt/conda/bin:$PATH ARG ANACONDA_PYTHON_VERSION ADD ./common/install_conda.sh install_conda.sh @@ -40,12 +40,6 @@ ARG CLANG_VERSION ADD ./common/install_clang.sh install_clang.sh RUN bash ./install_clang.sh && rm install_clang.sh -# Install non-standard Python versions (via Travis binaries) -ARG TRAVIS_PYTHON_VERSION -ENV PATH /opt/python/$TRAVIS_PYTHON_VERSION/bin:$PATH -ADD ./common/install_travis_python.sh install_travis_python.sh -RUN bash ./install_travis_python.sh && rm install_travis_python.sh - # (optional) Install protobuf for ONNX ARG PROTOBUF ADD ./common/install_protobuf.sh install_protobuf.sh diff --git a/.circleci/docker/ubuntu-rocm/Dockerfile b/.circleci/docker/ubuntu-rocm/Dockerfile index 5fd133d08245..761bf0438d7f 100644 --- a/.circleci/docker/ubuntu-rocm/Dockerfile +++ b/.circleci/docker/ubuntu-rocm/Dockerfile @@ -21,7 +21,7 @@ RUN bash ./install_clang.sh && rm install_clang.sh ADD ./common/install_user.sh install_user.sh RUN bash ./install_user.sh && rm install_user.sh -# Install conda +# Install conda and other packages (e.g., numpy, coverage, pytest) ENV PATH /opt/conda/bin:$PATH ARG ANACONDA_PYTHON_VERSION ADD ./common/install_conda.sh install_conda.sh diff --git a/.circleci/docker/ubuntu/Dockerfile b/.circleci/docker/ubuntu/Dockerfile index ca4d3c58dbc6..72f2c108ff11 100644 --- a/.circleci/docker/ubuntu/Dockerfile +++ b/.circleci/docker/ubuntu/Dockerfile @@ -33,7 +33,7 @@ ARG KATEX ADD ./common/install_katex.sh install_katex.sh RUN bash ./install_katex.sh && rm install_katex.sh -# Install conda +# Install conda and other packages (e.g., numpy, coverage, pytest) ENV PATH /opt/conda/bin:$PATH ARG ANACONDA_PYTHON_VERSION ADD ./common/install_conda.sh install_conda.sh @@ -48,13 +48,6 @@ RUN bash ./install_gcc.sh && rm install_gcc.sh ADD ./common/install_lcov.sh install_lcov.sh RUN bash ./install_lcov.sh && rm install_lcov.sh -# Install non-standard Python versions (via Travis binaries) -ARG TRAVIS_PYTHON_VERSION -ARG TRAVIS_DL_URL_PREFIX -ENV PATH /opt/python/$TRAVIS_PYTHON_VERSION/bin:$PATH -ADD ./common/install_travis_python.sh install_travis_python.sh -RUN bash ./install_travis_python.sh && rm install_travis_python.sh - # (optional) Install protobuf for ONNX ARG PROTOBUF ADD ./common/install_protobuf.sh install_protobuf.sh diff --git a/.circleci/verbatim-sources/job-specs/pytorch-job-specs.yml b/.circleci/verbatim-sources/job-specs/pytorch-job-specs.yml index 868f32fd49fa..f6f37dbb0470 100644 --- a/.circleci/verbatim-sources/job-specs/pytorch-job-specs.yml +++ b/.circleci/verbatim-sources/job-specs/pytorch-job-specs.yml @@ -217,9 +217,11 @@ jobs: echo "Retrieving test reports" docker cp $id:/var/lib/jenkins/workspace/test/test-reports ./ || echo 'No test reports found!' if [[ ${BUILD_ENVIRONMENT} == *"coverage"* ]]; then - echo "Retrieving coverage report" + echo "Retrieving Python coverage report" docker cp $id:/var/lib/jenkins/workspace/test/.coverage ./test docker cp $id:/var/lib/jenkins/workspace/test/coverage.xml ./test + echo "Retrieving C++ coverage report" + docker cp $id:/var/lib/jenkins/workspace/build/coverage.info ./test python3 -mpip install codecov python3 -mcodecov fi diff --git a/.flake8 b/.flake8 index 7ecc6df31754..8be8496e4224 100644 --- a/.flake8 +++ b/.flake8 @@ -12,5 +12,5 @@ ignore = B007,B008, # these ignores are from flake8-comprehensions; please fix! C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415 -per-file-ignores = __init__.py: F401 +per-file-ignores = __init__.py: F401 torch/utils/cpp_extension.py: B950 exclude = docs/src,venv,third_party,caffe2,scripts,docs/caffe2,torch/lib/include,torch/lib/tmp_install,build,torch/include,*.pyi,.git,build,build_test_custom_build,build_code_analyzer diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 3e6bb17375b9..8fdccf101af7 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -72,8 +72,7 @@ jobs: set -eux pip install flake8==3.8.2 flake8-bugbear flake8-comprehensions flake8-executable flake8-pyi==20.5.0 mccabe pycodestyle==2.6.0 pyflakes==2.2.0 flake8 --version - flake8 > ${GITHUB_WORKSPACE}/flake8-output.txt - cat ${GITHUB_WORKSPACE}/flake8-output.txt + flake8 | tee ${GITHUB_WORKSPACE}/flake8-output.txt - name: Add annotations uses: pytorch/add-annotations-github-action@master with: diff --git a/.jenkins/caffe2/test.sh b/.jenkins/caffe2/test.sh index 03583f3805c7..b7f0d71107b1 100755 --- a/.jenkins/caffe2/test.sh +++ b/.jenkins/caffe2/test.sh @@ -168,10 +168,7 @@ if [[ "$BUILD_ENVIRONMENT" == *onnx* ]]; then # JIT C++ extensions require ninja, so put it into PATH. export PATH="/var/lib/jenkins/.local/bin:$PATH" if [[ "$BUILD_ENVIRONMENT" == *py3* ]]; then - # default pip version is too old(9.0.2), unable to support tag `manylinux2010`. - # Fix the pip error: Couldn't find a version that satisfies the requirement - pip install --upgrade pip - pip install -q --user ort-nightly==1.5.0.dev202009182 + pip install -q --user onnxruntime==1.5.2 fi "$ROOT_DIR/scripts/onnx/test.sh" fi diff --git a/.jenkins/pytorch/build.sh b/.jenkins/pytorch/build.sh index 3e197d867b6e..b94e797e7010 100755 --- a/.jenkins/pytorch/build.sh +++ b/.jenkins/pytorch/build.sh @@ -42,6 +42,11 @@ if [[ "$BUILD_ENVIRONMENT" == *cuda* ]]; then nvcc --version fi +if [[ "$BUILD_ENVIRONMENT" == *coverage* ]]; then + # enable build option in CMake + export USE_CPP_CODE_COVERAGE=ON +fi + # TODO: Don't run this... pip_install -r requirements.txt || true diff --git a/.jenkins/pytorch/codegen-test.sh b/.jenkins/pytorch/codegen-test.sh new file mode 100755 index 000000000000..3b75999ceb2e --- /dev/null +++ b/.jenkins/pytorch/codegen-test.sh @@ -0,0 +1,60 @@ +#!/usr/bin/env bash + +# This script can also be used to test whether your diff changes any codegen output. +# +# Run it before and after your change: +# .jenkins/pytorch/codegen-test.sh +# .jenkins/pytorch/codegen-test.sh +# +# Then run diff to compare the generated files: +# diff -Naur + +set -eu -o pipefail + +if [ "$#" -eq 0 ]; then + COMPACT_JOB_NAME="${BUILD_ENVIRONMENT}" + source "$(dirname "${BASH_SOURCE[0]}")/common.sh" + OUT="$(dirname "${BASH_SOURCE[0]}")/../../codegen_result" +else + OUT=$1 +fi + +set -x + +rm -rf "$OUT" + +# aten codegen +python -m tools.codegen.gen \ + -d "$OUT"/torch/share/ATen + +# torch codegen +python -m tools.setup_helpers.generate_code \ + --declarations-path "$OUT"/torch/share/ATen/Declarations.yaml \ + --install_dir "$OUT" + +# pyi codegen +mkdir -p "$OUT"/pyi/torch/_C +mkdir -p "$OUT"/pyi/torch/nn +python -m tools.pyi.gen_pyi \ + --declarations-path "$OUT"/torch/share/ATen/Declarations.yaml \ + --out "$OUT"/pyi + +# autograd codegen (called by torch codegen but can run independently) +python -m tools.autograd.gen_autograd \ + "$OUT"/torch/share/ATen/Declarations.yaml \ + "$OUT"/autograd \ + tools/autograd + +# unboxing_wrappers codegen (called by torch codegen but can run independently) +mkdir -p "$OUT"/unboxing_wrappers +python -m tools.jit.gen_unboxing_wrappers \ + "$OUT"/torch/share/ATen/Declarations.yaml \ + "$OUT"/unboxing_wrappers \ + tools/jit/templates + +# annotated_fn_args codegen (called by torch codegen but can run independently) +mkdir -p "$OUT"/annotated_fn_args +python -m tools.autograd.gen_annotated_fn_args \ + "$OUT"/torch/share/ATen/Declarations.yaml \ + "$OUT"/annotated_fn_args \ + tools/autograd diff --git a/.jenkins/pytorch/test.sh b/.jenkins/pytorch/test.sh index 1af459ab8cc8..88bcfc93e19d 100755 --- a/.jenkins/pytorch/test.sh +++ b/.jenkins/pytorch/test.sh @@ -11,17 +11,13 @@ source "$(dirname "${BASH_SOURCE[0]}")/common.sh" echo "Testing pytorch" -if [ -n "${IN_CI}" ]; then - # TODO move this to docker - pip_install unittest-xml-reporting coverage pytest +if [[ "$BUILD_ENVIRONMENT" == *-slow-* ]]; then + export PYTORCH_TEST_WITH_SLOW=1 + export PYTORCH_TEST_SKIP_FAST=1 +fi - if [[ "$BUILD_ENVIRONMENT" == *-slow-* ]]; then - export PYTORCH_TEST_WITH_SLOW=1 - export PYTORCH_TEST_SKIP_FAST=1 - fi - if [[ "$BUILD_ENVIRONMENT" == *coverage* ]]; then - export PYTORCH_COLLECT_COVERAGE=1 - fi +if [[ "$BUILD_ENVIRONMENT" == *coverage* ]]; then + export PYTORCH_COLLECT_COVERAGE=1 fi if [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then @@ -401,10 +397,15 @@ else test_distributed test_benchmarks test_rpc - if [[ "$BUILD_ENVIRONMENT" == *coverage* ]]; then - pushd test - echo "Generating XML coverage report" - time python -mcoverage xml - popd - fi +fi + +if [[ "$BUILD_ENVIRONMENT" == *coverage* ]]; then + pushd test + echo "Generating XML coverage report" + time python -mcoverage xml + popd + pushd build + echo "Generating lcov coverage report for C++ sources" + time lcov --capture --directory . --output-file coverage.info + popd fi diff --git a/.jenkins/pytorch/win-test-helpers/setup_pytorch_env.bat b/.jenkins/pytorch/win-test-helpers/setup_pytorch_env.bat index 34c3698a1307..1e3cfe090abf 100644 --- a/.jenkins/pytorch/win-test-helpers/setup_pytorch_env.bat +++ b/.jenkins/pytorch/win-test-helpers/setup_pytorch_env.bat @@ -39,7 +39,7 @@ if %errorlevel% neq 0 ( exit /b %errorlevel% ) popd :: The version is fixed to avoid flakiness: https://github.com/pytorch/pytorch/issues/31136 -pip install "ninja==1.10.0.post1" future "hypothesis==4.53.2" "librosa>=0.6.2" psutil pillow unittest-xml-reporting pytest +pip install "ninja==1.10.0.post1" future "hypothesis==4.53.2" "librosa>=0.6.2" psutil pillow unittest-xml-reporting pytest coverage if %errorlevel% neq 0 ( exit /b %errorlevel% ) :: No need to install faulthandler since we only test Python >= 3.6 on Windows :: faulthandler is builtin since Python 3.3 diff --git a/.jenkins/pytorch/win-test.sh b/.jenkins/pytorch/win-test.sh index abcd5756d747..adf9b4c82620 100755 --- a/.jenkins/pytorch/win-test.sh +++ b/.jenkins/pytorch/win-test.sh @@ -14,6 +14,10 @@ fi export TMP_DIR="${PWD}/build/win_tmp" export TMP_DIR_WIN=$(cygpath -w "${TMP_DIR}") +export PROJECT_DIR="${PWD}" +export PROJECT_DIR_WIN=$(cygpath -w "${PROJECT_DIR}") +export TEST_DIR="${PWD}/test" +export TEST_DIR_WIN=$(cygpath -w "${TEST_DIR}") export PYTORCH_FINAL_PACKAGE_DIR="/c/users/circleci/workspace/build-results" export PYTORCH_FINAL_PACKAGE_DIR_WIN=$(cygpath -w "${PYTORCH_FINAL_PACKAGE_DIR}") @@ -45,6 +49,7 @@ run_tests() { $SCRIPT_HELPERS_DIR/test_libtorch.bat else if [[ "${JOB_BASE_NAME}" == *-test1 ]]; then + export PYTORCH_COLLECT_COVERAGE=1 $SCRIPT_HELPERS_DIR/test_python_nn.bat "$DETERMINE_FROM" && \ $SCRIPT_HELPERS_DIR/test_libtorch.bat if [[ "${USE_CUDA}" == "1" ]]; then @@ -59,3 +64,16 @@ run_tests() { } run_tests && assert_git_not_dirty && echo "TEST PASSED" + +if [[ "${BUILD_ENVIRONMENT}" == "pytorch-win-vs2019-cuda10-cudnn7-py3" ]] && [[ "${JOB_BASE_NAME}" == *-test1 ]]; then + pushd $TEST_DIR + python -mpip install coverage + echo "Generating XML coverage report" + time python -mcoverage xml + popd + + pushd $PROJECT_DIR + python -mpip install codecov + python -mcodecov + popd +fi diff --git a/BUILD.bazel b/BUILD.bazel index 9eced9b2c563..4ec99d770f70 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -126,18 +126,13 @@ genrule( outs = [ "aten/src/ATen/Declarations.yaml", "aten/src/ATen/BackendSelectRegister.cpp", - "aten/src/ATen/CPUType.h", "aten/src/ATen/CPUType.cpp", "aten/src/ATen/Functions.h", "aten/src/ATen/Functions.cpp", "aten/src/ATen/NativeFunctions.h", - "aten/src/ATen/MkldnnCPUType.h", "aten/src/ATen/MkldnnCPUType.cpp", - "aten/src/ATen/QuantizedCPUType.h", "aten/src/ATen/QuantizedCPUType.cpp", - "aten/src/ATen/SparseCPUType.h", "aten/src/ATen/SparseCPUType.cpp", - "aten/src/ATen/TypeDefault.h", "aten/src/ATen/TypeDefault.cpp", "aten/src/ATen/core/TensorBody.h", "aten/src/ATen/core/TensorMethods.cpp", diff --git a/NOTICE b/NOTICE index 020beaea4c46..5abaac479a75 100644 --- a/NOTICE +++ b/NOTICE @@ -284,6 +284,112 @@ Apache License Version 2.0: incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. +======================================================================= +Cephes's 3-Clause BSD License +======================================================================= + +Code derived from implementations in the Cephes Math Library should mention +its derivation and reference the following license: + + 3-Clause BSD License for the Cephes Math Library + Copyright (c) 2018, Steven Moshier + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + * Neither the name of the nor the + names of its contributors may be used to endorse or promote products + derived from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL Steven Moshier BE LIABLE FOR ANY + DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +======================================================================= +SciPy's 3-Clause BSD License +======================================================================= + +Code derived from implementations in SciPy should mention its derivation +and reference the following license: + + Copyright (c) 2001-2002 Enthought, Inc. 2003-2019, SciPy Developers. + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials provided + with the distribution. + + 3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +======================================================================= +Boost's 1.0 Software License +======================================================================= + +Code derived from implementations in Boost 1.0 should mention its +derivation and reference the following license: + + Boost Software License - Version 1.0 - August 17th, 2003 + + Permission is hereby granted, free of charge, to any person or organization + obtaining a copy of the software and accompanying documentation covered by + this license (the "Software") to use, reproduce, display, distribute, + execute, and transmit the Software, and to prepare derivative works of the + Software, and to permit third-parties to whom the Software is furnished to + do so, all subject to the following: + + The copyright notices in the Software and this entire statement, including + the above license grant, this restriction and the following disclaimer, + must be included in all copies of the Software, in whole or in part, and + all derivative works of the Software, unless such copies or derivative + works are solely in the form of machine-executable object code generated by + a source language processor. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT + SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE + FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, + ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + DEALINGS IN THE SOFTWARE. + END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. diff --git a/aten/src/ATen/BatchedFallback.cpp b/aten/src/ATen/BatchedFallback.cpp index 396acc9e0403..9b39006b106e 100644 --- a/aten/src/ATen/BatchedFallback.cpp +++ b/aten/src/ATen/BatchedFallback.cpp @@ -156,11 +156,7 @@ void batchedTensorInplaceForLoopFallback(const c10::OperatorHandle& op, torch::j auto first_physical_view_sizes = input_physical_views.front().tensor().sizes(); auto batch_sizes = ArrayRef( first_physical_view_sizes.begin(), first_physical_view_sizes.begin() + num_batch_dims); - auto num_batches = std::accumulate( - batch_sizes.begin(), - batch_sizes.end(), - 1, - std::multiplies()); + const auto num_batches = prod_intlist(batch_sizes); // Without a shape-checking API, we're unable to compute the correct shape of // the output so we just error out. TORCH_CHECK(num_batches > 0, @@ -293,11 +289,7 @@ void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Sta auto num_batch_dims = input_physical_views.front().numBatchDims(); auto some_sizes = input_physical_views.front().tensor().sizes(); auto batch_sizes = ArrayRef(some_sizes.begin(), some_sizes.begin() + num_batch_dims); - auto num_batches = std::accumulate( - batch_sizes.begin(), - batch_sizes.end(), - 1, - std::multiplies()); + const auto num_batches = prod_intlist(batch_sizes); // Without a shape-checking API, we're unable to compute the correct shape of // the output so we just error out. TORCH_CHECK(num_batches > 0, diff --git a/aten/src/ATen/BatchingRegistrations.cpp b/aten/src/ATen/BatchingRegistrations.cpp index 029a5be521f7..0b180b5059d1 100644 --- a/aten/src/ATen/BatchingRegistrations.cpp +++ b/aten/src/ATen/BatchingRegistrations.cpp @@ -222,7 +222,7 @@ Tensor permute_batching_rule(const Tensor& self, IntArrayRef dims) { VmapDimVector all_dims_physical; all_dims_physical.reserve(self_physical.tensor().dim()); for (int64_t bdim = 0; bdim < self_physical.numBatchDims(); bdim++) { - all_dims_physical.push_back(bdim); + all_dims_physical.push_back(bdim); } all_dims_physical.insert( all_dims_physical.end(), diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index bf9029ff6c6b..ff38f9f2086a 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -443,6 +443,8 @@ endif() list(APPEND ATen_MOBILE_BENCHMARK_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/benchmarks/tensor_add.cpp) +list(APPEND ATen_MOBILE_BENCHMARK_SRCS + ${CMAKE_CURRENT_SOURCE_DIR}/benchmarks/quantize_per_channel.cpp) list(APPEND ATen_MOBILE_BENCHMARK_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/benchmarks/stateful_conv1d.cpp) diff --git a/aten/src/ATen/LegacyTHFunctionsCPU.cpp b/aten/src/ATen/LegacyTHFunctionsCPU.cpp index f0a55470cc1c..5fcb5ede9cc5 100644 --- a/aten/src/ATen/LegacyTHFunctionsCPU.cpp +++ b/aten/src/ATen/LegacyTHFunctionsCPU.cpp @@ -1,7 +1,5 @@ #include -// @generated by aten/src/ATen/gen.py from LegacyTHFunctions.cpp - #include #include #include @@ -782,50 +780,7 @@ Tensor _th_histc(const Tensor & self, int64_t bins, Scalar min, Scalar max) { } return result; } -Tensor _th_trace(const Tensor & self) { - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(self); - switch (dispatch_scalar_type) { - case ScalarType::Byte: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_trace", false, DeviceType::CPU, dispatch_scalar_type); - return at::scalar_tensor(convert(THByteTensor_trace(self_)), options(ScalarType::Byte)); - break; - } - case ScalarType::Char: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_trace", false, DeviceType::CPU, dispatch_scalar_type); - return at::scalar_tensor(convert(THCharTensor_trace(self_)), options(ScalarType::Char)); - break; - } - case ScalarType::Double: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_trace", false, DeviceType::CPU, dispatch_scalar_type); - return at::scalar_tensor(convert(THDoubleTensor_trace(self_)), options(ScalarType::Double)); - break; - } - case ScalarType::Float: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_trace", false, DeviceType::CPU, dispatch_scalar_type); - return at::scalar_tensor(convert(THFloatTensor_trace(self_)), options(ScalarType::Float)); - break; - } - case ScalarType::Int: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_trace", false, DeviceType::CPU, dispatch_scalar_type); - return at::scalar_tensor(convert(THIntTensor_trace(self_)), options(ScalarType::Int)); - break; - } - case ScalarType::Long: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_trace", false, DeviceType::CPU, dispatch_scalar_type); - return at::scalar_tensor(convert(THLongTensor_trace(self_)), options(ScalarType::Long)); - break; - } - case ScalarType::Short: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_trace", false, DeviceType::CPU, dispatch_scalar_type); - return at::scalar_tensor(convert(THShortTensor_trace(self_)), options(ScalarType::Short)); - break; - } - default: - AT_ERROR("_th_trace not supported on CPUType for ", dispatch_scalar_type); - } -} std::tuple _th_gels_out(Tensor & res1, Tensor & res2, const Tensor & self, const Tensor & A) { // DeviceGuard omitted auto dispatch_scalar_type = infer_scalar_type(self); diff --git a/aten/src/ATen/LegacyTHFunctionsCPU.h b/aten/src/ATen/LegacyTHFunctionsCPU.h index 1bc9b66777bc..1aca02539311 100644 --- a/aten/src/ATen/LegacyTHFunctionsCPU.h +++ b/aten/src/ATen/LegacyTHFunctionsCPU.h @@ -1,7 +1,5 @@ #pragma once -// @generated by aten/src/ATen/gen.py from LegacyTHFunctions.h - #include #include #include @@ -38,7 +36,6 @@ Tensor _th_renorm(const Tensor & self, Scalar p, int64_t dim, Scalar maxnorm); Tensor & _th_renorm_(Tensor & self, Scalar p, int64_t dim, Scalar maxnorm); Tensor & _th_histc_out(Tensor & result, const Tensor & self, int64_t bins, Scalar min, Scalar max); Tensor _th_histc(const Tensor & self, int64_t bins, Scalar min, Scalar max); -Tensor _th_trace(const Tensor & self); std::tuple _th_gels_out(Tensor & res1, Tensor & res2, const Tensor & self, const Tensor & A); std::tuple _th_gels(const Tensor & self, const Tensor & A); std::tuple _th_eig_out(Tensor & res1, Tensor & res2, const Tensor & self, bool eigenvectors); diff --git a/aten/src/ATen/LegacyTHFunctionsCUDA.h b/aten/src/ATen/LegacyTHFunctionsCUDA.h index 20717ad43e6f..7b3be6db3d77 100644 --- a/aten/src/ATen/LegacyTHFunctionsCUDA.h +++ b/aten/src/ATen/LegacyTHFunctionsCUDA.h @@ -1,7 +1,5 @@ #pragma once -// @generated by aten/src/ATen/gen.py from LegacyTHFunctions.h - #include #include #include diff --git a/aten/src/ATen/NumericUtils.h b/aten/src/ATen/NumericUtils.h index d691fec1aa34..e6726602bbd5 100644 --- a/aten/src/ATen/NumericUtils.h +++ b/aten/src/ATen/NumericUtils.h @@ -46,6 +46,12 @@ inline C10_HOST_DEVICE bool _isnan(T val) { } +template ::value, int>::type = 0> +inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) { + return at::_isnan(static_cast(val)); +} + inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) { return at::_isnan(static_cast(val)); } diff --git a/aten/src/ATen/OpaqueTensorImpl.h b/aten/src/ATen/OpaqueTensorImpl.h index a4007c3115dc..b00f80d232db 100644 --- a/aten/src/ATen/OpaqueTensorImpl.h +++ b/aten/src/ATen/OpaqueTensorImpl.h @@ -21,7 +21,7 @@ struct CAFFE2_API OpaqueTensorImpl : public TensorImpl { // public constructor for now... OpaqueTensorImpl( at::DispatchKeySet key_set, - const caffe2::TypeMeta& data_type, + const caffe2::TypeMeta data_type, c10::Device device, OpaqueHandle opaque_handle, c10::IntArrayRef sizes) diff --git a/aten/src/ATen/ScalarOps.cpp b/aten/src/ATen/ScalarOps.cpp new file mode 100644 index 000000000000..7a794cb5c312 --- /dev/null +++ b/aten/src/ATen/ScalarOps.cpp @@ -0,0 +1,40 @@ +// FastPass +#ifdef _MSC_VER +#ifndef _USE_MATH_DEFINES +#define _USE_MATH_DEFINES +#endif +#include +#endif + +#include +#include +#include + +namespace at { +namespace { +template +inline void fill_inplace(Tensor& self, Scalar value_scalar) { + auto value = value_scalar.to(); + scalar_t* dptr = static_cast(self.data_ptr()); + *dptr = value; +} +} + +namespace detail { +Tensor& scalar_fill(Tensor& self, Scalar value) { + AT_DISPATCH_ALL_TYPES_AND3( + kHalf, kBool, kBFloat16, self.scalar_type(), "fill_out", [&]() { + fill_inplace(self, value); + }); + return self; +} + +Tensor scalar_tensor_static(Scalar s, const TensorOptions& options) { + at::tracer::impl::NoTracerDispatchMode tracer_guard; + at::AutoNonVariableTypeMode non_var_type_mode(true); + auto result = at::detail::empty_cpu({}, options); + scalar_fill(result, s); + return result; +} +} // namespace detail +} // namespace at diff --git a/aten/src/ATen/ScalarOps.h b/aten/src/ATen/ScalarOps.h index 8c07a9d618bc..60cee3ea284b 100644 --- a/aten/src/ATen/ScalarOps.h +++ b/aten/src/ATen/ScalarOps.h @@ -4,6 +4,18 @@ #include #include +namespace at { +namespace detail { +// When filling a number to 1-element CPU tensor, we want to skip +// everything but manipulate data ptr directly. +// Ideally this fast pass should be implemented in TensorIterator, +// but we also want to skip compute_types which in not avoidable +// in TensorIterator for now. +Tensor& scalar_fill(Tensor& self, Scalar value); +TORCH_API Tensor scalar_tensor_static(Scalar s, const TensorOptions& options); +} // namespace detail +} // namespace at + // This is in the c10 namespace because we use ADL to find the functions in it. namespace c10 { @@ -11,16 +23,14 @@ namespace c10 { // to implement this without going through Derived Types (which are not part of core). inline at::Tensor scalar_to_tensor(Scalar s, const Device device = at::kCPU) { // This is the fast track we have for CPU scalar tensors. - if (device == at::kCPU) { + if (device == at::kCPU && !s.isComplex()) { if (s.isFloatingPoint()) { - return at::native::scalar_tensor(s, at::device(at::kCPU).dtype(at::kDouble)); + return at::detail::scalar_tensor_static(s, at::device(at::kCPU).dtype(at::kDouble)); } else if (s.isBoolean()) { - return at::native::scalar_tensor(s, at::device(at::kCPU).dtype(at::kBool)); - } else if (s.isComplex()) { - return at::native::scalar_tensor(s, at::device(at::kCPU).dtype(at::kComplexDouble)); + return at::detail::scalar_tensor_static(s, at::device(at::kCPU).dtype(at::kBool)); } else { AT_ASSERT(s.isIntegral(false)); - return at::native::scalar_tensor(s, at::device(at::kCPU).dtype(at::kLong)); + return at::detail::scalar_tensor_static(s, at::device(at::kCPU).dtype(at::kLong)); } } if (s.isFloatingPoint()) { diff --git a/aten/src/ATen/SparseTensorImpl.cpp b/aten/src/ATen/SparseTensorImpl.cpp index 3119c81ac8aa..45492d7b212e 100644 --- a/aten/src/ATen/SparseTensorImpl.cpp +++ b/aten/src/ATen/SparseTensorImpl.cpp @@ -30,12 +30,12 @@ namespace { // // This means that we allocate a [1,0] size indices tensor and a [0] size // values tensor for such an empty tensor. -SparseTensorImpl::SparseTensorImpl(at::DispatchKeySet key_set, const caffe2::TypeMeta& data_type) +SparseTensorImpl::SparseTensorImpl(at::DispatchKeySet key_set, const caffe2::TypeMeta data_type) : SparseTensorImpl(key_set, data_type , at::empty({1, 0}, at::initialTensorOptions().device(sparseTensorSetToDeviceType(key_set)).dtype(ScalarType::Long)) , at::empty({0}, at::initialTensorOptions().device(sparseTensorSetToDeviceType(key_set)).dtype(data_type))) {} -SparseTensorImpl::SparseTensorImpl(at::DispatchKeySet key_set, const caffe2::TypeMeta& data_type, at::Tensor indices, at::Tensor values) +SparseTensorImpl::SparseTensorImpl(at::DispatchKeySet key_set, const caffe2::TypeMeta data_type, at::Tensor indices, at::Tensor values) : TensorImpl(key_set, data_type, values.device()) , sparse_dim_(1) , dense_dim_(0) diff --git a/aten/src/ATen/SparseTensorImpl.h b/aten/src/ATen/SparseTensorImpl.h index bdccb540734f..b8e6bb26bf7f 100644 --- a/aten/src/ATen/SparseTensorImpl.h +++ b/aten/src/ATen/SparseTensorImpl.h @@ -31,7 +31,7 @@ struct CAFFE2_API SparseTensorImpl : public TensorImpl { public: // Public for now... - explicit SparseTensorImpl(at::DispatchKeySet, const caffe2::TypeMeta&); + explicit SparseTensorImpl(at::DispatchKeySet, const caffe2::TypeMeta); int64_t nnz() const { return values_.size(0); } int64_t sparse_dim() const { return sparse_dim_; } @@ -217,7 +217,7 @@ struct CAFFE2_API SparseTensorImpl : public TensorImpl { refresh_numel(); } private: - explicit SparseTensorImpl(at::DispatchKeySet, const caffe2::TypeMeta&, at::Tensor indices, at::Tensor values); + explicit SparseTensorImpl(at::DispatchKeySet, const caffe2::TypeMeta, at::Tensor indices, at::Tensor values); /** * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer / storage_offset) diff --git a/aten/src/ATen/TensorUtils.cpp b/aten/src/ATen/TensorUtils.cpp index 626e0c73e45e..08588f6a8cdd 100644 --- a/aten/src/ATen/TensorUtils.cpp +++ b/aten/src/ATen/TensorUtils.cpp @@ -335,8 +335,7 @@ c10::optional> computeStride( // we use the stride as if it were computed via resize. // This could perhaps be combined with the below code, but the complexity // didn't seem worth it. - int64_t numel = std::accumulate(oldshape.begin(), oldshape.end(), 1, - std::multiplies()); + const int64_t numel = prod_intlist(oldshape); if (numel == 0 && oldshape.equals(newshape)) { return oldstride.vec(); } diff --git a/aten/src/ATen/templates/TypeDefault.h b/aten/src/ATen/TypeDefault.h similarity index 86% rename from aten/src/ATen/templates/TypeDefault.h rename to aten/src/ATen/TypeDefault.h index fb62c7ba6354..7b5d77ba4d22 100644 --- a/aten/src/ATen/templates/TypeDefault.h +++ b/aten/src/ATen/TypeDefault.h @@ -1,7 +1,5 @@ #pragma once -// ${generated_comment} - #include #include #include @@ -29,8 +27,4 @@ struct Quantizer; // to frontend using ConstQuantizerPtr = const c10::intrusive_ptr&; -namespace TypeDefault { - ${type_method_declarations} -} // namespace TypeDefault - } // namespace at diff --git a/aten/src/ATen/Utils.cpp b/aten/src/ATen/Utils.cpp index ccd4e4ba9f2f..8a4fa37e469e 100644 --- a/aten/src/ATen/Utils.cpp +++ b/aten/src/ATen/Utils.cpp @@ -3,6 +3,8 @@ #include #include #include +#include +#include namespace at { @@ -12,4 +14,52 @@ int _crash_if_asan(int arg) { return x[0]; } +namespace detail { +// empty_cpu is used in ScalarOps.h, which can be referenced by other ATen files. Since we want to decouple direct referencing native symbols and only access native symbols through dispatching, we move its implementation here. +Tensor empty_cpu( + IntArrayRef size, + const TensorOptions& options, + c10::optional optional_memory_format) { + TORCH_CHECK( + !(options.has_memory_format() && optional_memory_format.has_value()), + "Cannot set memory_format both in TensorOptions and explicit argument; please delete " + "the redundant setter."); + const MemoryFormat memory_format = + optional_memory_format.value_or( + options.memory_format_opt().value_or( + MemoryFormat::Contiguous)); + + AT_ASSERT(options.device().type() == DeviceType::CPU); + check_size_nonnegative(size); + + c10::Allocator* allocator; + if (options.pinned_memory()) { + allocator = detail::getCUDAHooks().getPinnedMemoryAllocator(); + } else { + allocator = at::getCPUAllocator(); + } + + int64_t nelements = prod_intlist(size); + const caffe2::TypeMeta dtype = options.dtype(); + const int64_t size_bytes = nelements * dtype.itemsize(); + auto storage_impl = c10::make_intrusive( + c10::StorageImpl::use_byte_size_t(), + size_bytes, + allocator->allocate(size_bytes), + allocator, + /*resizeable=*/true); + + auto tensor = detail::make_tensor( + std::move(storage_impl), at::DispatchKey::CPU, dtype); + // Default TensorImpl has size [0] + if (size.size() != 1 || size[0] != 0) { + tensor.unsafeGetTensorImpl()->set_sizes_contiguous(size); + } + + tensor.unsafeGetTensorImpl()->empty_tensor_restride(memory_format); + + return tensor; +} +} // namespace detail + } // at diff --git a/aten/src/ATen/Utils.h b/aten/src/ATen/Utils.h index df0e49920afa..4fe4b632362b 100644 --- a/aten/src/ATen/Utils.h +++ b/aten/src/ATen/Utils.h @@ -93,10 +93,18 @@ inline int64_t sum_intlist(ArrayRef list) { return std::accumulate(list.begin(), list.end(), 0ll); } -inline int64_t prod_intlist(ArrayRef list) { - return std::accumulate(list.begin(), list.end(), 1ll, std::multiplies()); +//std::accumulate infers return type from `init` type, so if `init` type is not enough to hold the result, computation can overflow +//the next 2 functions set `init` type to int64_t to avoid overflow. +template::value, int>::type = 0> +inline int64_t prod_intlist(const C &container){ + return std::accumulate(container.begin(), container.end(), static_cast(1), std::multiplies()); } +template::value_type>::value, int>::type = 0> +inline int64_t prod_intlist(Iter begin, Iter end){ + return std::accumulate(begin, end, static_cast(1), std::multiplies()); +} /** * Utility function to static cast input Generator* to * the backend generator type (CPU/CUDAGeneratorImpl etc.) @@ -120,4 +128,18 @@ static inline T* get_generator_or_default(const c10::optional& gen, c return gen.has_value() && gen->defined() ? check_generator(gen) : check_generator(default_gen); } +inline void check_size_nonnegative(IntArrayRef size) { + for (auto x: size) { + TORCH_CHECK(x >= 0, "Trying to create tensor with negative dimension ", x, ": ", size); + } +} + +namespace detail { +CAFFE2_API +Tensor empty_cpu( + IntArrayRef size, + const TensorOptions& options = {}, + c10::optional memory_format = c10::nullopt); +} // namespace detail + } // at diff --git a/aten/src/ATen/benchmarks/quantize_per_channel.cpp b/aten/src/ATen/benchmarks/quantize_per_channel.cpp new file mode 100644 index 000000000000..b9a356593706 --- /dev/null +++ b/aten/src/ATen/benchmarks/quantize_per_channel.cpp @@ -0,0 +1,85 @@ +#include +#include + +#include + +static void quantize_per_channel_4d_contiguous(benchmark::State& state) { + const size_t batches = static_cast(state.range(0)); + const size_t channels = static_cast(state.range(1)); + const size_t height = static_cast(state.range(2)); + const size_t width = static_cast(state.range(3)); + + at::Tensor a = at::rand({batches, channels, height, width}); + at::Tensor scales = at::rand({channels}); + at::Tensor zero_points = at::randint( + 0, 10, {channels}, at::TensorOptions().dtype(at::ScalarType::Int)); + + at::Tensor qa; + for (auto _ : state) { + qa = at::native::quantize_per_channel_cpu( + a, scales, zero_points, 1, at::ScalarType::QUInt8); + } +} + +static void quantize_per_channel_4d_channels_last(benchmark::State& state) { + const size_t batches = static_cast(state.range(0)); + const size_t channels = static_cast(state.range(1)); + const size_t height = static_cast(state.range(2)); + const size_t width = static_cast(state.range(3)); + + at::Tensor a = at::rand( + {batches, channels, height, width}, + at::TensorOptions().memory_format(at::MemoryFormat::ChannelsLast)); + at::Tensor scales = at::rand({channels}); + at::Tensor zero_points = at::randint( + 0, 10, {channels}, at::TensorOptions().dtype(at::ScalarType::Int)); + + at::Tensor qa; + for (auto _ : state) { + qa = at::native::quantize_per_channel_cpu( + a, scales, zero_points, 1, at::ScalarType::QUInt8); + } +} + +static void quantize_per_channel_2d(benchmark::State& state) { + const size_t channels = static_cast(state.range(0)); + const size_t nelem = static_cast(state.range(1)); + + at::Tensor a = at::rand({channels, nelem}); + at::Tensor scales = at::rand({channels}); + at::Tensor zero_points = at::randint( + 0, 10, {channels}, at::TensorOptions().dtype(at::ScalarType::Int)); + + at::Tensor qa; + for (auto _ : state) { + qa = at::native::quantize_per_channel_cpu( + a, scales, zero_points, 0, at::ScalarType::QUInt8); + } +} + +static void GenerateSizes4d(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "C", "H", "W"}); + + for (size_t n = 16; n < 256; n *= 2) { + for (size_t c = 4; c < 256; c *= 2) { + for (size_t hw = 4; hw < 256; hw *= 2) { + b->Args({n, c, hw, hw}); + } + } + } +} + +static void GenerateSizes2d(benchmark::internal::Benchmark* b) { + b->ArgNames({"C", "N"}); + + for (size_t c = 4; c < 512; c *= 2) { + for (size_t n = 4; n < 512; n *= 2) { + b->Args({c, n}); + } + } +} + +BENCHMARK(quantize_per_channel_2d)->Apply(GenerateSizes2d); +BENCHMARK(quantize_per_channel_4d_contiguous)->Apply(GenerateSizes4d); +BENCHMARK(quantize_per_channel_4d_channels_last)->Apply(GenerateSizes4d); +BENCHMARK_MAIN(); diff --git a/aten/src/ATen/core/Dimname.h b/aten/src/ATen/core/Dimname.h index d81cdfef34e7..8010614c54f0 100644 --- a/aten/src/ATen/core/Dimname.h +++ b/aten/src/ATen/core/Dimname.h @@ -21,7 +21,7 @@ struct CAFFE2_API Dimname { bool isWildcard() const { return type_ == NameType::WILDCARD; } bool matches(Dimname other) const; - optional unify(Dimname other) const; + c10::optional unify(Dimname other) const; private: Dimname(Symbol name) diff --git a/aten/src/ATen/core/NamedRegistrations.cpp b/aten/src/ATen/core/NamedRegistrations.cpp index 33e4ebcfc7dc..187b217604ba 100644 --- a/aten/src/ATen/core/NamedRegistrations.cpp +++ b/aten/src/ATen/core/NamedRegistrations.cpp @@ -210,6 +210,9 @@ TORCH_LIBRARY_IMPL(aten, Named, m) { m.impl("i0", CppFunction::makeFallthrough()); m.impl("i0.out", CppFunction::makeFallthrough()); m.impl("i0_", CppFunction::makeFallthrough()); + m.impl("igamma", CppFunction::makeFallthrough()); + m.impl("igamma.out", CppFunction::makeFallthrough()); + m.impl("igamma_", CppFunction::makeFallthrough()); m.impl("imag", CppFunction::makeFallthrough()); m.impl("index_fill.Dimname_Scalar", CppFunction::makeFallthrough()); m.impl("index_fill.Dimname_Tensor", CppFunction::makeFallthrough()); diff --git a/aten/src/ATen/core/NamedTensor.h b/aten/src/ATen/core/NamedTensor.h index 6efd0fe1f61a..b67e24aa26fe 100644 --- a/aten/src/ATen/core/NamedTensor.h +++ b/aten/src/ATen/core/NamedTensor.h @@ -99,7 +99,7 @@ void check_names_valid_for(const Tensor& tensor, DimnameList names); void check_names_valid_for(size_t tensor_dim, DimnameList names); // Sets the names of `tensor` to be `names`. -CAFFE2_API Tensor& internal_set_names_inplace(Tensor& tensor, optional names); +CAFFE2_API Tensor& internal_set_names_inplace(Tensor& tensor, c10::optional names); CAFFE2_API Tensor& internal_set_names_inplace(Tensor& tensor, std::vector&& names, bool validate_names); constexpr size_t kMaxNamedTensorDim = 64; @@ -110,7 +110,7 @@ namespace impl { // Some helper functions on TensorImpl. Useful for working with names in TH. // XXX: Ideally these would exist as methods on TensorImpl -CAFFE2_API void internal_set_names_inplace(TensorImpl* impl, optional names, bool validate_names); +CAFFE2_API void internal_set_names_inplace(TensorImpl* impl, c10::optional names, bool validate_names); CAFFE2_API void internal_set_names_inplace(TensorImpl* impl, std::vector&& names, bool validate_names); void check_names_valid_for(TensorImpl* impl, DimnameList names); @@ -131,7 +131,7 @@ CAFFE2_API DimnameList get_names(const TensorImpl* impl); // Returns the names of the tensor if they have been allocated; returns nullopt // instead if the haven't been. The names of a tensor are not allocated if a // tensor is constructed with names=None. -CAFFE2_API optional get_opt_names(const TensorImpl* impl); +CAFFE2_API c10::optional get_opt_names(const TensorImpl* impl); } // namespace impl diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h index da259e82990a..e84ad93de37d 100644 --- a/aten/src/ATen/core/aten_interned_strings.h +++ b/aten/src/ATen/core/aten_interned_strings.h @@ -371,6 +371,8 @@ _(aten, hstack) \ _(aten, hypot) \ _(aten, i0) \ _(aten, i0_) \ +_(aten, igamma) \ +_(aten, igamma_) \ _(aten, ifft) \ _(aten, index) \ _(aten, index_add) \ @@ -736,7 +738,6 @@ _(aten, vander) \ _(aten, var) \ _(aten, view) \ _(aten, view_as) \ -_(aten, vstack) \ _(aten, where) \ _(aten, zero) \ _(aten, zeros) \ @@ -781,6 +782,7 @@ _(attr, ceil_mode) \ _(attr, checked_signal_sizes) \ _(attr, chunks) \ _(attr, columns) \ +_(attr, column_stack) \ _(attr, complex_input) \ _(attr, complex_output) \ _(attr, condition) \ diff --git a/aten/src/ATen/core/blob.h b/aten/src/ATen/core/blob.h index 988e99b2395e..3b6bafa12e62 100644 --- a/aten/src/ATen/core/blob.h +++ b/aten/src/ATen/core/blob.h @@ -51,7 +51,7 @@ class CAFFE2_API Blob final : public c10::intrusive_ptr_target { /** * Returns the meta info of the blob. */ - const TypeMeta& meta() const noexcept { + const TypeMeta meta() const noexcept { return meta_; } @@ -155,7 +155,7 @@ class CAFFE2_API Blob final : public c10::intrusive_ptr_target { TypeMeta::Make::type>())); } - void* ShareExternal(void* allocated, const TypeMeta& meta) { + void* ShareExternal(void* allocated, const TypeMeta meta) { free_(); meta_ = meta; pointer_ = allocated; diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h index 3ae57341fcf8..a5f9354d7ca2 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.h +++ b/aten/src/ATen/core/dispatch/Dispatcher.h @@ -388,9 +388,9 @@ inline Return Dispatcher::callWithDispatchKey(const TypedOperatorHandle::boxArgs(args...); - guard.before(op.schema().name(), stack, seq_num); + guard.before(op, stack, seq_num); } else { - guard.before(op.schema().name(), seq_num); + guard.before(op, seq_num); } } } @@ -438,9 +438,9 @@ inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const seq_num = at::sequence_number::peek(); } if (guard.needs_inputs) { - guard.before(op.schema().name(), *stack, seq_num); + guard.before(op, *stack, seq_num); } else { - guard.before(op.schema().name(), seq_num); + guard.before(op, seq_num); } } } diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.cpp b/aten/src/ATen/core/dispatch/OperatorEntry.cpp index 97651c9865a1..a3cef61f4c21 100644 --- a/aten/src/ATen/core/dispatch/OperatorEntry.cpp +++ b/aten/src/ATen/core/dispatch/OperatorEntry.cpp @@ -83,13 +83,15 @@ std::list::iterator OperatorEntry::registerKernel( // that would also invalidate the old TypedOperatorHandles. if (cpp_signature.has_value()) { if (cpp_signature_.has_value()) { - TORCH_INTERNAL_ASSERT(*cpp_signature == *cpp_signature_, - "Tried to register a kernel (", debug, ") for operator ", name_," for dispatch key ", toString(dispatch_key), - ", but the C++ function signature ", cpp_signature->name(), " mismatched with a previous kernel that had the signature ", - cpp_signature_->name() + TORCH_INTERNAL_ASSERT(*cpp_signature == cpp_signature_->signature, + "Tried to register a kernel (", debug, ") for operator ", name_," (", + (this->schema_.has_value() ? this->schema_->debug : "no debug info"), + ") for dispatch key ", toString(dispatch_key), ", but the C++ function signature ", + cpp_signature->name(), " mismatched with a previous kernel (", cpp_signature_->debug, + ") that had the signature ", cpp_signature_->signature.name() ); } else { - cpp_signature_ = *cpp_signature; + cpp_signature_ = CppSignatureWithDebug { *cpp_signature, debug }; } } @@ -103,7 +105,12 @@ std::list::iterator OperatorEntry::registerKernel( auto& k = dispatch_key.has_value() ? kernels_[*dispatch_key] : kernels_[DispatchKey::Math]; if (k.size() > 0) { - TORCH_WARN("Registering a kernel (", debug, ") for operator ", name_, " for dispatch key ", toString(dispatch_key), " that overwrote a previously registered kernel with the same dispatch key for the same operator."); + TORCH_WARN("Registering a kernel (", debug, ") for operator ", name_, " (", + (this->schema_.has_value() ? this->schema_->debug : "no debug info"), + ") for dispatch key ", toString(dispatch_key), + " that overwrote a previously registered kernel (", + (cpp_signature_.has_value() ? cpp_signature_->debug : "no debug info"), + ") with the same dispatch key for the same operator."); } if (manuallyBoxedKernel_.has_value()) { @@ -377,7 +384,11 @@ void OperatorEntry::reportError(DispatchKey dispatchKey) const { } TORCH_CHECK(false, "Could not run '", name_, "' with arguments", - " from the '", toString(dispatchKey), "' backend. '", + " from the '", toString(dispatchKey), "' backend. This could be because " + "the operator doesn't exist for this backend, or was omitted during ", + "the selective/custom build process (if using custom build). If you are a ", + "Facebook employee using PyTorch on mobile, please visit ", + "https://fburl.com/ptmfixes for possible resolutions. '", name_, "' is only available for these backends: ", listAllDispatchKeys(), ".\n\n", dumpComputedTable()); } diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.h b/aten/src/ATen/core/dispatch/OperatorEntry.h index ed4d5f40b97f..26506cb0f76f 100644 --- a/aten/src/ATen/core/dispatch/OperatorEntry.h +++ b/aten/src/ATen/core/dispatch/OperatorEntry.h @@ -157,13 +157,15 @@ class CAFFE2_API OperatorEntry final { // Asserts that the given FuncType is correct for calling this operator in an unboxed way. template void assertSignatureIsCorrect() { - TORCH_INTERNAL_ASSERT(!cpp_signature_.has_value() || (CppSignature::make() == *cpp_signature_), + TORCH_INTERNAL_ASSERT(!cpp_signature_.has_value() || (CppSignature::make() == cpp_signature_->signature), "Tried to access operator ", name_, " with a wrong signature. Accessed with ", CppSignature::make().name(), " but the operator was registered with ", - cpp_signature_->name(), - " (", + cpp_signature_->signature.name(), + " (schema: ", (schema_.has_value() ? schema_->debug : "unknown debug info"), + ", kernel: ", + cpp_signature_->debug, ") This likely happened in a call to OperatorHandle::typed(). Please make sure that the function signature matches the signature in the operator registration call." ); } @@ -230,12 +232,17 @@ class CAFFE2_API OperatorEntry final { AnnotatedKernel missingKernel_; static const AnnotatedKernel ambiguousAutogradOtherKernel_; - // signature_hash_ is set to the hash of the function signature if any of + // cpp_signature_ stores function signature if any of // the kernels was created in a way that allowed us to know the function // signature (i.e. by supplying an unboxed C++ kernel function). - // If this is set, it will be used in unboxed function calls + // If this is set, it will be used to check that future kernel + // registrations match and it will be used in unboxed function calls // to verify their arguments against the known function signature. - c10::optional cpp_signature_; + struct CppSignatureWithDebug { + CppSignature signature; + std::string debug; + }; + c10::optional cpp_signature_; // Whether this operator needs to be observed with RecordFunction const bool is_observed_; diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index dd099be59dff..c29ff15c2b59 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -56,7 +56,7 @@ namespace c10 { _(prim, ReturnStmt) \ _(prim, BreakStmt) \ _(prim, ContinueStmt) \ - _(prim, LocalVariableScope) \ + _(prim, ListComprehensionScope) \ _(prim, Store) \ _(prim, AutogradZero) \ _(prim, AutogradAnyNonZero) \ @@ -70,6 +70,7 @@ namespace c10 { _(prim, ListConstruct) \ _(prim, ListUnpack) \ _(prim, DictConstruct) \ + _(prim, ModuleDictIndex) \ _(prim, EnumName) \ _(prim, EnumValue) \ _(prim, StringIndex) \ @@ -129,7 +130,7 @@ namespace c10 { _(prim, fork) \ _(prim, forkClosure) \ _(prim, RaiseException) \ - _(prim, Function) \ + _(prim, Closure) \ _(prim, CreateObject) \ _(prim, SetAttr) \ _(prim, GetAttr) \ @@ -268,6 +269,8 @@ namespace c10 { _(aten, bin) \ _(aten, pop) \ _(aten, insert) \ + _(aten, vstack) \ + _(aten, row_stack) \ _(prim, unchecked_unwrap_optional) \ _(aten, __contains__) \ _(prim, BailoutTemplate) \ diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index 06900064e266..c1da20221e62 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -2102,6 +2102,13 @@ struct CAFFE2_API ClassType : public NamedType { // valid again. void unsafeRemoveAttribute(const std::string& name); + // [Internal Only] Change the type of an attribute of the ClassType, + // The caller is responsible to make sure the modification is safe: + // it is unsafe to maintain uses of the old type of the attribute, + // and any code that works on the attribute is now invalid. + // Only newly created code is valid again. + void unsafeChangeAttributeType(const std::string& name, TypePtr new_ty); + // Add attribute \p NAME if it doesn't exist or verify that it has a // compatible type otherwise. size_t addOrCheckAttribute( diff --git a/aten/src/ATen/core/op_registration/op_registration_test.cpp b/aten/src/ATen/core/op_registration/op_registration_test.cpp index 93b0ffc1b88e..3fd8740d1ab1 100644 --- a/aten/src/ATen/core/op_registration/op_registration_test.cpp +++ b/aten/src/ATen/core/op_registration/op_registration_test.cpp @@ -310,7 +310,8 @@ TEST(OperatorRegistrationTest, givenMultipleKernelsWithSameDispatchKey_whenRegis std::string output = testing::internal::GetCapturedStderr(); EXPECT_THAT(output, testing::HasSubstr("_test::dummy")); EXPECT_THAT(output, testing::HasSubstr("CPU")); - EXPECT_THAT(output, testing::HasSubstr("overwrote a previously registered kernel with the same dispatch key for the same operator")); + EXPECT_THAT(output, testing::HasSubstr("overwrote a previously registered kernel ")); + EXPECT_THAT(output, testing::HasSubstr(" with the same dispatch key for the same operator")); } TEST(OperatorRegistrationTest, givenMultipleKernelsWithSameDispatchKey_whenRegisteringInSameOpCall_thenFails) { @@ -348,7 +349,8 @@ TEST(OperatorRegistrationTest, givenMultipleCatchallKernels_whenRegistering_then std::string output = testing::internal::GetCapturedStderr(); EXPECT_THAT(output, testing::HasSubstr("_test::dummy")); EXPECT_THAT(output, testing::HasSubstr("catch all")); - EXPECT_THAT(output, testing::HasSubstr("overwrote a previously registered kernel with the same dispatch key for the same operator")); + EXPECT_THAT(output, testing::HasSubstr("overwrote a previously registered kernel ")); + EXPECT_THAT(output, testing::HasSubstr(" with the same dispatch key for the same operator")); } TEST(OperatorRegistrationTest, givenMultipleCatchallKernels_whenRegisteringInSameOpCall_thenFails) { @@ -701,7 +703,7 @@ TEST(OperatorRegistrationTest, whenRegisteringMismatchingKernelsInSameOpCall_the auto registrar1 = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options() .kernel(c10::DispatchKey::CPU) .kernel(c10::DispatchKey::CUDA, &called_kernel)); - }, "mismatched with a previous kernel that had the signature"); + }, "mismatched with a previous kernel"); } void backend_fallback_kernel(const c10::OperatorHandle& op, c10::Stack* stack) { @@ -944,7 +946,7 @@ TEST(OperatorRegistrationTest, givenLambdaKernel_whenRegisteringWithMismatchingC expectThrows([] { auto registrar = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options() .kernel(DispatchKey::CPU, [] (const int64_t&) {})); - }, "mismatched with a previous kernel that had the signature"); + }, "mismatched with a previous kernel"); } TEST(OperatorRegistrationTest, givenLambdaKernel_whenRegisteringCatchAllAndBackendWithMismatchingCppSignatures_thenFails) { @@ -953,7 +955,7 @@ TEST(OperatorRegistrationTest, givenLambdaKernel_whenRegisteringCatchAllAndBacke expectThrows([] { auto registrar = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options() .kernel(DispatchKey::CPU, [] (const int64_t&) {})); - }, "mismatched with a previous kernel that had the signature"); + }, "mismatched with a previous kernel"); } TEST(OperatorRegistrationTest, givenLambdaKernel_whenRegisteringBackendAndCatchAllWithMismatchingCppSignatures_thenFails) { @@ -962,7 +964,7 @@ TEST(OperatorRegistrationTest, givenLambdaKernel_whenRegisteringBackendAndCatchA expectThrows([] { auto registrar = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options() .catchAllKernel([] (const int64_t&) {})); - }, "mismatched with a previous kernel that had the signature"); + }, "mismatched with a previous kernel"); } TEST(OperatorRegistrationTest, givenLambdaKernel_whenAccessingWithMismatchingCppSignatures_thenFails) { @@ -989,7 +991,7 @@ TEST(OperatorRegistrationTest, givenTorchLibrary_whenRegisteringWithMismatchingC m.impl("dummy", DispatchKey::CPU, [] (int64_t) {}); expectThrows([&] { m.impl("dummy", DispatchKey::CUDA, [] (const int64_t&) {}); - }, "mismatched with a previous kernel that had the signature"); + }, "mismatched with a previous kernel"); } TEST(OperatorRegistrationTest, givenTorchLibrary_whenAccessingWithMismatchingCppSignatures_thenFails) { diff --git a/aten/src/ATen/core/op_registration/op_whitelist.h b/aten/src/ATen/core/op_registration/op_whitelist.h index c8437e924a3c..26d5533244d7 100644 --- a/aten/src/ATen/core/op_registration/op_whitelist.h +++ b/aten/src/ATen/core/op_registration/op_whitelist.h @@ -36,7 +36,9 @@ namespace impl { // returns true iff whitelist contains item // op_whitelist_contains("a;bc;d", "bc") == true constexpr bool op_whitelist_contains(string_view whitelist, string_view item) { - size_t next = -1; + //Choose a really big value for next so that if something goes wrong + //this code will blow up in a hopefully detectable way. + size_t next = std::numeric_limits::max(); for (size_t cur = 0; cur <= whitelist.size(); cur = next) { next = whitelist.find(';', cur); if (next != string_view::npos) { diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp index f5478d040060..67b7899bb22f 100644 --- a/aten/src/ATen/core/type.cpp +++ b/aten/src/ATen/core/type.cpp @@ -1315,6 +1315,14 @@ void ClassType::unsafeRemoveAttribute(const std::string& name) { AT_ASSERT(attributes_.size() == attributeTypes_.size()); } +void ClassType::unsafeChangeAttributeType(const std::string& name, TypePtr new_ty) { + auto slot = getAttributeSlot(name); + auto old_attr_info = attributes_[slot]; + AT_ASSERT(old_attr_info.getKind() == AttributeKind::REGULAR_ATTRIBUTE); + attributes_[slot] = ClassAttribute(old_attr_info.getKind(), new_ty, old_attr_info.getName()); + attributeTypes_[slot] = new_ty; +} + size_t ClassType::addConstant(const std::string& name, const IValue& value) { checkNotExist(name, "constant"); size_t slot = constantNames_.size(); diff --git a/aten/src/ATen/cpu/vec256/vec256_base.h b/aten/src/ATen/cpu/vec256/vec256_base.h index edce0e3a2cce..807a9d9780f0 100644 --- a/aten/src/ATen/cpu/vec256/vec256_base.h +++ b/aten/src/ATen/cpu/vec256/vec256_base.h @@ -394,6 +394,13 @@ struct Vec256 { Vec256 i0() const { return map(calc_i0); } + Vec256 igamma(const Vec256 &x) const { + Vec256 ret; + for (int64_t i = 0; i < size(); i++) { + ret[i] = calc_igamma(values[i], x[i]); + } + return ret; + } Vec256 neg() const { // NB: the trailing return type is needed because we need to coerce the // return value back to T in the case of unary operator- incuring a diff --git a/aten/src/ATen/cpu/vec256/vec256_bfloat16.h b/aten/src/ATen/cpu/vec256/vec256_bfloat16.h index 37d41676e53c..10bbe139b63f 100644 --- a/aten/src/ATen/cpu/vec256/vec256_bfloat16.h +++ b/aten/src/ATen/cpu/vec256/vec256_bfloat16.h @@ -290,6 +290,25 @@ template <> class Vec256 { auto o2 = _mm256_loadu_ps(tmp2); return cvtfp32_bf16(o1, o2); } + Vec256 igamma(const Vec256 &x) const { + __m256 lo, hi; + __m256 xlo, xhi; + cvtbf16_fp32(values, lo, hi); + cvtbf16_fp32(x.values, xlo, xhi); + __at_align32__ float tmp1[size() / 2], tmp2[size() / 2]; + _mm256_storeu_ps(reinterpret_cast(tmp1), lo); + _mm256_storeu_ps(reinterpret_cast(tmp2), hi); + __at_align32__ float tmpx1[size() / 2], tmpx2[size() / 2]; + _mm256_storeu_ps(reinterpret_cast(tmpx1), xlo); + _mm256_storeu_ps(reinterpret_cast(tmpx2), xhi); + for (int64_t i = 0; i < size() / 2; ++i) { + tmp1[i] = calc_igamma(tmp1[i], tmpx1[i]); + tmp2[i] = calc_igamma(tmp2[i], tmpx2[i]); + } + auto o1 = _mm256_loadu_ps(tmp1); + auto o2 = _mm256_loadu_ps(tmp2); + return cvtfp32_bf16(o1, o2); + } Vec256 log() const { return map(Sleef_logf8_u10); } diff --git a/aten/src/ATen/cpu/vec256/vec256_complex_double.h b/aten/src/ATen/cpu/vec256/vec256_complex_double.h index d2ae6f46b44e..d7f5afd8b67d 100644 --- a/aten/src/ATen/cpu/vec256/vec256_complex_double.h +++ b/aten/src/ATen/cpu/vec256/vec256_complex_double.h @@ -252,6 +252,9 @@ template <> class Vec256> { Vec256> hypot(const Vec256> &b) const { AT_ERROR("not supported for complex numbers"); } + Vec256> igamma(const Vec256> &x) const { + AT_ERROR("not supported for complex numbers"); + } Vec256> neg() const { auto zero = _mm256_setzero_pd(); return _mm256_sub_pd(zero, values); diff --git a/aten/src/ATen/cpu/vec256/vec256_complex_float.h b/aten/src/ATen/cpu/vec256/vec256_complex_float.h index 8b4eba07f421..4df95dbea926 100644 --- a/aten/src/ATen/cpu/vec256/vec256_complex_float.h +++ b/aten/src/ATen/cpu/vec256/vec256_complex_float.h @@ -290,6 +290,9 @@ template <> class Vec256> { Vec256> hypot(const Vec256> &b) const { AT_ERROR("not supported for complex numbers"); } + Vec256> igamma(const Vec256> &x) const { + AT_ERROR("not supported for complex numbers"); + } Vec256> neg() const { auto zero = _mm256_setzero_ps(); return _mm256_sub_ps(zero, values); diff --git a/aten/src/ATen/cpu/vec256/vec256_double.h b/aten/src/ATen/cpu/vec256/vec256_double.h index fcad154e68b2..6b611e8d2e7a 100644 --- a/aten/src/ATen/cpu/vec256/vec256_double.h +++ b/aten/src/ATen/cpu/vec256/vec256_double.h @@ -155,6 +155,16 @@ template <> class Vec256 { Vec256 i0() const { return map(calc_i0); } + Vec256 igamma(const Vec256 &x) const { + __at_align32__ double tmp[size()]; + __at_align32__ double tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igamma(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } Vec256 log() const { return Vec256(Sleef_logd4_u10(values)); } diff --git a/aten/src/ATen/cpu/vec256/vec256_float.h b/aten/src/ATen/cpu/vec256/vec256_float.h index 1ab11ea81529..d83895fdf854 100644 --- a/aten/src/ATen/cpu/vec256/vec256_float.h +++ b/aten/src/ATen/cpu/vec256/vec256_float.h @@ -193,6 +193,16 @@ template <> class Vec256 { Vec256 i0() const { return map(calc_i0); } + Vec256 igamma(const Vec256 &x) const { + __at_align32__ float tmp[size()]; + __at_align32__ float tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igamma(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } Vec256 neg() const { return _mm256_xor_ps(_mm256_set1_ps(-0.f), values); } diff --git a/aten/src/ATen/cpu/vec256/vec256_float_neon.h b/aten/src/ATen/cpu/vec256/vec256_float_neon.h index f98c645a08d6..f410e415277f 100644 --- a/aten/src/ATen/cpu/vec256/vec256_float_neon.h +++ b/aten/src/ATen/cpu/vec256/vec256_float_neon.h @@ -25,6 +25,8 @@ namespace { // https://bugs.llvm.org/show_bug.cgi?id=45824 // Most likely we will do aarch32 support with inline asm. #if defined(__aarch64__) +// See https://github.com/pytorch/pytorch/issues/47098 +#if defined(__clang__) || (__GNUC__ > 8 || (__GNUC__ == 8 && __GNUC_MINOR__ > 3)) #ifdef __BIG_ENDIAN__ #error "Big endian is not supported." @@ -362,6 +364,16 @@ template <> class Vec256 { Vec256 i0() const { return map(calc_i0); } + Vec256 igamma(const Vec256 &x) const { + __at_align32__ float tmp[size()]; + __at_align32__ float tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igamma(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } Vec256 log() const { return map(std::log); } @@ -665,6 +677,7 @@ Vec256 inline fmadd(const Vec256& a, const Vec256& b, const return Vec256(r0, r1); } -#endif +#endif /* defined(__clang__) || (__GNUC__ > 8 || (__GNUC__ == 8 && __GNUC_MINOR__ > 3)) */ +#endif /* defined(aarch64) */ }}} diff --git a/aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp b/aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp index b2d8df49f51b..45ceddcd94e8 100644 --- a/aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp +++ b/aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp @@ -1,7 +1,5 @@ #include -// @generated by aten/src/ATen/gen.py from LegacyTHFunctions.cpp - #include #include #include diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index e7e5659babbb..16f706ca0ed5 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -331,6 +331,7 @@ static void apply_solve(Tensor& b, Tensor& A, std::vector& infos) { auto batch_size = batchCount(A); auto n = A.size(-2); auto nrhs = b.size(-1); + auto lda = std::max(int64_t{1}, n); auto ipiv = at::empty({n}, b.options().dtype(kInt)); auto ipiv_data = ipiv.data_ptr(); @@ -339,7 +340,7 @@ static void apply_solve(Tensor& b, Tensor& A, std::vector& infos) { for (int64_t i = 0; i < batch_size; i++) { scalar_t* A_working_ptr = &A_data[i * A_mat_stride]; scalar_t* b_working_ptr = &b_data[i * b_mat_stride]; - lapackSolve(n, nrhs, A_working_ptr, n, ipiv_data, b_working_ptr, n, &info); + lapackSolve(n, nrhs, A_working_ptr, lda, ipiv_data, b_working_ptr, lda, &info); infos[i] = info; if (info != 0) { return; diff --git a/aten/src/ATen/native/BinaryOps.cpp b/aten/src/ATen/native/BinaryOps.cpp index f8af756773c9..b7916ba3f9c8 100644 --- a/aten/src/ATen/native/BinaryOps.cpp +++ b/aten/src/ATen/native/BinaryOps.cpp @@ -46,6 +46,7 @@ DEFINE_DISPATCH(logaddexp2_stub); DEFINE_DISPATCH(gcd_stub); DEFINE_DISPATCH(lcm_stub); DEFINE_DISPATCH(hypot_stub); +DEFINE_DISPATCH(igamma_stub); DEFINE_DISPATCH(nextafter_stub); DEFINE_DISPATCH(heaviside_stub); @@ -968,6 +969,23 @@ Tensor& hypot_(Tensor& self, const Tensor& other) { return at::hypot_out(self, self, other); } +Tensor& igamma_out(Tensor& result, const Tensor& self, const Tensor& other) { + auto iter = TensorIterator::binary_op(result, self, other); + igamma_stub(iter.device_type(), iter); + return result; +} + +Tensor igamma(const Tensor& self, const Tensor& other) { + Tensor result; + auto iter = TensorIterator::binary_op(result, self, other); + igamma_stub(iter.device_type(), iter); + return iter.output(); +} + +Tensor& igamma_(Tensor& self, const Tensor& other) { + return at::igamma_out(self, self, other); +} + Tensor& nextafter_out(Tensor& result, const Tensor& self, const Tensor& other) { auto iter = TensorIterator::binary_op(result, self, other); nextafter_stub(iter.device_type(), iter); diff --git a/aten/src/ATen/native/BinaryOps.h b/aten/src/ATen/native/BinaryOps.h index 7640c8bd84ac..ee3f023fedc5 100644 --- a/aten/src/ATen/native/BinaryOps.h +++ b/aten/src/ATen/native/BinaryOps.h @@ -10,7 +10,7 @@ namespace at { namespace native { inline void alpha_check(const ScalarType dtype, Scalar alpha) { TORCH_CHECK(! alpha.isBoolean() || dtype == ScalarType::Bool, "Boolean alpha only supported for Boolean results."); - TORCH_CHECK(isFloatingType(dtype) || isComplexType(dtype) + TORCH_CHECK(isFloatingType(dtype) || isComplexType(dtype) || alpha.isIntegral(true), "For integral input tensors, argument alpha must not be a floating point number."); } @@ -68,6 +68,7 @@ DECLARE_DISPATCH(binary_fn, logaddexp2_stub); DECLARE_DISPATCH(binary_fn, gcd_stub); DECLARE_DISPATCH(binary_fn, lcm_stub); DECLARE_DISPATCH(binary_fn, hypot_stub); +DECLARE_DISPATCH(binary_fn, igamma_stub); DECLARE_DISPATCH(binary_fn, nextafter_stub); DECLARE_DISPATCH(binary_fn, heaviside_stub); diff --git a/aten/src/ATen/native/Copy.cpp b/aten/src/ATen/native/Copy.cpp index 0430de87eb77..360069998f19 100644 --- a/aten/src/ATen/native/Copy.cpp +++ b/aten/src/ATen/native/Copy.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -131,7 +132,11 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking) } if (self.device().type() == at::kVulkan || src.device().type() == at::kVulkan) { + #ifdef USE_VULKAN_API + return vulkan::ops::copy_(self, src); + #else return at::vulkan::vulkan_copy_(self, src); + #endif } if (self.device().type() == at::kMetal || src.device().type() == at::kMetal) { diff --git a/aten/src/ATen/native/DispatchStub.h b/aten/src/ATen/native/DispatchStub.h index dc21a505e8c1..63e2462489be 100644 --- a/aten/src/ATen/native/DispatchStub.h +++ b/aten/src/ATen/native/DispatchStub.h @@ -3,7 +3,9 @@ #include #include #include + #include +#include // Implements instruction set specific function dispatch. // diff --git a/aten/src/ATen/native/Distance.cpp b/aten/src/ATen/native/Distance.cpp index b2b760513a1d..91d804687290 100644 --- a/aten/src/ATen/native/Distance.cpp +++ b/aten/src/ATen/native/Distance.cpp @@ -27,7 +27,7 @@ Tensor pdist(const Tensor& self, const double p) { Tensor _euclidean_dist(const Tensor& x1, const Tensor& x2) { /** This function does the fist part of the euclidean distance calculation - * We divide it in two steps to simplify dealing with subgradients in the + * We divide it in two steps to simplify dealing with subgradients in the * backward step */ Tensor x1_norm = x1.pow(2).sum(-1, true); Tensor x1_pad = at::ones_like(x1_norm, LEGACY_CONTIGUOUS_MEMORY_FORMAT); @@ -74,7 +74,7 @@ static Tensor cdist_impl(const Tensor& x1, const Tensor& x2, const double p, c10 std::vector tensor2_expand_size(expand_batch_portion); tensor2_expand_size.insert(tensor2_expand_size.end(), {r2, c2}); - int expand_batch_product = std::accumulate(expand_batch_portion.begin(), expand_batch_portion.end(), 1, std::multiplies()); + const int64_t expand_batch_product = prod_intlist(expand_batch_portion); std::vector tensor1_view{expand_batch_product, r1, c1}; std::vector tensor2_view{expand_batch_product, r2, c2}; @@ -147,8 +147,10 @@ Tensor _cdist_backward(const Tensor& grad, const Tensor& x1, const Tensor& x2, c auto device2 = x2.device().type(); TORCH_CHECK(device2 == kCPU || device2 == kCUDA, "_cdist_backward only supports CPU and CUDA devices, X2 got: ", device2); IntArrayRef batch_tensor1(x1.sizes().data(), std::max(x1.dim() - 2, 0)); - int batch_product = std::accumulate(batch_tensor1.begin(), batch_tensor1.end(), 1, std::multiplies()); - Tensor grad_x1 = at::empty_like(x1, x1.options(), LEGACY_CONTIGUOUS_MEMORY_FORMAT).view({batch_product, n, m}); + const int64_t batch_product = prod_intlist(batch_tensor1); + Tensor grad_x1 = + at::empty_like(x1, x1.options(), LEGACY_CONTIGUOUS_MEMORY_FORMAT) + .view({batch_product, n, m}); cdist_backward_stub(device1, grad_x1, grad, x1, x2, p, cdist); return grad_x1; } diff --git a/aten/src/ATen/native/Embedding.cpp b/aten/src/ATen/native/Embedding.cpp index 3f250ae09909..6589a33ed2f4 100644 --- a/aten/src/ATen/native/Embedding.cpp +++ b/aten/src/ATen/native/Embedding.cpp @@ -17,16 +17,27 @@ Tensor embedding(const Tensor & weight, const Tensor & indices, auto indices_arg = TensorArg(indices, "indices", 1); checkScalarType("embedding", indices_arg, kLong); + auto zerofill_padding = [&](Tensor& embedding) { + if (padding_idx >= 0) { + embedding.masked_fill_((indices == padding_idx).reshape({-1, 1}), 0); + } + }; + // TODO: use tensor.index() after improving perf if (indices.dim() == 1) { - return weight.index_select(0, indices); + auto out = weight.index_select(0, indices); + zerofill_padding(out); + return out; } auto size = indices.sizes().vec(); for (auto d : weight.sizes().slice(1)) { size.push_back(d); } - return weight.index_select(0, indices.reshape(-1)).view(size); + + auto out = weight.index_select(0, indices.reshape(-1)); + zerofill_padding(out); + return out.view(size); } Tensor embedding_backward( diff --git a/aten/src/ATen/native/Fill.cpp b/aten/src/ATen/native/Fill.cpp index 73f7dcd61926..b466ca26fc0c 100644 --- a/aten/src/ATen/native/Fill.cpp +++ b/aten/src/ATen/native/Fill.cpp @@ -4,20 +4,12 @@ #include #include #include +#include namespace at { namespace native { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ fill ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -namespace { - template - inline void fill_fast(Tensor& self, Scalar value_scalar) { - auto value = value_scalar.to(); - scalar_t * dptr = static_cast(self.data_ptr()); - *dptr = value; - } -} // namspace - Tensor& fill_out(Tensor& self, Scalar value) { if (self.is_quantized()) { at::Tensor out = at::ones(self.sizes()).to(kFloat) * value; @@ -26,15 +18,8 @@ Tensor& fill_out(Tensor& self, Scalar value) { self.copy_(out); return self; } - // When filling a number to 1-element CPU tensor, we want to skip - // everything but manipulate data ptr directly. - // Ideally this fast pass should be implemented in TensorIterator, - // but we also want to skip compute_types which in not avoidable - // in TensorIterator for now. if (self.device() == at::kCPU && self.numel() == 1 && !self.is_complex() && !value.isComplex()) { - AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, kBFloat16, self.scalar_type(), "fill_out", [&]() { - fill_fast(self, value);}); - return self; + return at::detail::scalar_fill(self, value); } auto iter = TensorIteratorConfig() .set_check_mem_overlap(false) // Fill is idempotent, so overlap is okay diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 0dd727fb3197..8796657dc293 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -675,8 +675,8 @@ Tensor matmul( std::vector tensor2_expand_size(expand_batch_portion); tensor2_expand_size.insert(tensor2_expand_size.end(), {m2, p}); - int expand_batch_product = std::accumulate(expand_batch_portion.begin(), expand_batch_portion.end(), - 1, std::multiplies()); + const int64_t expand_batch_product = + prod_intlist(expand_batch_portion); std::vector tensor1_bmm_view({expand_batch_product}); tensor1_bmm_view.insert(tensor1_bmm_view.end(), {n, m1}); @@ -742,7 +742,7 @@ Tensor _allocate_buffer(const Tensor& a, int n_copies, bool is_zero = false) { {n_copies, a.size(0), a.size(1), a.size(2)}, a.options().memory_format(at::MemoryFormat::Contiguous) ); - + if (is_zero) { res.zero_(); } @@ -850,7 +850,7 @@ Tensor compute_T4(const Tensor& A) { auto As = _allocate_buffer(A, 4); // 3 for {I, A, A^2} _fill_matrix_powers(As, A, 3); - + at::native::matmul( // output for A^2 * (I / 2 + A / 6 + A^2 / 24) As.select(0, 3), @@ -1101,7 +1101,7 @@ Tensor mexp_impl( if (!compute_highest_degree_approx) { constexpr std::array< Tensor(*)(const Tensor&), - total_n_degs - 1> + total_n_degs - 1> compute_Ts = { compute_T1, compute_T2, compute_T4, compute_T8, compute_T12 @@ -1192,7 +1192,7 @@ Tensor mexp(const Tensor& a, bool compute_highest_degree_approx = false) { // Based on: // -// Mathias, Roy. +// Mathias, Roy. // A Chain Rule for Matrix Functions and Applications. // SIAM J. Matrix Anal. Appl. 17 (1996): 610-620. // @@ -1227,8 +1227,8 @@ Tensor backward_analytic_function_of_a_matrix( // Mathematics 2019, 7, 1174. // Tensor matrix_exp(const Tensor& a) { - TORCH_CHECK(a.dim() >= 2 - && (at::isFloatingType(a.scalar_type()) + TORCH_CHECK(a.dim() >= 2 + && (at::isFloatingType(a.scalar_type()) || at::isComplexType(a.scalar_type())), "matrix_exp(", a.scalar_type(), "{", a.sizes(), "}): expected a tensor " "of floating or complex types with dim at least 2"); @@ -1602,6 +1602,55 @@ Tensor& linalg_norm_out(Tensor& result, const Tensor& self, std::string ord, opt return linalg_norm_out_impl(result, self, c10::nullopt, ord, opt_dim, keepdim, opt_dtype); } +Tensor linalg_tensorsolve(const Tensor& self, const Tensor& other, optional dims) { + /* + The idea is to reduce the problem to 2D matrix solve. + Step 1. (optional) `self` is permuted with `dims` such that dimensions from `dims` are moved to the right. + For example, if we have 4D input with the shape (1, 2, 3, 4) and dims=(0, 2), + then the result of permutation would have the shape (2, 4, 1, 3). + Step 2. reshape `self` to 2D matrix. + Step 3. solve the matrix equation self.to_2D() @ result = other.to_1D() + Step 4. reshape the result. + */ + int64_t ndim = self.dim(); + Tensor self_ = self; + + // move dimensions of `self_` from `dims` to the end + if (dims.has_value()) { + DimVector dest_axes(dims.value().size()); + std::iota(dest_axes.begin(), dest_axes.end(), ndim - dest_axes.size()); + self_ = at::movedim(self_, dims.value(), dest_axes); + } + + // result_shape is self_.sizes[-(an-other.dim):] + std::vector result_shape = self_.sizes().slice(other.dim(), ndim - other.dim()).vec(); + + int64_t result_product = std::accumulate(result_shape.begin(), result_shape.end(), int64_t{1}, std::multiplies()); + int64_t other_product = std::accumulate(other.sizes().begin(), other.sizes().end(), int64_t{1}, std::multiplies()); + + // Check whether the self tensor can be reshaped to the 2D square matrix + TORCH_CHECK(result_product == other_product, + "Expected self to satisfy the requirement prod(self.shape[other.ndim:]) == prod(self.shape[:other.ndim]), but got ", + result_product, " != ", other_product); + + self_ = self_.reshape({result_product, result_product}); + + // 0th output of at::solve is the solution + // normally `other` would be flattened by at::solve expects 2D input + Tensor result = std::get<0>(at::solve(other.reshape({other.numel(), 1}), self_)); + return result.reshape(result_shape); +} + +Tensor& linalg_tensorsolve_out(Tensor& result, const Tensor& self, const Tensor& other, optional dims) { + TORCH_CHECK(result.scalar_type() == self.scalar_type(), + "result dtype ", result.scalar_type(), " does not match self dtype ", self.scalar_type()); + + Tensor result_tmp = at::linalg_tensorsolve(self, other, dims); + at::native::resize_output(result, result_tmp.sizes()); + result.copy_(result_tmp); + return result; +} + static inline Tensor _chain_matmul_general(TensorList matrices, std::vector>& order, int64_t i, int64_t j) { if (i == j) return matrices[i]; diff --git a/aten/src/ATen/native/Math.h b/aten/src/ATen/native/Math.h index c00ffec94119..dc5530a72813 100644 --- a/aten/src/ATen/native/Math.h +++ b/aten/src/ATen/native/Math.h @@ -381,6 +381,716 @@ static inline float calc_polygamma(int64_t n, float x) { zeta(double(n + 1), x); } +// regularized lower incomplete gamma +// the regularized lower, upper incomplete gamma, as well as their +// helper functions follow SciPy's implementation + +/* References + * [igam1] "The Digital Library of Mathematical Functions", dlmf.nist.gov + * [igam2] Maddock et. al., "Incomplete Gamma Functions", + * https://www.boost.org/doc/libs/1_61_0/libs/math/doc/html/math_toolkit/sf_gamma/igamma.html + */ + +/* + * This implementation of the regularized incomplete gamma functions and + * their helper functions are derived from the implementation of SciPy's + * gammainc, Cephes's igam and igamc, and Boost's Lanczos approximations. + * See NOTICE for the licenses. + */ +template +static scalar_t ratevl(scalar_t x, const scalar_t num[], int64_t M, + const scalar_t denom[], int64_t N) { + // evaluating rational function, i.e., the ratio of two polynomials + // the coefficients for numerator are given by `num` while coeffs for + // denumerator are given by `denom` + + int64_t i, dir; + scalar_t y, num_ans, denom_ans; + scalar_t absx = std::fabs(x); + const scalar_t *p; + + if (absx > 1) { + /* Evaluate as a polynomial in 1/x. */ + dir = -1; + p = num + M; + y = 1 / x; + } + else { + dir = 1; + p = num; + y = x; + } + + /* Evaluate the numerator */ + num_ans = *p; + p += dir; + for (i = 1; i <= M; i++) { + num_ans = num_ans * y + *p; + p += dir; + } + /* Evaluate the denominator */ + if (absx > 1) { + p = denom + N; + } + else { + p = denom; + } + + denom_ans = *p; + p += dir; + for (i = 1; i <= N; i++) { + denom_ans = denom_ans * y + *p; + p += dir; + } + if (absx > 1) { + i = N - M; + return std::pow(x, i) * num_ans / denom_ans; + } + else { + return num_ans / denom_ans; + } +} + +// SciPy's lanczos implementation is taken from Boost +/* (C) Copyright John Maddock 2006. + * Use, modification and distribution are subject to the + * Boost Software License, Version 1.0. See + * https://www.boost.org/LICENSE_1_0.txt or see NOTICE. + */ +template +static scalar_t lanczos_sum_expg_scaled(scalar_t x) { + // lanczos approximation + static const scalar_t lanczos_sum_expg_scaled_num[13] = { + 0.006061842346248906525783753964555936883222, + 0.5098416655656676188125178644804694509993, + 19.51992788247617482847860966235652136208, + 449.9445569063168119446858607650988409623, + 6955.999602515376140356310115515198987526, + 75999.29304014542649875303443598909137092, + 601859.6171681098786670226533699352302507, + 3481712.15498064590882071018964774556468, + 14605578.08768506808414169982791359218571, + 43338889.32467613834773723740590533316085, + 86363131.28813859145546927288977868422342, + 103794043.1163445451906271053616070238554, + 56906521.91347156388090791033559122686859 + }; + static const scalar_t lanczos_sum_expg_scaled_denom[13] = { + 1., + 66., + 1925., + 32670., + 357423., + 2637558., + 13339535., + 45995730., + 105258076., + 150917976., + 120543840., + 39916800., + 0. + }; + return ratevl(x, lanczos_sum_expg_scaled_num, + sizeof(lanczos_sum_expg_scaled_num) / sizeof(lanczos_sum_expg_scaled_num[0]) - 1, + lanczos_sum_expg_scaled_denom, + sizeof(lanczos_sum_expg_scaled_denom) / sizeof(lanczos_sum_expg_scaled_denom[0]) - 1); +} + +template +static scalar_t _igam_helper_fac(scalar_t a, scalar_t x) { + // compute x^a * exp(-a) / gamma(a) + // corrected from (15) and (16) in [igam2] by replacing exp(x - a) with + // exp(a - x). + + scalar_t ax, fac, res, num, numfac; + static scalar_t MAXLOG = std::is_same::value ? + 7.09782712893383996843E2 : 88.72283905206835; + static scalar_t EXP1 = 2.718281828459045; + static scalar_t lanczos_g = 6.024680040776729583740234375; + + if (std::fabs(a - x) > 0.4 * std::fabs(a)) { + ax = a * std::log(x) - x - std::lgamma(a); + if (ax < -MAXLOG) { + return 0.0; + } + return std::exp(ax); + } + + fac = a + lanczos_g - 0.5; + res = std::sqrt(fac / EXP1) / lanczos_sum_expg_scaled(a); + + if ((a < 200) && (x < 200)) { + res *= std::exp(a - x) * std::pow(x / fac, a); + } + else { + num = x - a - lanczos_g + 0.5; + numfac = num / fac; + res *= std::exp(a * (std::log1p(numfac) - numfac) + x * (0.5 - lanczos_g) / fac); + } + return res; +} + +template +static scalar_t _igam_helper_series(scalar_t a, scalar_t x) { + // Compute igam using DLMF 8.11.4. [igam1] + static scalar_t MACHEP = std::is_same::value ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + static int MAXITER = 2000; + + int i; + scalar_t ans, ax, c, r; + + ax = _igam_helper_fac(a, x); + if (ax == 0.0) { + return 0.0; + } + + /* power series */ + r = a; + c = 1.0; + ans = 1.0; + + for (i = 0; i < MAXITER; i++) { + r += 1.0; + c *= x / r; + ans += c; + if (c <= MACHEP * ans) { + break; + } + } + return (ans * ax / a); +} + +template +static scalar_t _igamc_helper_series(scalar_t a, scalar_t x) { + // Compute igamc using DLMF 8.7.3 [igam1]. This is related to the series in + // _igam_helper_series but extra care is taken to avoid cancellation. + + int n; + scalar_t fac = 1; + scalar_t sum = 0; + scalar_t term, logx; + static scalar_t MAXITER = 2000; + static scalar_t MACHEP = std::is_same::value ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + + for (n = 1; n < MAXITER; n++) { + fac *= -x / n; + term = fac / (a + n); + sum += term; + if (std::fabs(term) <= MACHEP * std::fabs(sum)) { + break; + } + } + + logx = std::log(x); + term = -std::expm1(a * logx - std::lgamma(1+a)); + return term - std::exp(a * logx - std::lgamma(a)) * sum; +} + +template +static scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t x, bool igam) { + // Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1] + static const scalar_t d[25][25] = + {{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2, + 1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4, + 3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6, + 8.296711340953086e-7, -1.7665952736826079e-7, 6.7078535434014986e-9, + 1.0261809784240308e-8, -4.3820360184533532e-9, 9.1476995822367902e-10, + -2.551419399494625e-11, -5.8307721325504251e-11, 2.4361948020667416e-11, + -5.0276692801141756e-12, 1.1004392031956135e-13, 3.3717632624009854e-13, + -1.3923887224181621e-13, 2.8534893807047443e-14, -5.1391118342425726e-16, + -1.9752288294349443e-15}, + {-1.8518518518518519e-3, -3.4722222222222222e-3, 2.6455026455026455e-3, + -9.9022633744855967e-4, 2.0576131687242798e-4, -4.0187757201646091e-7, + -1.8098550334489978e-5, 7.6491609160811101e-6, -1.6120900894563446e-6, + 4.6471278028074343e-9, 1.378633446915721e-7, -5.752545603517705e-8, + 1.1951628599778147e-8, -1.7543241719747648e-11, -1.0091543710600413e-9, + 4.1627929918425826e-10, -8.5639070264929806e-11, 6.0672151016047586e-14, + 7.1624989648114854e-12, -2.9331866437714371e-12, 5.9966963656836887e-13, + -2.1671786527323314e-16, -4.9783399723692616e-14, 2.0291628823713425e-14, + -4.13125571381061e-15}, + {4.1335978835978836e-3, -2.6813271604938272e-3, 7.7160493827160494e-4, + 2.0093878600823045e-6, -1.0736653226365161e-4, 5.2923448829120125e-5, + -1.2760635188618728e-5, 3.4235787340961381e-8, 1.3721957309062933e-6, + -6.298992138380055e-7, 1.4280614206064242e-7, -2.0477098421990866e-10, + -1.4092529910867521e-8, 6.228974084922022e-9, -1.3670488396617113e-9, + 9.4283561590146782e-13, 1.2872252400089318e-10, -5.5645956134363321e-11, + 1.1975935546366981e-11, -4.1689782251838635e-15, -1.0940640427884594e-12, + 4.6622399463901357e-13, -9.905105763906906e-14, 1.8931876768373515e-17, + 8.8592218725911273e-15}, + {6.4943415637860082e-4, 2.2947209362139918e-4, -4.6918949439525571e-4, + 2.6772063206283885e-4, -7.5618016718839764e-5, -2.3965051138672967e-7, + 1.1082654115347302e-5, -5.6749528269915966e-6, 1.4230900732435884e-6, + -2.7861080291528142e-11, -1.6958404091930277e-7, 8.0994649053880824e-8, + -1.9111168485973654e-8, 2.3928620439808118e-12, 2.0620131815488798e-9, + -9.4604966618551322e-10, 2.1541049775774908e-10, -1.388823336813903e-14, + -2.1894761681963939e-11, 9.7909989511716851e-12, -2.1782191880180962e-12, + 6.2088195734079014e-17, 2.126978363279737e-13, -9.3446887915174333e-14, + 2.0453671226782849e-14}, + {-8.618882909167117e-4, 7.8403922172006663e-4, -2.9907248030319018e-4, + -1.4638452578843418e-6, 6.6414982154651222e-5, -3.9683650471794347e-5, + 1.1375726970678419e-5, 2.5074972262375328e-10, -1.6954149536558306e-6, + 8.9075075322053097e-7, -2.2929348340008049e-7, 2.956794137544049e-11, + 2.8865829742708784e-8, -1.4189739437803219e-8, 3.4463580499464897e-9, + -2.3024517174528067e-13, -3.9409233028046405e-10, 1.8602338968504502e-10, + -4.356323005056618e-11, 1.2786001016296231e-15, 4.6792750266579195e-12, + -2.1492464706134829e-12, 4.9088156148096522e-13, -6.3385914848915603e-18, + -5.0453320690800944e-14}, + {-3.3679855336635815e-4, -6.9728137583658578e-5, 2.7727532449593921e-4, + -1.9932570516188848e-4, 6.7977804779372078e-5, 1.419062920643967e-7, + -1.3594048189768693e-5, 8.0184702563342015e-6, -2.2914811765080952e-6, + -3.252473551298454e-10, 3.4652846491085265e-7, -1.8447187191171343e-7, + 4.8240967037894181e-8, -1.7989466721743515e-14, -6.3061945000135234e-9, + 3.1624176287745679e-9, -7.8409242536974293e-10, 5.1926791652540407e-15, + 9.3589442423067836e-11, -4.5134262161632782e-11, 1.0799129993116827e-11, + -3.661886712685252e-17, -1.210902069055155e-12, 5.6807435849905643e-13, + -1.3249659916340829e-13}, + {5.3130793646399222e-4, -5.9216643735369388e-4, 2.7087820967180448e-4, + 7.9023532326603279e-7, -8.1539693675619688e-5, 5.6116827531062497e-5, + -1.8329116582843376e-5, -3.0796134506033048e-9, 3.4651553688036091e-6, + -2.0291327396058604e-6, 5.7887928631490037e-7, 2.338630673826657e-13, + -8.8286007463304835e-8, 4.7435958880408128e-8, -1.2545415020710382e-8, + 8.6496488580102925e-14, 1.6846058979264063e-9, -8.5754928235775947e-10, + 2.1598224929232125e-10, -7.6132305204761539e-16, -2.6639822008536144e-11, + 1.3065700536611057e-11, -3.1799163902367977e-12, 4.7109761213674315e-18, + 3.6902800842763467e-13}, + {3.4436760689237767e-4, 5.1717909082605922e-5, -3.3493161081142236e-4, + 2.812695154763237e-4, -1.0976582244684731e-4, -1.2741009095484485e-7, + 2.7744451511563644e-5, -1.8263488805711333e-5, 5.7876949497350524e-6, + 4.9387589339362704e-10, -1.0595367014026043e-6, 6.1667143761104075e-7, + -1.7562973359060462e-7, -1.2974473287015439e-12, 2.695423606288966e-8, + -1.4578352908731271e-8, 3.887645959386175e-9, -3.8810022510194121e-17, + -5.3279941738772867e-10, 2.7437977643314845e-10, -6.9957960920705679e-11, + 2.5899863874868481e-17, 8.8566890996696381e-12, -4.403168815871311e-12, + 1.0865561947091654e-12}, + {-6.5262391859530942e-4, 8.3949872067208728e-4, -4.3829709854172101e-4, + -6.969091458420552e-7, 1.6644846642067548e-4, -1.2783517679769219e-4, + 4.6299532636913043e-5, 4.5579098679227077e-9, -1.0595271125805195e-5, + 6.7833429048651666e-6, -2.1075476666258804e-6, -1.7213731432817145e-11, + 3.7735877416110979e-7, -2.1867506700122867e-7, 6.2202288040189269e-8, + 6.5977038267330006e-16, -9.5903864974256858e-9, 5.2132144922808078e-9, + -1.3991589583935709e-9, 5.382058999060575e-16, 1.9484714275467745e-10, + -1.0127287556389682e-10, 2.6077347197254926e-11, -5.0904186999932993e-18, + -3.3721464474854592e-12}, + {-5.9676129019274625e-4, -7.2048954160200106e-5, 6.7823088376673284e-4, + -6.4014752602627585e-4, 2.7750107634328704e-4, 1.8197008380465151e-7, + -8.4795071170685032e-5, 6.105192082501531e-5, -2.1073920183404862e-5, + -8.8585890141255994e-10, 4.5284535953805377e-6, -2.8427815022504408e-6, + 8.7082341778646412e-7, 3.6886101871706965e-12, -1.5344695190702061e-7, + 8.862466778790695e-8, -2.5184812301826817e-8, -1.0225912098215092e-14, + 3.8969470758154777e-9, -2.1267304792235635e-9, 5.7370135528051385e-10, + -1.887749850169741e-19, -8.0931538694657866e-11, 4.2382723283449199e-11, + -1.1002224534207726e-11}, + {1.3324454494800656e-3, -1.9144384985654775e-3, 1.1089369134596637e-3, + 9.932404122642299e-7, -5.0874501293093199e-4, 4.2735056665392884e-4, + -1.6858853767910799e-4, -8.1301893922784998e-9, 4.5284402370562147e-5, + -3.127053674781734e-5, 1.044986828530338e-5, 4.8435226265680926e-11, + -2.1482565873456258e-6, 1.329369701097492e-6, -4.0295693092101029e-7, + -1.7567877666323291e-13, 7.0145043163668257e-8, -4.040787734999483e-8, + 1.1474026743371963e-8, 3.9642746853563325e-18, -1.7804938269892714e-9, + 9.7480262548731646e-10, -2.6405338676507616e-10, 5.794875163403742e-18, + 3.7647749553543836e-11}, + {1.579727660730835e-3, 1.6251626278391582e-4, -2.0633421035543276e-3, + 2.1389686185689098e-3, -1.0108559391263003e-3, -3.9912705529919201e-7, + 3.6235025084764691e-4, -2.8143901463712154e-4, 1.0449513336495887e-4, + 2.1211418491830297e-9, -2.5779417251947842e-5, 1.7281818956040463e-5, + -5.6413773872904282e-6, -1.1024320105776174e-11, 1.1223224418895175e-6, + -6.8693396379526735e-7, 2.0653236975414887e-7, 4.6714772409838506e-14, + -3.5609886164949055e-8, 2.0470855345905963e-8, -5.8091738633283358e-9, + -1.332821287582869e-16, 9.0354604391335133e-10, -4.9598782517330834e-10, + 1.3481607129399749e-10}, + {-4.0725121195140166e-3, 6.4033628338080698e-3, -4.0410161081676618e-3, + -2.183732802866233e-6, 2.1740441801254639e-3, -1.9700440518418892e-3, + 8.3595469747962458e-4, 1.9445447567109655e-8, -2.5779387120421696e-4, + 1.9009987368139304e-4, -6.7696499937438965e-5, -1.4440629666426572e-10, + 1.5712512518742269e-5, -1.0304008744776893e-5, 3.304517767401387e-6, + 7.9829760242325709e-13, -6.4097794149313004e-7, 3.8894624761300056e-7, + -1.1618347644948869e-7, -2.816808630596451e-15, 1.9878012911297093e-8, + -1.1407719956357511e-8, 3.2355857064185555e-9, 4.1759468293455945e-20, + -5.0423112718105824e-10}, + {-5.9475779383993003e-3, -5.4016476789260452e-4, 8.7910413550767898e-3, + -9.8576315587856125e-3, 5.0134695031021538e-3, 1.2807521786221875e-6, + -2.0626019342754683e-3, 1.7109128573523058e-3, -6.7695312714133799e-4, + -6.9011545676562133e-9, 1.8855128143995902e-4, -1.3395215663491969e-4, + 4.6263183033528039e-5, 4.0034230613321351e-11, -1.0255652921494033e-5, + 6.612086372797651e-6, -2.0913022027253008e-6, -2.0951775649603837e-13, + 3.9756029041993247e-7, -2.3956211978815887e-7, 7.1182883382145864e-8, + 8.925574873053455e-16, -1.2101547235064676e-8, 6.9350618248334386e-9, + -1.9661464453856102e-9}, + {1.7402027787522711e-2, -2.9527880945699121e-2, 2.0045875571402799e-2, + 7.0289515966903407e-6, -1.2375421071343148e-2, 1.1976293444235254e-2, + -5.4156038466518525e-3, -6.3290893396418616e-8, 1.8855118129005065e-3, + -1.473473274825001e-3, 5.5515810097708387e-4, 5.2406834412550662e-10, + -1.4357913535784836e-4, 9.9181293224943297e-5, -3.3460834749478311e-5, + -3.5755837291098993e-12, 7.1560851960630076e-6, -4.5516802628155526e-6, + 1.4236576649271475e-6, 1.8803149082089664e-14, -2.6623403898929211e-7, + 1.5950642189595716e-7, -4.7187514673841102e-8, -6.5107872958755177e-17, + 7.9795091026746235e-9}, + {3.0249124160905891e-2, 2.4817436002649977e-3, -4.9939134373457022e-2, + 5.9915643009307869e-2, -3.2483207601623391e-2, -5.7212968652103441e-6, + 1.5085251778569354e-2, -1.3261324005088445e-2, 5.5515262632426148e-3, + 3.0263182257030016e-8, -1.7229548406756723e-3, 1.2893570099929637e-3, + -4.6845138348319876e-4, -1.830259937893045e-10, 1.1449739014822654e-4, + -7.7378565221244477e-5, 2.5625836246985201e-5, 1.0766165333192814e-12, + -5.3246809282422621e-6, 3.349634863064464e-6, -1.0381253128684018e-6, + -5.608909920621128e-15, 1.9150821930676591e-7, -1.1418365800203486e-7, + 3.3654425209171788e-8}, + {-9.9051020880159045e-2, 1.7954011706123486e-1, -1.2989606383463778e-1, + -3.1478872752284357e-5, 9.0510635276848131e-2, -9.2828824411184397e-2, + 4.4412112839877808e-2, 2.7779236316835888e-7, -1.7229543805449697e-2, + 1.4182925050891573e-2, -5.6214161633747336e-3, -2.39598509186381e-9, + 1.6029634366079908e-3, -1.1606784674435773e-3, 4.1001337768153873e-4, + 1.8365800754090661e-11, -9.5844256563655903e-5, 6.3643062337764708e-5, + -2.076250624489065e-5, -1.1806020912804483e-13, 4.2131808239120649e-6, + -2.6262241337012467e-6, 8.0770620494930662e-7, 6.0125912123632725e-16, + -1.4729737374018841e-7}, + {-1.9994542198219728e-1, -1.5056113040026424e-2, 3.6470239469348489e-1, + -4.6435192311733545e-1, 2.6640934719197893e-1, 3.4038266027147191e-5, + -1.3784338709329624e-1, 1.276467178337056e-1, -5.6213828755200985e-2, + -1.753150885483011e-7, 1.9235592956768113e-2, -1.5088821281095315e-2, + 5.7401854451350123e-3, 1.0622382710310225e-9, -1.5335082692563998e-3, + 1.0819320643228214e-3, -3.7372510193945659e-4, -6.6170909729031985e-12, + 8.4263617380909628e-5, -5.5150706827483479e-5, 1.7769536448348069e-5, + 3.8827923210205533e-14, -3.53513697488768e-6, 2.1865832130045269e-6, + -6.6812849447625594e-7}, + {7.2438608504029431e-1, -1.3918010932653375, 1.0654143352413968, + 1.876173868950258e-4, -8.2705501176152696e-1, 8.9352433347828414e-1, + -4.4971003995291339e-1, -1.6107401567546652e-6, 1.9235590165271091e-1, + -1.6597702160042609e-1, 6.8882222681814333e-2, 1.3910091724608687e-8, + -2.146911561508663e-2, 1.6228980898865892e-2, -5.9796016172584256e-3, + -1.1287469112826745e-10, 1.5167451119784857e-3, -1.0478634293553899e-3, + 3.5539072889126421e-4, 8.1704322111801517e-13, -7.7773013442452395e-5, + 5.0291413897007722e-5, -1.6035083867000518e-5, 1.2469354315487605e-14, + 3.1369106244517615e-6}, + {1.6668949727276811, 1.165462765994632e-1, -3.3288393225018906, + 4.4692325482864037, -2.6977693045875807, -2.600667859891061e-4, + 1.5389017615694539, -1.4937962361134612, 6.8881964633233148e-1, + 1.3077482004552385e-6, -2.5762963325596288e-1, 2.1097676102125449e-1, + -8.3714408359219882e-2, -7.7920428881354753e-9, 2.4267923064833599e-2, + -1.7813678334552311e-2, 6.3970330388900056e-3, 4.9430807090480523e-11, + -1.5554602758465635e-3, 1.0561196919903214e-3, -3.5277184460472902e-4, + 9.3002334645022459e-14, 7.5285855026557172e-5, -4.8186515569156351e-5, + 1.5227271505597605e-5}, + {-6.6188298861372935, 1.3397985455142589e+1, -1.0789350606845146e+1, + -1.4352254537875018e-3, 9.2333694596189809, -1.0456552819547769e+1, + 5.5105526029033471, 1.2024439690716742e-5, -2.5762961164755816, + 2.3207442745387179, -1.0045728797216284, -1.0207833290021914e-7, + 3.3975092171169466e-1, -2.6720517450757468e-1, 1.0235252851562706e-1, + 8.4329730484871625e-10, -2.7998284958442595e-2, 2.0066274144976813e-2, + -7.0554368915086242e-3, 1.9402238183698188e-12, 1.6562888105449611e-3, + -1.1082898580743683e-3, 3.654545161310169e-4, -5.1290032026971794e-11, + -7.6340103696869031e-5}, + {-1.7112706061976095e+1, -1.1208044642899116, 3.7131966511885444e+1, + -5.2298271025348962e+1, 3.3058589696624618e+1, 2.4791298976200222e-3, + -2.061089403411526e+1, 2.088672775145582e+1, -1.0045703956517752e+1, + -1.2238783449063012e-5, 4.0770134274221141, -3.473667358470195, + 1.4329352617312006, 7.1359914411879712e-8, -4.4797257159115612e-1, + 3.4112666080644461e-1, -1.2699786326594923e-1, -2.8953677269081528e-10, + 3.3125776278259863e-2, -2.3274087021036101e-2, 8.0399993503648882e-3, + -1.177805216235265e-9, -1.8321624891071668e-3, 1.2108282933588665e-3, + -3.9479941246822517e-4}, + {7.389033153567425e+1, -1.5680141270402273e+2, 1.322177542759164e+2, + 1.3692876877324546e-2, -1.2366496885920151e+2, 1.4620689391062729e+2, + -8.0365587724865346e+1, -1.1259851148881298e-4, 4.0770132196179938e+1, + -3.8210340013273034e+1, 1.719522294277362e+1, 9.3519707955168356e-7, + -6.2716159907747034, 5.1168999071852637, -2.0319658112299095, + -4.9507215582761543e-9, 5.9626397294332597e-1, -4.4220765337238094e-1, + 1.6079998700166273e-1, -2.4733786203223402e-8, -4.0307574759979762e-2, + 2.7849050747097869e-2, -9.4751858992054221e-3, 6.419922235909132e-6, + 2.1250180774699461e-3}, + {2.1216837098382522e+2, 1.3107863022633868e+1, -4.9698285932871748e+2, + 7.3121595266969204e+2, -4.8213821720890847e+2, -2.8817248692894889e-2, + 3.2616720302947102e+2, -3.4389340280087117e+2, 1.7195193870816232e+2, + 1.4038077378096158e-4, -7.52594195897599e+1, 6.651969984520934e+1, + -2.8447519748152462e+1, -7.613702615875391e-7, 9.5402237105304373, + -7.5175301113311376, 2.8943997568871961, -4.6612194999538201e-7, + -8.0615149598794088e-1, 5.8483006570631029e-1, -2.0845408972964956e-1, + 1.4765818959305817e-4, 5.1000433863753019e-2, -3.3066252141883665e-2, + 1.5109265210467774e-2}, + {-9.8959643098322368e+2, 2.1925555360905233e+3, -1.9283586782723356e+3, + -1.5925738122215253e-1, 1.9569985945919857e+3, -2.4072514765081556e+3, + 1.3756149959336496e+3, 1.2920735237496668e-3, -7.525941715948055e+2, + 7.3171668742208716e+2, -3.4137023466220065e+2, -9.9857390260608043e-6, + 1.3356313181291573e+2, -1.1276295161252794e+2, 4.6310396098204458e+1, + -7.9237387133614756e-6, -1.4510726927018646e+1, 1.1111771248100563e+1, + -4.1690817945270892, 3.1008219800117808e-3, 1.1220095449981468, + -7.6052379926149916e-1, 3.6262236505085254e-1, 2.216867741940747e-1, + 4.8683443692930507e-1}}; + + int k, n, sgn; + int maxpow = 0; + static scalar_t MACHEP = std::is_same::value ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + scalar_t lambda = x / a; + scalar_t sigma = (x - a) / a; + scalar_t eta, res, ck, ckterm, term, absterm; + scalar_t absoldterm = INFINITY; + scalar_t etapow[25] = {1}; + scalar_t sum = 0; + scalar_t afac = 1; + + if (igam) { + sgn = -1; + } + else { + sgn = 1; + } + + if (lambda > 1) { + eta = std::sqrt(-2 * (std::log1p(sigma) - sigma)); + } + else if (lambda < 1) { + eta = -std::sqrt(-2 * (std::log1p(sigma) - sigma)); + } + else { + eta = 0; + } + res = 0.5 * std::erfc(sgn * eta * std::sqrt(a / 2)); + + for (k = 0; k < 25; k++) { + ck = d[k][0]; + for (n = 1; n < 25; n++) { + if (n > maxpow) { + etapow[n] = eta * etapow[n-1]; + maxpow += 1; + } + ckterm = d[k][n]*etapow[n]; + ck += ckterm; + if (std::fabs(ckterm) < MACHEP * std::fabs(ck)) { + break; + } + } + term = ck * afac; + absterm = std::fabs(term); + if (absterm > absoldterm) { + break; + } + sum += term; + if (absterm < MACHEP * std::fabs(sum)) { + break; + } + absoldterm = absterm; + afac /= a; + } + res += sgn * std::exp(-0.5 * a * eta * eta) * sum / std::sqrt(2 * M_PIf * a); + + return res; +} + +template +static scalar_t _igamc_helper_continued_fraction(scalar_t a, scalar_t x) { + // Compute igamc using DLMF 8.9.2. [igam1] + int i; + scalar_t ans, ax, c, yc, r, t, y, z; + scalar_t pk, pkm1, pkm2, qk, qkm1, qkm2; + int MAXITER = 2000; + static scalar_t MACHEP = std::is_same::value ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + static scalar_t BIG = std::is_same::value ? + 4.503599627370496e15 : 16777216.; + static scalar_t BIGINV = std::is_same::value ? + 2.22044604925031308085e-16 : 5.9604644775390625E-8; + + ax = _igam_helper_fac(a, x); + if (ax == 0.0) { + return 0.0; + } + + /* continued fraction */ + y = 1.0 - a; + z = x + y + 1.0; + c = 0.0; + pkm2 = 1.0; + qkm2 = x; + pkm1 = x + 1.0; + qkm1 = z * x; + ans = pkm1 / qkm1; + + for (i = 0; i < MAXITER; i++) { + c += 1.0; + y += 1.0; + z += 2.0; + yc = y * c; + pk = pkm1 * z - pkm2 * yc; + qk = qkm1 * z - qkm2 * yc; + if (qk != 0) { + r = pk / qk; + t = std::fabs((ans - r) / r); + ans = r; + } + else { + t = 1.0; + } + pkm2 = pkm1; + pkm1 = pk; + qkm2 = qkm1; + qkm1 = qk; + if (std::fabs(pk) > BIG) { + pkm2 *= BIGINV; + pkm1 *= BIGINV; + qkm2 *= BIGINV; + qkm1 *= BIGINV; + } + if (t <= MACHEP) { + break; + } + } + return ans * ax; +} + +template +static inline scalar_t calc_igammac(scalar_t a, scalar_t x) { + /* the calculation of the regularized upper incomplete gamma function + * is done differently based on the values of a and x: + * - if x and/or a is at the boundary of defined region, then assign the + * result at the boundary + * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for + * Large Parameter (see DLMF 8.12.4 [igam1]) + * - if x > 1.1 and x < a, using the substraction from the regularized lower + * incomplete gamma + * - otherwise, calculate the series from [igam2] eq (5) + */ + scalar_t absxma_a; + + static scalar_t SMALL = 20.0; + static scalar_t LARGE = 200.0; + static scalar_t SMALLRATIO = 0.3; + static scalar_t LARGERATIO = 4.5; + + // note that in SciPy, a and x are non-negative, with exclusive 0s (i.e., + // at most 1 of them can be 0), where igammac(0, x) = 0.0 iff x > 0. + if ((x < 0) || (a < 0)) { + // out of defined-region of the function + return std::numeric_limits::quiet_NaN(); + } + else if (a == 0) { + if (x > 0) { + return 0.0; + } + else { + return std::numeric_limits::quiet_NaN(); + } + } + else if (x == 0) { + return 1.0; + } + else if (std::isinf(a)) { + if (std::isinf(x)) { + return std::numeric_limits::quiet_NaN(); + } + return 1.0; + } + else if (std::isinf(x)) { + return 0.0; + } + + absxma_a = std::fabs(x - a) / a; + if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) { + return _igam_helper_asymptotic_series(a, x, 0); + } + else if ((a > LARGE) && (absxma_a < LARGERATIO / std::sqrt(a))) { + return _igam_helper_asymptotic_series(a, x, 0); + } + + if (x > 1.1) { + if (x < a) { + return 1.0 - _igam_helper_series(a, x); + } + else { + return _igamc_helper_continued_fraction(a, x); + } + } + else if (x <= 0.5) { + if (-0.4 / std::log(x) < a) { + return 1.0 - _igam_helper_series(a, x); + } + else { + return _igamc_helper_series(a, x); + } + } + else { + if (x * 1.1 < a) { + return 1.0 - _igam_helper_series(a, x); + } + else { + return _igamc_helper_series(a, x); + } + } +} + +template +static inline scalar_t calc_igamma(scalar_t a, scalar_t x) { + /* the calculation of the regularized lower incomplete gamma function + * is done differently based on the values of a and x: + * - if x and/or a is at the boundary of defined region, then assign the + * result at the boundary + * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for + * Large Parameter (see DLMF 8.12.3 [igam1]) + * - if x > 1 and x > a, using the substraction from the regularized upper + * incomplete gamma + * - otherwise, calculate the series from [igam2] eq (4) + */ + scalar_t absxma_a; + static scalar_t SMALL = 20.0; + static scalar_t LARGE = 200.0; + static scalar_t SMALLRATIO = 0.3; + static scalar_t LARGERATIO = 4.5; + + // boundary values following SciPy + // note that in SciPy, a and x are non-negative, with exclusive 0s (i.e., + // at most 1 of them can be 0), where igamma(0, x) = 1.0 iff x > 0. + if ((x < 0) || (a < 0)) { + // out of defined-region of the function + return std::numeric_limits::quiet_NaN(); + } + else if (a == 0) { + if (x > 0) { + return 1.0; + } + else { + return std::numeric_limits::quiet_NaN(); + } + } + else if (x == 0) { + return 0.0; // zero integration limit + } + else if (std::isinf(a)) { + if (std::isinf(x)) { + return std::numeric_limits::quiet_NaN(); + } + return 0.0; + } + else if (std::isinf(x)) { + return 1.0; + } + + /* Asymptotic regime where a ~ x. See [igam2] */ + absxma_a = std::fabs(x - a) / a; + if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) { + return _igam_helper_asymptotic_series(a, x, 1); + } + else if ((a > LARGE) && (absxma_a < LARGERATIO / std::sqrt(a))) { + return _igam_helper_asymptotic_series(a, x, 1); + } + + if ((x > 1.0) && (x > a)) { + return 1.0 - calc_igammac(a, x); + } + + return _igam_helper_series(a, x); +} + +template <> +c10::BFloat16 calc_igamma(c10::BFloat16 a, c10::BFloat16 x) { + return calc_igamma(float(a), float(x)); +} + +template <> +c10::Half calc_igamma(c10::Half a, c10::Half x) { + return calc_igamma(float(a), float(x)); +} + inline c10::BFloat16 calc_erfinv(c10::BFloat16 a) { return calc_erfinv(float(a)); } template diff --git a/aten/src/ATen/native/MetaTensor.cpp b/aten/src/ATen/native/MetaTensor.cpp index f8f0231b181c..a7042b283c4c 100644 --- a/aten/src/ATen/native/MetaTensor.cpp +++ b/aten/src/ATen/native/MetaTensor.cpp @@ -14,7 +14,7 @@ Tensor empty_meta( !(options_.has_memory_format() && optional_memory_format.has_value()), "Cannot set memory_format both in TensorOptions and explicit argument; please delete " "the redundant setter."); - TensorOptions options = options_.merge_in(TensorOptions().memory_format(optional_memory_format)); + TensorOptions options = options_.merge_memory_format(optional_memory_format); // TODO: deduplicate this logic with empty_cpu diff --git a/aten/src/ATen/native/NaiveDilatedConvolution.cpp b/aten/src/ATen/native/NaiveDilatedConvolution.cpp index 459dd857727f..e80b0c546362 100644 --- a/aten/src/ATen/native/NaiveDilatedConvolution.cpp +++ b/aten/src/ATen/native/NaiveDilatedConvolution.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -181,10 +182,8 @@ void slow_conv_dilated_all_cpu_template( // Temporary buffer: Tensor columns = at::empty({0}, options); if (output.defined() || grad_weight.defined() || grad_input.defined()) { - int64_t m = std::accumulate( - kernel_size.begin(), kernel_size.end(), 1, std::multiplies()); - int64_t n = std::accumulate( - output_size.begin(), output_size.end(), 1, std::multiplies()); + const int64_t m = prod_intlist(kernel_size); + const int64_t n = prod_intlist(output_size); columns.resize_({nInputPlane * m, n}); } // Initialize diff --git a/aten/src/ATen/native/Pool.h b/aten/src/ATen/native/Pool.h index b89554fd4d48..071460b090cd 100644 --- a/aten/src/ATen/native/Pool.h +++ b/aten/src/ATen/native/Pool.h @@ -28,11 +28,12 @@ static inline T pooling_output_shape_pad_lr( T outputSize = div_rtn( inputSize + pad_l + pad_r - dilation * (kernelSize - 1) - 1 + (ceil_mode ? stride - 1 : 0), stride) + 1; - if (pad_l) { + if (ceil_mode) { // ensure that the last pooling starts inside the image // needed to avoid problems in ceil mode - if ((outputSize - 1) * stride >= inputSize + pad_l) + if ((outputSize - 1) * stride >= inputSize + pad_l) { --outputSize; + } } return outputSize; } diff --git a/aten/src/ATen/native/Resize.h b/aten/src/ATen/native/Resize.h index 8fdc977092f4..be61ffb8b546 100644 --- a/aten/src/ATen/native/Resize.h +++ b/aten/src/ATen/native/Resize.h @@ -73,7 +73,7 @@ static inline void checkInBoundsForStorage( IntArrayRef size, IntArrayRef stride, int64_t storage_offset, - const caffe2::TypeMeta& data_type, + const caffe2::TypeMeta data_type, const Storage& new_storage) { int64_t storage_size_bytes = detail::computeStorageNbytes(size, stride, data_type.itemsize()); diff --git a/aten/src/ATen/native/TensorConversions.cpp b/aten/src/ATen/native/TensorConversions.cpp index 58df4cf110f7..d941f3b8e169 100644 --- a/aten/src/ATen/native/TensorConversions.cpp +++ b/aten/src/ATen/native/TensorConversions.cpp @@ -29,10 +29,15 @@ static inline Tensor to_impl(const Tensor& self, const TensorOptions& options, b return self; } + bool pin_out = (non_blocking && self.is_cuda() && options.device().is_cpu() && + (options.layout() == c10::kStrided)); + if (memory_format == MemoryFormat::Preserve) { if (self.is_non_overlapping_and_dense()) { // Copy all strides - auto r = at::empty_strided(self.sizes(), self.strides(), options.memory_format(c10::nullopt)); + auto r = at::empty_strided(self.sizes(), + self.strides(), + options.memory_format(c10::nullopt).pinned_memory(pin_out)); r.copy_(self, non_blocking); return r; } else { @@ -40,7 +45,9 @@ static inline Tensor to_impl(const Tensor& self, const TensorOptions& options, b } } // See Note [Explicit nullopt MemoryFormat argument] - auto r = at::empty(self.sizes(), options.memory_format(memory_format), c10::nullopt); + auto r = at::empty(self.sizes(), + options.memory_format(memory_format).pinned_memory(pin_out), + c10::nullopt); r.copy_(self, non_blocking); return r; } @@ -56,7 +63,7 @@ Tensor to( !(options_.has_memory_format() && optional_memory_format.has_value()), "Cannot set memory_format both in TensorOptions and explicit argument; please delete " "the redundant setter."); - auto options = options_.merge_in(TensorOptions().memory_format(optional_memory_format)); + auto options = options_.merge_memory_format(optional_memory_format); TORCH_CHECK(options.requires_grad_opt() == c10::nullopt, "to(options) expects unset requires_grad flag, but got " diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp index 82d7363a1b32..0cec2dd32b0e 100644 --- a/aten/src/ATen/native/TensorFactories.cpp +++ b/aten/src/ATen/native/TensorFactories.cpp @@ -166,44 +166,7 @@ Tensor polar(const Tensor& abs, const Tensor& angle) { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ empty ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Tensor empty_cpu(IntArrayRef size, const TensorOptions& options_, c10::optional optional_memory_format) { - - TORCH_CHECK( - !(options_.has_memory_format() && optional_memory_format.has_value()), - "Cannot set memory_format both in TensorOptions and explicit argument; please delete " - "the redundant setter."); - TensorOptions options = options_.merge_in(TensorOptions().memory_format(optional_memory_format)); - - AT_ASSERT(options.device().type() == DeviceType::CPU); - check_size_nonnegative(size); - - c10::Allocator* allocator; - if (options.pinned_memory()) { - allocator = detail::getCUDAHooks().getPinnedMemoryAllocator(); - } else { - allocator = at::getCPUAllocator(); - } - - int64_t nelements = prod_intlist(size); - auto dtype = options.dtype(); - int64_t size_bytes = nelements * dtype.itemsize(); - auto storage_impl = c10::make_intrusive( - c10::StorageImpl::use_byte_size_t(), - size_bytes, - allocator->allocate(size_bytes), - allocator, - /*resizeable=*/true); - - auto tensor = detail::make_tensor( - std::move(storage_impl), at::DispatchKey::CPU, dtype); - // Default TensorImpl has size [0] - if (size.size() != 1 || size[0] != 0) { - tensor.unsafeGetTensorImpl()->set_sizes_contiguous(size); - } - - auto memory_format = options.memory_format_opt().value_or(MemoryFormat::Contiguous); - tensor.unsafeGetTensorImpl()->empty_tensor_restride(memory_format); - - return tensor; + return at::detail::empty_cpu(size, options_, optional_memory_format); } Tensor empty( @@ -277,7 +240,7 @@ Tensor empty_like( TensorOptions options = self.options() .merge_in(options_) - .merge_in(TensorOptions().memory_format(optional_memory_format)); + .merge_memory_format(optional_memory_format); TORCH_CHECK( !(options.layout() != kStrided && @@ -381,7 +344,8 @@ Tensor new_empty( // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ eye ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Tensor eye(int64_t n, const TensorOptions& options) { - return native::eye(n, -1, options); + // the default value of `m` equals to `n` + return native::eye(n, n, options); } Tensor eye(int64_t n, int64_t m, const TensorOptions& options) { @@ -390,15 +354,13 @@ Tensor eye(int64_t n, int64_t m, const TensorOptions& options) { } Tensor& eye_out_cpu(Tensor& result, int64_t n) { - return native::eye_out_cpu(result, n, -1); + // the default value of `m` equals to `n` + return native::eye_out_cpu(result, n, n); } Tensor& eye_out_cpu(Tensor& result, int64_t n, int64_t m) { TORCH_CHECK(n >= 0, "n must be greater or equal to 0, got ", n); - - if(m < 0) { - m = n; - } + TORCH_CHECK(m >= 0, "m must be greater or equal to 0, got ", m); result.resize_({n, m}); result.zero_(); diff --git a/aten/src/ATen/native/TensorFactories.h b/aten/src/ATen/native/TensorFactories.h index f551adcec693..8cae202efe13 100644 --- a/aten/src/ATen/native/TensorFactories.h +++ b/aten/src/ATen/native/TensorFactories.h @@ -61,11 +61,7 @@ inline void check_args( } } -inline void check_size_nonnegative(IntArrayRef size) { - for (auto x: size) { - TORCH_CHECK(x >= 0, "Trying to create tensor with negative dimension ", x, ": ", size); - } -} +using at::check_size_nonnegative; inline void check_supported_max_int_with_precision(int64_t n, const Tensor& tensor) { TORCH_CHECK(at::scalar_tensor(n, tensor.options()).defined(), diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index a42e90f399d9..b8bbe1edf8ee 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include @@ -368,7 +369,7 @@ static Tensor cat_sparse(TensorList tensors, int64_t dim) { // The dimension in each tensor's values object that corresponds to the overall dimension along which we're catting. int64_t values_dim = wrapped - sparse_dim + 1; // The final size along the catted dimension. - int64_t total_size = std::accumulate(tensors.begin(), tensors.end(), 0, [values_dim](int64_t l, Tensor const &r) { + const int64_t total_size = std::accumulate(tensors.begin(), tensors.end(), static_cast(0), [values_dim](int64_t l, Tensor const &r) { return l + r._values().size(values_dim); }); auto zeros_sizes = tensors[0]._values().sizes().vec(); @@ -1262,6 +1263,47 @@ static inline Tensor & sparse_transpose_(Tensor & self, int64_t dim0, int64_t di return self; } +// torch.row_stack, alias for torch.vstack +Tensor& row_stack_out(Tensor& result, TensorList tensors) { + return at::vstack_out(result, tensors); +} + +Tensor row_stack(TensorList tensors) { + return at::vstack(tensors); +} + +static std::vector reshape_input_for_column_stack(TensorList tensors) { + std::vector result(tensors.size()); + auto transform_lambda = [](const Tensor& input) -> Tensor { + // reshape 0D or 1D tensor t into (t.numel(), 1) + if (input.dim() <= 1) { + return input.reshape({input.numel(), 1}); + } + return input; + }; + std::transform(tensors.cbegin(), + tensors.cend(), + result.begin(), + transform_lambda); + return result; +} + +Tensor& column_stack_out(Tensor& result, TensorList tensors) { + TORCH_CHECK(tensors.size() > 0, + "column_stack expects a non-empty TensorList"); + + auto reshaped_tensors = reshape_input_for_column_stack(tensors); + return at::hstack_out(result, reshaped_tensors); +} + +Tensor column_stack(TensorList tensors) { + TORCH_CHECK(tensors.size() > 0, + "column_stack expects a non-empty TensorList"); + + auto reshaped_tensors = reshape_input_for_column_stack(tensors); + return at::hstack(reshaped_tensors); +} + static Tensor& propagate_transposed_names( Tensor& result, const Tensor& other, @@ -1634,7 +1676,7 @@ Tensor unflatten(const Tensor& self, int64_t dim, IntArrayRef sizes, c10::option TORCH_CHECK(sizes.size() > 0, "unflatten: sizes must be non-empty"); TORCH_INTERNAL_ASSERT(!names || names->size() == sizes.size()); - auto numel = std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies()); + const int64_t numel = prod_intlist(sizes); if (self.has_names()) { TORCH_CHECK(numel == self.size(dim), "unflatten: Provided sizes ", sizes, " don't multiply up to the size of dim ", @@ -1943,4 +1985,29 @@ Tensor movedim(const Tensor& self, int64_t src, int64_t dst) { return at::movedim(self, IntArrayRef{src}, IntArrayRef{dst}); } +Tensor trace_cpu(const Tensor& self) { + Tensor result = at::empty({}, self.options()); + AT_DISPATCH_ALL_TYPES(self.scalar_type(), "trace", [&] { + using accscalar_t = at::acc_type; + accscalar_t sum = 0; + const auto* t_data = self.data_ptr(); + + int64_t t_stride_0, t_stride_1, t_diag_size; + + TORCH_CHECK(self.dim() == 2, "trace: expected a matrix, but got tensor with dim ", self.dim()); + + t_stride_0 = self.stride(0); + t_stride_1 = self.stride(1); + + t_diag_size = std::min(self.size(0), self.size(1)); + for (int64_t i = 0; i < t_diag_size; i++) { + sum += t_data[i * (t_stride_0 + t_stride_1)]; + } + + *result.data_ptr() = sum; + }); + + return result; +} + }} // at::native diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp index 518032c81b04..1aebfda85da0 100644 --- a/aten/src/ATen/native/UnaryOps.cpp +++ b/aten/src/ATen/native/UnaryOps.cpp @@ -339,8 +339,8 @@ Tensor& sin_out(Tensor& result, const Tensor& self) { return unary_op_impl_float Tensor sin(const Tensor& self) { return unary_op_impl_float(self, sin_stub); } Tensor& sin_(Tensor& self) { return unary_op_impl_(self, at::sin_out); } -Tensor& cos_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, cos_stub); } -Tensor cos(const Tensor& self) { return unary_op_impl(self, at::cos_out); } +Tensor& cos_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, cos_stub); } +Tensor cos(const Tensor& self) { return unary_op_impl_float(self, cos_stub); } Tensor& cos_(Tensor& self) { return unary_op_impl_(self, at::cos_out); } Tensor& sinh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, sinh_stub); } @@ -452,8 +452,8 @@ Tensor& tanh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out( Tensor tanh(const Tensor& self) { return unary_op_impl(self, at::tanh_out); } Tensor& tanh_(Tensor& self) { return unary_op_impl_(self, at::tanh_out); } -Tensor& tan_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, tan_stub); } -Tensor tan(const Tensor& self) { return unary_op_impl(self, at::tan_out); } +Tensor& tan_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, tan_stub); } +Tensor tan(const Tensor& self) { return unary_op_impl_float(self, tan_stub); } Tensor& tan_(Tensor& self) { return unary_op_impl_(self, at::tan_out); } Tensor& trunc_out(Tensor& result, const Tensor& self) { diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp index 4e2d00e8347b..652f3ee063e1 100644 --- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp @@ -766,6 +766,19 @@ void hypot_kernel(TensorIterator& iter) { }); } +void igamma_kernel(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "igamma_cpu", [&]() { + cpu_kernel_vec( + iter, + [=](scalar_t a, scalar_t b) -> scalar_t { + return calc_igamma(a, b); + }, + [=](Vec256 a, Vec256 b) { + return a.igamma(b); + }); + }); +} + void nextafter_kernel(TensorIterator& iter) { AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "nextafter_cpu", [&]() { cpu_kernel_vec( @@ -824,6 +837,7 @@ REGISTER_DISPATCH(logaddexp2_stub, &logaddexp2_kernel); REGISTER_DISPATCH(gcd_stub, &gcd_kernel); REGISTER_DISPATCH(lcm_stub, &lcm_kernel); REGISTER_DISPATCH(hypot_stub, &hypot_kernel); +REGISTER_DISPATCH(igamma_stub, &igamma_kernel); REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel); REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel); diff --git a/aten/src/ATen/native/cpu/DistributionTemplates.h b/aten/src/ATen/native/cpu/DistributionTemplates.h index 6fe825bcde1e..d4b6da57111d 100644 --- a/aten/src/ATen/native/cpu/DistributionTemplates.h +++ b/aten/src/ATen/native/cpu/DistributionTemplates.h @@ -180,9 +180,7 @@ void normal_kernel(Tensor& self, double mean, double std, RNG generator) { normal_fill(self, static_cast(mean), static_cast(std), generator); #endif } else { - // bfloat16 cannot be properly tested due to the lack of other operations - // like add/sub/mean implemented for half - AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "normal_kernel_cpu", [&] { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, self.scalar_type(), "normal_kernel_cpu", [&] { if (size >= 16 && self.is_contiguous()) { normal_fill(self, static_cast(mean), static_cast(std), generator); } else { @@ -208,7 +206,7 @@ struct NormalKernel { template void uniform_kernel(TensorIterator& iter, double from_, double to_, RNG generator) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "uniform_kernel_cpu", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "uniform_kernel_cpu", [&]() { std::lock_guard lock(generator->mutex_); auto from = static_cast(from_); auto to = static_cast(to_); @@ -230,7 +228,7 @@ struct UniformKernel { template void cauchy_kernel(TensorIterator& iter, double median, double sigma, RNG generator) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "cauchy_cpu", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "cauchy_cpu", [&]() { std::lock_guard lock(generator->mutex_); at::cauchy_distribution cauchy(median, sigma); cpu_serial_kernel(iter, [&cauchy, generator]() -> scalar_t { diff --git a/aten/src/ATen/native/cpu/MaxPooling.cpp b/aten/src/ATen/native/cpu/MaxPooling.cpp index 35575091dcdb..3741d06a9bf5 100644 --- a/aten/src/ATen/native/cpu/MaxPooling.cpp +++ b/aten/src/ATen/native/cpu/MaxPooling.cpp @@ -30,8 +30,9 @@ void max_pool1d_impl( const Tensor& input, const PoolingParams1D& p) { AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "max_pool1d_impl", [&] { + const Tensor in = input.contiguous(); scalar_t* const OP = output.data_ptr(); - const scalar_t* const IP = input.contiguous().data_ptr(); + const scalar_t* const IP = in.data_ptr(); // Value used for padding constexpr scalar_t FILL = std::numeric_limits::has_infinity diff --git a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp index b6d38ce36bc0..d56582467894 100644 --- a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp @@ -181,7 +181,7 @@ static void norm_kernel_tensor_iterator_impl( if (val == 0) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kHalf, iter.dtype(), "norm_cpu", [&] { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "norm_cpu", [&] { binary_kernel_reduce( iter, NormZeroOps(), @@ -189,7 +189,7 @@ static void norm_kernel_tensor_iterator_impl( ); }); } else if (val == 1) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kHalf, iter.dtype(), "norm_cpu", [&] { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "norm_cpu", [&] { binary_kernel_reduce( iter, NormOneOps(), @@ -197,7 +197,7 @@ static void norm_kernel_tensor_iterator_impl( ); }); } else if (val == 2) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kHalf, iter.dtype(), "norm_cpu", [&] { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "norm_cpu", [&] { binary_kernel_reduce( iter, NormTwoOps(), @@ -205,7 +205,7 @@ static void norm_kernel_tensor_iterator_impl( ); }); } else if (val == INFINITY) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kHalf, iter.dtype(), "norm_cpu", [&] { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "norm_cpu", [&] { binary_kernel_reduce( iter, AbsMaxOps(), @@ -213,7 +213,7 @@ static void norm_kernel_tensor_iterator_impl( ); }); } else if (val == -INFINITY) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kHalf, iter.dtype(), "norm_cpu", [&] { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "norm_cpu", [&] { binary_kernel_reduce( iter, AbsMinOps(), @@ -221,7 +221,7 @@ static void norm_kernel_tensor_iterator_impl( ); }); } else { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kHalf, iter.dtype(), "norm_cpu", [&] { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "norm_cpu", [&] { binary_kernel_reduce( iter, NormOps { scalar_t(val) }, diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu index 88bad0a919b2..318185e43e8a 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu @@ -840,12 +840,13 @@ AT_ERROR("solve: MAGMA library not found in " auto b_data = b.data_ptr(); magma_int_t n = magma_int_cast(A.size(-2), "A.size(-2)"); magma_int_t nrhs = magma_int_cast(b.size(-1), "b.size(-1)"); + magma_int_t lda = std::max(magma_int_t{1}, n); if (b.dim() == 2) { auto ipiv = at::empty({n}, at::kInt); magma_int_t info = 0; - magmaSolve(n, nrhs, A_data, n, ipiv.data_ptr(), - b_data, n, &info); + magmaSolve(n, nrhs, A_data, lda, ipiv.data_ptr(), + b_data, lda, &info); infos[0] = info; } else { auto A_mat_stride = matrixStride(A); @@ -885,7 +886,7 @@ AT_ERROR("solve: MAGMA library not found in " magma_int_t* info_array_cur = &info_array[mini_idx]; magmaSolveBatched( - n, nrhs, A_array_cur, n, ipiv_array_cur, b_array_cur, n, + n, nrhs, A_array_cur, lda, ipiv_array_cur, b_array_cur, lda, info_array_cur, batch_limit, magma_queue); } @@ -893,7 +894,7 @@ AT_ERROR("solve: MAGMA library not found in " // which concisely is equal to batch_size % batch_limit if (batch_size % batch_limit != 0) { magmaSolveBatched( - n, nrhs, &A_array[mini_idx], n, &ipiv_array[mini_idx], &b_array[mini_idx], n, + n, nrhs, &A_array[mini_idx], lda, &ipiv_array[mini_idx], &b_array[mini_idx], lda, &info_array[mini_idx], batch_size % batch_limit, magma_queue); } diff --git a/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu b/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu index 12b1a0dbf305..2f53c2bb08d7 100644 --- a/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu +++ b/aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu @@ -92,6 +92,14 @@ void hypot_kernel_cuda(TensorIterator& iter) { }); } +void igamma_kernel_cuda(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "igamma_cuda", [&]() { + gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + return calc_igamma(a, b); + }); + }); +} + void nextafter_kernel_cuda(TensorIterator& iter) { AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "nextafter_cuda", [&]() { gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { @@ -116,6 +124,7 @@ REGISTER_DISPATCH(logaddexp2_stub, &logaddexp2_kernel_cuda); REGISTER_DISPATCH(gcd_stub, &gcd_kernel_cuda); REGISTER_DISPATCH(lcm_stub, &lcm_kernel_cuda); REGISTER_DISPATCH(hypot_stub, &hypot_kernel_cuda); +REGISTER_DISPATCH(igamma_stub, &igamma_kernel_cuda); REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel_cuda); REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel_cuda); diff --git a/aten/src/ATen/native/cuda/BinaryMulDivKernel.cu b/aten/src/ATen/native/cuda/BinaryMulDivKernel.cu index be3f4f0bb01e..f80d0906dfa2 100644 --- a/aten/src/ATen/native/cuda/BinaryMulDivKernel.cu +++ b/aten/src/ATen/native/cuda/BinaryMulDivKernel.cu @@ -65,7 +65,7 @@ void div_kernel_cuda(TensorIterator& iter) { } void mul_kernel_cuda(TensorIterator& iter) { - if (!isIntegralType(iter.common_dtype(), /*includeBool*/ false) && + if (!isIntegralType(iter.common_dtype(), /*includeBool*/ true) && (iter.is_cpu_scalar(1) || iter.is_cpu_scalar(2))) { //if common dtype is half the scalar constant can overflow in half precision, and yet the result can //still be representable in the half dtype. Cast scalar to acc_type to have better accuracy diff --git a/aten/src/ATen/native/cuda/DilatedMaxPool2d.cu b/aten/src/ATen/native/cuda/DilatedMaxPool2d.cu index 3e0e70c01952..d0b8c40ee4dc 100644 --- a/aten/src/ATen/native/cuda/DilatedMaxPool2d.cu +++ b/aten/src/ATen/native/cuda/DilatedMaxPool2d.cu @@ -73,13 +73,13 @@ __global__ void max_pool_forward_nchw(const int nthreads, const scalar_t* bottom template C10_LAUNCH_BOUNDS_1(CUDA_MAX_THREADS) -__global__ void max_pool_forward_nhwc(const scalar_t* bottom_data, const int nbatch, +__global__ void max_pool_forward_nhwc(const scalar_t* bottom_data, const int nbatch, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int pad_h, const int pad_w, const int dilation_h, const int dilation_w, - const int in_stride_n, const int in_stride_c, + const int in_stride_n, const int in_stride_c, const int in_stride_h, const int in_stride_w, const int kernel_stride_C, const int kernel_size_C, scalar_t* top_data, int64_t* top_mask) { @@ -100,9 +100,9 @@ __global__ void max_pool_forward_nhwc(const scalar_t* bottom_data, const int nba __syncthreads(); - int batch_id = blockIdx.x % nbatch; - int channel_id = blockIdx.x / nbatch; - int channel_offset = threadIdx.x + channel_id * blockDim.x; + int batch_id = blockIdx.x % nbatch; + int channel_id = blockIdx.x / nbatch; + int channel_offset = threadIdx.x + channel_id * blockDim.x; top_data = top_data + batch_id * pooled_height * pooled_width * channels; top_mask = top_mask + batch_id * pooled_height * pooled_width * channels; @@ -130,7 +130,7 @@ __global__ void max_pool_forward_nhwc(const scalar_t* bottom_data, const int nba wstart += dilation_w; for (int ih = hstart; ih < hend; ih++) { for (int iw = wstart; iw < wend; iw++) { - int cached_index = threadIdx.x; + int cached_index = threadIdx.x; const scalar_t *ptr_input = bottom_data + ih * in_stride_h + iw * in_stride_w; for(int c = channel_offset; c < channels; c+= blockDim.x*kernel_stride_C) { scalar_t val = ptr_input[c*in_stride_c]; @@ -138,20 +138,20 @@ __global__ void max_pool_forward_nhwc(const scalar_t* bottom_data, const int nba out_cached[cached_index] = scalar_cast(val); out_mask_cached[cached_index] = ih * width + iw; } - cached_index += blockDim.x; + cached_index += blockDim.x; } } } scalar_t *ptr_output_data = top_data + (oh * pooled_width + ow) * channels; int64_t *ptr_output_mask = top_mask + (oh * pooled_width + ow) * channels; - int cached_index = threadIdx.x; + int cached_index = threadIdx.x; for(int c = channel_offset; c < channels; c+= blockDim.x*kernel_stride_C) { ptr_output_data[c] = out_cached[cached_index]; ptr_output_mask[c] = out_mask_cached[cached_index]; out_cached[cached_index] = at::numeric_limits::lower_bound(); out_mask_cached[cached_index] = 0; - cached_index += blockDim.x; + cached_index += blockDim.x; } } } @@ -206,9 +206,9 @@ __global__ void max_pool_backward_nhwc(const int nthreads, const scalar_t* top_d const int stride_h, const int stride_w, const int pad_h, const int pad_w, const int dilation_h, const int dilation_w, const int out_stride_c, const int out_stride_h, const int out_stride_w, - const int in_stride_n, const int in_stride_c, + const int in_stride_n, const int in_stride_c, const int in_stride_h, const int in_stride_w, - const int kernel_stride_C, const int kernel_size_C, + const int kernel_stride_C, const int kernel_size_C, scalar_t* bottom_diff) { extern __shared__ int smem[]; accscalar_t *out_cached = reinterpret_cast(smem); @@ -216,9 +216,9 @@ __global__ void max_pool_backward_nhwc(const int nthreads, const scalar_t* top_d int thread_id = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z); int block_size = blockDim.x * blockDim.y * blockDim.z; - int batch_id = blockIdx.x % nbatch; - int channel_id = blockIdx.x / nbatch; - int channel_offset = threadIdx.x + channel_id * blockDim.x; + int batch_id = blockIdx.x % nbatch; + int channel_id = blockIdx.x / nbatch; + int channel_offset = threadIdx.x + channel_id * blockDim.x; for (int i = thread_id; i < kernel_size_C*blockDim.x*blockDim.y*blockDim.z; i+= block_size) { out_cached[i] = accscalar_t(0.0); @@ -245,38 +245,38 @@ __global__ void max_pool_backward_nhwc(const int nthreads, const scalar_t* top_d for (int iw = istartW; iw < iendW; iw+=blockDim.y) { int pwstart = p_start(iw, pad_w, kernel_w, dilation_w, stride_w); int pwend = p_end(iw, pad_w, pooled_width, stride_w); - int index_shift = ih * width + iw; + int index_shift = ih * width + iw; if ((phstart + 1 != phend) || (pwstart + 1 != pwend)) { for(int oh = phstart; oh < phend; ++oh) { for(int ow = pwstart; ow < pwend; ++ow) { - int cached_index = threadIdx.x; + int cached_index = threadIdx.x; const int64_t* ptr_top_mask = top_mask + oh * out_stride_h + ow * out_stride_w; for (int c = channel_offset; c < channels; c += blockDim.x*kernel_stride_C) { if (ptr_top_mask[c*out_stride_c] == index_shift) { - out_cached[cached_index] += + out_cached[cached_index] += scalar_cast(top_diff[oh * out_stride_h + ow * out_stride_w + c*out_stride_c]); } - cached_index += blockDim.x; + cached_index += blockDim.x; } } } scalar_t *ptr_bottom_diff = bottom_diff + index_shift * channels; - int cached_index = threadIdx.x; + int cached_index = threadIdx.x; for (int c = channel_offset; c < channels; c += blockDim.x*kernel_stride_C) { ptr_bottom_diff[c] = scalar_cast(out_cached[cached_index]); out_cached[cached_index] = accscalar_t(0.0); - cached_index += blockDim.x; + cached_index += blockDim.x; } } else { const int64_t* ptr_top_mask = top_mask + phstart * out_stride_h + pwstart * out_stride_w; scalar_t *ptr_bottom_diff = bottom_diff + index_shift * channels; - int cached_index = threadIdx.x; + int cached_index = threadIdx.x; for (int c = channel_offset; c < channels; c += blockDim.x*kernel_stride_C) { if (ptr_top_mask[c*out_stride_c] == index_shift) { - ptr_bottom_diff[c] = + ptr_bottom_diff[c] = scalar_cast(top_diff[phstart * out_stride_h + pwstart * out_stride_w + c*out_stride_c]); } - cached_index += blockDim.x; + cached_index += blockDim.x; } } } @@ -388,9 +388,9 @@ void max_pool2d_with_indices_out_cuda_template( const dim3 block(block_x, block_y, block_z); int kernel_stride_C = cuda::ATenCeilDiv( - safe_downcast(nInputPlane), block_x * 4); + safe_downcast(nInputPlane), block_x * 4); int kernel_size_C = cuda::ATenCeilDiv( - safe_downcast(nInputPlane), block_x * kernel_stride_C); + safe_downcast(nInputPlane), block_x * kernel_stride_C); int grid_x = nbatch*kernel_stride_C; int grid_y = std::min( @@ -402,17 +402,18 @@ void max_pool2d_with_indices_out_cuda_template( const dim3 grid(grid_x, grid_y, grid_z); size_t shmem_size = (kernel_size_C * block_x*block_y*block_z) * (sizeof(int) + sizeof(scalar_t)); - AT_ASSERT(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock); + AT_ASSERT(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock); max_pool_forward_nhwc <<>>( - input_data, nbatch, + input_data, nbatch, nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, kH, kW, dH, dW, padH, padW, dilationH, dilationW, - in_stride_n, in_stride_c, + in_stride_n, in_stride_c, in_stride_h, in_stride_w, - kernel_stride_C, kernel_size_C, + kernel_stride_C, kernel_size_C, output_data, indices_data); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); break; } case MemoryFormat::Contiguous: { @@ -424,6 +425,7 @@ void max_pool2d_with_indices_out_cuda_template( nbatch, nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, kH, kW, dH, dW, padH, padW, dilationH, dilationW, output_data, indices_data); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); break; } default: TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous"); @@ -431,8 +433,6 @@ void max_pool2d_with_indices_out_cuda_template( } ); - AT_CUDA_CHECK(cudaGetLastError()); - if(input.ndimension() == 3) { output.resize_({nInputPlane, outputHeight, outputWidth}); indices.resize_({nInputPlane, outputHeight, outputWidth}); @@ -565,11 +565,11 @@ void max_pool2d_with_indices_backward_out_cuda_template( const dim3 grid(grid_x, grid_y, grid_z); size_t shmem_size = (kernel_size_C * block_x*block_y*block_z) * sizeof(accscalar_t); - AT_ASSERT(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock); + AT_ASSERT(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock); - // The backward kernel is launched on input instead output. - // If it is launched on output layer, atomic_add would not provide much benefit on FP16. - // Please check comments at https://github.com/pytorch/pytorch/pull/34519. + // The backward kernel is launched on input instead output. + // If it is launched on output layer, atomic_add would not provide much benefit on FP16. + // Please check comments at https://github.com/pytorch/pytorch/pull/34519. max_pool_backward_nhwc <<>>( count, @@ -579,10 +579,11 @@ void max_pool2d_with_indices_backward_out_cuda_template( nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, kH, kW, dH, dW, padH, padW, dilationH, dilationW, out_stride_c, out_stride_h, out_stride_w, - in_stride_n, in_stride_c, + in_stride_n, in_stride_c, in_stride_h, in_stride_w, - kernel_stride_C, kernel_size_C, + kernel_stride_C, kernel_size_C, gradInput_data); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); break; } case MemoryFormat::Contiguous: { @@ -606,14 +607,13 @@ void max_pool2d_with_indices_backward_out_cuda_template( nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, kH, kW, dH, dW, padH, padW, dilationH, dilationW, gradInput_data); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); break; } default: TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous"); } } ); - - AT_CUDA_CHECK(cudaGetLastError()); } } // namespace diff --git a/aten/src/ATen/native/cuda/DilatedMaxPool3d.cu b/aten/src/ATen/native/cuda/DilatedMaxPool3d.cu index 9d72e0027007..bbafebacbf13 100644 --- a/aten/src/ATen/native/cuda/DilatedMaxPool3d.cu +++ b/aten/src/ATen/native/cuda/DilatedMaxPool3d.cu @@ -112,8 +112,7 @@ void max_pool3d_with_indices_out_frame( pT, pH, pW, dilationT, dilationH, dilationW, offsetZ); - - AT_CUDA_CHECK(cudaGetLastError()); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); totalZ -= 65535; offsetZ += 65535; @@ -178,8 +177,7 @@ void max_pool3d_with_indices_backward_out_frame( pT, pH, pW, dilationT, dilationH, dilationW, offsetZ); - - AT_CUDA_CHECK(cudaGetLastError()); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); totalZ -= 65535; offsetZ += 65535; diff --git a/aten/src/ATen/native/cuda/DistanceKernel.cu b/aten/src/ATen/native/cuda/DistanceKernel.cu index c43a2ae9877e..60a2c943e742 100644 --- a/aten/src/ATen/native/cuda/DistanceKernel.cu +++ b/aten/src/ATen/native/cuda/DistanceKernel.cu @@ -231,17 +231,21 @@ void cdist_kernel_impl(Tensor& result, const Tensor& x1, const Tensor& x2, doubl AT_DISPATCH_FLOATING_TYPES(x1.scalar_type(), "cdist_cuda", [&] { if (p == 0.0) { cdist_kernel_cuda_impl::zero><<>>(result.data_ptr(), x1.data_ptr(), x2.data_ptr(), p, r1, r2, m, r_size, l1_size, l2_size); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); } else if (p == 1.0) { cdist_kernel_cuda_impl::one><<>>(result.data_ptr(), x1.data_ptr(), x2.data_ptr(), p, r1, r2, m, r_size, l1_size, l2_size); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); } else if (p == 2.0) { cdist_kernel_cuda_impl::two><<>>(result.data_ptr(), x1.data_ptr(), x2.data_ptr(), p, r1, r2, m, r_size, l1_size, l2_size); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); } else if (std::isinf(p)) { cdist_kernel_cuda_impl::inf><<>>(result.data_ptr(), x1.data_ptr(), x2.data_ptr(), p, r1, r2, m, r_size, l1_size, l2_size); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); } else { cdist_kernel_cuda_impl::p><<>>(result.data_ptr(), x1.data_ptr(), x2.data_ptr(), p, r1, r2, m, r_size, l1_size, l2_size); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); } }); - AT_CUDA_CHECK(cudaGetLastError()); } void pdist_forward_kernel_impl(Tensor& result, const Tensor& self, double p) { @@ -257,17 +261,21 @@ void pdist_forward_kernel_impl(Tensor& result, const Tensor& self, double p) { AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "pdist_cuda", [&] { if (p == 0.0) { pdist_kernel_cuda_impl::zero><<>>(result.data_ptr(), self.data_ptr(), n, m, p, n2, n2_squared_minus_1); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); } else if (p == 1.0) { pdist_kernel_cuda_impl::one><<>>(result.data_ptr(), self.data_ptr(), n, m, p, n2, n2_squared_minus_1); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); } else if (p == 2.0) { pdist_kernel_cuda_impl::two><<>>(result.data_ptr(), self.data_ptr(), n, m, p, n2, n2_squared_minus_1); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); } else if (std::isinf(p)) { pdist_kernel_cuda_impl::inf><<>>(result.data_ptr(), self.data_ptr(), n, m, p, n2, n2_squared_minus_1); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); } else { pdist_kernel_cuda_impl::p><<>>(result.data_ptr(), self.data_ptr(), n, m, p, n2, n2_squared_minus_1); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); } }); - AT_CUDA_CHECK(cudaGetLastError()); } void pdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor& self, const double p, const Tensor& dist) { @@ -295,17 +303,21 @@ void pdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "pdist_cuda_backward", [&] { if (p == 1.0) { pdist_backward_kernel_cuda_impl::one><<>>(buffer.data_ptr(), grad.data_ptr(), self.data_ptr(), dist.data_ptr(), grad.stride(0), n, m, dist.numel(), p, n2, n2_squared_minus_1); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); } else if (p < 2.0) { pdist_backward_kernel_cuda_impl::lt_two><<>>(buffer.data_ptr(), grad.data_ptr(), self.data_ptr(), dist.data_ptr(), grad.stride(0), n, m, dist.numel(), p, n2, n2_squared_minus_1); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); } else if (p == 2.0) { pdist_backward_kernel_cuda_impl::two><<>>(buffer.data_ptr(), grad.data_ptr(), self.data_ptr(), dist.data_ptr(), grad.stride(0), n, m, dist.numel(), p, n2, n2_squared_minus_1); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); } else if (std::isinf(p)) { pdist_backward_kernel_cuda_impl::inf><<>>(buffer.data_ptr(), grad.data_ptr(), self.data_ptr(), dist.data_ptr(), grad.stride(0), n, m, dist.numel(), p, n2, n2_squared_minus_1); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); } else { pdist_backward_kernel_cuda_impl::p><<>>(buffer.data_ptr(), grad.data_ptr(), self.data_ptr(), dist.data_ptr(), grad.stride(0), n, m, dist.numel(), p, n2, n2_squared_minus_1); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); } }); - AT_CUDA_CHECK(cudaGetLastError()); at::sum_out(result, buffer, 0); } @@ -342,25 +354,29 @@ void cdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor cdist_backward_kernel_cuda_impl::one><<>>(buffer.data_ptr(), grad.data_ptr(), x1.data_ptr(), x2.data_ptr(), dist.data_ptr(), gs, p, r1, r2, m, count, r_size, l1_size, l2_size); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); } else if (p < 2.0) { cdist_backward_kernel_cuda_impl::lt_two><<>>(buffer.data_ptr(), grad.data_ptr(), x1.data_ptr(), x2.data_ptr(), dist.data_ptr(), gs, p, r1, r2, m, count, r_size, l1_size, l2_size); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); } else if (p == 2.0) { cdist_backward_kernel_cuda_impl::two><<>>(buffer.data_ptr(), grad.data_ptr(), x1.data_ptr(), x2.data_ptr(), dist.data_ptr(), gs, p, r1, r2, m, count, r_size, l1_size, l2_size); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); } else if (std::isinf(p)) { cdist_backward_kernel_cuda_impl::inf><<>>(buffer.data_ptr(), grad.data_ptr(), x1.data_ptr(), x2.data_ptr(), dist.data_ptr(), gs, p, r1, r2, m, count, r_size, l1_size, l2_size); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); } else { cdist_backward_kernel_cuda_impl::p><<>>(buffer.data_ptr(), grad.data_ptr(), x1.data_ptr(), x2.data_ptr(), dist.data_ptr(), gs, p, r1, r2, m, count, r_size, l1_size, l2_size); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); } }); - AT_CUDA_CHECK(cudaGetLastError()); if (x1.dim() > 2) { at::sum_out(result, buffer, 1); diff --git a/aten/src/ATen/native/cuda/Dropout.cu b/aten/src/ATen/native/cuda/Dropout.cu index 2016e96c9fd8..c417a7ccfabd 100644 --- a/aten/src/ATen/native/cuda/Dropout.cu +++ b/aten/src/ATen/native/cuda/Dropout.cu @@ -235,18 +235,22 @@ fused_dropout_cuda(const Tensor& self, double p, c10::optional gen_){ switch (vec_size) { case 4: fused_dropout_kernel_vec<<>>(self_info, ret_info, mask_info, nelem, pa, rng_engine_inputs); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); break; case 2: fused_dropout_kernel_vec<<>>(self_info, ret_info, mask_info, nelem, pa, rng_engine_inputs); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); break; } } else { switch (self_info.dims) { case 1: fused_dropout_kernel<<>>(self_info, ret_info, mask_info, nelem, pa, rng_engine_inputs); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); break; default: fused_dropout_kernel<<>>(self_info, ret_info, mask_info, nelem, pa, rng_engine_inputs); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); } } }); @@ -269,24 +273,27 @@ fused_dropout_cuda(const Tensor& self, double p, c10::optional gen_){ switch (vec_size) { case 4: fused_dropout_kernel_vec<<>>(self_info, ret_info, mask_info, nelem, pa, rng_engine_inputs); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); break; case 2: fused_dropout_kernel_vec<<>>(self_info, ret_info, mask_info, nelem, pa, rng_engine_inputs); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); break; } } else { switch (self_info.dims) { case 1: fused_dropout_kernel<<>>(self_info, ret_info, mask_info, nelem, pa, rng_engine_inputs); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); break; default: fused_dropout_kernel<<>>(self_info, ret_info, mask_info, nelem, pa, rng_engine_inputs); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); } } }); }); } - AT_CUDA_CHECK(cudaGetLastError()); return std::tuple(ret, mask); } diff --git a/aten/src/ATen/native/cuda/Embedding.cu b/aten/src/ATen/native/cuda/Embedding.cu index dbf968084e6e..cab8483093df 100644 --- a/aten/src/ATen/native/cuda/Embedding.cu +++ b/aten/src/ATen/native/cuda/Embedding.cu @@ -261,10 +261,9 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice static_cast(num_indices), static_cast(stride), static_cast(padding_idx)); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); }); }); - - AT_CUDA_CHECK(cudaGetLastError()); return grad_weight; } @@ -365,10 +364,9 @@ Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices, static_cast(max_norm), static_cast(norm_type), dim, self.stride(0), self.stride(1)); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); }); }); - AT_CUDA_CHECK(cudaGetLastError()); - return self; } diff --git a/aten/src/ATen/native/cuda/Math.cuh b/aten/src/ATen/native/cuda/Math.cuh index eec428ae2a12..ddf3679b4c27 100644 --- a/aten/src/ATen/native/cuda/Math.cuh +++ b/aten/src/ATen/native/cuda/Math.cuh @@ -54,12 +54,12 @@ static inline __host__ __device__ scalar_t zeta(scalar_t _x, scalar_t _q) { a = q; i = 0; b = 0.0; - while((i < 9) || (a <= 9.0)){ + while ((i < 9) || (a <= 9.0)) { i += 1; a += 1.0; b = ::pow( a, -x ); s += b; - if((-MACHEP < (b / s)) && ((b / s) < MACHEP)) { + if ((-MACHEP < (b / s)) && ((b / s) < MACHEP)) { return static_cast(s); } }; @@ -68,16 +68,16 @@ static inline __host__ __device__ scalar_t zeta(scalar_t _x, scalar_t _q) { s -= 0.5 * b; a = 1.0; k = 0.0; - for(int i=0; i < 12; i++) { + for (int i=0; i < 12; i++) { a *= x + k; b /= w; t = a * b / A[i]; s = s + t; t = t / s; - if(t < 0){ + if (t < 0){ t = -t; } - if((-MACHEP (s); } k += 1.0; @@ -174,6 +174,503 @@ static inline __host__ __device__ scalar_t calc_polygamma(int n, scalar_t x) { return ((n % 2) ? 1.0 : -1.0) * ::exp(::lgamma(static_cast(n) + 1.0)) * zeta(static_cast(n + 1), x); } +/* + * This implementation of the regularized incomplete gamma functions and + * their helper functions are derived from the implementation of SciPy's + * gammainc, Cephes's igam and igamc, and Boost's Lanczos approximations. + * See NOTICE for the licenses. + */ +// regularized lower & upper incomplete gamma +template +static __host__ __device__ scalar_t ratevl(scalar_t x, const scalar_t num[], int64_t M, + const scalar_t denom[], int64_t N) { + // evaluating rational function, i.e., the ratio of two polynomials + // the coefficients for numerator are given by `num` while coeffs for + // denumerator are given by `denom` + + using accscalar_t = at::acc_type; + int64_t i, dir; + accscalar_t y, num_ans, denom_ans; + accscalar_t absx = ::fabs(x); + const accscalar_t *p; + + if (absx > 1) { + /* Evaluate as a polynomial in 1/x. */ + dir = -1; + p = num + M; + y = 1 / x; + } + else { + dir = 1; + p = num; + y = x; + } + + /* Evaluate the numerator */ + num_ans = *p; + p += dir; + for (i = 1; i <= M; i++) { + num_ans = num_ans * y + *p; + p += dir; + } + /* Evaluate the denominator */ + if (absx > 1) { + p = denom + N; + } + else { + p = denom; + } + + denom_ans = *p; + p += dir; + for (i = 1; i <= N; i++) { + denom_ans = denom_ans * y + *p; + p += dir; + } + if (absx > 1) { + i = N - M; + return ::pow(x, static_cast(i)) * num_ans / denom_ans; + } + else { + return num_ans / denom_ans; + } +} + +template +static __host__ __device__ scalar_t lanczos_sum_expg_scaled(scalar_t x) { + // lanczos approximation + using accscalar_t = at::acc_type; + + static const accscalar_t lanczos_sum_expg_scaled_num[13] = { + 0.006061842346248906525783753964555936883222, + 0.5098416655656676188125178644804694509993, + 19.51992788247617482847860966235652136208, + 449.9445569063168119446858607650988409623, + 6955.999602515376140356310115515198987526, + 75999.29304014542649875303443598909137092, + 601859.6171681098786670226533699352302507, + 3481712.15498064590882071018964774556468, + 14605578.08768506808414169982791359218571, + 43338889.32467613834773723740590533316085, + 86363131.28813859145546927288977868422342, + 103794043.1163445451906271053616070238554, + 56906521.91347156388090791033559122686859 + }; + static const accscalar_t lanczos_sum_expg_scaled_denom[13] = { + 1., + 66., + 1925., + 32670., + 357423., + 2637558., + 13339535., + 45995730., + 105258076., + 150917976., + 120543840., + 39916800., + 0 + }; + return ratevl(static_cast(x), lanczos_sum_expg_scaled_num, + sizeof(lanczos_sum_expg_scaled_num) / sizeof(lanczos_sum_expg_scaled_num[0]) - 1, + lanczos_sum_expg_scaled_denom, + sizeof(lanczos_sum_expg_scaled_denom) / sizeof(lanczos_sum_expg_scaled_denom[0]) - 1); +} + +template +static __host__ __device__ scalar_t _igam_helper_fac(scalar_t a, scalar_t x) { + // compute x^a * exp(-a) / gamma(a) + // corrected from (15) and (16) in [igam2] by replacing exp(x - a) with + // exp(a - x). + + using accscalar_t = at::acc_type; + accscalar_t ax, fac, res, num, numfac; + static accscalar_t MAXLOG = std::is_same::value ? + 7.09782712893383996843E2 : 88.72283905206835; + static accscalar_t EXP1 = 2.718281828459045; + static accscalar_t lanczos_g = 6.024680040776729583740234375; + + if (::fabs(a - x) > 0.4 * ::fabs(a)) { + ax = a * ::log(x) - x - ::lgamma(a); + if (ax < -MAXLOG) { + return 0.0; + } + return ::exp(ax); + } + + fac = a + lanczos_g - 0.5; + res = ::sqrt(fac / EXP1) / lanczos_sum_expg_scaled(a); + + if ((a < 200) && (x < 200)) { + res *= ::exp(a - x) * ::pow(x / fac, a); + } + else { + num = x - a - lanczos_g + 0.5; + numfac = num / fac; + res *= ::exp(a * (::log1p(numfac) - numfac) + x * (0.5 - lanczos_g) / fac); + } + return res; +} + +template +static __host__ __device__ scalar_t _igam_helper_series(scalar_t a, scalar_t x) { + // Compute igam using DLMF 8.11.4. [igam1] + + using accscalar_t = at::acc_type; + static accscalar_t MACHEP = std::is_same::value ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + static int MAXITER = 2000; + + int i; + accscalar_t ans, ax, c, r; + + ax = _igam_helper_fac(a, x); + if (ax == 0.0) { + return 0.0; + } + + /* power series */ + r = a; + c = 1.0; + ans = 1.0; + + for (i = 0; i < MAXITER; i++) { + r += 1.0; + c *= x / r; + ans += c; + if (c <= MACHEP * ans) { + break; + } + } + return (ans * ax / a); +} + +template +static __host__ __device__ scalar_t _igamc_helper_series(scalar_t a, scalar_t x) { + // Compute igamc using DLMF 8.7.3 [igam1]. This is related to the series in + // _igam_helper_series but extra care is taken to avoid cancellation. + + using accscalar_t = at::acc_type; + int n; + accscalar_t fac = 1; + accscalar_t sum = 0; + accscalar_t term, logx; + static accscalar_t MAXITER = 2000; + static accscalar_t MACHEP = std::is_same::value ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + + for (n = 1; n < MAXITER; n++) { + fac *= -x / n; + term = fac / (a + n); + sum += term; + if (::fabs(term) <= MACHEP * ::fabs(sum)) { + break; + } + } + + logx = ::log(x); + term = -::expm1(a * logx - ::lgamma(1+a)); + return term - ::exp(a * logx - ::lgamma(a)) * sum; +} + +template +static __host__ __device__ scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t x, bool igam) { + // Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1] + + using accscalar_t = at::acc_type; + static const accscalar_t d[25][25] = + {{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2, 1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4, 3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6, 8.296711340953086e-7, -1.7665952736826079e-7, 6.7078535434014986e-9, 1.0261809784240308e-8, -4.3820360184533532e-9, 9.1476995822367902e-10, -2.551419399494625e-11, -5.8307721325504251e-11, 2.4361948020667416e-11, -5.0276692801141756e-12, 1.1004392031956135e-13, 3.3717632624009854e-13, -1.3923887224181621e-13, 2.8534893807047443e-14, -5.1391118342425726e-16, -1.9752288294349443e-15}, + {-1.8518518518518519e-3, -3.4722222222222222e-3, 2.6455026455026455e-3, -9.9022633744855967e-4, 2.0576131687242798e-4, -4.0187757201646091e-7, -1.8098550334489978e-5, 7.6491609160811101e-6, -1.6120900894563446e-6, 4.6471278028074343e-9, 1.378633446915721e-7, -5.752545603517705e-8, 1.1951628599778147e-8, -1.7543241719747648e-11, -1.0091543710600413e-9, 4.1627929918425826e-10, -8.5639070264929806e-11, 6.0672151016047586e-14, 7.1624989648114854e-12, -2.9331866437714371e-12, 5.9966963656836887e-13, -2.1671786527323314e-16, -4.9783399723692616e-14, 2.0291628823713425e-14, -4.13125571381061e-15}, + {4.1335978835978836e-3, -2.6813271604938272e-3, 7.7160493827160494e-4, 2.0093878600823045e-6, -1.0736653226365161e-4, 5.2923448829120125e-5, -1.2760635188618728e-5, 3.4235787340961381e-8, 1.3721957309062933e-6, -6.298992138380055e-7, 1.4280614206064242e-7, -2.0477098421990866e-10, -1.4092529910867521e-8, 6.228974084922022e-9, -1.3670488396617113e-9, 9.4283561590146782e-13, 1.2872252400089318e-10, -5.5645956134363321e-11, 1.1975935546366981e-11, -4.1689782251838635e-15, -1.0940640427884594e-12, 4.6622399463901357e-13, -9.905105763906906e-14, 1.8931876768373515e-17, 8.8592218725911273e-15}, + {6.4943415637860082e-4, 2.2947209362139918e-4, -4.6918949439525571e-4, 2.6772063206283885e-4, -7.5618016718839764e-5, -2.3965051138672967e-7, 1.1082654115347302e-5, -5.6749528269915966e-6, 1.4230900732435884e-6, -2.7861080291528142e-11, -1.6958404091930277e-7, 8.0994649053880824e-8, -1.9111168485973654e-8, 2.3928620439808118e-12, 2.0620131815488798e-9, -9.4604966618551322e-10, 2.1541049775774908e-10, -1.388823336813903e-14, -2.1894761681963939e-11, 9.7909989511716851e-12, -2.1782191880180962e-12, 6.2088195734079014e-17, 2.126978363279737e-13, -9.3446887915174333e-14, 2.0453671226782849e-14}, + {-8.618882909167117e-4, 7.8403922172006663e-4, -2.9907248030319018e-4, -1.4638452578843418e-6, 6.6414982154651222e-5, -3.9683650471794347e-5, 1.1375726970678419e-5, 2.5074972262375328e-10, -1.6954149536558306e-6, 8.9075075322053097e-7, -2.2929348340008049e-7, 2.956794137544049e-11, 2.8865829742708784e-8, -1.4189739437803219e-8, 3.4463580499464897e-9, -2.3024517174528067e-13, -3.9409233028046405e-10, 1.8602338968504502e-10, -4.356323005056618e-11, 1.2786001016296231e-15, 4.6792750266579195e-12, -2.1492464706134829e-12, 4.9088156148096522e-13, -6.3385914848915603e-18, -5.0453320690800944e-14}, + {-3.3679855336635815e-4, -6.9728137583658578e-5, 2.7727532449593921e-4, -1.9932570516188848e-4, 6.7977804779372078e-5, 1.419062920643967e-7, -1.3594048189768693e-5, 8.0184702563342015e-6, -2.2914811765080952e-6, -3.252473551298454e-10, 3.4652846491085265e-7, -1.8447187191171343e-7, 4.8240967037894181e-8, -1.7989466721743515e-14, -6.3061945000135234e-9, 3.1624176287745679e-9, -7.8409242536974293e-10, 5.1926791652540407e-15, 9.3589442423067836e-11, -4.5134262161632782e-11, 1.0799129993116827e-11, -3.661886712685252e-17, -1.210902069055155e-12, 5.6807435849905643e-13, -1.3249659916340829e-13}, + {5.3130793646399222e-4, -5.9216643735369388e-4, 2.7087820967180448e-4, 7.9023532326603279e-7, -8.1539693675619688e-5, 5.6116827531062497e-5, -1.8329116582843376e-5, -3.0796134506033048e-9, 3.4651553688036091e-6, -2.0291327396058604e-6, 5.7887928631490037e-7, 2.338630673826657e-13, -8.8286007463304835e-8, 4.7435958880408128e-8, -1.2545415020710382e-8, 8.6496488580102925e-14, 1.6846058979264063e-9, -8.5754928235775947e-10, 2.1598224929232125e-10, -7.6132305204761539e-16, -2.6639822008536144e-11, 1.3065700536611057e-11, -3.1799163902367977e-12, 4.7109761213674315e-18, 3.6902800842763467e-13}, + {3.4436760689237767e-4, 5.1717909082605922e-5, -3.3493161081142236e-4, 2.812695154763237e-4, -1.0976582244684731e-4, -1.2741009095484485e-7, 2.7744451511563644e-5, -1.8263488805711333e-5, 5.7876949497350524e-6, 4.9387589339362704e-10, -1.0595367014026043e-6, 6.1667143761104075e-7, -1.7562973359060462e-7, -1.2974473287015439e-12, 2.695423606288966e-8, -1.4578352908731271e-8, 3.887645959386175e-9, -3.8810022510194121e-17, -5.3279941738772867e-10, 2.7437977643314845e-10, -6.9957960920705679e-11, 2.5899863874868481e-17, 8.8566890996696381e-12, -4.403168815871311e-12, 1.0865561947091654e-12}, + {-6.5262391859530942e-4, 8.3949872067208728e-4, -4.3829709854172101e-4, -6.969091458420552e-7, 1.6644846642067548e-4, -1.2783517679769219e-4, 4.6299532636913043e-5, 4.5579098679227077e-9, -1.0595271125805195e-5, 6.7833429048651666e-6, -2.1075476666258804e-6, -1.7213731432817145e-11, 3.7735877416110979e-7, -2.1867506700122867e-7, 6.2202288040189269e-8, 6.5977038267330006e-16, -9.5903864974256858e-9, 5.2132144922808078e-9, -1.3991589583935709e-9, 5.382058999060575e-16, 1.9484714275467745e-10, -1.0127287556389682e-10, 2.6077347197254926e-11, -5.0904186999932993e-18, -3.3721464474854592e-12}, + {-5.9676129019274625e-4, -7.2048954160200106e-5, 6.7823088376673284e-4, -6.4014752602627585e-4, 2.7750107634328704e-4, 1.8197008380465151e-7, -8.4795071170685032e-5, 6.105192082501531e-5, -2.1073920183404862e-5, -8.8585890141255994e-10, 4.5284535953805377e-6, -2.8427815022504408e-6, 8.7082341778646412e-7, 3.6886101871706965e-12, -1.5344695190702061e-7, 8.862466778790695e-8, -2.5184812301826817e-8, -1.0225912098215092e-14, 3.8969470758154777e-9, -2.1267304792235635e-9, 5.7370135528051385e-10, -1.887749850169741e-19, -8.0931538694657866e-11, 4.2382723283449199e-11, -1.1002224534207726e-11}, + {1.3324454494800656e-3, -1.9144384985654775e-3, 1.1089369134596637e-3, 9.932404122642299e-7, -5.0874501293093199e-4, 4.2735056665392884e-4, -1.6858853767910799e-4, -8.1301893922784998e-9, 4.5284402370562147e-5, -3.127053674781734e-5, 1.044986828530338e-5, 4.8435226265680926e-11, -2.1482565873456258e-6, 1.329369701097492e-6, -4.0295693092101029e-7, -1.7567877666323291e-13, 7.0145043163668257e-8, -4.040787734999483e-8, 1.1474026743371963e-8, 3.9642746853563325e-18, -1.7804938269892714e-9, 9.7480262548731646e-10, -2.6405338676507616e-10, 5.794875163403742e-18, 3.7647749553543836e-11}, + {1.579727660730835e-3, 1.6251626278391582e-4, -2.0633421035543276e-3, 2.1389686185689098e-3, -1.0108559391263003e-3, -3.9912705529919201e-7, 3.6235025084764691e-4, -2.8143901463712154e-4, 1.0449513336495887e-4, 2.1211418491830297e-9, -2.5779417251947842e-5, 1.7281818956040463e-5, -5.6413773872904282e-6, -1.1024320105776174e-11, 1.1223224418895175e-6, -6.8693396379526735e-7, 2.0653236975414887e-7, 4.6714772409838506e-14, -3.5609886164949055e-8, 2.0470855345905963e-8, -5.8091738633283358e-9, -1.332821287582869e-16, 9.0354604391335133e-10, -4.9598782517330834e-10, 1.3481607129399749e-10}, + {-4.0725121195140166e-3, 6.4033628338080698e-3, -4.0410161081676618e-3, -2.183732802866233e-6, 2.1740441801254639e-3, -1.9700440518418892e-3, 8.3595469747962458e-4, 1.9445447567109655e-8, -2.5779387120421696e-4, 1.9009987368139304e-4, -6.7696499937438965e-5, -1.4440629666426572e-10, 1.5712512518742269e-5, -1.0304008744776893e-5, 3.304517767401387e-6, 7.9829760242325709e-13, -6.4097794149313004e-7, 3.8894624761300056e-7, -1.1618347644948869e-7, -2.816808630596451e-15, 1.9878012911297093e-8, -1.1407719956357511e-8, 3.2355857064185555e-9, 4.1759468293455945e-20, -5.0423112718105824e-10}, + {-5.9475779383993003e-3, -5.4016476789260452e-4, 8.7910413550767898e-3, -9.8576315587856125e-3, 5.0134695031021538e-3, 1.2807521786221875e-6, -2.0626019342754683e-3, 1.7109128573523058e-3, -6.7695312714133799e-4, -6.9011545676562133e-9, 1.8855128143995902e-4, -1.3395215663491969e-4, 4.6263183033528039e-5, 4.0034230613321351e-11, -1.0255652921494033e-5, 6.612086372797651e-6, -2.0913022027253008e-6, -2.0951775649603837e-13, 3.9756029041993247e-7, -2.3956211978815887e-7, 7.1182883382145864e-8, 8.925574873053455e-16, -1.2101547235064676e-8, 6.9350618248334386e-9, -1.9661464453856102e-9}, + {1.7402027787522711e-2, -2.9527880945699121e-2, 2.0045875571402799e-2, 7.0289515966903407e-6, -1.2375421071343148e-2, 1.1976293444235254e-2, -5.4156038466518525e-3, -6.3290893396418616e-8, 1.8855118129005065e-3, -1.473473274825001e-3, 5.5515810097708387e-4, 5.2406834412550662e-10, -1.4357913535784836e-4, 9.9181293224943297e-5, -3.3460834749478311e-5, -3.5755837291098993e-12, 7.1560851960630076e-6, -4.5516802628155526e-6, 1.4236576649271475e-6, 1.8803149082089664e-14, -2.6623403898929211e-7, 1.5950642189595716e-7, -4.7187514673841102e-8, -6.5107872958755177e-17, 7.9795091026746235e-9}, + {3.0249124160905891e-2, 2.4817436002649977e-3, -4.9939134373457022e-2, 5.9915643009307869e-2, -3.2483207601623391e-2, -5.7212968652103441e-6, 1.5085251778569354e-2, -1.3261324005088445e-2, 5.5515262632426148e-3, 3.0263182257030016e-8, -1.7229548406756723e-3, 1.2893570099929637e-3, -4.6845138348319876e-4, -1.830259937893045e-10, 1.1449739014822654e-4, -7.7378565221244477e-5, 2.5625836246985201e-5, 1.0766165333192814e-12, -5.3246809282422621e-6, 3.349634863064464e-6, -1.0381253128684018e-6, -5.608909920621128e-15, 1.9150821930676591e-7, -1.1418365800203486e-7, 3.3654425209171788e-8}, + {-9.9051020880159045e-2, 1.7954011706123486e-1, -1.2989606383463778e-1, -3.1478872752284357e-5, 9.0510635276848131e-2, -9.2828824411184397e-2, 4.4412112839877808e-2, 2.7779236316835888e-7, -1.7229543805449697e-2, 1.4182925050891573e-2, -5.6214161633747336e-3, -2.39598509186381e-9, 1.6029634366079908e-3, -1.1606784674435773e-3, 4.1001337768153873e-4, 1.8365800754090661e-11, -9.5844256563655903e-5, 6.3643062337764708e-5, -2.076250624489065e-5, -1.1806020912804483e-13, 4.2131808239120649e-6, -2.6262241337012467e-6, 8.0770620494930662e-7, 6.0125912123632725e-16, -1.4729737374018841e-7}, + {-1.9994542198219728e-1, -1.5056113040026424e-2, 3.6470239469348489e-1, -4.6435192311733545e-1, 2.6640934719197893e-1, 3.4038266027147191e-5, -1.3784338709329624e-1, 1.276467178337056e-1, -5.6213828755200985e-2, -1.753150885483011e-7, 1.9235592956768113e-2, -1.5088821281095315e-2, 5.7401854451350123e-3, 1.0622382710310225e-9, -1.5335082692563998e-3, 1.0819320643228214e-3, -3.7372510193945659e-4, -6.6170909729031985e-12, 8.4263617380909628e-5, -5.5150706827483479e-5, 1.7769536448348069e-5, 3.8827923210205533e-14, -3.53513697488768e-6, 2.1865832130045269e-6, -6.6812849447625594e-7}, + {7.2438608504029431e-1, -1.3918010932653375, 1.0654143352413968, 1.876173868950258e-4, -8.2705501176152696e-1, 8.9352433347828414e-1, -4.4971003995291339e-1, -1.6107401567546652e-6, 1.9235590165271091e-1, -1.6597702160042609e-1, 6.8882222681814333e-2, 1.3910091724608687e-8, -2.146911561508663e-2, 1.6228980898865892e-2, -5.9796016172584256e-3, -1.1287469112826745e-10, 1.5167451119784857e-3, -1.0478634293553899e-3, 3.5539072889126421e-4, 8.1704322111801517e-13, -7.7773013442452395e-5, 5.0291413897007722e-5, -1.6035083867000518e-5, 1.2469354315487605e-14, 3.1369106244517615e-6}, + {1.6668949727276811, 1.165462765994632e-1, -3.3288393225018906, 4.4692325482864037, -2.6977693045875807, -2.600667859891061e-4, 1.5389017615694539, -1.4937962361134612, 6.8881964633233148e-1, 1.3077482004552385e-6, -2.5762963325596288e-1, 2.1097676102125449e-1, -8.3714408359219882e-2, -7.7920428881354753e-9, 2.4267923064833599e-2, -1.7813678334552311e-2, 6.3970330388900056e-3, 4.9430807090480523e-11, -1.5554602758465635e-3, 1.0561196919903214e-3, -3.5277184460472902e-4, 9.3002334645022459e-14, 7.5285855026557172e-5, -4.8186515569156351e-5, 1.5227271505597605e-5}, + {-6.6188298861372935, 1.3397985455142589e+1, -1.0789350606845146e+1, -1.4352254537875018e-3, 9.2333694596189809, -1.0456552819547769e+1, 5.5105526029033471, 1.2024439690716742e-5, -2.5762961164755816, 2.3207442745387179, -1.0045728797216284, -1.0207833290021914e-7, 3.3975092171169466e-1, -2.6720517450757468e-1, 1.0235252851562706e-1, 8.4329730484871625e-10, -2.7998284958442595e-2, 2.0066274144976813e-2, -7.0554368915086242e-3, 1.9402238183698188e-12, 1.6562888105449611e-3, -1.1082898580743683e-3, 3.654545161310169e-4, -5.1290032026971794e-11, -7.6340103696869031e-5}, + {-1.7112706061976095e+1, -1.1208044642899116, 3.7131966511885444e+1, -5.2298271025348962e+1, 3.3058589696624618e+1, 2.4791298976200222e-3, -2.061089403411526e+1, 2.088672775145582e+1, -1.0045703956517752e+1, -1.2238783449063012e-5, 4.0770134274221141, -3.473667358470195, 1.4329352617312006, 7.1359914411879712e-8, -4.4797257159115612e-1, 3.4112666080644461e-1, -1.2699786326594923e-1, -2.8953677269081528e-10, 3.3125776278259863e-2, -2.3274087021036101e-2, 8.0399993503648882e-3, -1.177805216235265e-9, -1.8321624891071668e-3, 1.2108282933588665e-3, -3.9479941246822517e-4}, + {7.389033153567425e+1, -1.5680141270402273e+2, 1.322177542759164e+2, 1.3692876877324546e-2, -1.2366496885920151e+2, 1.4620689391062729e+2, -8.0365587724865346e+1, -1.1259851148881298e-4, 4.0770132196179938e+1, -3.8210340013273034e+1, 1.719522294277362e+1, 9.3519707955168356e-7, -6.2716159907747034, 5.1168999071852637, -2.0319658112299095, -4.9507215582761543e-9, 5.9626397294332597e-1, -4.4220765337238094e-1, 1.6079998700166273e-1, -2.4733786203223402e-8, -4.0307574759979762e-2, 2.7849050747097869e-2, -9.4751858992054221e-3, 6.419922235909132e-6, 2.1250180774699461e-3}, + {2.1216837098382522e+2, 1.3107863022633868e+1, -4.9698285932871748e+2, 7.3121595266969204e+2, -4.8213821720890847e+2, -2.8817248692894889e-2, 3.2616720302947102e+2, -3.4389340280087117e+2, 1.7195193870816232e+2, 1.4038077378096158e-4, -7.52594195897599e+1, 6.651969984520934e+1, -2.8447519748152462e+1, -7.613702615875391e-7, 9.5402237105304373, -7.5175301113311376, 2.8943997568871961, -4.6612194999538201e-7, -8.0615149598794088e-1, 5.8483006570631029e-1, -2.0845408972964956e-1, 1.4765818959305817e-4, 5.1000433863753019e-2, -3.3066252141883665e-2, 1.5109265210467774e-2}, + {-9.8959643098322368e+2, 2.1925555360905233e+3, -1.9283586782723356e+3, -1.5925738122215253e-1, 1.9569985945919857e+3, -2.4072514765081556e+3, 1.3756149959336496e+3, 1.2920735237496668e-3, -7.525941715948055e+2, 7.3171668742208716e+2, -3.4137023466220065e+2, -9.9857390260608043e-6, 1.3356313181291573e+2, -1.1276295161252794e+2, 4.6310396098204458e+1, -7.9237387133614756e-6, -1.4510726927018646e+1, 1.1111771248100563e+1, -4.1690817945270892, 3.1008219800117808e-3, 1.1220095449981468, -7.6052379926149916e-1, 3.6262236505085254e-1, 2.216867741940747e-1, 4.8683443692930507e-1}}; + + int k, n, sgn; + int maxpow = 0; + static accscalar_t MACHEP = std::is_same::value ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + accscalar_t lambda = x / a; + accscalar_t sigma = (x - a) / a; + accscalar_t eta, res, ck, ckterm, term, absterm; + accscalar_t absoldterm = INFINITY; + accscalar_t etapow[25] = {1}; + accscalar_t sum = 0; + accscalar_t afac = 1; + + if (igam) { + sgn = -1; + } + else { + sgn = 1; + } + + if (lambda > 1) { + eta = ::sqrt(-2 * (::log1p(sigma) - sigma)); + } + else if (lambda < 1) { + eta = -::sqrt(-2 * (::log1p(sigma) - sigma)); + } + else { + eta = 0; + } + res = 0.5 * ::erfc(sgn * eta * ::sqrt(a / 2)); + + for (k = 0; k < 25; k++) { + ck = d[k][0]; + for (n = 1; n < 25; n++) { + if (n > maxpow) { + etapow[n] = eta * etapow[n-1]; + maxpow += 1; + } + ckterm = d[k][n]*etapow[n]; + ck += ckterm; + if (std::fabs(ckterm) < MACHEP * std::fabs(ck)) { + break; + } + } + term = ck * afac; + absterm = std::fabs(term); + if (absterm > absoldterm) { + break; + } + sum += term; + if (absterm < MACHEP * std::fabs(sum)) { + break; + } + absoldterm = absterm; + afac /= a; + } + res += sgn * ::exp(-0.5 * a * eta * eta) * sum / ::sqrt(2 * 3.1415926535 * a); + + return res; +} + +template +static __host__ __device__ scalar_t _igamc_helper_continued_fraction(scalar_t a, scalar_t x) { + // Compute igamc using DLMF 8.9.2. [igam1] + + using accscalar_t = at::acc_type; + int i; + accscalar_t ans, ax, c, yc, r, t, y, z; + accscalar_t pk, pkm1, pkm2, qk, qkm1, qkm2; + int MAXITER = 2000; + static accscalar_t MACHEP = std::is_same::value ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + static accscalar_t BIG = std::is_same::value ? + 4.503599627370496e15 : 16777216.; + static accscalar_t BIGINV = std::is_same::value ? + 2.22044604925031308085e-16 : 5.9604644775390625E-8; + + ax = _igam_helper_fac(a, x); + if (ax == 0.0) { + return 0.0; + } + + /* continued fraction */ + y = 1.0 - a; + z = x + y + 1.0; + c = 0.0; + pkm2 = 1.0; + qkm2 = x; + pkm1 = x + 1.0; + qkm1 = z * x; + ans = pkm1 / qkm1; + + for (i = 0; i < MAXITER; i++) { + c += 1.0; + y += 1.0; + z += 2.0; + yc = y * c; + pk = pkm1 * z - pkm2 * yc; + qk = qkm1 * z - qkm2 * yc; + if (qk != 0) { + r = pk / qk; + t = ::fabs((ans - r) / r); + ans = r; + } + else { + t = 1.0; + } + pkm2 = pkm1; + pkm1 = pk; + qkm2 = qkm1; + qkm1 = qk; + if (::fabs(pk) > BIG) { + pkm2 *= BIGINV; + pkm1 *= BIGINV; + qkm2 *= BIGINV; + qkm1 *= BIGINV; + } + if (t <= MACHEP) { + break; + } + } + return ans * ax; +} + +template +static inline __host__ __device__ scalar_t calc_igammac(scalar_t a, scalar_t x) { + /* the calculation of the regularized upper incomplete gamma function + * is done differently based on the values of a and x: + * - if x and/or a is at the boundary of defined region, then assign the + * result at the boundary + * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for + * Large Parameter (see DLMF 8.12.4 [igam1]) + * - if x > 1.1 and x < a, using the substraction from the regularized lower + * incomplete gamma + * - otherwise, calculate the series from [igam2] eq (5) + */ + + using accscalar_t = at::acc_type; + accscalar_t absxma_a; + + static accscalar_t SMALL = 20.0; + static accscalar_t LARGE = 200.0; + static accscalar_t SMALLRATIO = 0.3; + static accscalar_t LARGERATIO = 4.5; + + if ((x < 0) || (a < 0)) { + // out of defined-region of the function + return std::numeric_limits::quiet_NaN(); + } + else if (a == 0) { + if (x > 0) { + return 0.0; + } + else { + return std::numeric_limits::quiet_NaN(); + } + } + else if (x == 0) { + return 1.0; + } + else if (::isinf(static_cast(a))) { + if (::isinf(static_cast(x))) { + return std::numeric_limits::quiet_NaN(); + } + return 1.0; + } + else if (::isinf(static_cast(x))) { + return 0.0; + } + + absxma_a = ::fabs(x - a) / a; + if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) { + return _igam_helper_asymptotic_series(a, x, 0); + } + else if ((a > LARGE) && (absxma_a < LARGERATIO / ::sqrt(a))) { + return _igam_helper_asymptotic_series(a, x, 0); + } + + if (x > 1.1) { + if (x < a) { + return 1.0 - _igam_helper_series(a, x); + } + else { + return _igamc_helper_continued_fraction(a, x); + } + } + else if (x <= 0.5) { + if (-0.4 / ::log(x) < a) { + return 1.0 - _igam_helper_series(a, x); + } + else { + return _igamc_helper_series(a, x); + } + } + else { + if (x * 1.1 < a) { + return 1.0 - _igam_helper_series(a, x); + } + else { + return _igamc_helper_series(a, x); + } + } +} + +template +static inline __host__ __device__ scalar_t calc_igamma(scalar_t a, scalar_t x) { + /* the calculation of the regularized lower incomplete gamma function + * is done differently based on the values of a and x: + * - if x and/or a is at the boundary of defined region, then assign the + * result at the boundary + * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for + * Large Parameter (see DLMF 8.12.3 [igam1]) + * - if x > 1 and x > a, using the substraction from the regularized upper + * incomplete gamma + * - otherwise, calculate the series from [igam2] eq (4) + */ + + using accscalar_t = at::acc_type; + accscalar_t absxma_a; + static accscalar_t SMALL = 20.0; + static accscalar_t LARGE = 200.0; + static accscalar_t SMALLRATIO = 0.3; + static accscalar_t LARGERATIO = 4.5; + + // boundary values following SciPy + if ((x < 0) || (a < 0)) { + // out of defined-region of the function + return std::numeric_limits::quiet_NaN(); + } + else if (a == 0) { + if (x > 0) { + return 1.0; + } + else { + return std::numeric_limits::quiet_NaN(); + } + } + else if (x == 0) { + return 0.0; // zero integration limit + } + else if (::isinf(static_cast(a))) { + if (::isinf(static_cast(x))) { + return std::numeric_limits::quiet_NaN(); + } + return 0.0; + } + else if (::isinf(static_cast(x))) { + return 1.0; + } + + /* Asymptotic regime where a ~ x. */ + absxma_a = ::fabs(x - a) / a; + if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) { + return _igam_helper_asymptotic_series(a, x, 1); + } + else if ((a > LARGE) && (absxma_a < LARGERATIO / ::sqrt(a))) { + return _igam_helper_asymptotic_series(a, x, 1); + } + + if ((x > 1.0) && (x > a)) { + return 1.0 - calc_igammac(a, x); + } + + return _igam_helper_series(a, x); +} + +// end of regularized lower & upper incomplete gamma template static inline C10_HOST_DEVICE scalar_t calc_gcd(scalar_t a_in, scalar_t b_in) { diff --git a/aten/src/ATen/native/cuda/NaiveDilatedConvolution.cu b/aten/src/ATen/native/cuda/NaiveDilatedConvolution.cu index 2ad6f0785a17..522e3bbd8760 100644 --- a/aten/src/ATen/native/cuda/NaiveDilatedConvolution.cu +++ b/aten/src/ATen/native/cuda/NaiveDilatedConvolution.cu @@ -188,10 +188,8 @@ void slow_conv_dilated_all_cuda_template( int64_t nInputPlane = weight.size(1); int64_t nOutputPlane = weight.size(0); // Temporary buffers: - int64_t m = std::accumulate( - kernel_size.begin(), kernel_size.end(), 1, std::multiplies()); - int64_t output_vsize = std::accumulate( - output_size.begin(), output_size.end(), 1, std::multiplies()); + const int64_t m = prod_intlist(kernel_size); + const int64_t output_vsize = prod_intlist(output_size); Tensor columns = at::empty({0}, options); if (output.defined() || grad_weight.defined() || grad_input.defined()) { columns.resize_({nInputPlane * m, output_vsize}); diff --git a/aten/src/ATen/native/cuda/ScanKernels.cu b/aten/src/ATen/native/cuda/ScanKernels.cu index 6bc2c381e1db..b0dc71c568ba 100644 --- a/aten/src/ATen/native/cuda/ScanKernels.cu +++ b/aten/src/ATen/native/cuda/ScanKernels.cu @@ -128,16 +128,16 @@ __global__ void tensor_kernel_scan_innermost_dim_with_indices(const scalar_t *se */ template __global__ void tensor_kernel_scan_outer_dim_with_indices(scalar_t *self_, scalar_t *values_, int64_t *indices_, - int num_orows, int num_irows, int row_size, scalar_t init, BinaryFunction binary_op) { - for (int orow = blockIdx.x; orow < num_orows; orow += gridDim.x) { - for (int irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) { + const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size, scalar_t init, BinaryFunction binary_op) { + for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) { + for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) { scalar_t *self = self_ + orow * row_size * num_irows + irow; scalar_t *values = values_ + orow * row_size * num_irows + irow; int64_t *indices = indices_ + orow * row_size * num_irows + irow; scalar_t out = init; int64_t out_idx = 0; - for (int64_t col = 0; col < row_size; ++col) { + for (auto col = decltype(row_size){0}; col < row_size; ++col) { if(THCNumerics::isnan(*self) || (!THCNumerics::isnan(out) && binary_op(*self, out))) { out = *self; out_idx = col; @@ -152,21 +152,34 @@ __global__ void tensor_kernel_scan_outer_dim_with_indices(scalar_t *self_, scala } } +void check_fits_in_unsigned(int64_t val, const char* name) { + constexpr auto umax = std::numeric_limits::max(); + TORCH_CHECK( + val >= 0 && val <= umax, name, " must fit in a 32-bit uint32_t value"); +} + + template __host__ void scan_outer_dim_with_indices(const Tensor& self, Tensor& values, Tensor& indices, int dim, scalar_t init, BinaryFunction binary_op) { - int row_size = self.size(dim); + int64_t row_size = self.size(dim); auto sizes = self.sizes(); // Treat all outer dimensions (i.e. dim_ < dim) as one. - int num_orows = std::accumulate(sizes.begin(), sizes.begin() + dim, 1, std::multiplies()); + const int64_t num_orows = prod_intlist(sizes.begin(), sizes.begin() + dim); // Treat all inner dimensions (i.e. dim > dimension) as one. - int num_irows = std::accumulate(sizes.begin() + dim + 1, sizes.end(), 1, std::multiplies()); + const int64_t num_irows = prod_intlist(sizes.begin() + dim + 1, sizes.end()); + //for performance reasons, cuda kernels use uint32_t for loops over irows, orows and row, + //make sure that input is not bigger than supported by uint32_t + check_fits_in_unsigned(num_irows, "num_irows"); + check_fits_in_unsigned(num_orows, "num_orows"); + check_fits_in_unsigned(row_size, "row_size"); + dim3 threads(std::min(512, int(num_irows))); - int maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; - dim3 grid(std::min(maxGridDim, num_orows), std::min(maxGridDim, ceil_div(num_irows, int(threads.x)))); + int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + dim3 grid(std::min(maxGridDim, num_orows), std::min(maxGridDim, ceil_div(num_irows, int64_t{threads.x}))); tensor_kernel_scan_outer_dim_with_indices<<>>( self.data_ptr(), values.data_ptr(), indices.data_ptr(), num_orows, num_irows, row_size, init, binary_op); @@ -254,16 +267,16 @@ void cummin_helper_cuda(const Tensor& self, Tensor& values, Tensor& indices, int */ template __global__ void tensor_kernel_scan_outer_dim(scalar_t *tgt_, scalar_t *src_, - unsigned num_orows, unsigned num_irows, unsigned row_size, - scalar_t init, BinaryOp binary_op) + const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size, + const scalar_t init, BinaryOp binary_op) { - for (unsigned orow = blockIdx.x; orow < num_orows; orow += gridDim.x) { - for (unsigned irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) { + for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) { + for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) { scalar_t *src = src_ + orow * row_size * num_irows + irow; scalar_t *tgt = tgt_ + orow * row_size * num_irows + irow; scalar_t acc = init; - for (unsigned col = 0; col < row_size; ++col) { + for (uint32_t col = 0; col < row_size; ++col) { acc = binary_op(acc, *src); *tgt = acc; @@ -286,12 +299,12 @@ __global__ void tensor_kernel_scan_outer_dim(scalar_t *tgt_, scalar_t *src_, */ template __device__ void tensor_kernel_scan_innermost_dim_impl(T* row_buf, T *tgt_, T *src_, - unsigned num_rows, unsigned row_size, + const uint32_t num_rows, const uint32_t row_size, T init, BinaryFunction binary_op){ - for (unsigned block_row = blockIdx.x * blockDim.y; + for (uint32_t block_row = blockIdx.x * blockDim.y; block_row < num_rows; block_row += blockDim.y * gridDim.x) { - unsigned row = block_row + threadIdx.y; + uint32_t row = block_row + threadIdx.y; T block_total = init; T *row_src = src_ + row * row_size; @@ -299,10 +312,10 @@ __device__ void tensor_kernel_scan_innermost_dim_impl(T* row_buf, T *tgt_, T *sr // Perform scan on one block at a time, keeping track of the total value of // all blocks processed so far. - for (unsigned block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) { + for (uint32_t block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) { // Load data into shared memory (two values per thread). - unsigned col1 = block_col + threadIdx.x; - unsigned col2 = block_col + num_threads_x + threadIdx.x; + uint32_t col1 = block_col + threadIdx.x; + uint32_t col2 = block_col + num_threads_x + threadIdx.x; if (row < num_rows) { if (col1 < row_size) { row_buf[threadIdx.x] = row_src[col1]; @@ -324,18 +337,18 @@ __device__ void tensor_kernel_scan_innermost_dim_impl(T* row_buf, T *tgt_, T *sr __syncthreads(); // Parallel reduction (up-sweep). - for (unsigned s = num_threads_x, d = 1; s >= 1; s >>= 1, d <<= 1) { + for (uint32_t s = num_threads_x, d = 1; s >= 1; s >>= 1, d <<= 1) { if (row < num_rows && threadIdx.x < s) { - unsigned offset = (2 * threadIdx.x + 1) * d - 1; + uint32_t offset = (2 * threadIdx.x + 1) * d - 1; row_buf[offset + d] = binary_op(row_buf[offset], row_buf[offset + d]); } __syncthreads(); } // Down-sweep. - for (unsigned s = 2, d = num_threads_x / 2; d >= 1; s <<= 1, d >>= 1) { + for (uint32_t s = 2, d = num_threads_x / 2; d >= 1; s <<= 1, d >>= 1) { if (row < num_rows && threadIdx.x < s - 1) { - unsigned offset = 2 * (threadIdx.x + 1) * d - 1; + uint32_t offset = 2 * (threadIdx.x + 1) * d - 1; row_buf[offset + d] = binary_op(row_buf[offset], row_buf[offset + d]); } __syncthreads(); @@ -361,8 +374,8 @@ __global__ typename std::enable_if::value, void>::type tensor_kernel_scan_innermost_dim( T* tgt_, T* src_, - unsigned num_rows, - unsigned row_size, + const uint32_t num_rows, + const uint32_t row_size, T init, BinaryFunction binary_op) { __shared__ T sbuf[num_threads_y][2 * num_threads_x]; @@ -381,8 +394,8 @@ __global__ typename std::enable_if::value, void>::type tensor_kernel_scan_innermost_dim( T* tgt_, T* src_, - unsigned num_rows, - unsigned row_size, + const uint32_t num_rows, + const uint32_t row_size, T init, BinaryFunction binary_op) { // As we cannot directly initialize shared array for complex types @@ -399,23 +412,18 @@ tensor_kernel_scan_innermost_dim( row_buf, tgt_, src_, num_rows, row_size, init, binary_op); } -void check_fits_in_unsigned(int64_t val, const char* name) { - constexpr auto umax = std::numeric_limits::max(); - TORCH_CHECK( - val >= 0 && val <= umax, name, " must fit in a 32-bit unsigned value"); -} template __host__ void scan_outer_dim(const Tensor& self, Tensor& result, int dim, scalar_t init, BinaryFunction binary_op) { - int64_t row_size = self.size(dim); + const int64_t row_size = self.size(dim); auto sizes = self.sizes(); // Treat all outer dimensions (i.e. dim_ < dim) as one. - int64_t num_orows = std::accumulate(sizes.begin(), sizes.begin() + dim, 1, std::multiplies()); + const int64_t num_orows = prod_intlist(sizes.begin(), sizes.begin() + dim); // Treat all inner dimensions (i.e. dim > dimension) as one. - int64_t num_irows = std::accumulate(sizes.begin() + dim + 1, sizes.end(), 1, std::multiplies()); + const int64_t num_irows = prod_intlist(sizes.begin() + dim + 1, sizes.end()); dim3 threads(std::min(512, int(num_irows))); int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; diff --git a/aten/src/ATen/native/cuda/Shape.cu b/aten/src/ATen/native/cuda/Shape.cu index 5da007507905..64af6cb268a2 100644 --- a/aten/src/ATen/native/cuda/Shape.cu +++ b/aten/src/ATen/native/cuda/Shape.cu @@ -237,7 +237,12 @@ void hip_parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension, batchCounter < CAT_ARRAY_BATCH_SIZE && (i+batchCounter) < inputs.size(); ++batchCounter) { - int64_t dimSize = at::native::size(inputs[i+batchCounter], dimension); + int64_t dimSize = 0; + // There is a legacy case where a 1-D empty tensor can be concat with + // high-dimensional tensor + if (inputs[i+batchCounter].numel() > 0) { + dimSize = at::native::size(inputs[i+batchCounter], dimension); + } stackInputs[batchCounter].input = inputs[i+batchCounter].data_ptr(); @@ -338,7 +343,12 @@ void parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension, batchCounter < CAT_ARRAY_BATCH_SIZE && (i+batchCounter) < inputs.size(); ++batchCounter) { - int64_t dimSize = at::native::size(inputs[i+batchCounter], dimension); + int64_t dimSize = 0; + // There is a legacy case where a 1-D empty tensor can be concat with + // high-dimensional tensor + if (inputs[i+batchCounter].numel() > 0) { + dimSize = at::native::size(inputs[i+batchCounter], dimension); + } catMetaData.input[batchCounter] = inputs[i+batchCounter].data_ptr(); catMetaData.offset[batchCounter] = offset; catMetaData.dimSize[batchCounter] = dimSize; @@ -431,7 +441,6 @@ Tensor& cat_out_cuda(Tensor& out, TensorList inputs, int64_t dimension) { auto should_skip = [](const Tensor &t) { return t.dim() == 1 && at::native::size(t, 0) == 0; }; - bool hasSkippedInput = false; const Tensor *notSkippedTensor = NULL; // non-owning reference int nDims = 0; @@ -452,10 +461,8 @@ Tensor& cat_out_cuda(Tensor& out, TensorList inputs, int64_t dimension) { } at::assert_no_internal_overlap(out); - for (int i = 0; i < inputs.size(); i++) - { + for (int i = 0; i < inputs.size(); i++) { if (should_skip(inputs[i])) { - hasSkippedInput = true; continue; } nDims = inputs[i].dim(); @@ -501,11 +508,10 @@ Tensor& cat_out_cuda(Tensor& out, TensorList inputs, int64_t dimension) { // We parallelize the copy if all 6 conditions pass: // // 1. There is more than one input tensor - // 2. No empty inputs - // 3. The out tensor is 32-bit indexable - // 4. The number of dimensions is <= 4 - // 5. All input tensors are contiguous (output tensor may be non-contig) - // 6. All input tensors can use 32-bit indexing + // 2. The out tensor is 32-bit indexable + // 3. The number of dimensions is <= 4 + // 4. All input tensors are contiguous (output tensor may be non-contig) + // 5. All input tensors can use 32-bit indexing const bool all32BitIndexable = std::all_of(inputs.begin(), inputs.end(), [] (const Tensor& t) { @@ -522,7 +528,6 @@ Tensor& cat_out_cuda(Tensor& out, TensorList inputs, int64_t dimension) { }); allSameType = allSameType && (out.scalar_type() == firstType); if (inputs.size() > 1 && - !hasSkippedInput && out.dim() <= CAT_ARRAY_MAX_INPUT_DIMS && at::cuda::detail::canUse32BitIndexMath(out) && allContiguous && diff --git a/aten/src/ATen/native/cuda/Sorting.cu b/aten/src/ATen/native/cuda/Sorting.cu index 0a5760580c06..c6688b286914 100644 --- a/aten/src/ATen/native/cuda/Sorting.cu +++ b/aten/src/ATen/native/cuda/Sorting.cu @@ -319,7 +319,7 @@ std::tuple median_with_indices_impl( NoNamesGuard guard; dim = at::maybe_wrap_dim(dim, self.dim()); - Tensor in = self.dim() > 0 ? self : self.unsqueeze(0); + Tensor in = self.dim() > 0 ? self.contiguous() : self.unsqueeze(0); int64_t size = in.size(dim); TORCH_CHECK( diff --git a/aten/src/ATen/native/cuda/SortingCommon.cuh b/aten/src/ATen/native/cuda/SortingCommon.cuh index 54513955e912..0e5cb7371d58 100644 --- a/aten/src/ATen/native/cuda/SortingCommon.cuh +++ b/aten/src/ATen/native/cuda/SortingCommon.cuh @@ -143,6 +143,7 @@ static uint64_t nextHighestPowerOf2(uint64_t n) { } +// WARNING: This function assumes input tensors are contiguous template void run_launcher( Tensor& values, diff --git a/aten/src/ATen/native/cuda/TensorFactories.cu b/aten/src/ATen/native/cuda/TensorFactories.cu index fb0eb4ca8b09..13f0b53516de 100644 --- a/aten/src/ATen/native/cuda/TensorFactories.cu +++ b/aten/src/ATen/native/cuda/TensorFactories.cu @@ -22,15 +22,13 @@ namespace at { namespace native { Tensor& eye_out_cuda(Tensor& result, int64_t n) { - return at::native::eye_out_cuda(result, n, /*m=*/-1); + // the default value of `m` equals to `n` + return at::native::eye_out_cuda(result, n, n); } Tensor& eye_out_cuda(Tensor& result, int64_t n, int64_t m) { TORCH_CHECK(n >= 0, "n must be greater or equal to 0, got ", n); - - if(m < 0) { - m = n; - } + TORCH_CHECK(m >= 0, "m must be greater or equal to 0, got ", m); result.resize_({n, m}); result.zero_(); diff --git a/aten/src/ATen/native/cuda/TensorTransformations.cu b/aten/src/ATen/native/cuda/TensorTransformations.cu index 4318b35c1295..147f7f3fad6f 100644 --- a/aten/src/ATen/native/cuda/TensorTransformations.cu +++ b/aten/src/ATen/native/cuda/TensorTransformations.cu @@ -95,6 +95,7 @@ Tensor flip_cuda(const Tensor& self, IntArrayRef dims) { kernel_pointwise_flip_apply2 <<>>( in_tensor_info, out_tensor_info, N, flip_dim, total_dims); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); }); return out_tensor; } @@ -131,6 +132,7 @@ Tensor flip_cuda(const Tensor& self, IntArrayRef dims) { stride_contiguous.cuda().data_ptr(), shape_t.cuda().data_ptr(), total_dims); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); }); return out_tensor; @@ -195,6 +197,7 @@ Tensor roll_cuda(const Tensor& self, IntArrayRef shifts, IntArrayRef dims) { size, in_tensor.stride(dim), total_dims); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); }); return out_tensor; diff --git a/aten/src/ATen/native/cuda/TriangularOps.cu b/aten/src/ATen/native/cuda/TriangularOps.cu index bb17233b3866..d9ba6d09b72f 100644 --- a/aten/src/ATen/native/cuda/TriangularOps.cu +++ b/aten/src/ATen/native/cuda/TriangularOps.cu @@ -67,15 +67,16 @@ Tensor& triu_tril_cuda_template(Tensor& result, const Tensor& self, int64_t k, c triu_tril_kernel <<>>( result_info, self_info, k, N); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); } else { auto result_info = cuda::detail::getTensorInfo(result); auto self_info = cuda::detail::getTensorInfo(self); triu_tril_kernel <<>>( result_info, self_info, k, N); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); } }); - AT_CUDA_CHECK(cudaGetLastError()); return result; } @@ -191,6 +192,7 @@ Tensor& apply_diag(Tensor& result, const Tensor& self, int64_t dimension) { sz, self_stride_0 + self_stride_1, result_stride); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); } } else { auto n_elems = self.numel(); @@ -219,6 +221,7 @@ Tensor& apply_diag(Tensor& result, const Tensor& self, int64_t dimension) { n_elems, result_stride_0 + result_stride_1, self_stride); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); } } diff --git a/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu b/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu index 465e54db51d6..2f7e92f3fc2e 100644 --- a/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu +++ b/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu @@ -43,7 +43,7 @@ void sin_kernel_cuda(TensorIterator& iter) { } void cos_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "cos_cuda", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "cos_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::cos(a); }); @@ -99,7 +99,7 @@ void atanh_kernel_cuda(TensorIterator& iter) { } void tan_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.dtype(), "tan_cuda", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.common_dtype(), "tan_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return ::tan(a); }); diff --git a/aten/src/ATen/native/cuda/UnfoldBackwardKernel.cu b/aten/src/ATen/native/cuda/UnfoldBackwardKernel.cu index d7fd3e924b49..1884b09a4fab 100644 --- a/aten/src/ATen/native/cuda/UnfoldBackwardKernel.cu +++ b/aten/src/ATen/native/cuda/UnfoldBackwardKernel.cu @@ -41,11 +41,11 @@ static void _launch_unfold_backward_kernel(int total_n_elems, func_t f) { dim3 block(n_threads); constexpr int total_work_block = n_threads * n_elems_per_thread; dim3 grid((total_n_elems + total_work_block - 1) / total_work_block); - + auto stream = at::cuda::getCurrentCUDAStream(); _unfold_backward_elementwise_kernel <<>>(total_n_elems, f); - AT_CUDA_CHECK(cudaGetLastError()); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); } template diff --git a/aten/src/ATen/native/cuda/UpSampleBicubic2d.cu b/aten/src/ATen/native/cuda/UpSampleBicubic2d.cu index fa4cde69e499..53af1d463606 100644 --- a/aten/src/ATen/native/cuda/UpSampleBicubic2d.cu +++ b/aten/src/ATen/native/cuda/UpSampleBicubic2d.cu @@ -228,9 +228,8 @@ static void upsample_bicubic2d_out_cuda_template( align_corners, idata, odata); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); }); - - AT_CUDA_CHECK(cudaGetLastError()); } static void upsample_bicubic2d_backward_out_cuda_template( @@ -303,9 +302,8 @@ static void upsample_bicubic2d_backward_out_cuda_template( 0, stream>>>( num_kernels, rheight, rwidth, align_corners, idata, odata); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); }); - - AT_CUDA_CHECK(cudaGetLastError()); } } // namespace diff --git a/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu b/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu index d65d6fa5e1b8..248d972bb320 100644 --- a/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu +++ b/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu @@ -213,9 +213,8 @@ static void upsample_bilinear2d_out_cuda_template( 0, stream>>>( num_kernels, rheight, rwidth, align_corners, idata, odata); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); }); - - AT_CUDA_CHECK(cudaGetLastError()); } static void upsample_bilinear2d_backward_out_cuda_template( @@ -306,9 +305,8 @@ static void upsample_bilinear2d_backward_out_cuda_template( align_corners, idata, odata); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); }); - - AT_CUDA_CHECK(cudaGetLastError()); } } // namespace diff --git a/aten/src/ATen/native/cuda/UpSampleLinear1d.cu b/aten/src/ATen/native/cuda/UpSampleLinear1d.cu index a81d3e6c78b6..08824565b150 100644 --- a/aten/src/ATen/native/cuda/UpSampleLinear1d.cu +++ b/aten/src/ATen/native/cuda/UpSampleLinear1d.cu @@ -160,9 +160,8 @@ static void upsample_linear1d_out_cuda_template( num_threads, 0, stream>>>(num_kernels, rwidth, align_corners, idata, odata); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); }); - - AT_CUDA_CHECK(cudaGetLastError()); } static void upsample_linear1d_backward_out_cuda_template( @@ -221,9 +220,8 @@ static void upsample_linear1d_backward_out_cuda_template( num_threads, 0, stream>>>(num_kernels, rwidth, align_corners, idata, odata); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); }); - - AT_CUDA_CHECK(cudaGetLastError()); } } // namespace diff --git a/aten/src/ATen/native/cuda/UpSampleNearest1d.cu b/aten/src/ATen/native/cuda/UpSampleNearest1d.cu index 08bea73727ea..425d450b375f 100644 --- a/aten/src/ATen/native/cuda/UpSampleNearest1d.cu +++ b/aten/src/ATen/native/cuda/UpSampleNearest1d.cu @@ -128,9 +128,8 @@ static void upsample_nearest1d_out_cuda_template( upsample_nearest1d_out_frame<<>>( idata, nbatch, channels, input_width, output_width, odata, scale_factor); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); }); - - AT_CUDA_CHECK(cudaGetLastError()); } static void upsample_nearest1d_backward_out_cuda_template( @@ -191,9 +190,8 @@ static void upsample_nearest1d_backward_out_cuda_template( upsample_nearest1d_backward_out_frame <<>>( odata, nbatch, channels, output_width, input_width, idata, scale_factor); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); }); - - AT_CUDA_CHECK(cudaGetLastError()); } } // namespace diff --git a/aten/src/ATen/native/cuda/UpSampleNearest2d.cu b/aten/src/ATen/native/cuda/UpSampleNearest2d.cu index 49a74f46ee14..a7f935e5f681 100644 --- a/aten/src/ATen/native/cuda/UpSampleNearest2d.cu +++ b/aten/src/ATen/native/cuda/UpSampleNearest2d.cu @@ -204,9 +204,8 @@ static void upsample_nearest2d_out_cuda_template( output_width, height_scale, width_scale); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); }); - - AT_CUDA_CHECK(cudaGetLastError()); } static void upsample_nearest2d_backward_out_cuda_template( @@ -287,8 +286,8 @@ static void upsample_nearest2d_backward_out_cuda_template( idata, height_scale, width_scale); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); }); - AT_CUDA_CHECK(cudaGetLastError()); } } // namespace diff --git a/aten/src/ATen/native/cuda/UpSampleNearest3d.cu b/aten/src/ATen/native/cuda/UpSampleNearest3d.cu index 76f694274f89..820358152351 100644 --- a/aten/src/ATen/native/cuda/UpSampleNearest3d.cu +++ b/aten/src/ATen/native/cuda/UpSampleNearest3d.cu @@ -199,9 +199,8 @@ static void upsample_nearest3d_out_cuda_template( depth_scale, height_scale, width_scale); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); }); - - AT_CUDA_CHECK(cudaGetLastError()); } static void upsample_nearest3d_backward_out_cuda_template( @@ -292,9 +291,8 @@ static void upsample_nearest3d_backward_out_cuda_template( depth_scale, height_scale, width_scale); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); }); - - AT_CUDA_CHECK(cudaGetLastError()); } } // namespace diff --git a/aten/src/ATen/native/cuda/UpSampleTrilinear3d.cu b/aten/src/ATen/native/cuda/UpSampleTrilinear3d.cu index 0498daa037c9..cf623723eaaa 100644 --- a/aten/src/ATen/native/cuda/UpSampleTrilinear3d.cu +++ b/aten/src/ATen/native/cuda/UpSampleTrilinear3d.cu @@ -271,9 +271,8 @@ static void upsample_trilinear3d_out_cuda_template( align_corners, idata, odata); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); }); - - AT_CUDA_CHECK(cudaGetLastError()); } static void upsample_trilinear3d_backward_out_cuda_template( @@ -361,9 +360,8 @@ static void upsample_trilinear3d_backward_out_cuda_template( align_corners, idata, odata); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); }); - - AT_CUDA_CHECK(cudaGetLastError()); } } // namespace diff --git a/aten/src/ATen/native/cuda/group_norm_kernel.cu b/aten/src/ATen/native/cuda/group_norm_kernel.cu index 89751d18891b..6cb9351548fa 100644 --- a/aten/src/ATen/native/cuda/group_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/group_norm_kernel.cu @@ -570,7 +570,7 @@ void GroupNormKernelImplInternal( : cuda_utils::kCUDABlockReduceNumThreads; RowwiseMomentsCUDAKernel<<>>( D * HxW, eps, X_data, mean_data, rstd_data); - AT_CUDA_CHECK(cudaGetLastError()); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); if (HxW == 1) { GroupNorm1dForward(X, mean, rstd, gamma, beta, N, C, G, Y); @@ -604,6 +604,7 @@ void GroupNormKernelImplInternal( const int64_t B = (N * C + kCUDANumThreads - 1) / kCUDANumThreads; ComputeFusedParamsCUDAKernel<<>>( N, C, G, mean_data, rstd_data, gamma_data, beta_data, a_data, b_data); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); auto iter = TensorIteratorConfig() .check_all_same_dtype(std::is_same::value) .resize_outputs(false) @@ -616,7 +617,6 @@ void GroupNormKernelImplInternal( return a * static_cast(x) + b; }); } - AT_CUDA_CHECK(cudaGetLastError()); } void GroupNormKernelImpl( @@ -698,6 +698,7 @@ void GroupNorm1dBackward( gamma_data, c2_data, c3_data); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); if (gamma.defined()) { auto iter = TensorIteratorConfig() @@ -753,6 +754,7 @@ void GroupNorm1dBackward( rstd_data, dgamma_data, dbeta_data); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); } else { const int64_t B = (C + kReduceTileSize - 1) / kReduceTileSize; // The algorithm for colwise reduction here is to accumulate each 32 cols @@ -771,8 +773,8 @@ void GroupNorm1dBackward( rstd_data, dgamma_data, dbeta_data); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); } - AT_CUDA_CHECK(cudaGetLastError()); } } @@ -835,7 +837,7 @@ void GroupNormBackwardKernelImplInternal( : cuda_utils::kCUDABlockReduceNumThreads; ComputeInternalGradientsCUDAKernel<<>>( HxW, dY_data, X_data, ds_data, db_data); - AT_CUDA_CHECK(cudaGetLastError()); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); if (dX.defined()) { Tensor c1 = at::empty({0}, X.options().dtype(kAccType)); @@ -871,6 +873,7 @@ void GroupNormBackwardKernelImplInternal( db_data, c2_data, c3_data); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); if (gamma.defined()) { auto iter = TensorIteratorConfig() @@ -905,7 +908,6 @@ void GroupNormBackwardKernelImplInternal( c3; }); } - AT_CUDA_CHECK(cudaGetLastError()); } if (dgamma.defined() || dbeta.defined()) { T* dgamma_data = dgamma.defined() ? dgamma.data_ptr() : nullptr; @@ -923,6 +925,7 @@ void GroupNormBackwardKernelImplInternal( db_data, dgamma_data, dbeta_data); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); } else { const int64_t B = (C + kReduceTileSize - 1) / kReduceTileSize; // The algorithm for colwise reduction here is to accumulate each 32 cols @@ -941,8 +944,8 @@ void GroupNormBackwardKernelImplInternal( db_data, dgamma_data, dbeta_data); + TORCH_CUDA_KERNEL_LAUNCH_CHECK(); } - AT_CUDA_CHECK(cudaGetLastError()); } } diff --git a/aten/src/ATen/native/group_norm.cpp b/aten/src/ATen/native/group_norm.cpp index 759167095ae3..3bd4daac917b 100644 --- a/aten/src/ATen/native/group_norm.cpp +++ b/aten/src/ATen/native/group_norm.cpp @@ -106,11 +106,8 @@ Tensor group_norm( input.sizes()); const auto input_shape = input.sizes(); - const int64_t HxW = std::accumulate( - input_shape.cbegin() + 2, - input_shape.cend(), - 1LL, - std::multiplies()); + const int64_t HxW = + prod_intlist(input_shape.cbegin() + 2, input_shape.cend()); const Tensor kEmpty; const auto& X = input.is_contiguous() ? input : input.contiguous(); diff --git a/aten/src/ATen/native/layer_norm.h b/aten/src/ATen/native/layer_norm.h index bf931fb26c5f..fa936ab7d4ce 100644 --- a/aten/src/ATen/native/layer_norm.h +++ b/aten/src/ATen/native/layer_norm.h @@ -52,16 +52,10 @@ std::tuple _prepare_layer_norm_inputs( } const int axis = input_ndim - normalized_ndim; - const int64_t M = std::accumulate( - input_shape.cbegin(), - input_shape.cbegin() + axis, - 1LL, - std::multiplies()); - const int64_t N = std::accumulate( - input_shape.cbegin() + axis, - input_shape.cend(), - 1LL, - std::multiplies()); + const int64_t M = + prod_intlist(input_shape.cbegin(), input_shape.cbegin() + axis); + const int64_t N = + prod_intlist(input_shape.cbegin() + axis, input_shape.cend()); const auto& X = input.is_contiguous() ? input : input.contiguous(); const auto& gamma = weight.is_contiguous() ? weight : weight.contiguous(); diff --git a/aten/src/ATen/native/metal/MetalAten.mm b/aten/src/ATen/native/metal/MetalAten.mm index ba3362174bb6..d9550352b922 100644 --- a/aten/src/ATen/native/metal/MetalAten.mm +++ b/aten/src/ATen/native/metal/MetalAten.mm @@ -91,7 +91,7 @@ Tensor empty( optional device, optional pin_memory) { TORCH_CHECK( - !pin_memory.has_value(), + !pin_memory.has_value() || !pin_memory.value(), "'pin_memory' argument is incompatible with Metal tensor"); MetalTensor mt{size.vec(), stride.vec()}; return MetalTensor::toTensor( diff --git a/aten/src/ATen/native/metal/MetalPrepackOpRegister.cpp b/aten/src/ATen/native/metal/MetalPrepackOpRegister.cpp index f73b18bea497..89f08e2aa03a 100644 --- a/aten/src/ATen/native/metal/MetalPrepackOpRegister.cpp +++ b/aten/src/ATen/native/metal/MetalPrepackOpRegister.cpp @@ -1,6 +1,7 @@ #include #include -#include +#include + #if defined(C10_IOS) #import diff --git a/aten/src/ATen/native/metal/MetalTensor.mm b/aten/src/ATen/native/metal/MetalTensor.mm index 6dfe3932bf16..b1fc38d92a6b 100644 --- a/aten/src/ATen/native/metal/MetalTensor.mm +++ b/aten/src/ATen/native/metal/MetalTensor.mm @@ -17,7 +17,7 @@ class API_AVAILABLE(ios(10.0), macos(10.13)) MetalTensor::Impl { _numel(std::accumulate( std::begin(_sizes), std::end(_sizes), - 1, + (int64_t)1, std::multiplies())), _textureImpl(std::make_unique(sizes)) {} diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index b143d39e67e2..ad5846d72812 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -283,13 +283,13 @@ use_c10_dispatcher: full variants: function dispatch: - DefaultBackend: view_as_real + CPU, CUDA: view_as_real - func: view_as_complex(Tensor(a) self) -> Tensor(a) use_c10_dispatcher: full variants: function dispatch: - DefaultBackend: view_as_complex + CPU, CUDA: view_as_complex - func: sgn(Tensor self) -> Tensor use_c10_dispatcher: full @@ -1546,6 +1546,16 @@ - func: rowwise_prune(Tensor weight, Tensor mask, ScalarType compressed_indices_dtype) -> (Tensor, Tensor) use_c10_dispatcher: full +# row_stack is the alias of vstack +- func: row_stack(Tensor[] tensors) -> Tensor + use_c10_dispatcher: full + dispatch: + Math: row_stack + +- func: row_stack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + Math: row_stack_out + - func: embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor) use_c10_dispatcher: hacky_wrapper_for_legacy_signatures @@ -5640,7 +5650,7 @@ use_c10_dispatcher: full variants: method, function dispatch: - CPU: legacy::cpu::_th_trace + CPU: trace_cpu CUDA: trace_cuda - func: trace_backward(Tensor grad, int[] sizes) -> Tensor @@ -6496,6 +6506,21 @@ dispatch: DefaultBackend: hypot_ +- func: igamma.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU, CUDA: igamma_out + +- func: igamma(Tensor self, Tensor other) -> Tensor + use_c10_dispatcher: full + variants: method, function + dispatch: + CPU, CUDA: igamma + +- func: igamma_(Tensor(a!) self, Tensor other) -> Tensor(a!) + variants: method + dispatch: + CPU, CUDA: igamma_ + - func: nextafter.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) dispatch: CPU, CUDA: nextafter_out @@ -8649,6 +8674,15 @@ CPU: col2im_backward_cpu CUDA: col2im_backward_cuda +- func: column_stack(Tensor[] tensors) -> Tensor + use_c10_dispatcher: full + dispatch: + Math: column_stack + +- func: column_stack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + Math: column_stack_out + - func: im2col.out(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) out) -> Tensor(a!) python_module: nn dispatch: @@ -8875,6 +8909,18 @@ python_module: linalg variants: function +- func: linalg_tensorsolve(Tensor self, Tensor other, int[]? dims=None) -> Tensor + python_module: linalg + variants: function + dispatch: + Math: linalg_tensorsolve + +- func: linalg_tensorsolve.out(Tensor self, Tensor other, int[]? dims=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + variants: function + dispatch: + Math: linalg_tensorsolve_out + ## Functions that are only for testing # It is undocumented and should not be used outside of tests. - func: _test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=1) -> Tensor diff --git a/aten/src/ATen/native/quantized/TensorFactories.cpp b/aten/src/ATen/native/quantized/TensorFactories.cpp index 11b84fa4713d..a05026d49f46 100644 --- a/aten/src/ATen/native/quantized/TensorFactories.cpp +++ b/aten/src/ATen/native/quantized/TensorFactories.cpp @@ -20,7 +20,7 @@ Tensor empty_affine_quantized( !(options_.has_memory_format() && optional_memory_format.has_value()), "Cannot set memory_format both in TensorOptions and explicit argument; please delete " "the redundant setter."); - auto options = options_.merge_in(TensorOptions().memory_format(optional_memory_format)); + auto options = options_.merge_memory_format(optional_memory_format); TORCH_CHECK( options.has_dtype(), "Must provide data type for Tensor creation functions."); @@ -42,7 +42,7 @@ Tensor empty_per_channel_affine_quantized( !(options_.has_memory_format() && optional_memory_format.has_value()), "Cannot set memory_format both in TensorOptions and explicit argument; please delete " "the redundant setter."); - auto options = options_.merge_in(TensorOptions().memory_format(optional_memory_format)); + auto options = options_.merge_memory_format(optional_memory_format); TORCH_CHECK( options.has_dtype(), "Must provide data type for Tensor creation functions."); diff --git a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp index 4d95ce4ffb4c..91c895685fd3 100644 --- a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp +++ b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp @@ -60,6 +60,30 @@ void CopyToChannelsLast3dTensor( } } +template +void CopyICFirst3dTensorToChannelsLast3dTensor( + int64_t G, + int64_t IC_G, + int64_t OC_G, + int64_t D, + int64_t H, + int64_t W, + const T* src, + T* dst) { + // IC OC/G THW -> G OC/G THW IC/G + const int64_t inner_size = D * H * W; + for (int64_t i = 0; i < G * OC_G; ++i) { + for (int64_t j = 0; j < inner_size; ++j) { + for (int64_t ic = 0; ic < IC_G; ++ic) { + int g = i / OC_G; + int oc = i % OC_G; + dst[(i * inner_size + j) * IC_G + ic] = + src[((g * IC_G + ic) * OC_G + oc) * inner_size + j]; + } + } + } +} + } // namespace template @@ -256,6 +280,75 @@ template fbgemm::conv_param_t<3> MakeFbgemmConvParam<3>( const std::vector& dilations, const std::vector& output_padding, bool transposed); +template <> +Tensor TransposeConvTensorUnpackConversion<3>(const Tensor& src, int groups) { + // OC IC/G DHW -> IC OC/G DHW logically + auto oc_g_ic_g_hw_tensors = src.chunk(groups); + auto fused_tensor = + at::cat(oc_g_ic_g_hw_tensors, 1).set_quantizer_(src.quantizer()); + return fused_tensor.permute({1, 0, 2, 3, 4}); +} + +template <> +Tensor ConvertConvWeightsToChannelLastTensor<2>( + const at::Tensor& src, + int groups, + bool transpose) { + return transpose ? + // 2D conv transpose weight transform + // IC OC/G KH KW -> G OC/G KH KW IC/G + [&]() { + auto ic_g_oc_g_hw_tensors = src.chunk(groups); + for (auto& tensor : ic_g_oc_g_hw_tensors) { + tensor = tensor.unsqueeze(0); + } + auto fused_tensor = + at::cat(ic_g_oc_g_hw_tensors).set_quantizer_(src.quantizer()); + return fused_tensor.permute({0, 2, 3, 4, 1}) + .contiguous(c10::MemoryFormat::Contiguous); + }() + // 2d conv weight transform + : src.contiguous(c10::MemoryFormat::ChannelsLast); +} + +template <> +Tensor ConvertConvWeightsToChannelLastTensor<3>( + const at::Tensor& src, + int groups, + bool transpose) { + if (!transpose) { + return ConvertToChannelsLast3dTensor(src); + } else { + TORCH_CHECK(src.dim() == 5); + Tensor dst; + const int64_t N = src.size(0); + const int64_t IC_G = N / groups; + const int64_t OC_G = src.size(1); + const int64_t D = src.size(2); + const int64_t H = src.size(3); + const int64_t W = src.size(4); + dst = MakeStridedQTensorCPU( + {groups * OC_G, IC_G, D, H, W}, + {D * H * W * IC_G, 1, H * W * IC_G, W * IC_G, IC_G}, + src.options(), + src.quantizer()); + AT_DISPATCH_QINT_TYPES( + src.scalar_type(), "CopyICFirst3dTensorToChannelsLast3dTensor", [&]() { + const Tensor src_contig = src.contiguous(); + CopyICFirst3dTensorToChannelsLast3dTensor( + groups, + IC_G, + OC_G, + D, + H, + W, + src_contig.data_ptr(), + dst.data_ptr()); + }); + return dst; + } +} + } // namespace fbgemm_utils } // namespace native } // namespace at @@ -263,8 +356,9 @@ template fbgemm::conv_param_t<3> MakeFbgemmConvParam<3>( #endif // USE_FBGEMM -template -CAFFE2_API torch::class_> register_conv_params() { + template + CAFFE2_API torch::class_> + register_conv_params() { static auto register_conv_params = torch::class_>( "quantized", "Conv" + c10::to_string(kSpatialDim) + "dPackedParamsBase") diff --git a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h index 40ef0feba61e..0cccf81e35d8 100644 --- a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h +++ b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h @@ -295,6 +295,11 @@ Tensor TransposeConvTensorUnpackConversion( const Tensor& src, int groups); +template +Tensor ConvertConvWeightsToChannelLastTensor( + const at::Tensor& src, + int groups, + bool transpose); } // namespace fbgemm_utils } // namespace native } // namespace at diff --git a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp index 9dd98b71ad40..8137049a75c8 100644 --- a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp +++ b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp @@ -2568,7 +2568,9 @@ void dequantize_tensor_per_tensor_affine_cpu( #endif // USE_FBGEMM // TODO: add fbgemm for per channel -void quantize_tensor_per_channel_affine_cpu( +// Generic template defaults to naive quantize implementation +template +void quantize_tensor_per_channel_impl( Tensor rtensor, Tensor qtensor, Tensor scales, @@ -2580,47 +2582,253 @@ void quantize_tensor_per_channel_affine_cpu( // Since current implemntation on channels_last format does not // cover per channel quant with arbitrary axis value, it is better // to check and fail. - TORCH_CHECK(rtensor.is_contiguous() || (axis <=1), + int64_t batches = size_to_dim_(axis, rtensor.sizes()); + int64_t elements_per_channel = size_from_dim_(axis + 1, rtensor.sizes()); + int64_t channels = rtensor.size(axis); + auto scales_data = scales.data_ptr(); + auto zero_points_data = zero_points.data_ptr(); + const float* in = rtensor.data_ptr(); + auto out = qtensor.data_ptr(); + if (axis == 1 && + (rtensor.is_contiguous(MemoryFormat::ChannelsLast) || + rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) { + // This code handles per channel quant when axis = 1 and + // channels_last contig. + // If axis = 0 and channels_last contig, implementation for channels + // first (NCHW) works. + for (auto b = 0; b < batches; ++b) { + for (auto e = 0; e < elements_per_channel; ++e) { + for (auto c = 0; c < channels; ++c) { + auto i = b * channels * elements_per_channel + e * channels + c; + out[i] = at::native::quantize_val( + scales_data[c], zero_points_data[c], in[i]); + } + } + } + } else { + for (auto b = 0; b < batches; ++b) { + for (auto c = 0; c < channels; ++c) { + for (auto e = 0; e < elements_per_channel; ++e) { + auto i = b * channels * elements_per_channel + + c * elements_per_channel + e; + out[i] = at::native::quantize_val( + scales_data[c], zero_points_data[c], in[i]); + } + } + } + } +} + +#if defined(__ARM_NEON__) || defined(__aarch64__) +// Specialized implementation from caffe2::Int8Quantize. +// There may be slight accuracy difference between this and implementation of +// quantize_val +// TODO Update quantize_tensor_per_channel_impl implementation to follow +// quantize_val, i.e. f = Round(value/scale + zero_point) +// TODO Make quantize_tensor_per_channel_impl work for other datatypes too +// (int8, int32). +template <> +void quantize_tensor_per_channel_impl( + Tensor rtensor, + Tensor qtensor, + Tensor scales, + Tensor zero_points, + int64_t axis) { + int64_t batches = size_to_dim_(axis, rtensor.sizes()); + int64_t elements_per_channel = size_from_dim_(axis + 1, rtensor.sizes()); + int64_t channels = rtensor.size(axis); + auto scales_data = scales.data_ptr(); + auto zero_points_data = zero_points.data_ptr(); + const float* in = rtensor.data_ptr(); + auto out = (uint8_t*)qtensor.data_ptr(); +#if defined(__ARM_NEON__) + // magic float and magic int to take care of rounding + // int magic_round(float f): interpret_int32(f + 12582912.0f) - 0x4B400000 + // Some detail: + // 12582912.0f is 2**23 + 2**22. The trick is based on the fact that when you + // add a small number to a large number, the result rounds to the precision of + // the least significant bit of the large number. For IEEE-754 + // single-precision number mantissa has 23 bits, and adding 2**23 would cause + // rounding to the nearest even integer. The we cast to int and subtract the + // same number (0x4B400000 is the integer representation of 12582912.0f) to + // get only the mantissa. This works if -2**22 < x < 2**22, but preserves the + // sign for negative numbers. + const float32x4_t vmagic_float = vdupq_n_f32(12582912.0f); + // Copy reciprocal of scales (double) into float array + // Copy zero_points with magic int (int64_t) into int32_t array + std::vector inv_scales(channels); + std::vector zero_points_int32t(channels); + for (int i = 0; i < channels; ++i) { + inv_scales[i] = 1.0f / (float)scales_data[i]; + zero_points_int32t[i] = (int32_t)(uint32_t)zero_points_data[i] - 0x4B400000; + } + if (axis == 1 && + (rtensor.is_contiguous(MemoryFormat::ChannelsLast) || + rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) { + // This code handles per channel quant when axis = 1 and + // channels_last contig. + // If axis = 0 and channels_last contig, implementation for channels + // first (NCHW) works. + for (uint32_t b = 0; b < batches; ++b) { + for (uint32_t e = 0; e < elements_per_channel; ++e) { + uint32_t c = 0; + while (c + 8 < channels) { + const int32x4_t voffset0123 = vld1q_s32(&zero_points_int32t[c]); + const float32x4_t vinv_scale0123 = vld1q_f32(&inv_scales[c]); + c += 4; + const int32x4_t voffset4567 = vld1q_s32(&zero_points_int32t[c]); + const float32x4_t vinv_scale4567 = vld1q_f32(&inv_scales[c]); + c += 4; + const float32x4_t vin0123 = vld1q_f32(in); + in += 4; + const float32x4_t vin4567 = vld1q_f32(in); + in += 4; + const int32x4_t vraw0123 = vaddq_s32( + voffset0123, + vreinterpretq_s32_f32( + vaddq_f32(vmagic_float, vmulq_f32(vin0123, vinv_scale0123)))); + const int32x4_t vraw4567 = vaddq_s32( + voffset4567, + vreinterpretq_s32_f32( + vaddq_f32(vmagic_float, vmulq_f32(vin4567, vinv_scale4567)))); + const int16x8_t vraw01234567 = + vcombine_s16(vqmovn_s32(vraw0123), vqmovn_s32(vraw4567)); + const uint8x8_t vout01234567 = vqmovun_s16(vraw01234567); + vst1_u8(out, vout01234567); + out += 8; + } + for (; c < channels; ++c) { + (*out++) = + at::native::quantize_val_arm(scales_data[c], zero_points_data[c], (*in++)); + } + } + } + } else { + for (uint32_t b = 0; b < batches; ++b) { + for (uint32_t c = 0; c < channels; ++c) { + uint32_t e = 0; + const int32x4_t voffset = vdupq_n_s32(zero_points_int32t[c]); + const float32x4_t vinv_scale = vdupq_n_f32(inv_scales[c]); + for (; e + 8 < elements_per_channel; e += 8) { + const float32x4_t vin0123 = vld1q_f32(in); + in += 4; + const float32x4_t vin4567 = vld1q_f32(in); + in += 4; + const int32x4_t vraw0123 = vaddq_s32( + voffset, + vreinterpretq_s32_f32( + vaddq_f32(vmagic_float, vmulq_f32(vin0123, vinv_scale)))); + const int32x4_t vraw4567 = vaddq_s32( + voffset, + vreinterpretq_s32_f32( + vaddq_f32(vmagic_float, vmulq_f32(vin4567, vinv_scale)))); + const int16x8_t vraw01234567 = + vcombine_s16(vqmovn_s32(vraw0123), vqmovn_s32(vraw4567)); + const uint8x8_t vout01234567 = vqmovun_s16(vraw01234567); + vst1_u8(out, vout01234567); + out += 8; + } + for (; e < elements_per_channel; ++e) { + (*out++) = + at::native::quantize_val_arm(scales_data[c], zero_points_data[c], (*in++)); + } + } + } + } +#else // defined(__ARM_NEON__) + // Copy scales (double) into float array + // Copy zero_points (int64_t) into int16_t array + std::vector inv_scales(channels); + std::vector zero_points_int16t(channels); + for (int i = 0; i < channels; ++i) { + inv_scales[i] = 1.0f / (float)scales_data[i]; + zero_points_int16t[i] = (int16_t)(uint16_t)zero_points_data[i]; + } + if (axis == 1 && + (rtensor.is_contiguous(MemoryFormat::ChannelsLast) || + rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) { + // This code handles per channel quant when axis = 1 and + // channels_last contig. + // If axis = 0 and channels_last contig, implementation for channels + // first (NCHW) works. + for (uint32_t b = 0; b < batches; ++b) { + for (uint32_t e = 0; e < elements_per_channel; ++e) { + uint32_t c = 0; + while (c + 8 < channels) { + const int16x8_t vzero_point = vld1q_s16(&zero_points_int16t[c]); + const float32x4_t vinv_scale0123 = vld1q_f32(&inv_scales[c]); + c += 4; + const float32x4_t vinv_scale4567 = vld1q_f32(&inv_scales[c]); + c += 4; + const float32x4_t vin0123 = vld1q_f32(in); + in += 4; + const float32x4_t vin4567 = vld1q_f32(in); + in += 4; + const int32x4_t v0123_rounded = + vcvtnq_s32_f32(vmulq_f32(vin0123, vinv_scale0123)); + const int32x4_t v4567_rounded = + vcvtnq_s32_f32(vmulq_f32(vin4567, vinv_scale4567)); + const int16x8_t v01234567_packed = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(v0123_rounded), v4567_rounded), + vzero_point); + const uint8x8_t vout01234567 = vqmovun_s16(v01234567_packed); + vst1_u8(out, vout01234567); + out += 8; + } + for (; c < channels; ++c) { + (*out++) = + at::native::quantize_val_arm(scales_data[c], zero_points_data[c], (*in++)); + } + } + } + } else { + for (uint32_t b = 0; b < batches; ++b) { + for (uint32_t c = 0; c < channels; ++c) { + uint32_t e = 0; + const int16x8_t vzero_point = vdupq_n_s16(zero_points_int16t[c]); + const float32x4_t vinv_scale = vdupq_n_f32(inv_scales[c]); + for (; e + 8 < elements_per_channel; e += 8) { + const float32x4_t vin0123 = vld1q_f32(in); + in += 4; + const float32x4_t vin4567 = vld1q_f32(in); + in += 4; + const int32x4_t v0123_rounded = + vcvtnq_s32_f32(vmulq_f32(vin0123, vinv_scale)); + const int32x4_t v4567_rounded = + vcvtnq_s32_f32(vmulq_f32(vin4567, vinv_scale)); + const int16x8_t v01234567_packed = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(v0123_rounded), v4567_rounded), + vzero_point); + const uint8x8_t vout01234567 = vqmovun_s16(v01234567_packed); + vst1_u8(out, vout01234567); + out += 8; + } + for (; e < elements_per_channel; ++e) { + (*out++) = + at::native::quantize_val_arm(scales_data[c], zero_points_data[c], (*in++)); + } + } + } + } +#endif // defined(__ARM_NEON__) +} +#endif // defined(__ARM_NEON__) || defined(__aarch64__) + +void quantize_tensor_per_channel_affine_cpu( + Tensor rtensor, + Tensor qtensor, + Tensor scales, + Tensor zero_points, + int64_t axis) { + TORCH_CHECK( + rtensor.is_contiguous() || (axis <= 1), "If tensor is channels_last contig then per channel quantization " "is supported only for axis = 0 or 1."); AT_DISPATCH_QINT_TYPES( qtensor.scalar_type(), "quantize_tensor_per_channel_affine_cpu", [&]() { - int64_t batches = size_to_dim_(axis, rtensor.sizes()); - int64_t elements_per_channel = - size_from_dim_(axis + 1, rtensor.sizes()); - int64_t channel = rtensor.size(axis); - auto scales_data = scales.data_ptr(); - auto zero_points_data = zero_points.data_ptr(); check_tensor_memory_format(rtensor, qtensor); - const float* rdata = rtensor.data_ptr(); - auto qdata = qtensor.data_ptr(); - if (axis == 1 && (rtensor.is_contiguous(MemoryFormat::ChannelsLast) || - rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) { - // This code handles per channel quant when axis = 1 and - // channels_last contig. - // If axis = 0 and channels_last contig, implementation - // for channels first (NCHW) works. - for (auto b = 0; b < batches; ++b) { - for (auto e = 0; e < elements_per_channel; ++e) { - for (auto c = 0; c < channel; ++c) { - auto i = b * channel * elements_per_channel + e * channel + c; - qdata[i] = quantize_val( - scales_data[c], zero_points_data[c], rdata[i]); - } - } - } - } else { - for (auto b = 0; b < batches; ++b) { - for (auto c = 0; c < channel; ++c) { - for (auto e = 0; e < elements_per_channel; ++e) { - auto i = b * channel * elements_per_channel + - c * elements_per_channel + e; - qdata[i] = quantize_val( - scales_data[c], zero_points_data[c], rdata[i]); - } - } - } - } + quantize_tensor_per_channel_impl( + rtensor, qtensor, scales, zero_points, axis); }); } diff --git a/aten/src/ATen/native/quantized/cpu/qconv.cpp b/aten/src/ATen/native/quantized/cpu/qconv.cpp index e96bd26acaba..05762bfb036f 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv.cpp @@ -411,7 +411,7 @@ at::Tensor PackedConvWeight::apply_impl( output_shape = MakeDeConvOutputShape( N, M, - {H, W}, + kSpatialDim == 2 ? std::vector{H, W} : std::vector{D, H, W}, kernel, stride(), padding(), @@ -886,6 +886,9 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) { // transpose m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose1d"), QConv1dInt8::run); m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose2d"), QConvInt8<2, false>::run); + m.impl( + TORCH_SELECTIVE_NAME("quantized::conv_transpose3d"), + QConvInt8<3, false>::run); } TORCH_LIBRARY_IMPL(_quantized, QuantizedCPU, m) { diff --git a/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp index af400558b5fe..c3b20163502d 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp @@ -39,9 +39,6 @@ c10::intrusive_ptr> PackedConvWeight< padding.size() == kSpatialDim, "Specify front/top/left padding only. " "end/bottom/right padding assumed to be equal to front/top/left"); - TORCH_CHECK( - !(transpose && kSpatialDim == 3), - "Currently no support for 3d conv_transpose in FBGEM. "); TORCH_CHECK( !transpose || output_padding.size() == kSpatialDim, "quantized::conv_prepack: Specify top/left output padding " @@ -104,26 +101,10 @@ c10::intrusive_ptr> PackedConvWeight< // for both conv and conv transpose // but PyTorch lays them out as {out_c, in_c/groups, kH, kW} // (or for ConvTranspose {in_c, out_c/groups, kH, kW}) - const at::Tensor weight_nhwc = transpose - ? - // check transpose - // 2D conv transpose weight transform - // IC OC/G KH KW -> OC KH KW IC/G - // transpose does not support 3d yet. - [&]() { - auto ic_g_oc_g_hw_tensors = weight.chunk(groups); - auto fused_tensor = - at::cat(ic_g_oc_g_hw_tensors, 1).set_quantizer_(weight.quantizer()); - return fused_tensor.permute({1, 2, 3, 0}) - .contiguous(c10::MemoryFormat::Contiguous); - }() - : (kSpatialDim == 2 - // 2d conv weight transform - ? weight.contiguous(c10::MemoryFormat::ChannelsLast) - // 3d conv weight transform - : at::native::fbgemm_utils::ConvertToChannelsLast3dTensor(weight)); + const at::Tensor weight_nhwc = + at::native::fbgemm_utils::ConvertConvWeightsToChannelLastTensor(weight, groups, transpose); const int8_t* weight_data_int8 = - reinterpret_cast(weight_nhwc.data_ptr()); + reinterpret_cast(weight_nhwc.data_ptr()); std::vector col_offsets(output_channels); // compute column offsets (Similar to // fbgemm::col_offsets_with_zero_pt_s8acc32_ref) please note that offsets @@ -444,6 +425,7 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) { // ConvTranspose m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose1d_prepack"), TORCH_FN(QConv1dPackWeightInt8::run_deconv)); m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose2d_prepack"), TORCH_FN(QConvPackWeightInt8<2>::run_deconv)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose3d_prepack"), TORCH_FN(QConvPackWeightInt8<3>::run_deconv)); } TORCH_LIBRARY_IMPL(_quantized, QuantizedCPU, m) { @@ -452,6 +434,7 @@ TORCH_LIBRARY_IMPL(_quantized, QuantizedCPU, m) { // ConvTranspose m.impl(TORCH_SELECTIVE_NAME("_quantized::conv_transpose1d_prepack"), TORCH_FN(QConv1dPackWeightInt8::run_deconv)); m.impl(TORCH_SELECTIVE_NAME("_quantized::conv_transpose2d_prepack"), TORCH_FN(QConvPackWeightInt8<2>::run_deconv)); + m.impl(TORCH_SELECTIVE_NAME("_quantized::conv_transpose3d_prepack"), TORCH_FN(QConvPackWeightInt8<3>::run_deconv)); } } // namespace diff --git a/aten/src/ATen/native/quantized/cpu/qconv_unpack.cpp b/aten/src/ATen/native/quantized/cpu/qconv_unpack.cpp index a908a0b77732..484bfe44fc76 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv_unpack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv_unpack.cpp @@ -13,10 +13,6 @@ template std::tuple> PackedConvWeight< kSpatialDim>::unpack() { auto* packed_weights_p = w.get(); - TORCH_CHECK( - !(kSpatialDim != 2 && transpose()), - "FBGEMM does not support 3d unpack right " - "now."); // output channels const int output_channels = packed_weights_p->outputChannels(); const int input_channels = packed_weights_p->inputChannels(); @@ -91,7 +87,7 @@ std::tuple> PackedConvWeight< if(transpose()){ unpacked_weights = at::native::fbgemm_utils::TransposeConvTensorUnpackConversion< - 2>(unpacked_weights, groups); + kSpatialDim>(unpacked_weights, groups); } return std::tuple>( unpacked_weights, bias); @@ -276,6 +272,7 @@ TORCH_LIBRARY_IMPL(quantized, CatchAll, m) { // ConvTranspose is the same, however, we want to have different name. m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose1d_unpack"), TORCH_FN(QConv1dUnpackWeightsInt8::run)); m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose2d_unpack"), TORCH_FN(QConvUnpackWeightsInt8<2>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose3d_unpack"), TORCH_FN(QConvUnpackWeightsInt8<3>::run)); m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose2d_stride"), TORCH_FN(QConvStride<2>::run)); m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose2d_padding"), TORCH_FN(QConvPadding<2>::run)); @@ -283,6 +280,12 @@ TORCH_LIBRARY_IMPL(quantized, CatchAll, m) { m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose2d_dilation"), TORCH_FN(QConvDilation<2>::run)); m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose2d_groups"), TORCH_FN(QConvGroups<2>::run)); m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose2d_transpose"), TORCH_FN(QConvTranspose<2>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose3d_stride"), TORCH_FN(QConvStride<3>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose3d_padding"), TORCH_FN(QConvPadding<3>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose3d_output_padding"), TORCH_FN(QConvOutputPadding<3>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose3d_dilation"), TORCH_FN(QConvDilation<3>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose3d_groups"), TORCH_FN(QConvGroups<3>::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::conv_transpose3d_transpose"), TORCH_FN(QConvTranspose<3>::run)); } } // namespace diff --git a/aten/src/ATen/native/quantized/cpu/qnormalization.cpp b/aten/src/ATen/native/quantized/cpu/qnormalization.cpp index 6ed193cd82c9..c8bbe9d29b24 100644 --- a/aten/src/ATen/native/quantized/cpu/qnormalization.cpp +++ b/aten/src/ATen/native/quantized/cpu/qnormalization.cpp @@ -71,11 +71,8 @@ Tensor quantized_group_norm_impl( const int64_t batches = input_shape[0]; const int64_t num_channels = input_shape[1]; - const int64_t elements_per_batch = std::accumulate( - input_shape.cbegin() + 1, - input_shape.cend(), - 1LL, - std::multiplies()); + const int64_t elements_per_batch = + prod_intlist(input_shape.cbegin() + 1, input_shape.cend()); const int64_t M = batches * num_groups; const int64_t N = elements_per_batch / num_groups; diff --git a/aten/src/ATen/native/quantized/library.cpp b/aten/src/ATen/native/quantized/library.cpp index 3150fd986300..c09501deec91 100644 --- a/aten/src/ATen/native/quantized/library.cpp +++ b/aten/src/ATen/native/quantized/library.cpp @@ -97,6 +97,7 @@ TORCH_LIBRARY(quantized, m) { // conv_tranpsose m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose1d(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose2d(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose3d(Tensor qx, __torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose1d_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int groups) -> __torch__.torch.classes.quantized.Conv2dPackedParamsBase")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose1d_unpack(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> (Tensor unpacked_weights, Tensor? B_origin)")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose2d_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int groups) -> __torch__.torch.classes.quantized.Conv2dPackedParamsBase")); @@ -107,6 +108,14 @@ TORCH_LIBRARY(quantized, m) { m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose2d_dilation(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int[]")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose2d_groups(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose2d_transpose(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose3d_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int groups) -> __torch__.torch.classes.quantized.Conv3dPackedParamsBase")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose3d_unpack(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> (Tensor unpacked_weights, Tensor? B_origin)")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose3d_stride(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int[]")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose3d_padding(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int[]")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose3d_output_padding(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int[]")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose3d_dilation(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int[]")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose3d_groups(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose3d_transpose(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::elu(Tensor self, float output_scale, int output_zero_point, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_prepack(Tensor weight) -> __torch__.torch.classes.quantized.EmbeddingPackedParamsBase W_prepack")); @@ -182,6 +191,7 @@ TORCH_LIBRARY(_quantized, m) { m.def(TORCH_SELECTIVE_SCHEMA("_quantized::conv_transpose2d(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("_quantized::conv_transpose1d_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int groups) -> __torch__.torch.classes.quantized.Conv2dPackedParamsBase")); m.def(TORCH_SELECTIVE_SCHEMA("_quantized::conv_transpose2d_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int groups) -> __torch__.torch.classes.quantized.Conv2dPackedParamsBase")); + m.def(TORCH_SELECTIVE_SCHEMA("_quantized::conv_transpose3d_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int groups) -> __torch__.torch.classes.quantized.Conv3dPackedParamsBase")); m.def(TORCH_SELECTIVE_SCHEMA("_quantized::linear(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, float Y_scale_i, int Y_zero_point_i) -> Tensor Y")); m.def(TORCH_SELECTIVE_SCHEMA("_quantized::linear_dynamic(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, bool reduce_range=False) -> Tensor Y")); m.def(TORCH_SELECTIVE_SCHEMA("_quantized::linear_prepack(Tensor W, Tensor? B=None) -> __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack")); diff --git a/aten/src/ATen/native/vulkan/Vulkan.cpp b/aten/src/ATen/native/vulkan/Vulkan.cpp index d6fa6a32291b..d58d7d7dcc09 100644 --- a/aten/src/ATen/native/vulkan/Vulkan.cpp +++ b/aten/src/ATen/native/vulkan/Vulkan.cpp @@ -7,6 +7,7 @@ #include #include +#include #ifdef USE_VULKAN_WRAPPER #include @@ -1182,11 +1183,7 @@ class VulkanTensor::Impl final { explicit Impl(std::vector sizes) : sizes_(std::move(sizes)), strides_(std::vector(sizes_.size())), - numel_(std::accumulate( - std::begin(sizes_), - std::end(sizes_), - 1, - std::multiplies())) { + numel_(prod_intlist(sizes_)) { TORCH_CHECK( initVulkanContextOnce(), "Vulkan Failed to create Vulkan Context"); } @@ -1289,8 +1286,7 @@ class VulkanTensor::Impl final { VkDeviceSize buffer_size_for_sizes(std::vector sizes) const { const auto d = sizes.size(); - const auto numel = std::accumulate( - std::begin(sizes), std::end(sizes), 1, std::multiplies()); + const auto numel = prod_intlist(sizes); VkDeviceSize bufferSize{sizeof(float) * numel}; // alignment to be able to copy between image and buffer if (d == 4) { diff --git a/aten/src/ATen/native/vulkan/VulkanOpaqueTensorImpl.h b/aten/src/ATen/native/vulkan/VulkanOpaqueTensorImpl.h index 6afc28676f2b..fc4f9945fcaf 100644 --- a/aten/src/ATen/native/vulkan/VulkanOpaqueTensorImpl.h +++ b/aten/src/ATen/native/vulkan/VulkanOpaqueTensorImpl.h @@ -11,7 +11,7 @@ template struct VulkanOpaqueTensorImpl : public OpaqueTensorImpl { VulkanOpaqueTensorImpl( at::DispatchKeySet key_set, - const caffe2::TypeMeta& data_type, + const caffe2::TypeMeta data_type, c10::Device device, OpaqueHandle opaque_handle, c10::IntArrayRef sizes, diff --git a/aten/src/ATen/native/vulkan/VulkanOps.cpp b/aten/src/ATen/native/vulkan/VulkanOps.cpp index 302525582c9d..2aff22163071 100644 --- a/aten/src/ATen/native/vulkan/VulkanOps.cpp +++ b/aten/src/ATen/native/vulkan/VulkanOps.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -66,7 +67,7 @@ void upsample_nearest2d( WorkGroupSize workGroupSize{8, 8, 1}; auto& computeUnit = context().computeUnitFactory().get( - GLSL_SPV(upsampleNearest2d), descriptorSetLayout, workGroupSize); + GLSL_SPV(upsample_nearest2d), descriptorSetLayout, workGroupSize); computeUnit.createCommandBuffer(descriptorSet); input.image()->addImageMemoryBarrierToShaderRead(computeUnit.commandBuffer()); computeUnit.dispatchCommandBuffer(OW, OH, C, workGroupSize); @@ -553,21 +554,19 @@ void add( void add(VulkanTensor& output, const VulkanTensor& input, const float s) { const auto sizes = input.sizes(); - const auto C = std::accumulate( - sizes.cbegin(), sizes.cend() - 2, 1, std::multiplies()); + const auto C = prod_intlist(sizes.cbegin(), sizes.cend() - 2); const auto C_4 = UP_DIV(C, 4); const auto H = sizes[2]; const auto W = sizes[3]; auto device = context().device(); struct ConstBlock { - int32_t inputSize[4]; + int32_t inputSize[3]; float s; }; ConstBlock cb{{safe_downcast(W), safe_downcast(H), - safe_downcast(C_4), - 0}, + safe_downcast(C_4)}, s}; VBuffer constBuffer = makeUniformConstBuffer((void*)&cb, sizeof(cb)); @@ -606,21 +605,19 @@ void add(VulkanTensor& output, const VulkanTensor& input, const float s) { void mul(VulkanTensor& output, const VulkanTensor& input, const float s) { const auto sizes = input.sizes(); - const auto C = std::accumulate( - sizes.cbegin(), sizes.cend() - 2, 1, std::multiplies()); + const auto C = prod_intlist(sizes.cbegin(), sizes.cend() - 2); const auto C_4 = UP_DIV(C, 4); const auto H = sizes[2]; const auto W = sizes[3]; auto device = context().device(); struct ConstBlock { - int32_t inputSize[4]; + int32_t inputSize[3]; float s; }; ConstBlock cb{{safe_downcast(W), safe_downcast(H), - safe_downcast(C_4), - 0}, + safe_downcast(C_4)}, s}; VBuffer constBuffer = makeUniformConstBuffer((void*)&cb, sizeof(cb)); @@ -1166,14 +1163,14 @@ void clamp( int32_t W; int32_t H; int32_t C_4; - int32_t C; + //int32_t C; float min; float max; }; ConstBlock cb{safe_downcast(W), safe_downcast(H), safe_downcast(C_4), - safe_downcast(C), + //safe_downcast(C), min, max}; VBuffer constBuffer = makeUniformConstBuffer((void*)&cb, sizeof(cb)); @@ -1246,7 +1243,6 @@ void addmm( int32_t OW; int32_t OH; int32_t C_4; - int32_t C; float beta; float alpha; int32_t K; @@ -1254,7 +1250,6 @@ void addmm( ConstBlock cb{safe_downcast(OW), safe_downcast(OH), safe_downcast(C_4), - safe_downcast(C), beta, alpha, safe_downcast(K)}; diff --git a/aten/src/ATen/native/vulkan/api/vk_mem_alloc.h b/aten/src/ATen/native/vulkan/api/vk_mem_alloc.h index fdeadf9cdbfa..b468a1c05c6d 100644 --- a/aten/src/ATen/native/vulkan/api/vk_mem_alloc.h +++ b/aten/src/ATen/native/vulkan/api/vk_mem_alloc.h @@ -4594,7 +4594,7 @@ static IterT VmaBinaryFindFirstNotLess(IterT beg, IterT end, const KeyT &key, co size_t down = 0, up = (end - beg); while(down < up) { - const size_t mid = (down + up) / 2; + const size_t mid = down + (up - down) / 2; //Overflow-safe midpoint calculation if(cmp(*(beg+mid), key)) { down = mid + 1; diff --git a/aten/src/ATen/native/vulkan/glsl/adaptive_avg_pool2d.glsl b/aten/src/ATen/native/vulkan/glsl/adaptive_avg_pool2d.glsl index ab07da5e4897..a1f3d6f21df9 100644 --- a/aten/src/ATen/native/vulkan/glsl/adaptive_avg_pool2d.glsl +++ b/aten/src/ATen/native/vulkan/glsl/adaptive_avg_pool2d.glsl @@ -1,26 +1,29 @@ #version 450 core #define PRECISION $precision + layout(std430) buffer; layout(std430) uniform; -layout(set = 0, rgba16f, binding = 0) writeonly PRECISION uniform image3D uOutput; -layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; -layout(set = 0, binding = 2) uniform constBlock { + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba16f) uniform PRECISION writeonly image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; +layout(set = 0, binding = 2) uniform restrict Block { int IW; int IH; int OW; int OH; -} -uConstBlock; +} uBlock; layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in; void main() { ivec3 pos = ivec3(gl_GlobalInvocationID); - int ow = uConstBlock.OW; - int oh = uConstBlock.OH; + int ow = uBlock.OW; + int oh = uBlock.OH; if (pos.x < ow && pos.y < oh) { - int iw = uConstBlock.IW; - int ih = uConstBlock.IH; + int iw = uBlock.IW; + int ih = uBlock.IH; int sx = int(floor(float(pos.x * iw) / ow)); int sy = int(floor(float(pos.y * ih) / oh)); diff --git a/aten/src/ATen/native/vulkan/glsl/add.glsl b/aten/src/ATen/native/vulkan/glsl/add.glsl index 9b7e992e78c5..27e69152ac1b 100644 --- a/aten/src/ATen/native/vulkan/glsl/add.glsl +++ b/aten/src/ATen/native/vulkan/glsl/add.glsl @@ -1,27 +1,28 @@ #version 450 core #define PRECISION $precision + layout(std430) buffer; layout(std430) uniform; -layout(set = 0, rgba16f, binding = 0) writeonly PRECISION uniform image3D uOutput; -layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput0; -layout(set = 0, binding = 2) uniform PRECISION sampler3D uInput1; -layout(set = 0, binding = 3) uniform constBlock { - int W; - int H; - int C; +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict writeonly image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput0; +layout(set = 0, binding = 2) uniform PRECISION sampler3D uInput1; +layout(set = 0, binding = 3) uniform restrict Block { + ivec3 WHC; float alpha; -} -uConstBlock; +} uBlock; layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in; void main() { - ivec3 pos = ivec3(gl_GlobalInvocationID); - ivec3 WHC = ivec3(uConstBlock.W, uConstBlock.H, uConstBlock.C); - if (all(lessThan(pos, WHC))) { - vec4 v = texelFetch(uInput0, pos, 0) + - uConstBlock.alpha * texelFetch(uInput1, pos, 0); - imageStore(uOutput, pos, v); + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.WHC))) { + imageStore( + uOutput, + pos, + texelFetch(uInput0, pos, 0) + uBlock.alpha * texelFetch(uInput1, pos, 0)); } } diff --git a/aten/src/ATen/native/vulkan/glsl/add_.glsl b/aten/src/ATen/native/vulkan/glsl/add_.glsl new file mode 100644 index 000000000000..c872a8193ca3 --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/add_.glsl @@ -0,0 +1,27 @@ +#version 450 core +#define PRECISION $precision + +layout(std430) buffer; +layout(std430) uniform; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba16f) uniform PRECISION image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput0; +layout(set = 0, binding = 2) uniform restrict Block { + ivec3 WHC; + float alpha; +} uBlock; + +layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.WHC))) { + imageStore( + uOutput, + pos, + imageLoad(uOutput, pos) + uBlock.alpha * texelFetch(uInput0, pos, 0)); + } +} diff --git a/aten/src/ATen/native/vulkan/glsl/add_scalar.glsl b/aten/src/ATen/native/vulkan/glsl/add_scalar.glsl index 559cdd7441c3..10a95330a48c 100644 --- a/aten/src/ATen/native/vulkan/glsl/add_scalar.glsl +++ b/aten/src/ATen/native/vulkan/glsl/add_scalar.glsl @@ -1,21 +1,27 @@ #version 450 core +#define PRECISION $precision + layout(std430) buffer; layout(std430) uniform; -layout(set = 0, rgba16f, binding = 0) writeonly highp uniform image3D uOutput; -layout(set = 0, binding = 1) uniform highp sampler3D uInput; -layout(set = 0, binding = 2) uniform constBlock { - ivec4 sizes; +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict writeonly image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; +layout(set = 0, binding = 2) uniform restrict Block { + ivec3 WHC; float other; -} -uConstBlock; +} uBlock; layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in; void main() { - ivec3 pos = ivec3(gl_GlobalInvocationID); - if (all(lessThan(pos, uConstBlock.sizes.xyz))) { - vec4 v = texelFetch(uInput, pos, 0) + uConstBlock.other; - imageStore(uOutput, pos, v); + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.WHC))) { + imageStore( + uOutput, + pos, + texelFetch(uInput, pos, 0) + uBlock.other); } } diff --git a/aten/src/ATen/native/vulkan/glsl/add_scalar_.glsl b/aten/src/ATen/native/vulkan/glsl/add_scalar_.glsl new file mode 100644 index 000000000000..8e736e2a6a71 --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/add_scalar_.glsl @@ -0,0 +1,26 @@ +#version 450 core +#define PRECISION $precision + +layout(std430) buffer; +layout(std430) uniform; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict image3D uOutput; +layout(set = 0, binding = 1) uniform restrict Block { + ivec3 WHC; + float other; +} uBlock; + +layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.WHC))) { + imageStore( + uOutput, + pos, + imageLoad(uOutput, pos) + uBlock.other); + } +} diff --git a/aten/src/ATen/native/vulkan/glsl/addmm.glsl b/aten/src/ATen/native/vulkan/glsl/addmm.glsl index 79987990e595..55fa3da02e0b 100644 --- a/aten/src/ATen/native/vulkan/glsl/addmm.glsl +++ b/aten/src/ATen/native/vulkan/glsl/addmm.glsl @@ -2,24 +2,26 @@ #define PRECISION $precision layout(std430) buffer; layout(std430) uniform; -layout(set = 0, rgba16f, binding = 0) writeonly PRECISION uniform image3D uOutput; -layout(set = 0, binding = 1) uniform PRECISION sampler3D uM1; -layout(set = 0, binding = 2) uniform PRECISION sampler3D uM2; -layout(set = 0, binding = 3) uniform constBlock { - ivec4 outputSize; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba16f) uniform PRECISION writeonly image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uM1; +layout(set = 0, binding = 2) uniform PRECISION sampler3D uM2; +layout(set = 0, binding = 3) uniform restrict Block { + ivec3 WHC; float beta; float alpha; int K; -} -uConstBlock; -layout(set = 0, binding = 4) uniform PRECISION sampler3D uT; +} uBlock; +layout(set = 0, binding = 4) uniform PRECISION sampler3D uT; layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in; void main() { - ivec3 pos = ivec3(gl_GlobalInvocationID); - if (all(lessThan(pos, uConstBlock.outputSize.xyz))) { - int K = uConstBlock.K; + const ivec3 pos = ivec3(gl_GlobalInvocationID); + if (all(lessThan(pos, uBlock.WHC))) { + const int K = uBlock.K; vec4 mmv = vec4(0); int ki = 0; for (; ki < K; ++ki) { @@ -28,6 +30,6 @@ void main() { mmv += m1ki * m2ki; } vec4 tv = texelFetch(uT, pos, 0); - imageStore(uOutput, pos, uConstBlock.beta * tv + uConstBlock.alpha * mmv); + imageStore(uOutput, pos, uBlock.beta * tv + uBlock.alpha * mmv); } } diff --git a/aten/src/ATen/native/vulkan/glsl/avg_pool2d.glsl b/aten/src/ATen/native/vulkan/glsl/avg_pool2d.glsl index 552e75c11d59..7de1455a9051 100644 --- a/aten/src/ATen/native/vulkan/glsl/avg_pool2d.glsl +++ b/aten/src/ATen/native/vulkan/glsl/avg_pool2d.glsl @@ -1,17 +1,21 @@ #version 450 core +#define PRECISION $precision + layout(std430) buffer; layout(std430) uniform; -layout(set = 0, rgba16f, binding = 0) writeonly highp uniform image3D uOutput; -layout(set = 0, binding = 1) uniform highp sampler3D uInput; -layout(set = 0, binding = 2) uniform constBlock { + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict writeonly image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; +layout(set = 0, binding = 2) uniform restrict Block { ivec4 inputSize; ivec4 outputSize; ivec2 kernelSize; ivec2 stride; ivec2 padding; ivec2 dilate; -} -uConstBlock; +} uBlock; #define UP_DIV(x, y) (((x) + (y)-1) / (y)) @@ -19,13 +23,10 @@ layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in; void main() { ivec3 pos = ivec3(gl_GlobalInvocationID); - ivec3 outputSize = uConstBlock.outputSize.xyz; - if (all(lessThan(pos, outputSize))) { - ivec2 s0 = pos.xy * uConstBlock.stride - uConstBlock.padding; - ivec2 sfxy = max(ivec2(0), (UP_DIV(-s0, uConstBlock.dilate))); - ivec2 efxy = - min(uConstBlock.kernelSize, - UP_DIV(uConstBlock.inputSize.xy - s0, uConstBlock.dilate)); + if (all(lessThan(pos, uBlock.outputSize.xyz))) { + ivec2 s0 = pos.xy * uBlock.stride - uBlock.padding; + ivec2 sfxy = max(ivec2(0), (UP_DIV(-s0, uBlock.dilate))); + ivec2 efxy = min(uBlock.kernelSize, UP_DIV(uBlock.inputSize.xy - s0, uBlock.dilate)); vec4 r = vec4(1.0) / float(efxy.x - sfxy.x) / float(efxy.x - sfxy.x); vec4 acc = vec4(0); diff --git a/aten/src/ATen/native/vulkan/glsl/clamp.glsl b/aten/src/ATen/native/vulkan/glsl/clamp.glsl index 24104c2285a1..25caddefd037 100644 --- a/aten/src/ATen/native/vulkan/glsl/clamp.glsl +++ b/aten/src/ATen/native/vulkan/glsl/clamp.glsl @@ -2,22 +2,22 @@ #define PRECISION $precision layout(std430) buffer; layout(std430) uniform; -layout(set = 0, rgba16f, binding = 0) writeonly PRECISION uniform image3D uOutput; -layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; -layout(set = 0, binding = 2) uniform constBlock { - ivec4 size; +layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict writeonly image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; +layout(set = 0, binding = 2) uniform restrict Block { + ivec3 WHC; float minValue; float maxValue; -} -uConstBlock; +} uBlock; layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in; void main() { - ivec3 pos = ivec3(gl_GlobalInvocationID); - if (all(lessThan(pos, uConstBlock.size.xyz))) { - vec4 v = texelFetch(uInput, pos, 0); + const ivec3 pos = ivec3(gl_GlobalInvocationID); + if (all(lessThan(pos, uBlock.WHC))) { imageStore( - uOutput, pos, clamp(v, uConstBlock.minValue, uConstBlock.maxValue)); + uOutput, + pos, + clamp(texelFetch(uInput, pos, 0), uBlock.minValue, uBlock.maxValue)); } } diff --git a/aten/src/ATen/native/vulkan/glsl/clamp_.glsl b/aten/src/ATen/native/vulkan/glsl/clamp_.glsl new file mode 100644 index 000000000000..c7c6e0d61ba1 --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/clamp_.glsl @@ -0,0 +1,22 @@ +#version 450 core +#define PRECISION $precision +layout(std430) buffer; +layout(std430) uniform; +layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict image3D uOutput; +layout(set = 0, binding = 1) uniform restrict Block { + ivec3 WHC; + float minValue; + float maxValue; +} uBlock; + +layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + if (all(lessThan(pos, uBlock.WHC))) { + imageStore( + uOutput, + pos, + clamp(imageLoad(uOutput, pos), uBlock.minValue, uBlock.maxValue)); + } +} diff --git a/aten/src/ATen/native/vulkan/glsl/mm.glsl b/aten/src/ATen/native/vulkan/glsl/mm.glsl index 771617d64b8a..2d39b28802e5 100644 --- a/aten/src/ATen/native/vulkan/glsl/mm.glsl +++ b/aten/src/ATen/native/vulkan/glsl/mm.glsl @@ -2,23 +2,23 @@ #define PRECISION $precision layout(std430) buffer; layout(std430) uniform; -layout(set = 0, rgba16f, binding = 0) writeonly PRECISION uniform image3D uOutput; -layout(set = 0, binding = 1) uniform PRECISION sampler3D uM1; -layout(set = 0, binding = 2) uniform PRECISION sampler3D uM2; -layout(set = 0, binding = 3) uniform constBlock { - ivec4 outputSize; - float beta; - float alpha; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba16f) uniform PRECISION writeonly image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uM1; +layout(set = 0, binding = 2) uniform PRECISION sampler3D uM2; +layout(set = 0, binding = 3) uniform restrict Block { + ivec3 WHC; int K; -} -uConstBlock; +} uBlock; layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in; void main() { - ivec3 pos = ivec3(gl_GlobalInvocationID); - if (all(lessThan(pos, uConstBlock.outputSize.xyz))) { - int K = uConstBlock.K; + const ivec3 pos = ivec3(gl_GlobalInvocationID); + if (all(lessThan(pos, uBlock.WHC))) { + const int K = uBlock.K; vec4 mmv = vec4(0); int ki = 0; for (; ki < K; ++ki) { @@ -26,6 +26,6 @@ void main() { vec4 m2ki = texelFetch(uM2, ivec3(pos.x, ki, pos.z), 0); mmv += m1ki * m2ki; } - imageStore(uOutput, pos, uConstBlock.alpha * mmv); + imageStore(uOutput, pos, mmv); } } diff --git a/aten/src/ATen/native/vulkan/glsl/mul_scalar.glsl b/aten/src/ATen/native/vulkan/glsl/mul_scalar.glsl index d34a99d2c6e8..8d7fc2b93a7f 100644 --- a/aten/src/ATen/native/vulkan/glsl/mul_scalar.glsl +++ b/aten/src/ATen/native/vulkan/glsl/mul_scalar.glsl @@ -1,21 +1,27 @@ #version 450 core +#define PRECISION $precision + layout(std430) buffer; layout(std430) uniform; -layout(set = 0, rgba16f, binding = 0) writeonly highp uniform image3D uOutput; -layout(set = 0, binding = 1) uniform highp sampler3D uInput; -layout(set = 0, binding = 2) uniform constBlock { - ivec4 sizes; +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict writeonly image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; +layout(set = 0, binding = 2) uniform restrict Block { + ivec3 WHC; float other; -} -uConstBlock; +} uBlock; layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in; void main() { - ivec3 pos = ivec3(gl_GlobalInvocationID); - if (all(lessThan(pos, uConstBlock.sizes.xyz))) { - vec4 v = uConstBlock.other * texelFetch(uInput, pos, 0); - imageStore(uOutput, pos, v); + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.WHC))) { + imageStore( + uOutput, + pos, + texelFetch(uInput, pos, 0) * uBlock.other); } } diff --git a/aten/src/ATen/native/vulkan/glsl/mul_scalar_.glsl b/aten/src/ATen/native/vulkan/glsl/mul_scalar_.glsl new file mode 100644 index 000000000000..9d1626a2ba83 --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/mul_scalar_.glsl @@ -0,0 +1,26 @@ +#version 450 core +#define PRECISION $precision + +layout(std430) buffer; +layout(std430) uniform; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict image3D uOutput; +layout(set = 0, binding = 1) uniform restrict Block { + ivec3 WHC; + float other; +} uBlock; + +layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.WHC))) { + imageStore( + uOutput, + pos, + imageLoad(uOutput, pos) * uBlock.other); + } +} diff --git a/aten/src/ATen/native/vulkan/glsl/upsampleNearest2d.glsl b/aten/src/ATen/native/vulkan/glsl/upsampleNearest2d.glsl deleted file mode 100644 index d7e4619a283a..000000000000 --- a/aten/src/ATen/native/vulkan/glsl/upsampleNearest2d.glsl +++ /dev/null @@ -1,35 +0,0 @@ -#version 450 core -#define PRECISION $precision -layout(std430) buffer; -layout(std430) uniform; -layout(set = 0, rgba16f, binding = 0) writeonly PRECISION uniform image3D uOutput; -layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; -layout(set = 0, binding = 2) uniform constBlock { - int IW; - int IH; - int OW; - int OH; - float scaleX; - float scaleY; -} -uConstBlock; - -layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in; - -void main() { - ivec3 pos = ivec3(gl_GlobalInvocationID); - int ow = uConstBlock.OW; - int oh = uConstBlock.OH; - if (pos.x < ow && pos.y < oh) { - int iw = uConstBlock.IW; - int ih = uConstBlock.IH; - float srcX = float(pos.x) * uConstBlock.scaleX; - int x1 = int(floor(srcX)); - int x11 = clamp(x1, 0, iw - 1); - float srcY = float(pos.y) * uConstBlock.scaleY; - int y1 = int(floor(srcY)); - int y11 = clamp(y1, 0, ih - 1); - vec4 outValue = texelFetch(uInput, ivec3(x11, y11, pos.z), 0); - imageStore(uOutput, pos, outValue); - } -} diff --git a/aten/src/ATen/native/vulkan/glsl/upsample_nearest2d.glsl b/aten/src/ATen/native/vulkan/glsl/upsample_nearest2d.glsl new file mode 100644 index 000000000000..9e0da8bf6211 --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/upsample_nearest2d.glsl @@ -0,0 +1,39 @@ +#version 450 core +#define PRECISION $precision + +layout(std430) buffer; +layout(std430) uniform; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict writeonly image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; +layout(set = 0, binding = 2) uniform restrict Block { + int input_width; + int input_height; + int output_width; + int output_height; + float scale_x; + float scale_y; +} +uBlock; + +layout(local_size_x_id = 1, local_size_y_id = 2, local_size_z_id = 3) in; + +void main() { + ivec3 pos = ivec3(gl_GlobalInvocationID); + const int ow = uBlock.output_width; + const int oh = uBlock.output_height; + if (pos.x < ow && pos.y < oh) { + const int iw = uBlock.input_width; + const int ih = uBlock.input_height; + float srcX = float(pos.x) * uBlock.scale_x; + int x1 = int(floor(srcX)); + int x11 = clamp(x1, 0, iw - 1); + float srcY = float(pos.y) * uBlock.scale_y; + int y1 = int(floor(srcY)); + int y11 = clamp(y1, 0, ih - 1); + vec4 outValue = texelFetch(uInput, ivec3(x11, y11, pos.z), 0); + imageStore(uOutput, pos, outValue); + } +} diff --git a/aten/src/ATen/native/vulkan/ops/Add.cpp b/aten/src/ATen/native/vulkan/ops/Add.cpp new file mode 100644 index 000000000000..1c1fa216d3f3 --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Add.cpp @@ -0,0 +1,257 @@ +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { +namespace { + +Tensor add_scalar( + const Tensor& self_arg, + const Scalar other, + const Scalar alpha) { + api::Context* const context = api::context(); + + const Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan(); + const vTensor& v_self = convert(self); + + vTensor v_output{ + context, + self.sizes(), + self.options(), + }; + + api::Command::Buffer command_buffer = context->command().pool.allocate(); + command_buffer.begin(); + { + if (v_output.has_image() && v_self.has_image()) { + const struct { + uint32_t width, height, channels; + float other; + } block { + v_output.extents().width, + v_output.extents().height, + v_output.extents().depth, + other.to() * alpha.to(), + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(add_scalar), + v_output.extents(), + // Write-only access bypasses synchronization but inserts appropriate + // barriers if necessary. + v_output.image(command_buffer, vTensor::Access::Write), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_self.image(command_buffer), + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. + context->resource().pool.uniform(block).object); + } + else { + TORCH_CHECK(false, "Not implemented!"); + } + } + command_buffer.end(); + command_buffer.submit(context->gpu().queue); + + return convert(v_output); +} + +Tensor& add_scalar_( + Tensor& self_arg, + const Scalar other, + const Scalar alpha) { + api::Context* const context = api::context(); + + TORCH_CHECK( + self_arg.is_vulkan(), + "Vulkan: In-place add is only supported on Vulkan tensors."); + + vTensor& v_self = convert(self_arg); + + api::Command::Buffer command_buffer = context->command().pool.allocate(); + command_buffer.begin(); + { + if (v_self.has_image()) { + const struct { + uint32_t width, height, channels; + float other; + } block { + v_self.extents().width, + v_self.extents().height, + v_self.extents().depth, + other.to() * alpha.to(), + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(add_scalar_), + v_self.extents(), + // Read-Write access triggers an async synchronization if necessory + // and inserts appropriate barriers if hazards are detected. + v_self.image(command_buffer, vTensor::Access::Read | vTensor::Access::Write), + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. + context->resource().pool.uniform(block).object); + } + else { + TORCH_CHECK(false, "Not implemented!"); + } + } + command_buffer.end(); + command_buffer.submit(context->gpu().queue); + + return self_arg; +} + +Tensor add_tensor( + const Tensor& self_arg, + const Tensor& other_arg, + const Scalar alpha) { + api::Context* const context = api::context(); + + const Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan(); + const vTensor& v_self = convert(self); + + const Tensor other = other_arg.is_vulkan() ? other_arg : other_arg.vulkan(); + const vTensor& v_other = convert(other); + + vTensor v_output{ + context, + self.sizes(), + self.options(), + }; + + api::Command::Buffer command_buffer = context->command().pool.allocate(); + command_buffer.begin(); + { + if (v_self.has_image() && v_other.has_image()) { + const struct { + uint32_t width, height, channels; + float alpha; + } block { + v_output.extents().width, + v_output.extents().height, + v_output.extents().depth, + alpha.to(), + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(add), + v_output.extents(), + // Write-only access bypasses synchronization but inserts appropriate + // barriers if necessary. + v_output.image(command_buffer, vTensor::Access::Write), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_self.image(command_buffer), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_other.image(command_buffer), + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. + context->resource().pool.uniform(block).object); + } + else { + TORCH_CHECK(false, "Not implemented!"); + } + } + command_buffer.end(); + command_buffer.submit(context->gpu().queue); + + return convert(v_output); +} + +Tensor& add_tensor_( + Tensor& self_arg, + const Tensor& other_arg, + const Scalar alpha) { + api::Context* const context = api::context(); + + TORCH_CHECK( + self_arg.is_vulkan(), + "Vulkan: In-place add is only supported on Vulkan tensors."); + + vTensor& v_self = convert(self_arg); + + const Tensor other = other_arg.is_vulkan() ? other_arg : other_arg.vulkan(); + const vTensor& v_other = convert(other); + + api::Command::Buffer command_buffer = context->command().pool.allocate(); + command_buffer.begin(); + { + if (v_self.has_image() && v_other.has_image()) { + const struct { + uint32_t width, height, channels; + float alpha; + } block { + v_self.extents().width, + v_self.extents().height, + v_self.extents().depth, + alpha.to(), + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(add_), + v_self.extents(), + // Read-Write access triggers an async synchronization if necessory + // and inserts appropriate barriers if hazards are detected. + v_self.image(command_buffer, vTensor::Access::Read | vTensor::Access::Write), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_other.image(command_buffer), + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. + context->resource().pool.uniform(block).object); + } + else { + TORCH_CHECK(false, "Not implemented!"); + } + } + command_buffer.end(); + command_buffer.submit(context->gpu().queue); + + return self_arg; +} + +#ifdef USE_VULKAN_API + +TORCH_LIBRARY_IMPL(aten, Vulkan, m) { + m.impl("add.Scalar", TORCH_FN(add_scalar)); + m.impl("add_.Scalar", TORCH_FN(add_scalar_)); + m.impl("add.Tensor", TORCH_FN(add_tensor)); + m.impl("add_.Tensor", TORCH_FN(add_tensor_)); +} + +#endif /* USE_VULKAN_API */ + +} // namespace +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/ops/Clamp.cpp b/aten/src/ATen/native/vulkan/ops/Clamp.cpp new file mode 100644 index 000000000000..31627296ca55 --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Clamp.cpp @@ -0,0 +1,142 @@ +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { +namespace { + +Tensor clamp( + const Tensor& self_arg, + const c10::optional min_value, + const c10::optional max_value) { + if (!min_value && !max_value) { + TORCH_CHECK(false, "At least one of 'min' or 'max' must not be None"); + } + + api::Context* const context = api::context(); + + const Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan(); + const vTensor& v_self = convert(self); + + vTensor v_output{ + context, + self.sizes(), + self.options(), + }; + + api::Command::Buffer command_buffer = context->command().pool.allocate(); + command_buffer.begin(); + { + if (v_output.has_image() && v_self.has_image()) { + const struct { + uint32_t width, height, channels; + float min_value; + float max_value; + } block { + v_output.extents().width, + v_output.extents().height, + v_output.extents().depth, + min_value ? min_value->to() : -std::numeric_limits::infinity(), + max_value ? max_value->to() : std::numeric_limits::infinity(), + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(clamp), + v_output.extents(), + // Write-only access bypasses synchronization but inserts appropriate + // barriers if necessary. + v_output.image(command_buffer, vTensor::Access::Write), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_self.image(command_buffer), + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. + context->resource().pool.uniform(block).object); + } + else { + TORCH_CHECK(false, "Not implemented!"); + } + } + command_buffer.end(); + command_buffer.submit(context->gpu().queue); + + return convert(v_output); +} + +Tensor& clamp_( + Tensor& self_arg, + const c10::optional min_value, + const c10::optional max_value) { + api::Context* const context = api::context(); + if (!min_value && !max_value) { + TORCH_CHECK(false, "At least one of 'min' or 'max' must not be None"); + } + TORCH_CHECK( + self_arg.is_vulkan(), + "Vulkan: In-place clamp is only supported on Vulkan tensors."); + + vTensor& v_self = convert(self_arg); + + api::Command::Buffer command_buffer = context->command().pool.allocate(); + command_buffer.begin(); + { + if (v_self.has_image()) { + const struct { + uint32_t width, height, channels; + float min_value; + float max_value; + } block { + v_self.extents().width, + v_self.extents().height, + v_self.extents().depth, + min_value ? min_value->to() : -std::numeric_limits::infinity(), + max_value ? max_value->to() : std::numeric_limits::infinity(), + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(clamp_), + v_self.extents(), + // Read-Write access triggers an async synchronization if necessory + // and inserts appropriate barriers if hazards are detected. + v_self.image(command_buffer, vTensor::Access::Read | vTensor::Access::Write), + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. + context->resource().pool.uniform(block).object); + } + else { + TORCH_CHECK(false, "Not implemented!"); + } + } + command_buffer.end(); + command_buffer.submit(context->gpu().queue); + + return self_arg; +} + +#ifdef USE_VULKAN_API + +TORCH_LIBRARY_IMPL(aten, Vulkan, m) { + m.impl("clamp", TORCH_FN(clamp)); + m.impl("clamp_", TORCH_FN(clamp_)); +} + +#endif /* USE_VULKAN_API */ + +} // namespace +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/ops/Common.h b/aten/src/ATen/native/vulkan/ops/Common.h index 121b40cbdb4b..91fc585fe193 100644 --- a/aten/src/ATen/native/vulkan/ops/Common.h +++ b/aten/src/ATen/native/vulkan/ops/Common.h @@ -6,4 +6,42 @@ #include #include +namespace at { +namespace native { +namespace vulkan { + +template +inline constexpr To safe_downcast_internal(const From v) { + typedef std::common_type_t Type; + constexpr Type min{static_cast(std::numeric_limits::lowest())}; + constexpr Type max{static_cast(std::numeric_limits::max())}; + TORCH_CHECK(min <= v && v <= max, "Cast failed: out of range"); + return static_cast(v); +} + +template +inline constexpr bool is_signed_to_unsigned() { + return std::is_signed::value && std::is_unsigned::value; +} + +template < + typename To, + typename From, + std::enable_if_t(), bool> = true> +inline constexpr To safe_downcast(const From v) { + TORCH_CHECK(v >= From{}, "Cast failed: negative signed to unsigned"); + return safe_downcast_internal(v); +} + +template < + typename To, + typename From, + std::enable_if_t(), bool> = true> +inline constexpr To safe_downcast(const From v) { + return safe_downcast_internal(v); +} + +} // namespace vulkan +} // namespace native +} // namespace at #endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/ops/Copy.cpp b/aten/src/ATen/native/vulkan/ops/Copy.cpp new file mode 100644 index 000000000000..2f74d1be00ab --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Copy.cpp @@ -0,0 +1,149 @@ +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { + +Tensor& copy_(Tensor& self, const Tensor& src) { + // X -> Vulkan + if (at::kVulkan == self.device().type()) { + vTensor& v_self = convert(self); + + // CPU -> Vulkan + if (at::kCPU == src.device().type()) { + // Requesting write-only host access to the tensor never triggers a sync + // as the contents will be overwritten regardless. Having said that, + // appropriate barriers are inserted automatically if WAR or WAW hazards + // are detected. Examples of such scenario for instance are if any of + // these async operations are on going in the background on 'self': + // - On discrete systems: + // * buffer-to-staging transfers + // * staging-to-buffer transfers + // - On UMA buffer is an alias for staging and accessible both on host + // and device. Consequently: + // * buffer-to-image NHWC -> NC4HW packing + // * image-to-buffer NC4HW -> NHWC unpacking + + using Future = vTensor::Future; + Future v_self_future = v_self.host(); + + // This wait() will be a no-op if no hazards are detected, including the + // obvious, yet important, special case of 'self' being an empty tensor. + + Future::Payload v_self_payload = v_self_future.wait(); + + memcpy( + v_self_payload.get(), + src.contiguous().data_ptr(), + std::min(src.nbytes(), self.nbytes())); + } + // Vulkan -> Vulkan + else if (at::kVulkan == src.device().type()) { + api::Command::Buffer command_buffer = api::context()->command().pool.allocate(); + command_buffer.begin(); + + command_buffer.copy( + // - Read-only access is implied on const tensors. Memory barriers + // are automatically inserted if a RAW hazard is detected. + // - Recording any potential pending sync operations into the same + // command buffer prevents an expensive queue submission. + convert(src).buffer(command_buffer), + // - Write-only access never triggers a sync as the contents will be + // overwritten regardless. Having said that, appropriate barriers + // are inserted automatically if WAR or WAW hazards are detected. + // - Recording pending sync operations into the same command buffer + // prevents an expensive queue submission. + v_self.buffer(command_buffer, vTensor::Access::Write)); + + command_buffer.end(); + command_buffer.submit(api::context()->gpu().queue); + } + else { + TORCH_INTERNAL_ASSERT(false, "Unsupported!"); + } + } + // Vulkan -> X + else if (at::kVulkan == src.device().type()) { + const vTensor& v_src = convert(src); + + { + // Similar notes as above applies, with the additional consideration of + // potential syncs on read accesses. Namely, + // - on discrete systems, if the (staging, buffer, image) trio, or + // - on UMA, if the (buffer, image) duo + // have gone out of sync as a result of one processor writing to one + // resource which is then either accessed as an another resource type on + // the same or another processor. Same considerations regarding hazard + // avoidance as above applies. + + using Future = vTensor::Future; + const Future v_src_future = v_src.host(); + + // Vulkan -> CPU + if (at::kCPU == self.device().type()) { + // This wait() is a no-op if data is not out of sync. More often than + // not though, waits here are expected as the GPU catches up with + // compute submitted from CPU. + + const Future::Payload v_src_payload = v_src_future.wait(); + + memcpy( + self.data_ptr(), + v_src_payload.get(), + std::min(src.nbytes(), self.nbytes())); + } + else { + TORCH_INTERNAL_ASSERT(false, "Unsupported!"); + } + } + + // + // WARNING + // + + // This is not great. We almost never want to flush the GPU pipeline as + // that has far reaching consequences, especially if PyTorch is not the only + // process accessing the GPU. If we have done our job properly, above + // synchronization mechanisms should be enough to ensure correctness at a more + // modest cost, as there is no need to flush the entirety of jobs in flight + // if one is only interested on waiting on computation affecting one single + // tensor to finish. + // + // Having said that, we still do need to release all pool resources at one + // point per inference run or we will run out of memory otherwise. There is + // no perfect answer to this problem that checks all boxes, which leaves us + // with one of several design decisions: + // + // 1) Use graph mode to gain an understanding of the computation graph, + // itself allowing us to place pool purges intelligently. Best option + // for performance and memory consumption. Not without its downsides if + // flexibility is a top priority. + // 2) If on eager mode, and hence are seeing operations one at a time, expose + // this release of resources to the user as a Python / C++ function. This + // makes for suboptimal user experience but is efficient in terms of + // performance. + // 3) If on eager mode, and interested in keeping this bookkeeping transparent + // to the user, release all resources somewhere ... like here. This is + // not ideal since it requires a pipeline flush to make sure these objects + // are not already in use by a workload in flight. Cannot do much better + // within the constraints of this approach. Good for user experience, + // suboptimal for performance. + // 4) If on eager mode, and interested in keeping this bookkeeping transparent + // to the user, and performance does not matter, make CPU and GPU run in + // lockstep. Obviously this is just bad. Mentioned for the sake of + // completeness. + + api::context()->flush(); + } + else { + TORCH_INTERNAL_ASSERT(false, "Unsupported!"); + } + + return self; +} + +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/ops/Copy.h b/aten/src/ATen/native/vulkan/ops/Copy.h new file mode 100644 index 000000000000..e69af06357c5 --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Copy.h @@ -0,0 +1,19 @@ +#pragma once + +#ifdef USE_VULKAN_API + +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { + +Tensor& copy_(Tensor& self, const Tensor& src); + +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at + +#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/ops/Mm.cpp b/aten/src/ATen/native/vulkan/ops/Mm.cpp new file mode 100644 index 000000000000..cf39d845178a --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Mm.cpp @@ -0,0 +1,139 @@ +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { +namespace { + +Tensor addmm( + const Tensor& self_arg, + const Tensor& mat1_arg, + const Tensor& mat2_arg, + const Scalar beta, + const Scalar alpha) { + api::Context* const context = api::context(); + const Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan(); + const vTensor& v_self = convert(self); + + const Tensor mat1 = mat1_arg.is_vulkan() ? mat1_arg : mat1_arg.vulkan(); + const vTensor& v_mat1 = convert(mat1); + + const Tensor mat2 = mat2_arg.is_vulkan() ? mat2_arg : mat2_arg.vulkan(); + const vTensor& v_mat2 = convert(mat2); + + vTensor v_output{ + context, + {mat1.sizes()[0], mat2.sizes()[1]}, + self.options() + }; + + api::Command::Buffer command_buffer = context->command().pool.allocate(); + command_buffer.begin(); + { + if (v_self.has_image()) { + const struct { + uint32_t width, height, channels; + float beta, alpha; + uint32_t k; + } block { + mat2_arg.sizes()[1], + mat1_arg.sizes()[0], + 1u, + beta.to(), + alpha.to(), + mat1_arg.sizes()[1], + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + }, + VK_KERNEL(addmm), + v_output.extents(), + v_output.image(command_buffer, vTensor::Access::Write), + v_mat1.image(command_buffer), + v_mat2.image(command_buffer), + context->resource().pool.uniform(block).object, + v_self.image(command_buffer)); + } else { + TORCH_CHECK(false, "Not implemented!"); + } + } + command_buffer.end(); + command_buffer.submit(context->gpu().queue); + + return convert(v_output); +} + +Tensor mm(const Tensor& self_arg, const Tensor& mat2_arg) { + api::Context* const context = api::context(); + const Tensor mat1 = self_arg.is_vulkan() ? self_arg : self_arg.vulkan(); + const vTensor& v_mat1 = convert(mat1); + + const Tensor mat2 = mat2_arg.is_vulkan() ? mat2_arg : mat2_arg.vulkan(); + const vTensor& v_mat2 = convert(mat2); + + vTensor v_output{ + context, + {mat1.sizes()[0], mat2.sizes()[1]}, + mat1.options() + }; + + api::Command::Buffer command_buffer = context->command().pool.allocate(); + command_buffer.begin(); + { + if (v_mat1.has_image() && v_mat2.has_image()) { + const struct { + uint32_t width, height, channels, k; + } block { + mat2.sizes()[1], + mat1.sizes()[0], + 1u, + mat1.sizes()[1], + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(mm), + v_output.extents(), + v_output.image(command_buffer, vTensor::Access::Write), + v_mat1.image(command_buffer), + v_mat2.image(command_buffer), + context->resource().pool.uniform(block).object); + } else { + TORCH_CHECK(false, "Not implemented!"); + } + } + command_buffer.end(); + command_buffer.submit(context->gpu().queue); + + return convert(v_output); +} + +#ifdef USE_VULKAN_API + +TORCH_LIBRARY_IMPL(aten, Vulkan, m) { + m.impl("addmm", TORCH_FN(addmm)); + m.impl("mm", TORCH_FN(mm)); +} + +#endif /* USE_VULKAN_API */ + +} // namespace +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/ops/Mul.cpp b/aten/src/ATen/native/vulkan/ops/Mul.cpp new file mode 100644 index 000000000000..76ddf0d41bd9 --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Mul.cpp @@ -0,0 +1,130 @@ +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { +namespace { + +Tensor mul_scalar( + const Tensor& self_arg, + const Scalar other) { + api::Context* const context = api::context(); + + const Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan(); + const vTensor& v_self = convert(self); + + vTensor v_output{ + context, + self.sizes(), + self.options(), + }; + + api::Command::Buffer command_buffer = context->command().pool.allocate(); + command_buffer.begin(); + { + if (v_output.has_image() && v_self.has_image()) { + const struct { + uint32_t width, height, channels; + float other; + } block { + v_output.extents().width, + v_output.extents().height, + v_output.extents().depth, + other.to(), + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(mul_scalar), + v_output.extents(), + // Write-only access bypasses synchronization but inserts appropriate + // barriers if necessary. + v_output.image(command_buffer, vTensor::Access::Write), + // Read-only access is implied on const tensors and triggers an async + // synchronization if necessary. + v_self.image(command_buffer), + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. + context->resource().pool.uniform(block).object); + } + else { + TORCH_CHECK(false, "Not implemented!"); + } + } + command_buffer.end(); + command_buffer.submit(context->gpu().queue); + + return convert(v_output); +} + +Tensor& mul_scalar_( + Tensor& self_arg, + const Scalar other) { + api::Context* const context = api::context(); + + TORCH_CHECK( + self_arg.is_vulkan(), + "Vulkan: In-place mul_scalar is only supported on Vulkan tensors."); + + vTensor& v_self = convert(self_arg); + + api::Command::Buffer command_buffer = context->command().pool.allocate(); + command_buffer.begin(); + { + if (v_self.has_image()) { + const struct { + uint32_t width, height, channels; + float other; + } block { + v_self.extents().width, + v_self.extents().height, + v_self.extents().depth, + other.to(), + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(mul_scalar_), + v_self.extents(), + // Read-Write access triggers an async synchronization if necessory + // and inserts appropriate barriers if hazards are detected. + v_self.image(command_buffer, vTensor::Access::Read | vTensor::Access::Write), + // Object lifetime is managed by the resource pool. + // It is OK not to keep track of the handle. + context->resource().pool.uniform(block).object); + } + else { + TORCH_CHECK(false, "Not implemented!"); + } + } + command_buffer.end(); + command_buffer.submit(context->gpu().queue); + + return self_arg; +} + +#ifdef USE_VULKAN_API + +TORCH_LIBRARY_IMPL(aten, Vulkan, m) { + m.impl("mul.Scalar", TORCH_FN(mul_scalar)); + m.impl("mul_.Scalar", TORCH_FN(mul_scalar_)); +} + +#endif /* USE_VULKAN_API */ + +} // namespace +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/ops/Pool.cpp b/aten/src/ATen/native/vulkan/ops/Pool.cpp new file mode 100644 index 000000000000..8c2c05ff26a3 --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Pool.cpp @@ -0,0 +1,168 @@ +#include +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { +namespace { + +Tensor adaptive_avg_pool2d(const at::Tensor& input_arg, IntArrayRef output_size) { + TORCH_INTERNAL_ASSERT( + input_arg.dim() == 4, + "vulkan_adaptive_avg_pool2d expects 4-dimensional input"); + + api::Context* const context = api::context(); + const vTensor& v_input = convert(input_arg); + vTensor v_output{ + context, + {input_arg.sizes()[0], input_arg.sizes()[1], output_size[0], output_size[1]}, + input_arg.options(), + }; + + api::Command::Buffer command_buffer = context->command().pool.allocate(); + command_buffer.begin(); + { + if (v_input.has_image()) { + const struct { + uint32_t input_width, input_height, output_width, output_height; + } block { + input_arg.sizes()[3], + input_arg.sizes()[2], + output_size[1], + output_size[0], + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(adaptive_avg_pool2d), + v_output.extents(), + v_output.image(command_buffer, vTensor::Access::Write), + v_input.image(command_buffer), + context->resource().pool.uniform(block).object); + } + else { + TORCH_CHECK(false, "Not implemented!"); + } + } + command_buffer.end(); + command_buffer.submit(context->gpu().queue); + + return convert(v_output); +} + +Tensor avg_pool2d( + const Tensor& self, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + bool ceil_mode, + bool count_include_pad, + c10::optional divisor_override) { + TORCH_CHECK( + kernel_size.size() == 1 || kernel_size.size() == 2, + "avg_pool2d: kernel_size must either be a single int, or a tuple of two ints"); + const int kernel_height = safe_downcast(kernel_size[0]); + const int kernel_width = + kernel_size.size() == 1 ? kernel_height : safe_downcast(kernel_size[1]); + + TORCH_CHECK( + stride.empty() || stride.size() == 1 || stride.size() == 2, + "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints"); + const int dH = stride.empty() ? kernel_height : safe_downcast(stride[0]); + const int dW = stride.empty() + ? kernel_width + : stride.size() == 1 ? dH : safe_downcast(stride[1]); + + TORCH_CHECK( + padding.size() == 1 || padding.size() == 2, + "avg_pool2d: padding must either be a single int, or a tuple of two ints"); + const int padH = safe_downcast(padding[0]); + const int padW = padding.size() == 1 ? padH : safe_downcast(padding[1]); + + const int64_t input_batch = self.sizes()[0]; + const int64_t input_channels = self.sizes()[1]; + const int64_t input_height = self.sizes()[2]; + const int64_t input_width = self.sizes()[3]; + + const int64_t output_height = + pooling_output_shape(input_height, kernel_height, padH, dH, 1, ceil_mode); + const int64_t output_width = + pooling_output_shape(input_width, kernel_width, padW, dW, 1, ceil_mode); + + pool2d_shape_check( + self, kernel_height, kernel_width, dH, dW, padH, padW, 1, 1, input_channels, input_height, input_width, output_height, output_width); + + api::Context* const context = api::context(); + + const vTensor& v_self = convert(self); + + vTensor v_output{ + context, + {input_batch, input_channels, output_height, output_width}, + self.options(), + }; + + api::Command::Buffer command_buffer = context->command().pool.allocate(); + command_buffer.begin(); + { + if (v_self.has_image()) { + const struct { + uint32_t input_width, input_height, input_channels, input_size_stub; + uint32_t output_width, output_height, output_channels, output_size_stub; + uint32_t kernel_width, kernel_height; + uint32_t stride_x, stride_y; + uint32_t padding_x, padding_y; + uint32_t dilate_x, dilate_y; + } block { + input_width, input_height, input_batch * input_channels, 0u, + output_width, output_height, input_batch * input_channels, 0u, + kernel_width, kernel_height, + dW, dH, + padW, padH, + 1u, 1u + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(avg_pool2d), + v_output.extents(), + v_output.image(command_buffer, vTensor::Access::Write), + v_self.image(command_buffer), + context->resource().pool.uniform(block).object); + } + else { + TORCH_CHECK(false, "Not implemented!"); + } + } + command_buffer.end(); + command_buffer.submit(context->gpu().queue); + + return convert(v_output); + +} +#ifdef USE_VULKAN_API + +TORCH_LIBRARY_IMPL(aten, Vulkan, m) { + m.impl("_adaptive_avg_pool2d", TORCH_FN(adaptive_avg_pool2d)); + m.impl("avg_pool2d", TORCH_FN(avg_pool2d)); +} + +#endif /* USE_VULKAN_API */ + +} // namespace +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/ops/Tensor.cpp b/aten/src/ATen/native/vulkan/ops/Tensor.cpp index a51fd972d19a..a5baf716069f 100644 --- a/aten/src/ATen/native/vulkan/ops/Tensor.cpp +++ b/aten/src/ATen/native/vulkan/ops/Tensor.cpp @@ -22,11 +22,7 @@ VkDeviceSize bytes( size *= extents.width * extents.height * (4u * extents.depth); } else { - size = std::accumulate( - sizes.cbegin(), - sizes.cend(), - size, - std::multiplies()); + size *= prod_intlist(sizes); } return size; diff --git a/aten/src/ATen/native/vulkan/ops/Upsample.cpp b/aten/src/ATen/native/vulkan/ops/Upsample.cpp new file mode 100644 index 000000000000..2a95751e59bd --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Upsample.cpp @@ -0,0 +1,80 @@ +#include +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { +namespace { + +Tensor upsample_nearest2d( + const Tensor& input_arg, + const IntArrayRef output_sizes, + const c10::optional scales_h, + const c10::optional scales_w) { + api::Context* const context = api::context(); + + const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan(); + const vTensor& v_input = convert(input); + + vTensor v_output{ + context, + {input_arg.sizes()[0], input_arg.sizes()[1], output_sizes[0], output_sizes[1]}, + input.options(), + }; + + api::Command::Buffer command_buffer = context->command().pool.allocate(); + command_buffer.begin(); + { + const float scale_x = compute_scales_value(scales_w, input_arg.sizes()[3], output_sizes[1]); + const float scale_y = compute_scales_value(scales_h, input_arg.sizes()[2], output_sizes[0]); + if (v_input.has_image()) { + const struct { + uint32_t input_width, input_height, output_width, output_height; + float scale_x, scale_y; + } block { + input_arg.sizes()[3], + input_arg.sizes()[2], + output_sizes[1], + output_sizes[0], + scale_x, + scale_y + }; + + context->dispatch( + command_buffer, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + }, + VK_KERNEL(upsample_nearest2d), + v_output.extents(), + v_output.image(command_buffer, vTensor::Access::Write), + v_input.image(command_buffer), + context->resource().pool.uniform(block).object); + } + else { + TORCH_CHECK(false, "Not implemented!"); + } + } + command_buffer.end(); + command_buffer.submit(context->gpu().queue); + + return convert(v_output); +} + +#ifdef USE_VULKAN_API + +TORCH_LIBRARY_IMPL(aten, Vulkan, m) { + m.impl("upsample_nearest2d", TORCH_FN(upsample_nearest2d)); +} + +#endif /* USE_VULKAN_API */ + +} // namespace +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/quantized/QTensorImpl.cpp b/aten/src/ATen/quantized/QTensorImpl.cpp index 40ecbf6d5720..1c79ac186c1a 100644 --- a/aten/src/ATen/quantized/QTensorImpl.cpp +++ b/aten/src/ATen/quantized/QTensorImpl.cpp @@ -5,7 +5,7 @@ namespace at { QTensorImpl::QTensorImpl( Storage&& storage, DispatchKeySet key_set, - const caffe2::TypeMeta& data_type, + const caffe2::TypeMeta data_type, QuantizerPtr quantizer) : TensorImpl(std::move(storage), key_set, data_type), quantizer_(quantizer) {} diff --git a/aten/src/ATen/quantized/QTensorImpl.h b/aten/src/ATen/quantized/QTensorImpl.h index c2728c7aab46..efce432d5863 100644 --- a/aten/src/ATen/quantized/QTensorImpl.h +++ b/aten/src/ATen/quantized/QTensorImpl.h @@ -18,7 +18,7 @@ struct CAFFE2_API QTensorImpl : public c10::TensorImpl { QTensorImpl( Storage&& storage, DispatchKeySet key_set, - const caffe2::TypeMeta& data_type, + const caffe2::TypeMeta data_type, QuantizerPtr quantizer); // TODO: Expose in PyTorch Frontend diff --git a/aten/src/ATen/record_function.cpp b/aten/src/ATen/record_function.cpp index 26e9fd9f21fa..41f31968688d 100644 --- a/aten/src/ATen/record_function.cpp +++ b/aten/src/ATen/record_function.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -376,6 +377,7 @@ void RecordFunction::before(const char* name, int64_t sequence_nr) { name_ = StringView(name); sequence_nr_ = sequence_nr; thread_id_ = currentThreadId(); + operator_name_.reset(); manager().runStartCallbacks(*this); } @@ -387,6 +389,21 @@ void RecordFunction::before(std::string name, int64_t sequence_nr) { name_ = StringView(std::move(name)); sequence_nr_ = sequence_nr; thread_id_ = currentThreadId(); + operator_name_.reset(); + + manager().runStartCallbacks(*this); +} + +void RecordFunction::before( + c10::OperatorHandle const& op, + int64_t sequence_nr) { + if (!active) { + return; + } + sequence_nr_ = sequence_nr; + thread_id_ = currentThreadId(); + operator_name_ = op.operator_name(); + name_ = StringView(op.schema().name()); manager().runStartCallbacks(*this); } diff --git a/aten/src/ATen/record_function.h b/aten/src/ATen/record_function.h index cf839ad4a188..db2ee221a09a 100644 --- a/aten/src/ATen/record_function.h +++ b/aten/src/ATen/record_function.h @@ -1,12 +1,18 @@ #pragma once #include -#include +#include #include +#include +#include #include #include +namespace c10 { +class CAFFE2_API OperatorHandle; +} + namespace at { // Kind of record function scope; @@ -147,6 +153,7 @@ struct TORCH_API RecordFunction { // start callbacks void before(const char* name, int64_t sequence_nr = -1); void before(std::string name, int64_t sequence_nr = -1); + void before(c10::OperatorHandle const& op, int64_t sequence_nr = -1); // Sets node ID for distributed profiling static void setDefaultNodeId(int64_t defaultNodeId); @@ -178,6 +185,10 @@ struct TORCH_API RecordFunction { return handle_; } + inline c10::optional operator_name() const { + return operator_name_; + } + inline void setHandle(RecordFunctionHandle handle) { handle_ = handle; } @@ -213,6 +224,8 @@ struct TORCH_API RecordFunction { int64_t sequence_nr_ = -1; std::vector inputs_; + c10::optional operator_name_; + // Kind of scope this RecordFunction is observing const RecordScope scope_; diff --git a/aten/src/ATen/templates/Functions.cpp b/aten/src/ATen/templates/Functions.cpp index 7c9aa96f6e70..589121af07ef 100644 --- a/aten/src/ATen/templates/Functions.cpp +++ b/aten/src/ATen/templates/Functions.cpp @@ -3,13 +3,7 @@ #include #include -#include -#include -#include #include -#ifdef USE_VULKAN -#include -#endif namespace at { diff --git a/aten/src/ATen/templates/SchemaRegister.cpp b/aten/src/ATen/templates/SchemaRegister.cpp deleted file mode 100644 index f48e732f4760..000000000000 --- a/aten/src/ATen/templates/SchemaRegister.cpp +++ /dev/null @@ -1,10 +0,0 @@ -// ${generated_comment} - -#include -#include - -using namespace at; - -TORCH_LIBRARY_FRAGMENT_THIS_API_IS_FOR_PER_OP_REGISTRATION_ONLY(aten, m) { - ${schema_registrations} -} diff --git a/aten/src/ATen/templates/SparseTypeDerived.cpp b/aten/src/ATen/templates/SparseTypeDerived.cpp deleted file mode 100644 index b0a4fed24a63..000000000000 --- a/aten/src/ATen/templates/SparseTypeDerived.cpp +++ /dev/null @@ -1,43 +0,0 @@ -// required for old g++ to compile PRId64 macros, see -// https://github.com/pytorch/pytorch/issues/3571 -// for context -#define __STDC_FORMAT_MACROS - -#include - -// ${generated_comment} - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -#include -$extra_cuda_headers - -namespace at { - -namespace ${Type} { - -${type_derived_method_definitions} - -} // namespace ${Type} - -TORCH_LIBRARY_IMPL(aten, ${Backend}, m) { - ${function_registrations}; -} - -} // namespace at diff --git a/aten/src/ATen/templates/TensorBody.h b/aten/src/ATen/templates/TensorBody.h index 3f6292f41178..1fd7eb7e5d9f 100644 --- a/aten/src/ATen/templates/TensorBody.h +++ b/aten/src/ATen/templates/TensorBody.h @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -208,7 +209,7 @@ class CAFFE2_API Tensor { return impl_->strides(); } // See impl::get_opt_names in ATen/NamedTensor.h for docs. - optional opt_names() const { + c10::optional opt_names() const { return impl::get_opt_names(unsafeGetTensorImpl()); } // See impl::get_names in ATen/NamedTensor.h for docs. diff --git a/aten/src/ATen/templates/TypeDefault.cpp b/aten/src/ATen/templates/TypeDefault.cpp index 145a5b421019..4cd1d1586d6a 100644 --- a/aten/src/ATen/templates/TypeDefault.cpp +++ b/aten/src/ATen/templates/TypeDefault.cpp @@ -64,6 +64,10 @@ TORCH_LIBRARY(aten, m) { m.def("get_gradients(int context_id) -> Dict(Tensor, Tensor)"); } +TORCH_LIBRARY_IMPL(aten, Math, m) { + ${math_function_registrations}; +} + TORCH_LIBRARY_IMPL(aten, DefaultBackend, m) { ${default_backend_function_registrations}; } diff --git a/aten/src/ATen/templates/TypeDerived.cpp b/aten/src/ATen/templates/TypeDerived.cpp index d65c13ae8d97..3275ab76ef62 100644 --- a/aten/src/ATen/templates/TypeDerived.cpp +++ b/aten/src/ATen/templates/TypeDerived.cpp @@ -5,12 +5,9 @@ #define __STDC_FORMAT_MACROS #endif -#include - // ${generated_comment} -$storage_tensor_headers -#include +#include #include #include #include diff --git a/aten/src/ATen/templates/TypeDerived.h b/aten/src/ATen/templates/TypeDerived.h deleted file mode 100644 index 4b571f40383f..000000000000 --- a/aten/src/ATen/templates/TypeDerived.h +++ /dev/null @@ -1,38 +0,0 @@ -#pragma once - -// ${generated_comment} - -#include -#include -#include -#include -#include -#include -#include -#include - -$extra_cuda_headers - -namespace c10 { -struct Storage; -} - -namespace at { - -class Tensor; -using TensorList = ArrayRef; - -class Context; -struct Generator; - -struct Quantizer; -// This is temporary typedef to enable Quantizer in aten native function API -// we'll remove them when we are actually exposing Quantizer class -// to frontend -using ConstQuantizerPtr = const c10::intrusive_ptr&; - -namespace ${Type} { - ${type_derived_method_declarations} -} - -} // namespace at diff --git a/aten/src/ATen/test/CMakeLists.txt b/aten/src/ATen/test/CMakeLists.txt index 4268db33fa16..067902c0a3b7 100644 --- a/aten/src/ATen/test/CMakeLists.txt +++ b/aten/src/ATen/test/CMakeLists.txt @@ -76,13 +76,18 @@ list(APPEND ATen_HIP_TEST_SRCS # ${CMAKE_CURRENT_SOURCE_DIR}/hip/hip_stream_test.cpp list(APPEND ATen_VULKAN_TEST_SRCS - ${CMAKE_CURRENT_SOURCE_DIR}/vulkan_api_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/vulkan_test.cpp) +if(USE_VULKAN_API) + list(APPEND ATen_VULKAN_TEST_SRCS + ${CMAKE_CURRENT_SOURCE_DIR}/vulkan_api_test.cpp) +endif() + list(APPEND ATen_MOBILE_TEST_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/vec256_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpu_profiling_allocator_test.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpu_caching_allocator_test.cpp) + ${CMAKE_CURRENT_SOURCE_DIR}/cpu_caching_allocator_test.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/quantized_test.cpp) list(APPEND ATen_VEC256_TEST_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/vec256_test.cpp diff --git a/aten/src/ATen/test/quantized_test.cpp b/aten/src/ATen/test/quantized_test.cpp index a2b64618ccfe..71ce5053334b 100644 --- a/aten/src/ATen/test/quantized_test.cpp +++ b/aten/src/ATen/test/quantized_test.cpp @@ -101,9 +101,8 @@ TEST(TestQTensor, EmptyQuantized) { int zero_point = 10; int val = 100; int numel = 10; - Tensor q = at::_empty_affine_quantized({numel}, - at::device(at::kCPU).dtype(kQUInt8), - scale, zero_point); + Tensor q = at::_empty_affine_quantized( + {numel}, at::device(at::kCPU).dtype(kQUInt8), scale, zero_point); // Assigning to QTensor auto* q_data = q.data_ptr(); for (int i = 0; i < numel; ++i) { @@ -142,7 +141,66 @@ TEST(TestQTensor, EmptyPerchannelQuantized) { for (int i = 0; i < numel; ++i) { ASSERT_EQ( r_data[i], - (val - zero_points[i].item().to()) * - scales[i].item().to()); + (val - zero_points[i].item().to()) * scales[i].item().to()); + } +} + +TEST(TestQTensor, QuantizePerChannel4d) { + int C = 32, H = 10, W = 10; + auto scales = rand({C}).toType(kDouble); + auto zero_points = randint(10, {C}).toType(kLong); + int ch_axis = 1; + // create 4d tensor where each H x W image is a range(0, H*W) + Tensor tensor = at::empty({1, C, H, W}, at::device(at::kCPU).dtype(kFloat)); + auto* tensor_data = tensor.data_ptr(); + for (int c = 0, i = 0; c < C; ++c) { + for (int e = 0; e < H * W; ++e, ++i) { + tensor_data[i] = e; + } + } + // quantize and check values + Tensor q = at::native::quantize_per_channel_cpu( + tensor, scales, zero_points, ch_axis, kQUInt8); + auto* q_data = (uint8_t*)q.data_ptr(); + for (int c = 0, i = 0; c < C; ++c) { + auto scale = scales[c].item(); + auto zero_point = zero_points[c].item(); + for (int e = 0; e < H * W; ++e, ++i) { + // downsize qval to 255 if val is greater than max uint8_t value + int qval = std::min((int)round(e / scale) + zero_point, 255); + ASSERT_EQ((int)q_data[i], qval); + } + } +} + +TEST(TestQTensor, QuantizePerChannel4dChannelsLast) { + int C = 32, H = 10, W = 10; + auto scales = rand({C}).toType(kFloat); + auto zero_points = randint(10, {C}).toType(kInt); + int ch_axis = 1; + // create 4d tensor where each H x W image is a range(0, H*W) + Tensor tensor = at::empty( + {1, C, H, W}, + at::device(at::kCPU).dtype(kFloat).memory_format( + at::MemoryFormat::ChannelsLast)); + auto* tensor_data = tensor.data_ptr(); + for (int e = 0, i = 0; e < H * W; ++e) { + for (int c = 0; c < C; ++c, ++i) { + tensor_data[i] = e; + } + } + + // quantize and check values + Tensor q = at::native::quantize_per_channel_cpu( + tensor, scales, zero_points, ch_axis, kQUInt8); + auto* q_data = (uint8_t*)q.data_ptr(); + for (int e = 0, i = 0; e < H * W; ++e) { + for (int c = 0; c < C; ++c, ++i) { + auto scale = scales[c].item(); + auto zero_point = zero_points[c].item(); + // downsize qval to 255 if val is greater than max uint8_t value + int qval = std::min((int)round(e / scale) + zero_point, 255); + ASSERT_EQ((int)q_data[i], qval); + } } } diff --git a/aten/src/ATen/test/scalar_test.cpp b/aten/src/ATen/test/scalar_test.cpp index dc73460b3728..68c0b4f3f71a 100644 --- a/aten/src/ATen/test/scalar_test.cpp +++ b/aten/src/ATen/test/scalar_test.cpp @@ -128,3 +128,13 @@ TEST(TestScalar, TestScalar) { ASSERT_EQ(float_one.item(), 1); ASSERT_EQ(float_one.item(), 1); } + +TEST(TestScalar, TestConj) { + Scalar int_scalar = 257; + Scalar float_scalar = 3.0; + Scalar complex_scalar = c10::complex(2.3, 3.5); + + ASSERT_EQ(int_scalar.conj().toInt(), 257); + ASSERT_EQ(float_scalar.conj().toDouble(), 3.0); + ASSERT_EQ(complex_scalar.conj().toComplexDouble(), c10::complex(2.3, -3.5)); +} diff --git a/aten/src/ATen/test/vulkan_api_test.cpp b/aten/src/ATen/test/vulkan_api_test.cpp index c7ac15544a99..5596dc8e1d67 100644 --- a/aten/src/ATen/test/vulkan_api_test.cpp +++ b/aten/src/ATen/test/vulkan_api_test.cpp @@ -1,11 +1,232 @@ -#include +#ifdef USE_VULKAN_API +#include #include -#ifdef USE_VULKAN_API +// TODO: These functions should move to a common place. + +namespace { + +bool checkRtol(const at::Tensor& diff, const std::vector& inputs) { + float maxValue = 0.0f; + + for (const auto& tensor : inputs) { + maxValue = fmax(tensor.abs().max().item(), maxValue); + } + + return diff.abs().max().item() < (2e-6 * maxValue); +} + +bool almostEqual(const at::Tensor& a, const at::Tensor& b) { + return checkRtol(a - b, {a, b}); +} + +bool exactlyEqual(const at::Tensor& a, const at::Tensor& b) { + return (a - b).abs().max().item() == 0.0f; +} + +} // namespace namespace { +TEST(VulkanAPITest, add) { + const auto a_cpu = at::rand({11, 7, 139, 109}, at::device(at::kCPU).dtype(at::kFloat)); + const auto a_vulkan = a_cpu.vulkan(); + + const auto b_cpu = at::rand({11, 7, 139, 109}, at::device(at::kCPU).dtype(at::kFloat)); + const auto b_vulkan = b_cpu.vulkan(); + + const auto c_cpu = at::add(a_cpu, b_cpu, 2.1f); + const auto c_vulkan = at::add(a_vulkan, b_vulkan, 2.1f); + + ASSERT_TRUE(almostEqual(c_cpu, c_vulkan.cpu())); +} + +TEST(VulkanAPITest, add_) { + auto a_cpu = at::rand({11, 7, 139, 109}, at::device(at::kCPU).dtype(at::kFloat)); + auto a_vulkan = a_cpu.vulkan(); + + const auto b_cpu = at::rand({11, 7, 139, 109}, at::device(at::kCPU).dtype(at::kFloat)); + const auto b_vulkan = b_cpu.vulkan(); + + a_cpu.add_(b_cpu, 2.1f); + a_vulkan.add_(b_vulkan, 2.1f); + + ASSERT_TRUE(almostEqual(a_cpu, a_vulkan.cpu())); +} + +TEST(VulkanAPITest, add_scalar) { + const auto a_cpu = at::rand({1, 1, 1, 1}, at::device(at::kCPU).dtype(at::kFloat)); + const auto a_vulkan = a_cpu.vulkan(); + + const float b_scalar = 3.1415f; + + const auto c_cpu = at::add(a_cpu, b_scalar, 2.1f); + const auto c_vulkan = at::add(a_vulkan, b_scalar, 2.1f); + + ASSERT_TRUE(almostEqual(c_cpu, c_vulkan.cpu())); +} + +TEST(VulkanAPITest, add_scalar_) { + auto a_cpu = at::rand({11, 7, 139, 109}, at::device(at::kCPU).dtype(at::kFloat)); + auto a_vulkan = a_cpu.vulkan(); + + const float b_scalar = 3.1415f; + + a_cpu.add_(b_scalar, 2.1f); + a_vulkan.add_(b_scalar, 2.1f); + + ASSERT_TRUE(almostEqual(a_cpu, a_vulkan.cpu())); +} + +TEST(VulkanAPITest, mul_scalar) { + const auto a_cpu = at::rand({17, 213, 213, 7}, at::device(at::kCPU).dtype(at::kFloat)); + const auto a_vulkan = a_cpu.vulkan(); + + const float b_scalar = 3.1415f; + + const auto c_cpu = at::mul(a_cpu, b_scalar); + const auto c_vulkan = at::mul(a_vulkan, b_scalar); + + ASSERT_TRUE(almostEqual(c_cpu, c_vulkan.cpu())); +} + +TEST(VulkanAPITest, mul_scalar_) { + auto a_cpu = at::rand({11, 7, 139, 109}, at::device(at::kCPU).dtype(at::kFloat)); + auto a_vulkan = a_cpu.vulkan(); + + const float b_scalar = 3.1415f; + + a_cpu.mul_(b_scalar); + a_vulkan.mul_(b_scalar); + + ASSERT_TRUE(almostEqual(a_cpu, a_vulkan.cpu())); +} + +TEST(VulkanAPITest, clamp) { + const auto a_cpu = at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat)); + const auto a_vulkan = a_cpu.vulkan(); + + const float min_value = 0.2f; + const float max_value = 0.8f; + + const auto c_cpu = at::clamp(a_cpu, min_value, max_value); + const auto c_vulkan = at::clamp(a_vulkan, min_value, max_value); + + ASSERT_TRUE(almostEqual(c_cpu, c_vulkan.cpu())); +} + +TEST(VulkanAPITest, clamp_) { + const auto a_cpu = at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat)); + const auto a_vulkan = a_cpu.vulkan(); + + const float min_value = 0.2f; + const float max_value = 0.8f; + + a_cpu.clamp_(min_value, max_value); + a_vulkan.clamp_(min_value, max_value); + + ASSERT_TRUE(almostEqual(a_cpu, a_vulkan.cpu())); +} + +TEST(VulkanTest, addmm) { + auto t_m1 = at::rand({2, 2}, at::device(at::kCPU).dtype(at::kFloat)); + auto t_m2 = at::rand({2, 3}, at::device(at::kCPU).dtype(at::kFloat)); + auto t_b = at::rand({2, 3}, at::device(at::kCPU).dtype(at::kFloat)); + + float beta = 100; + float alpha = 2; + auto t_out_expected = at::addmm(t_b, t_m1, t_m2, beta, alpha); + + auto tv_m1 = t_m1.vulkan(); + auto tv_m2 = t_m2.vulkan(); + auto tv_b = t_b.vulkan(); + auto tv_out = at::addmm(tv_b, tv_m1, tv_m2, beta, alpha); + auto t_out = tv_out.cpu(); + const auto check = almostEqual(t_out, t_out_expected); + if (!check) { + std::cout << "expected:\n" << t_out_expected << std::endl; + std::cout << "got:\n" << t_out << std::endl; + } + ASSERT_TRUE(check); +} + +TEST(VulkanTest, mm) { + auto t_m1 = at::rand({2, 3}, at::device(at::kCPU).dtype(at::kFloat)); + auto t_m2 = at::rand({3, 2}, at::device(at::kCPU).dtype(at::kFloat)); + + auto t_out_expected = t_m1.mm(t_m2); + + auto tv_m1 = t_m1.vulkan(); + auto tv_m2 = t_m2.vulkan(); + auto tv_out = tv_m1.mm(tv_m2); + auto t_out = tv_out.cpu(); + const auto check = almostEqual(t_out, t_out_expected); + if (!check) { + std::cout << "expected:\n" << t_out_expected << std::endl; + std::cout << "got:\n" << t_out << std::endl; + } + ASSERT_TRUE(check); +} + +TEST(VulkanTest, adaptive_avg_pool2d) { + auto t_in = + at::rand({1, 2, 7, 7}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto t_out_expected = at::adaptive_avg_pool2d(t_in, {3, 3}); + auto tv_in = t_in.vulkan(); + + auto tv_out = at::adaptive_avg_pool2d(tv_in, {3, 3}); + auto t_out = tv_out.cpu(); + + const auto check = almostEqual(t_out, t_out_expected); + if (!check) { + std::cout << "expected:\n" << t_out_expected << std::endl; + std::cout << "got:\n" << t_out << std::endl; + } + ASSERT_TRUE(check); +} + +TEST(VulkanTest, upsample_nearest2d) { + auto t_in = at::rand({1, 2, 2, 3}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto t_out_expected = at::upsample_nearest2d(t_in, {4, 6}); + auto tv_in = t_in.vulkan(); + + auto tv_out = at::upsample_nearest2d(tv_in, {4, 6}); + auto t_out = tv_out.cpu(); + + const auto check = almostEqual(t_out, t_out_expected); + if (!check) { + std::cout << "expected:\n" << t_out_expected << std::endl; + std::cout << "got:\n" << t_out << std::endl; + } + ASSERT_TRUE(check); +} + +TEST(VulkanTest, avg_pool2d) { + if (!at::is_vulkan_available()) + return; + + auto t_in = + at::rand({1, 3, 7, 7}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto t_out_expected = at::avg_pool2d(t_in, {2, 2}, {1}, {0}, {1}); + auto tv_in = t_in.vulkan(); + + auto tv_out = at::avg_pool2d(tv_in, {2, 2}, {1}, {0}, {1}); + auto t_out = tv_out.cpu(); + + const auto check = almostEqual(t_out, t_out_expected); + if (!check) { + std::cout << "expected:\n" << t_out_expected << std::endl; + std::cout << "got:\n" << t_out << std::endl; + } + ASSERT_TRUE(check); +} + +TEST(VulkanAPITest, copy) { + const auto cpu = at::rand({13, 17, 37, 19}, at::device(at::kCPU).dtype(at::kFloat)); + ASSERT_TRUE(exactlyEqual(cpu, cpu.vulkan().cpu())); +} + TEST(VulkanAPITest, empty) { ASSERT_NO_THROW(at::empty({1, 17, 41, 53}, at::device(at::kVulkan).dtype(at::kFloat))); } diff --git a/aten/src/TH/THStorageFunctions.hpp b/aten/src/TH/THStorageFunctions.hpp index b78f8c7a3035..8d5c28daa796 100644 --- a/aten/src/TH/THStorageFunctions.hpp +++ b/aten/src/TH/THStorageFunctions.hpp @@ -8,6 +8,7 @@ #include #include +#include // Note [Weak references for intrusive refcounting] // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/aten/src/TH/generic/THLapack.cpp b/aten/src/TH/generic/THLapack.cpp index 6a776f4d0a17..c0fb51f53e45 100644 --- a/aten/src/TH/generic/THLapack.cpp +++ b/aten/src/TH/generic/THLapack.cpp @@ -2,7 +2,6 @@ #define TH_GENERIC_FILE "TH/generic/THLapack.cpp" #else - TH_EXTERNC void dgels_(char *trans, int *m, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, double *work, int *lwork, int *info); TH_EXTERNC void sgels_(char *trans, int *m, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, float *work, int *lwork, int *info); TH_EXTERNC void dgeev_(char *jobvl, char *jobvr, int *n, double *a, int *lda, double *wr, double *wi, double* vl, int *ldvl, double *vr, int *ldvr, double *work, int *lwork, int *info); diff --git a/aten/src/TH/generic/THTensorLapack.cpp b/aten/src/TH/generic/THTensorLapack.cpp index 2494e21791e4..76d7d7bc48d8 100644 --- a/aten/src/TH/generic/THTensorLapack.cpp +++ b/aten/src/TH/generic/THTensorLapack.cpp @@ -410,7 +410,8 @@ void THTensor_(orgqr)(THTensor *ra_, THTensor *a, THTensor *tau) { if (a == NULL) a = ra_; THArgCheck(THTensor_nDimension(a) == 2, 1, "'input' should be 2 dimensional"); - THArgCheck(!a->is_empty(), 1, "'input' should not be empty"); + THArgCheck(!a->is_empty(), 2, "'input' should not be empty"); + THArgCheck(!tau->is_empty(), 3, "'tau' should not be empty"); THTensor *ra__ = NULL; ra__ = THTensor_(cloneColumnMajor)(ra_, a); diff --git a/aten/src/TH/generic/THTensorMath.h b/aten/src/TH/generic/THTensorMath.h index 1d0daf1206de..3f56494e5999 100644 --- a/aten/src/TH/generic/THTensorMath.h +++ b/aten/src/TH/generic/THTensorMath.h @@ -31,7 +31,6 @@ TH_API void THTensor_(indexFill)(THTensor *tensor, int dim, THLongTensor *index, TH_API void THTensor_(kthvalue)(THTensor *values_, THLongTensor *indices_, THTensor *t, int64_t k, int dimension, int keepdim); TH_API void THTensor_(mode)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim); -TH_API accreal THTensor_(trace)(THTensor *t); #if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) diff --git a/aten/src/TH/generic/THTensorMoreMath.cpp b/aten/src/TH/generic/THTensorMoreMath.cpp index 93708556dfb5..2faeadd76e01 100644 --- a/aten/src/TH/generic/THTensorMoreMath.cpp +++ b/aten/src/TH/generic/THTensorMoreMath.cpp @@ -252,27 +252,6 @@ void THTensor_(preserveReduceDimSemantics)( #if !defined(TH_REAL_IS_BOOL) /* non bool only part */ -accreal THTensor_(trace)(THTensor *t) -{ - scalar_t *t_data = t->data(); - accreal sum = 0; - int64_t i = 0; - int64_t t_stride_0, t_stride_1, t_diag_size; - - THArgCheck(THTensor_(nDimensionLegacyAll)(t) == 2, 1, "expected a matrix"); - - t_stride_0 = THTensor_(stride)(t, 0); - t_stride_1 = THTensor_(stride)(t, 1); - t_diag_size = THMin(THTensor_(size)(t, 0), THTensor_(size)(t, 1)); - while(i < t_diag_size) - { - sum += t_data[i*(t_stride_0+t_stride_1)]; - i++; - } - - return sum; -} - /* Implementation of the Quickselect algorithm, based on Nicolas Devillard's public domain implementation at http://ndevilla.free.fr/median/median/ Adapted similarly to the above Quicksort algorithm. */ diff --git a/c10/core/DefaultDtype.cpp b/c10/core/DefaultDtype.cpp index c4f420ab6e22..583d4452bfbd 100644 --- a/c10/core/DefaultDtype.cpp +++ b/c10/core/DefaultDtype.cpp @@ -3,26 +3,32 @@ namespace c10 { static auto default_dtype = caffe2::TypeMeta::Make(); -static auto default_dtype_as_scalartype = typeMetaToScalarType(default_dtype); +static auto default_dtype_as_scalartype = default_dtype.toScalarType(); static auto default_complex_dtype = caffe2::TypeMeta::Make>(); void set_default_dtype(caffe2::TypeMeta dtype) { - default_dtype = std::move(dtype); - default_dtype_as_scalartype = typeMetaToScalarType(default_dtype); - if(default_dtype_as_scalartype == ScalarType::Double) { - default_complex_dtype = std::move(caffe2::TypeMeta::Make>()); - } else { - default_complex_dtype = std::move(caffe2::TypeMeta::Make>()); + default_dtype = dtype; + default_dtype_as_scalartype = default_dtype.toScalarType(); + switch (default_dtype_as_scalartype) { + case ScalarType::Half: + default_complex_dtype = ScalarType::ComplexHalf; + break; + case ScalarType::Double: + default_complex_dtype = ScalarType::ComplexDouble; + break; + default: + default_complex_dtype = ScalarType::ComplexFloat; + break; } } -const caffe2::TypeMeta& get_default_dtype() { +const caffe2::TypeMeta get_default_dtype() { return default_dtype; } ScalarType get_default_dtype_as_scalartype() { return default_dtype_as_scalartype; } -const caffe2::TypeMeta& get_default_complex_dtype() { +const caffe2::TypeMeta get_default_complex_dtype() { return default_complex_dtype; } } // namespace c10 diff --git a/c10/core/DefaultDtype.h b/c10/core/DefaultDtype.h index eda34b217727..d0a17474bda4 100644 --- a/c10/core/DefaultDtype.h +++ b/c10/core/DefaultDtype.h @@ -9,7 +9,7 @@ class TypeMeta; namespace c10 { C10_API void set_default_dtype(caffe2::TypeMeta dtype); -C10_API const caffe2::TypeMeta& get_default_dtype(); +C10_API const caffe2::TypeMeta get_default_dtype(); C10_API ScalarType get_default_dtype_as_scalartype(); -C10_API const caffe2::TypeMeta& get_default_complex_dtype(); +C10_API const caffe2::TypeMeta get_default_complex_dtype(); } // namespace c10 diff --git a/c10/core/DispatchKey.cpp b/c10/core/DispatchKey.cpp index eaf06a7846a6..36c9ab0d6164 100644 --- a/c10/core/DispatchKey.cpp +++ b/c10/core/DispatchKey.cpp @@ -53,6 +53,9 @@ const char* toString(DispatchKey t) { case DispatchKey::SparseHIP: return "SparseHIP"; + case DispatchKey::NestedTensor: + return "NestedTensor"; + case DispatchKey::PrivateUse1: return "PrivateUse1"; case DispatchKey::PrivateUse2: @@ -71,6 +74,8 @@ const char* toString(DispatchKey t) { return "AutogradCUDA"; case DispatchKey::AutogradXLA: return "AutogradXLA"; + case DispatchKey::AutogradNestedTensor: + return "AutogradNestedTensor"; case DispatchKey::AutogradPrivateUse1: return "AutogradPrivateUse1"; case DispatchKey::AutogradPrivateUse2: @@ -132,6 +137,8 @@ DispatchKey getAutogradKeyFromBackend(DispatchKey t) { return DispatchKey::AutogradCUDA; case DispatchKey::XLA: return DispatchKey::AutogradXLA; + case DispatchKey::NestedTensor: + return DispatchKey::AutogradNestedTensor; case DispatchKey::PrivateUse1: return DispatchKey::AutogradPrivateUse1; case DispatchKey::PrivateUse2: diff --git a/c10/core/DispatchKey.h b/c10/core/DispatchKey.h index 35ad7907ef16..aa4f11fe1439 100644 --- a/c10/core/DispatchKey.h +++ b/c10/core/DispatchKey.h @@ -104,6 +104,7 @@ enum class DispatchKey : uint8_t { SparseHIP, // TODO: I think this is not actually used, due to Note // [Masquerading as CUDA] + NestedTensor, // lives out of tree at https://github.com/pytorch/nestedtensor // Here are reserved backends for user-defined backends, see Note [Private use // DispatchKey] // To see some example about how to use this, check out MSNPU @@ -217,6 +218,7 @@ enum class DispatchKey : uint8_t { AutogradCPU, AutogradCUDA, AutogradXLA, + AutogradNestedTensor, // lives out of tree at https://github.com/pytorch/nestedtensor // Here are some reserved pre-autograd keys for user-defined backends, see // Note [Private use DispatchKey] AutogradPrivateUse1, diff --git a/c10/core/DispatchKeySet.cpp b/c10/core/DispatchKeySet.cpp index 2529d4c4bd51..ef8355ef463c 100644 --- a/c10/core/DispatchKeySet.cpp +++ b/c10/core/DispatchKeySet.cpp @@ -8,6 +8,7 @@ constexpr DispatchKeySet backend_dispatch_keyset = autogradother_backends | Disp DispatchKey::CPU, DispatchKey::CUDA, DispatchKey::XLA, + DispatchKey::NestedTensor, DispatchKey::PrivateUse1, DispatchKey::PrivateUse2, DispatchKey::PrivateUse3, @@ -45,6 +46,8 @@ DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) { return DispatchKeySet(DispatchKey::CUDA); case DispatchKey::AutogradXLA: return DispatchKeySet(DispatchKey::XLA); + case DispatchKey::AutogradNestedTensor: + return DispatchKeySet(DispatchKey::NestedTensor); case DispatchKey::AutogradPrivateUse1: return DispatchKeySet(DispatchKey::PrivateUse1); case DispatchKey::AutogradPrivateUse2: diff --git a/c10/core/DispatchKeySet.h b/c10/core/DispatchKeySet.h index abcc4becc3b0..e8a9c70c143d 100644 --- a/c10/core/DispatchKeySet.h +++ b/c10/core/DispatchKeySet.h @@ -192,6 +192,7 @@ constexpr DispatchKeySet autograd_dispatch_keyset = DispatchKeySet({ DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA, DispatchKey::AutogradXLA, + DispatchKey::AutogradNestedTensor, DispatchKey::AutogradPrivateUse1, DispatchKey::AutogradPrivateUse2, DispatchKey::AutogradPrivateUse3, diff --git a/c10/core/Scalar.cpp b/c10/core/Scalar.cpp index 04bba06a91a5..35aa5d60f001 100644 --- a/c10/core/Scalar.cpp +++ b/c10/core/Scalar.cpp @@ -13,4 +13,12 @@ Scalar Scalar::operator-() const { } } +Scalar Scalar::conj() const { + if (isComplex()) { + return Scalar(std::conj(v.z)); + } else { + return *this; + } +} + } // namespace c10 diff --git a/c10/core/Scalar.h b/c10/core/Scalar.h index 19f0d3b90e6f..6151f6d2b150 100644 --- a/c10/core/Scalar.h +++ b/c10/core/Scalar.h @@ -87,7 +87,7 @@ class C10_API Scalar { } Scalar operator-() const; - + Scalar conj() const; ScalarType type() const { if (isComplex()) { return ScalarType::ComplexDouble; diff --git a/c10/core/ScalarType.h b/c10/core/ScalarType.h index 8f2acebd84f0..6903cf9f77ce 100644 --- a/c10/core/ScalarType.h +++ b/c10/core/ScalarType.h @@ -3,9 +3,12 @@ #include #include #include +#include +#include +#include #include +#include #include -#include #include #include @@ -68,6 +71,8 @@ enum class ScalarType : int8_t { NumOptions }; +constexpr uint16_t NumScalarTypes = static_cast(ScalarType::NumOptions); + namespace impl { // These are used to map ScalarTypes to C++ types. @@ -94,7 +99,7 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_ScalarTypeToCPPType) #undef SPECIALIZE_ScalarTypeToCPPType -} +} // namespace impl template struct CppTypeToScalarType; @@ -162,64 +167,6 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType) _(c10::complex, ComplexFloat) \ _(c10::complex, ComplexDouble) -static inline caffe2::TypeMeta scalarTypeToTypeMeta(ScalarType scalar_type) { -#define DEFINE_CASE(ctype, name) \ - case ScalarType::name: \ - return caffe2::TypeMeta::Make(); - - switch (scalar_type) { - AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE) - case ScalarType::Undefined: - return caffe2::TypeMeta(); - default: - AT_ERROR( - "Unrecognized Scalartype ", - scalar_type, - " (please report this error)"); - } -#undef DEFINE_CASE -} - -static inline c10::optional tryTypeMetaToScalarType( - caffe2::TypeMeta dtype) { -#define DEFINE_IF(ctype, name) \ - if (dtype == caffe2::TypeMeta::Make()) { \ - return {ScalarType::name}; \ - } - AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_IF) -#undef DEFINE_IF - if (dtype == caffe2::TypeMeta()) { - return {ScalarType::Undefined}; - } - return c10::nullopt; -} - -static inline ScalarType typeMetaToScalarType(caffe2::TypeMeta dtype) { - if (auto scalar_type = tryTypeMetaToScalarType(dtype)) { - return *scalar_type; - } - AT_ERROR( - "Unsupported TypeMeta in ATen: ", dtype, " (please report this error)"); -} - -inline optional optTypeMetaToScalarType(optional type_meta) { - if (!type_meta.has_value()) { - return c10::nullopt; - } - return typeMetaToScalarType(*type_meta); -} - -static inline bool operator==(ScalarType t, caffe2::TypeMeta m) { - if (auto mt = tryTypeMetaToScalarType(m)) { - return (*mt) == t; - } - return false; -} - -static inline bool operator==(caffe2::TypeMeta m, ScalarType t) { - return t == m; -} - #define DEFINE_CONSTANT(_, name) \ constexpr ScalarType k##name = ScalarType::name; diff --git a/c10/core/ScalarTypeToTypeMeta.h b/c10/core/ScalarTypeToTypeMeta.h new file mode 100644 index 000000000000..b6e7f6cf1993 --- /dev/null +++ b/c10/core/ScalarTypeToTypeMeta.h @@ -0,0 +1,55 @@ +#pragma once + +#include +#include + +// these just expose TypeMeta/ScalarType bridge functions in c10 +// TODO move to typeid.h (or codemod away) when TypeMeta et al +// are moved from caffe2 to c10 (see note at top of typeid.h) + +namespace c10 { + +/** + * convert ScalarType enum values to TypeMeta handles + */ +static inline caffe2::TypeMeta scalarTypeToTypeMeta(ScalarType scalar_type) { + return caffe2::TypeMeta::fromScalarType(scalar_type); +} + +/** + * convert TypeMeta handles to ScalarType enum values + */ +static inline ScalarType typeMetaToScalarType(caffe2::TypeMeta dtype) { + return dtype.toScalarType(); +} + +/** + * typeMetaToScalarType(), lifted to optional + */ +static inline optional optTypeMetaToScalarType(optional type_meta) { + if (!type_meta.has_value()) { + return c10::nullopt; + } + return type_meta->toScalarType(); +} + +/** + * convenience: equality across TypeMeta/ScalarType conversion + */ +static inline bool operator==(ScalarType t, caffe2::TypeMeta m) { + return m.isScalarType(t); +} + +static inline bool operator==(caffe2::TypeMeta m, ScalarType t) { + return t == m; +} + +static inline bool operator!=(ScalarType t, caffe2::TypeMeta m) { + return !(t == m); +} + +static inline bool operator!=(caffe2::TypeMeta m, ScalarType t) { + return !(t == m); +} + +} // namespace c10 diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index 8702ed4fdebf..9f2ca1d2ca07 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -47,13 +47,13 @@ const at::Tensor& TensorImpl::grad() const { TensorImpl::TensorImpl( Storage&& storage, DispatchKeySet key_set, - const caffe2::TypeMeta& data_type) + const caffe2::TypeMeta data_type) : TensorImpl(std::move(storage), key_set, data_type, storage.device()) {} -TensorImpl::TensorImpl(DispatchKeySet key_set, const caffe2::TypeMeta& data_type, c10::optional device_opt) +TensorImpl::TensorImpl(DispatchKeySet key_set, const caffe2::TypeMeta data_type, c10::optional device_opt) : TensorImpl({}, key_set, data_type, std::move(device_opt)) {} -TensorImpl::TensorImpl(Storage&& storage, DispatchKeySet key_set, const caffe2::TypeMeta& data_type, +TensorImpl::TensorImpl(Storage&& storage, DispatchKeySet key_set, const caffe2::TypeMeta data_type, c10::optional device_opt) : storage_(std::move(storage)), sizes_{0}, @@ -61,9 +61,11 @@ TensorImpl::TensorImpl(Storage&& storage, DispatchKeySet key_set, const caffe2:: numel_(0), data_type_(data_type), device_opt_(device_opt) { + + init_bitfields(); + if (!key_set.empty()) { - AT_ASSERT(data_type.id() == caffe2::TypeIdentifier::uninitialized() || - device_opt_.has_value()); + TORCH_INTERNAL_ASSERT(data_type == ScalarType::Undefined || device_opt_.has_value()); // UndefinedTensorImpl is a singleton, so we skip logging it C10_LOG_API_USAGE_ONCE("tensor.create"); } diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index ad636e51ff12..da849b049b65 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -322,24 +322,24 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { TensorImpl( Storage&& storage, DispatchKeySet, - const caffe2::TypeMeta& data_type); + const caffe2::TypeMeta data_type); /** * Construct a 1-dim 0 size tensor that doesn't have a storage. */ - TensorImpl(DispatchKeySet, const caffe2::TypeMeta& data_type, c10::optional device_opt); + TensorImpl(DispatchKeySet, const caffe2::TypeMeta data_type, c10::optional device_opt); // Legacy constructors so I don't have to go update call sites. // TODO: When Variable is added, delete these constructors TensorImpl( Storage&& storage, DispatchKey dispatch_key, - const caffe2::TypeMeta& data_type) + const caffe2::TypeMeta data_type) : TensorImpl( std::move(storage), DispatchKeySet(dispatch_key), data_type) {} - TensorImpl(DispatchKey dispatch_key, const caffe2::TypeMeta& data_type, c10::optional device_opt) + TensorImpl(DispatchKey dispatch_key, const caffe2::TypeMeta data_type, c10::optional device_opt) : TensorImpl(DispatchKeySet(dispatch_key), data_type, device_opt) {} private: @@ -347,7 +347,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { // storage. Still, we pass it in separately because it's easier to write // the initializer list if we're not worried about storage being moved out // from under us. - TensorImpl(Storage&& storage, DispatchKeySet, const caffe2::TypeMeta& data_type, c10::optional); + TensorImpl(Storage&& storage, DispatchKeySet, const caffe2::TypeMeta data_type, c10::optional); public: TensorImpl(const TensorImpl&) = delete; @@ -665,7 +665,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * Returns the TypeMeta of a tensor, which describes what data type * it is (e.g., int, float, ...) */ - const caffe2::TypeMeta& dtype() const { + const caffe2::TypeMeta dtype() const { return data_type_; } @@ -1235,10 +1235,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { void ShareExternalPointer( DataPtr&& data_ptr, - const caffe2::TypeMeta& data_type, + const caffe2::TypeMeta data_type, size_t size_bytes) { TORCH_CHECK( - data_type.id() != caffe2::TypeIdentifier::uninitialized(), + data_type != ScalarType::Undefined, "To share with a raw external pointer you need to pass in an " "initialized data_type(TypeMeta)."); if (!size_bytes) { @@ -1275,7 +1275,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * If the existing data does not match the desired type, it will be deleted * and a new storage will be created. */ - inline void* raw_mutable_data(const caffe2::TypeMeta& meta) { + inline void* raw_mutable_data(const caffe2::TypeMeta meta) { // For 0-size tensors it's fine to return any pointer (including nullptr) if (data_type_ == meta && storage_initialized()) { return static_cast(static_cast(storage_.data()) + storage_offset_ * meta.itemsize()); @@ -1369,7 +1369,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { void set_storage_and_dtype( at::Storage storage, - const caffe2::TypeMeta& data_type) { + const caffe2::TypeMeta data_type) { set_storage_keep_dtype(storage); data_type_ = data_type; } @@ -1675,36 +1675,47 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { // INVARIANT: named_tensor_meta_ != nullptr <==> key_set_.has(DispatchKey::Named) DispatchKeySet key_set_; - // You get to have eight byte-size fields here, before you - // should pack this into a bitfield. + // Tensor is contiguous bool is_contiguous_ = true; + // default member initializers for bit-fields only available with -std=c++2a or -std=gnu++2a + inline void init_bitfields() { + is_channels_last_ = false; + is_channels_last_contiguous_ = false; + is_channels_last_3d_ = false; + is_channels_last_3d_contiguous_ = false; + is_non_overlapping_and_dense_ = false; + is_wrapped_number_ = false; + allow_tensor_metadata_change_ = true; + reserved_ = false; + } + // Tensor is stored in the channels last 2d memory format, when dimensions // order is (N)CHW and C-strides < W-strides < H-strides (< N-strides) // (If size of any dimension is equal to 1, this dimension strides value // is not taken into account). - bool is_channels_last_ = false; + bool is_channels_last_ : 1; // Channels last contiguous tensor is channel last tensor which occupies // contiguous memory block. - bool is_channels_last_contiguous_ = false; + bool is_channels_last_contiguous_ : 1; // Tensor is stored in the channels last 3d memory format, when dimensions // order is (N)CDHW and C-strides < W-strides < H-strides < D - strides (< N-strides) // (If size of any dimension is equal to 1, this dimension strides value // is not taken into account). - bool is_channels_last_3d_ = false; + bool is_channels_last_3d_ : 1; // Channels last 3d contiguous tensor is channel last 3d tensor which occupies // contiguous memory block. - bool is_channels_last_3d_contiguous_ = false; + bool is_channels_last_3d_contiguous_ : 1; // Dense tensor is the tensor that store values in a contiguous block of memory. // Non-overlapping tensor is the tensor in which elements occupy individual // non-repetitive memory. - bool is_non_overlapping_and_dense_ = false; + bool is_non_overlapping_and_dense_ : 1; - bool is_wrapped_number_ = false; + bool is_wrapped_number_ : 1; // NOTE [ Metadata Change for a Detached Tensor ] // @@ -1721,14 +1732,13 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { // NOTE: For a full list of tensor metadata fields, please see // `copy_tensor_metadata()` in TensorImpl and its subclasses to find // which fields are copied by value. - bool allow_tensor_metadata_change_ = true; + bool allow_tensor_metadata_change_ : 1; // we decide to keep reserved_ and it will // live in Tensor after the split // The logic is that if Extend() or ReserveSpace() were ever called, // then subsequent Resize()s will not free up Storage. - bool reserved_ = false; - + bool reserved_ : 1; }; // Note [TensorImpl size constraints] @@ -1781,13 +1791,13 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { // strides SmallVector (pre-allocated 4) // storage offset // numel -// data type pointer +// data type // (optional) device // tensor type id // miscellaneous bitfield // static_assert(sizeof(void*) != sizeof(int64_t) || // if 64-bit... - sizeof(TensorImpl) == sizeof(int64_t) * 31, + sizeof(TensorImpl) == sizeof(int64_t) * 29, "You changed the size of TensorImpl on 64-bit arch." "See Note [TensorImpl size constraints] on how to proceed."); } // namespace c10 diff --git a/c10/core/TensorOptions.h b/c10/core/TensorOptions.h index 10342908a49b..c8c7f058513d 100644 --- a/c10/core/TensorOptions.h +++ b/c10/core/TensorOptions.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -376,15 +377,24 @@ struct C10_API TensorOptions { /// device guard. /// TensorOptions merge_in(TensorOptions options) const noexcept { - TensorOptions r = options; - if (!r.has_device()) r.set_device(device_opt()); - if (!r.has_dtype()) r.set_dtype(dtype_opt()); - if (!r.has_layout()) r.set_layout(layout_opt()); + TensorOptions merged = *this; + if (options.has_device()) merged.set_device(options.device_opt()); + if (options.has_dtype()) merged.set_dtype(options.dtype_opt()); + if (options.has_layout()) merged.set_layout(options.layout_opt()); // NB: requires grad is right biased; not a logical AND/OR! - if (!r.has_requires_grad()) r.set_requires_grad(requires_grad_opt()); - if (!r.has_pinned_memory()) r.set_pinned_memory(pinned_memory_opt()); - if (!r.has_memory_format()) r.set_memory_format(memory_format_opt()); - return r; + if (options.has_requires_grad()) merged.set_requires_grad(options.requires_grad_opt()); + if (options.has_pinned_memory()) merged.set_pinned_memory(options.pinned_memory_opt()); + if (options.has_memory_format()) merged.set_memory_format(options.memory_format_opt()); + return merged; + } + + // TODO remove after TensorOptions rationalization + TensorOptions merge_memory_format(c10::optional optional_memory_format) const noexcept { + TensorOptions merged = *this; + if (optional_memory_format.has_value()) { + merged.set_memory_format(*optional_memory_format); + } + return merged; } // Resolves the tensor type set specified by the current construction axes. @@ -492,8 +502,8 @@ struct C10_API TensorOptions { // NB: We didn't use c10::optional here, because then we can't pack // the has_***_ boolean fields. - caffe2::TypeMeta dtype_ = caffe2::TypeMeta::Make(); // 64-bit Device device_ = at::kCPU; // 32-bit + caffe2::TypeMeta dtype_ = caffe2::TypeMeta::Make(); // 16-bit Layout layout_ = at::kStrided; // 8-bit MemoryFormat memory_format_ = MemoryFormat::Contiguous; // 8-bit diff --git a/c10/core/UndefinedTensorImpl.h b/c10/core/UndefinedTensorImpl.h index 9f1cb93c10eb..26122ed305e2 100644 --- a/c10/core/UndefinedTensorImpl.h +++ b/c10/core/UndefinedTensorImpl.h @@ -28,8 +28,6 @@ struct C10_API UndefinedTensorImpl final : public TensorImpl { private: UndefinedTensorImpl(); static UndefinedTensorImpl _singleton; -public: - friend struct UndefinedType; }; } // namespace c10 diff --git a/c10/cuda/CUDAStream.cpp b/c10/cuda/CUDAStream.cpp index 393826f75a03..457331f4a00d 100644 --- a/c10/cuda/CUDAStream.cpp +++ b/c10/cuda/CUDAStream.cpp @@ -43,12 +43,9 @@ static constexpr int kStreamsPerPoolBits = 5; static constexpr int kStreamsPerPool = 1 << kStreamsPerPoolBits; static constexpr unsigned int kDefaultFlags = cudaStreamNonBlocking; -// Note: stream priority is not supported by HIP // Note: lower numbers are higher priorities, zero is default priority -#ifndef __HIP_PLATFORM_HCC__ static int kHighPriority = -1; static int kLowPriority = 0; -#endif // __HIP_PLATFORM_HCC__ // Default streams static std::once_flag init_flag; @@ -229,17 +226,10 @@ static void initDeviceStreamState(DeviceIndex device_index) { lowpri_stream.device_index = device_index; hipri_stream.device_index = device_index; -#ifndef __HIP_PLATFORM_HCC__ C10_CUDA_CHECK(cudaStreamCreateWithPriority( &lowpri_stream.stream, kDefaultFlags, kLowPriority)); C10_CUDA_CHECK(cudaStreamCreateWithPriority( &hipri_stream.stream, kDefaultFlags, kHighPriority)); -#else - C10_CUDA_CHECK( - cudaStreamCreateWithFlags(&lowpri_stream.stream, kDefaultFlags)); - C10_CUDA_CHECK( - cudaStreamCreateWithFlags(&hipri_stream.stream, kDefaultFlags)); -#endif // __HIP_PLATFORM_HCC__ } } diff --git a/c10/cuda/CUDAStream.h b/c10/cuda/CUDAStream.h index d9bc553aa263..e82565443450 100644 --- a/c10/cuda/CUDAStream.h +++ b/c10/cuda/CUDAStream.h @@ -120,14 +120,10 @@ class C10_CUDA_API CUDAStream { } int priority() const { - #ifndef __HIP_PLATFORM_HCC__ DeviceGuard guard{stream_.device()}; int priority = 0; C10_CUDA_CHECK(cudaStreamGetPriority(stream(), &priority)); return priority; - #else - AT_ERROR("cuStreamGetPriority with HIP is not supported"); - #endif } /// Explicit conversion to cudaStream_t. @@ -154,7 +150,6 @@ class C10_CUDA_API CUDAStream { } static std::tuple priority_range() { - #ifndef __HIP_PLATFORM_HCC__ // Note: this returns the range of priority **supported by PyTorch**, not // the range of priority **supported by CUDA**. The former is a subset of // the latter. Curently PyTorch only supports 0 and -1, which are "low" and @@ -165,9 +160,6 @@ class C10_CUDA_API CUDAStream { TORCH_INTERNAL_ASSERT(least_priority >= 0, "Unexpected CUDA stream priority range"); TORCH_INTERNAL_ASSERT(greatest_priority <= -1, "Unexpected CUDA stream priority range"); return std::make_tuple(0, -1); - #else - AT_ERROR("cuDeviceGetStreamPriorityRange with HIP is not supported"); - #endif } // Deleted for now; use CUDAEvent::block instead diff --git a/c10/util/Half.h b/c10/util/Half.h index 8f8dd3467367..01562acea704 100644 --- a/c10/util/Half.h +++ b/c10/util/Half.h @@ -328,7 +328,9 @@ namespace detail { const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); const uint32_t nonsign = exp_bits + mantissa_bits; - return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign); + return static_cast( + (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign) + ); } } // namespace detail @@ -372,10 +374,11 @@ struct alignas(4) complex { Half imag() const { return imag_; } - inline complex(c10::complex value) - : real_(value.real()), imag_(value.imag()) {} - inline complex(c10::complex value) + explicit inline complex(c10::complex value) : real_(value.real()), imag_(value.imag()) {} + explicit inline complex(c10::complex value) + : real_(static_cast(value.real())), + imag_(static_cast(value.imag())) {} inline operator c10::complex() const { return {real_, imag_}; } diff --git a/c10/util/intrusive_ptr.h b/c10/util/intrusive_ptr.h index 453196510aa8..2f44b9af0714 100644 --- a/c10/util/intrusive_ptr.h +++ b/c10/util/intrusive_ptr.h @@ -5,6 +5,11 @@ #include #include +namespace pybind11 { + template + class class_; +} + namespace c10 { class intrusive_ptr_target; namespace raw { @@ -14,6 +19,9 @@ namespace raw { namespace intrusive_ptr { inline void incref(intrusive_ptr_target * self); } + + // constructor tag used by intrusive_ptr constructors + struct DontIncreaseRefcount {}; } /** * intrusive_ptr is an alternative to shared_ptr that has better @@ -182,6 +190,16 @@ class intrusive_ptr final { friend class intrusive_ptr; friend class weak_intrusive_ptr; + // Make pybind11::class_ be a friend class of intrusive_ptr, so that custom smart + // holder in pybind11 could access the private constructor of intrusive_ptr(T*) + // which took the ownership of the object. + // This is required by customer holder macro PYBIND11_DECLARE_HOLDER_TYPE, where + // it uses intrusive_ptr(TTarget*) to initialize and take ownership of the object. + // For details, see + // https://pybind11.readthedocs.io/en/stable/advanced/smart_ptrs.html#custom-smart-pointers + template + friend class pybind11::class_; + void retain_() { if (target_ != NullType::singleton()) { size_t new_refcount = ++target_->refcount_; @@ -207,16 +225,37 @@ class intrusive_ptr final { target_ = NullType::singleton(); } + + // raw pointer constructors are not public because we shouldn't make intrusive_ptr + // out of raw pointers except from inside the make_intrusive(), reclaim() and + // weak_intrusive_ptr::lock() implementations. + // This constructor will not increase the ref counter for you. - // This is not public because we shouldn't make intrusive_ptr out of raw - // pointers except from inside the make_intrusive() and - // weak_intrusive_ptr::lock() implementations - explicit intrusive_ptr(TTarget* target) noexcept : target_(target) {} + // We use the tagged dispatch mechanism to explicitly mark this constructor + // to not increase the refcount + explicit intrusive_ptr(TTarget* target, raw::DontIncreaseRefcount) noexcept + : target_(target) {} + + // This constructor will increase the ref counter for you. + // This constructor will be used by the make_intrusive(), and also pybind11, which + // wrap the intrusive_ptr holder around the raw pointer and incref correspondingly + // (pybind11 requires raw pointer constructor to incref by default). + explicit intrusive_ptr(TTarget* target) + : intrusive_ptr(target, raw::DontIncreaseRefcount{}) { + if (target_ != NullType::singleton()) { + // We can't use retain_(), because we also have to increase weakcount + // and because we allow raising these values from 0, which retain_() + // has an assertion against. + ++target_->refcount_; + ++target_->weakcount_; + } + } public: using element_type = TTarget; - intrusive_ptr() noexcept : intrusive_ptr(NullType::singleton()) {} + intrusive_ptr() noexcept + : intrusive_ptr(NullType::singleton(), raw::DontIncreaseRefcount{}) {} intrusive_ptr(intrusive_ptr&& rhs) noexcept : target_(rhs.target_) { rhs.target_ = NullType::singleton(); @@ -347,19 +386,17 @@ class intrusive_ptr final { * passed in *must* have been created using intrusive_ptr::release(). */ static intrusive_ptr reclaim(TTarget* owning_ptr) { - return intrusive_ptr(owning_ptr); + return intrusive_ptr(owning_ptr, raw::DontIncreaseRefcount{}); } + /** + * Allocate a heap object with args and wrap it inside a intrusive_ptr and + * incref. This is a helper function to let make_intrusive() access private + * intrusive_ptr constructors. + */ template static intrusive_ptr make(Args&&... args) { - auto result = intrusive_ptr(new TTarget(std::forward(args)...)); - // We can't use retain_(), because we also have to increase weakcount - // and because we allow raising these values from 0, which retain_() - // has an assertion against. - ++result.target_->refcount_; - ++result.target_->weakcount_; - - return result; + return intrusive_ptr(new TTarget(std::forward(args)...)); } /** @@ -590,17 +627,18 @@ class weak_intrusive_ptr final { intrusive_ptr lock() const noexcept { if (expired()) { - return intrusive_ptr(NullType::singleton()); + return intrusive_ptr(); } else { auto refcount = target_->refcount_.load(); do { if (refcount == 0) { // Object already destructed, no strong references left anymore. // Return nullptr. - return intrusive_ptr(NullType::singleton()); + return intrusive_ptr(); } } while (!target_->refcount_.compare_exchange_weak(refcount, refcount + 1)); - return intrusive_ptr(target_); + return intrusive_ptr( + target_, raw::DontIncreaseRefcount{}); } } diff --git a/c10/util/llvmMathExtras.h b/c10/util/llvmMathExtras.h index 8def126c29aa..2c4fbf8a501b 100644 --- a/c10/util/llvmMathExtras.h +++ b/c10/util/llvmMathExtras.h @@ -14,9 +14,10 @@ #define LLVM_SUPPORT_MATHEXTRAS_H #include - #include #include #include + #include + #include #include #include #include @@ -547,26 +548,26 @@ /// (32 bit edition.) /// Ex. Log2_32(32) == 5, Log2_32(1) == 0, Log2_32(0) == -1, Log2_32(6) == 2 inline unsigned Log2_32(uint32_t Value) { - return 31 - countLeadingZeros(Value); + return static_cast(31 - countLeadingZeros(Value)); } /// Return the floor log base 2 of the specified value, -1 if the value is zero. /// (64 bit edition.) inline unsigned Log2_64(uint64_t Value) { - return 63 - countLeadingZeros(Value); + return static_cast(63 - countLeadingZeros(Value)); } /// Return the ceil log base 2 of the specified value, 32 if the value is zero. /// (32 bit edition). /// Ex. Log2_32_Ceil(32) == 5, Log2_32_Ceil(1) == 0, Log2_32_Ceil(6) == 3 inline unsigned Log2_32_Ceil(uint32_t Value) { - return 32 - countLeadingZeros(Value - 1); + return static_cast(32 - countLeadingZeros(Value - 1)); } /// Return the ceil log base 2 of the specified value, 64 if the value is zero. /// (64 bit edition.) inline unsigned Log2_64_Ceil(uint64_t Value) { - return 64 - countLeadingZeros(Value - 1); + return static_cast(64 - countLeadingZeros(Value - 1)); } /// Return the greatest common divisor of the values using Euclid's algorithm. @@ -589,6 +590,7 @@ /// This function takes a 32-bit integer and returns the bit equivalent float. inline float BitsToFloat(uint32_t Bits) { + //TODO: Use bit_cast once C++20 becomes available. float F; static_assert(sizeof(uint32_t) == sizeof(float), "Unexpected type sizes"); memcpy(&F, &Bits, sizeof(Bits)); diff --git a/c10/util/math_compat.h b/c10/util/math_compat.h index 7d1a7b643850..b522cd26f989 100644 --- a/c10/util/math_compat.h +++ b/c10/util/math_compat.h @@ -59,6 +59,14 @@ namespace std { throw std::runtime_error("std::hypot is not implemented on older Android"); } + // TODO: this function needs to be implemented and tested. Currently just throw an error. + inline float igamma(float x, float y) { + throw std::runtime_error("igamma is not implemented on older Android"); + } + inline double igamma(double x, double y) { + throw std::runtime_error("igamma is not implemented on older Android"); + } + // TODO: this function needs to be implemented and tested. Currently just throw an error. inline float nextafter(float x, float y) { throw std::runtime_error("std::nextafter is not implemented on older Android"); @@ -66,7 +74,7 @@ namespace std { inline double nextafter(double x, double y) { throw std::runtime_error("std::nextafter is not implemented on older Android"); } - + // TODO: this function needs to be implemented and tested. Currently just throw an error. inline float exp2(float x) { throw std::runtime_error("std::exp2 is not implemented on older Android"); diff --git a/c10/util/typeid.cpp b/c10/util/typeid.cpp index e2070a1584a2..f3fe048b4cca 100644 --- a/c10/util/typeid.cpp +++ b/c10/util/typeid.cpp @@ -14,42 +14,41 @@ namespace detail { C10_EXPORT void _ThrowRuntimeTypeLogicError(const string& msg) { // In earlier versions it used to be std::abort() but it's a bit hard-core // for a library - AT_ERROR(msg); + TORCH_CHECK(false, msg); } +} // namespace detail +[[noreturn]] void TypeMeta::error_unsupported_typemeta(caffe2::TypeMeta dtype) { + TORCH_CHECK(false, "Unsupported TypeMeta in ATen: ", dtype, " (please report this error)"); +} -} // namespace detail +// see TypeMeta::addTypeMetaData +std::atomic TypeMeta::nextTypeIndex(NumScalarTypes); -template <> -EXPORT_IF_NOT_GCC const detail::TypeMetaData* TypeMeta::_typeMetaDataInstance< - detail::_Uninitialized>() noexcept { - static constexpr detail::TypeMetaData singleton = detail::TypeMetaData( - 0, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - TypeIdentifier::uninitialized(), - "nullptr (uninitialized)"); - return &singleton; +// fixed length array of TypeMetaData instances +detail::TypeMetaData* TypeMeta::typeMetaDatas() { + static detail::TypeMetaData instances[MaxTypeIndex + 1] = { +#define SCALAR_TYPE_META(T, name) \ + /* ScalarType::name */ \ + detail::TypeMetaData( \ + sizeof(T), \ + detail::_PickNew(), \ + detail::_PickPlacementNew(), \ + detail::_PickCopy(), \ + detail::_PickPlacementDelete(), \ + detail::_PickDelete(), \ + TypeIdentifier::Get(), \ + c10::util::get_fully_qualified_type_name()), +AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SCALAR_TYPE_META) +#undef SCALAR_TYPE_META + // The remainder of the array is padded with TypeMetaData blanks. + // The first of these is the entry for ScalarType::Undefined. + // The rest are consumed by CAFFE_KNOWN_TYPE entries. + }; + return instances; } -CAFFE_KNOWN_TYPE(uint8_t) -CAFFE_KNOWN_TYPE(int8_t) -CAFFE_KNOWN_TYPE(int16_t) -CAFFE_KNOWN_TYPE(int) -CAFFE_KNOWN_TYPE(int64_t) -CAFFE_KNOWN_TYPE(at::Half) -CAFFE_KNOWN_TYPE(float) -CAFFE_KNOWN_TYPE(double) -CAFFE_KNOWN_TYPE(c10::complex) -CAFFE_KNOWN_TYPE(c10::complex) -CAFFE_KNOWN_TYPE(c10::complex) -// 11 = undefined type id -// 12 = Tensor (defined in tensor.cc) CAFFE_KNOWN_TYPE(std::string) -CAFFE_KNOWN_TYPE(bool) CAFFE_KNOWN_TYPE(uint16_t) CAFFE_KNOWN_TYPE(char) CAFFE_KNOWN_TYPE(std::unique_ptr) @@ -79,15 +78,11 @@ using _guard_long_unique = std::conditional_t< _guard_long_unique_dummy, T>; } // namespace detail + CAFFE_KNOWN_TYPE(detail::_guard_long_unique); CAFFE_KNOWN_TYPE(detail::_guard_long_unique>) CAFFE_KNOWN_TYPE(float*) CAFFE_KNOWN_TYPE(at::Half*) -CAFFE_KNOWN_TYPE(c10::qint8) -CAFFE_KNOWN_TYPE(c10::quint8) -CAFFE_KNOWN_TYPE(c10::qint32) -CAFFE_KNOWN_TYPE(at::BFloat16) -CAFFE_KNOWN_TYPE(c10::quint4x2) } // namespace caffe2 diff --git a/c10/util/typeid.h b/c10/util/typeid.h index 51833fb545ad..5bdbdc4271df 100644 --- a/c10/util/typeid.h +++ b/c10/util/typeid.h @@ -21,18 +21,14 @@ #include #include #include -#include #include #include #include #include -#include -#include -#include -#include -#include #include +#include + /* * TypeIdentifier is a small type containing an id. * Types must be registered using CAFFE_KNOWN_TYPE() for them to have a type id. @@ -67,7 +63,7 @@ namespace caffe2 { */ class C10_API TypeIdentifier final : public at::IdWrapper { - public: +public: friend std::ostream& operator<<(std::ostream& stream, TypeIdentifier typeId); friend constexpr bool operator<(TypeIdentifier lhs, TypeIdentifier rhs); @@ -87,9 +83,8 @@ class C10_API TypeIdentifier final return TypeIdentifier(c10::util::type_index{0}); } - private: +private: constexpr explicit TypeIdentifier(c10::util::type_index id) : IdWrapper(id) {} - friend class TypeMeta; // TODO Is this friend an issue? }; // Allow usage in std::map / std::set @@ -126,7 +121,16 @@ struct TypeMetaData final { using PlacementDelete = void(void*, size_t); using Delete = void(void*); - TypeMetaData() = delete; + constexpr TypeMetaData() noexcept + : itemsize_(0), + new_(nullptr), + placementNew_(nullptr), + copy_(nullptr), + placementDelete_(nullptr), + delete_(nullptr), + id_(TypeIdentifier::uninitialized()), + name_("nullptr (uninitialized)") {} + constexpr TypeMetaData( size_t itemsize, New* newFn, @@ -136,14 +140,14 @@ struct TypeMetaData final { Delete* deleteFn, TypeIdentifier id, c10::string_view name) noexcept - : itemsize_(itemsize), - new_(newFn), - placementNew_(placementNew), - copy_(copy), - placementDelete_(placementDelete), - delete_(deleteFn), - id_(id), - name_(name) {} + : itemsize_(itemsize), + new_(newFn), + placementNew_(placementNew), + copy_(copy), + placementDelete_(placementDelete), + delete_(deleteFn), + id_(id), + name_(name) {} size_t itemsize_; New* new_; @@ -294,25 +298,24 @@ inline constexpr TypeMetaData::Delete* _PickDelete() noexcept { return &_Delete; } -template -inline C10_TYPENAME_CONSTEXPR TypeMetaData _makeTypeMetaDataInstance() { - C10_HOST_CONSTEXPR_VAR auto typeId = TypeIdentifier::Get(); - C10_TYPENAME_CONSTEXPR auto typeName = c10::util::get_fully_qualified_type_name(); - - return {sizeof(T), - _PickNew(), - _PickPlacementNew(), - _PickCopy(), - _PickPlacementDelete(), - _PickDelete(), - typeId, - typeName}; -} - class _Uninitialized final {}; } // namespace detail +// +// note: this is outside TypeMeta bc gcc seems to have trouble +// with scalarTypeItemSizes as a constexpr static member used by +// a public inline instance method +// + +// item sizes for TypeMeta::itemsize() fast path +static constexpr size_t scalarTypeItemSizes[NumScalarTypes] = { +#define SCALAR_TYPE_SIZE(T, name) sizeof(T), + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SCALAR_TYPE_SIZE) +#undef SCALAR_TYPE_SIZE + 0, // Undefined +}; + /** * TypeMeta is a thin class that allows us to store the type of a container such * as a blob, or the data type of a tensor, with a unique run-time id. It also @@ -338,17 +341,22 @@ class C10_API TypeMeta final { TypeMeta(const TypeMeta& src) noexcept = default; /** - * Assignment operator. + * Assignment operators. */ TypeMeta& operator=(const TypeMeta& src) noexcept = default; TypeMeta(TypeMeta&& rhs) noexcept = default; - private: + inline TypeMeta& operator=(ScalarType scalar_type) noexcept { + index_ = static_cast(scalar_type); + return *this; + } + +private: // TypeMeta can only be created by Make, making sure that we do not // create incorrectly mixed up TypeMeta objects. - explicit TypeMeta(const detail::TypeMetaData* data) noexcept - : data_(data) { + explicit TypeMeta(const uint16_t index) noexcept + : index_(index) { } public: @@ -356,48 +364,66 @@ class C10_API TypeMeta final { * Returns the type id. */ TypeIdentifier id() const noexcept { - return data_->id_; + return data().id_; + } + /** + * true if we represent some ScalarType type + */ + inline bool isScalarType() const noexcept { + return index_ < NumScalarTypes; + } + /** + * true if we represent ScalarType scalar_type + */ + inline bool isScalarType(ScalarType scalar_type) const noexcept { + return index_ == static_cast(scalar_type); } /** * Returns the size of the item. */ - size_t itemsize() const noexcept { - return data_->itemsize_; + inline size_t itemsize() const noexcept { + if (C10_LIKELY(isScalarType())) { + return scalarTypeItemSizes[index_]; + } + return data().itemsize_; } + /** + * Returns the new function pointer for individual items. + */ New* newFn() const noexcept { - return data_->new_; + return data().new_; } /** * Returns the placement new function pointer for individual items. */ PlacementNew* placementNew() const noexcept { - return data_->placementNew_; + return data().placementNew_; } /** * Returns the typed copy function pointer for individual iterms. */ Copy* copy() const noexcept { - return data_->copy_; + return data().copy_; } /** * Returns the destructor function pointer for individual items. */ PlacementDelete* placementDelete() const noexcept { - return data_->placementDelete_; + return data().placementDelete_; } Delete* deleteFn() const noexcept { - return data_->delete_; + return data().delete_; } /** * Returns a printable name for the type. */ c10::string_view name() const noexcept { - return data_->name_; + return data().name_; } friend bool operator==( - const TypeMeta& lhs, - const TypeMeta& rhs) noexcept; + const TypeMeta lhs, + const TypeMeta rhs) noexcept; template bool Match() const noexcept { @@ -412,7 +438,7 @@ class C10_API TypeMeta final { } template - static C10_TYPENAME_CONSTEXPR c10::string_view TypeName() noexcept { + static c10::string_view TypeName() noexcept { return c10::util::get_fully_qualified_type_name(); } @@ -437,35 +463,105 @@ class C10_API TypeMeta final { #pragma GCC diagnostic ignored "-Wunknown-warning-option" #pragma GCC diagnostic ignored "-Wundefined-var-template" #endif - return TypeMeta(_typeMetaDataInstance()); + return TypeMeta(_typeMetaData()); #ifndef _MSC_VER #pragma GCC diagnostic pop #endif } - private: - const detail::TypeMetaData* data_; + /** + * convert ScalarType enum values to TypeMeta handles + */ + static inline caffe2::TypeMeta fromScalarType(ScalarType scalar_type) { + const size_t index = static_cast(scalar_type); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + index < NumScalarTypes, + "Unrecognized Scalartype ", scalar_type, " (please report this error)"); + return TypeMeta(index); + } + + /** + * convert TypeMeta handles to ScalarType enum values + */ + inline ScalarType toScalarType() { + if (C10_LIKELY(isScalarType())) { + return static_cast(index_); + } + error_unsupported_typemeta(*this); + } + +private: + [[noreturn]] static void error_unsupported_typemeta(caffe2::TypeMeta dtype); + + // hard limit number of registered types + // note: constexpr provokes Windows compilation error "member may not be initialized" + // static constexpr size_t MaxTypeIndex = UINT8_MAX; + #define MaxTypeIndex UINT8_MAX + + static std::atomic nextTypeIndex; + + static detail::TypeMetaData* typeMetaDatas(); template - C10_API static const detail::TypeMetaData* _typeMetaDataInstance() noexcept; + static uint16_t addTypeMetaData() { + const uint16_t index = nextTypeIndex++; + TORCH_CHECK(index <= MaxTypeIndex, + "Maximum number of CAFFE_KNOWN_TYPE declarations has been exceeded. ", + "Please report this issue."); + typeMetaDatas()[index] = detail::TypeMetaData{ + sizeof(T), + detail::_PickNew(), + detail::_PickPlacementNew(), + detail::_PickCopy(), + detail::_PickPlacementDelete(), + detail::_PickDelete(), + TypeIdentifier::Get(), + c10::util::get_fully_qualified_type_name()}; + return index; + } + + // specializations return indexes into typeMetaDataInstances() + template + C10_API static uint16_t _typeMetaData() noexcept; + + // + // TypeMeta just wraps this index + // + + uint16_t index_; + + inline const detail::TypeMetaData& data() const { + return typeMetaDatas()[index_]; + } }; +// specializations of TypeMeta::_typeMetaData for ScalarType types + +#define DEFINE_SCALAR_METADATA_INSTANCE(T, name) \ + template <> \ + constexpr uint16_t TypeMeta::_typeMetaData() noexcept { \ + return static_cast(ScalarType::name); \ + } +AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_SCALAR_METADATA_INSTANCE) +#undef DEFINE_SCALAR_METADATA_INSTANCE + template <> -C10_EXPORT const detail::TypeMetaData* TypeMeta::_typeMetaDataInstance< - detail::_Uninitialized>() noexcept; +C10_EXPORT constexpr uint16_t TypeMeta::_typeMetaData() noexcept { + return static_cast(ScalarType::Undefined); +} inline TypeMeta::TypeMeta() noexcept - : data_(_typeMetaDataInstance()) { + : index_(_typeMetaData()) { } inline bool operator==( - const TypeMeta& lhs, - const TypeMeta& rhs) noexcept { - return (lhs.data_ == rhs.data_); + const TypeMeta lhs, + const TypeMeta rhs) noexcept { + return (lhs.index_ == rhs.index_); } inline bool operator!=( - const TypeMeta& lhs, - const TypeMeta& rhs) noexcept { + const TypeMeta lhs, + const TypeMeta rhs) noexcept { return !operator==(lhs, rhs); } @@ -500,13 +596,11 @@ inline std::ostream& operator<<( #define EXPORT_IF_NOT_GCC #endif -#define CAFFE_KNOWN_TYPE(T) \ - template <> \ - EXPORT_IF_NOT_GCC const detail::TypeMetaData* \ - TypeMeta::_typeMetaDataInstance() noexcept { \ - static C10_TYPENAME_CONSTEXPR detail::TypeMetaData singleton = \ - detail::_makeTypeMetaDataInstance(); \ - return &singleton; \ +#define CAFFE_KNOWN_TYPE(T) \ + template <> \ + EXPORT_IF_NOT_GCC uint16_t TypeMeta::_typeMetaData() noexcept { \ + static const uint16_t index = addTypeMetaData(); \ + return index; \ } } // namespace caffe2 diff --git a/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.h b/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.h index 4aef84663adc..ddeea5d5f56c 100644 --- a/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.h +++ b/caffe2/contrib/fakelowp/layernorm_fp16_fake_op.h @@ -189,7 +189,7 @@ class LayerNormFakeFp16Op final : public Operator { int Nout = X.numel(); std::vector inv_scalev(Nout, inv_scale); - std::vector offsetv(Nout, Y_offset - 128.0); + std::vector offsetv(Nout, Y_offset); uint8_t* Y_uint8_data = Y_int8->t.template mutable_data(); fake_fp16::fma_fp16(Nout, Y_fp16.data(), inv_scalev.data(), offsetv.data()); @@ -200,7 +200,6 @@ class LayerNormFakeFp16Op final : public Operator { for (int i = 0; i < Nout; i++) { float halfRes = offsetv[i]; halfRes = round(halfRes); - halfRes = halfRes + 128.0; if (std::isinf(halfRes)) { if (halfRes > 0) { halfRes = qmax; diff --git a/caffe2/contrib/fakelowp/test/test_layernorm_nnpi_fp16.py b/caffe2/contrib/fakelowp/test/test_layernorm_nnpi_fp16.py index 9ff0986116b6..5129a38c5241 100644 --- a/caffe2/contrib/fakelowp/test/test_layernorm_nnpi_fp16.py +++ b/caffe2/contrib/fakelowp/test/test_layernorm_nnpi_fp16.py @@ -1,8 +1,3 @@ - - - - - import numpy as np import caffe2.python.fakelowp.init_shared_libs # noqa from caffe2.proto import caffe2_pb2 diff --git a/caffe2/contrib/nccl/cuda_nccl_gpu.cc b/caffe2/contrib/nccl/cuda_nccl_gpu.cc index 31cd55d08578..ef2b9ab37ea0 100644 --- a/caffe2/contrib/nccl/cuda_nccl_gpu.cc +++ b/caffe2/contrib/nccl/cuda_nccl_gpu.cc @@ -28,13 +28,8 @@ class NCCLContext { // get stream priorities int lo_pri, hi_pri; CUDA_ENFORCE(cudaDeviceGetStreamPriorityRange(&lo_pri, &hi_pri)); -#ifndef __HIP_PLATFORM_HCC__ CUDA_ENFORCE(cudaStreamCreateWithPriority( &streams_[i], cudaStreamNonBlocking, hi_pri)); -#else - CUDA_ENFORCE(cudaStreamCreateWithFlags( - &streams_[i], cudaStreamNonBlocking)); -#endif // __HIP_PLATFORM_HCC__ CUDA_ENFORCE(cudaEventCreateWithFlags( &events_[i], cudaEventDefault | cudaEventDisableTiming)); } diff --git a/caffe2/contrib/opencl/context.h b/caffe2/contrib/opencl/context.h index ce788a39a7cd..15bfda2203f0 100644 --- a/caffe2/contrib/opencl/context.h +++ b/caffe2/contrib/opencl/context.h @@ -59,7 +59,7 @@ class OpenCLContext final { template inline void - CopyItems(const TypeMeta& meta, size_t n, const void* src, void* dst) { + CopyItems(const TypeMeta meta, size_t n, const void* src, void* dst) { CAFFE_ENFORCE(!meta.copy(), "OpenCLContext requires fundamental types."); CopyBytes(n * meta.itemsize(), src, dst); } diff --git a/caffe2/core/context.h b/caffe2/core/context.h index f3f4a9138ce1..b0e99ef1e59e 100644 --- a/caffe2/core/context.h +++ b/caffe2/core/context.h @@ -131,7 +131,7 @@ class CAFFE2_API CPUContext final : public BaseContext { template inline void - CopyItems(const TypeMeta& meta, size_t n, const void* src, void* dst) { + CopyItems(const TypeMeta meta, size_t n, const void* src, void* dst) { if (meta.copy()) { meta.copy()(src, dst, n); } else { diff --git a/caffe2/core/context_base.h b/caffe2/core/context_base.h index bad6872819de..036ac98fdc91 100644 --- a/caffe2/core/context_base.h +++ b/caffe2/core/context_base.h @@ -104,7 +104,7 @@ class CAFFE2_API BaseContext { } void CopyItemsSameDevice( - const caffe2::TypeMeta& meta, + const caffe2::TypeMeta meta, size_t n, const void* src, void* dst) { @@ -117,7 +117,7 @@ class CAFFE2_API BaseContext { } void CopyItemsFromCPU( - const caffe2::TypeMeta& meta, + const caffe2::TypeMeta meta, size_t n, const void* src, void* dst) { @@ -130,7 +130,7 @@ class CAFFE2_API BaseContext { } void CopyItemsToCPU( - const caffe2::TypeMeta& meta, + const caffe2::TypeMeta meta, size_t n, const void* src, void* dst) { diff --git a/caffe2/core/context_gpu.h b/caffe2/core/context_gpu.h index c0930b1a0e61..7406132f8788 100644 --- a/caffe2/core/context_gpu.h +++ b/caffe2/core/context_gpu.h @@ -279,7 +279,7 @@ class CAFFE2_CUDA_API CUDAContext final : public BaseContext { template inline void - CopyItems(const TypeMeta& meta, size_t n, const void* src, void* dst) { + CopyItems(const TypeMeta meta, size_t n, const void* src, void* dst) { CAFFE_ENFORCE(!meta.copy(), "CUDAContext requires fundamental types."); CopyBytes(n * meta.itemsize(), src, dst); } diff --git a/caffe2/core/operator.h b/caffe2/core/operator.h index ea9ae7892a23..a4abc97f73e4 100644 --- a/caffe2/core/operator.h +++ b/caffe2/core/operator.h @@ -1246,7 +1246,7 @@ struct DispatchHelper, ExtraArgs...> { template \ struct DispatchHelper, ExtraArgs...> { \ template \ - static bool call(Op* op, const TypeMeta& meta) { \ + static bool call(Op* op, const TypeMeta meta) { \ static_assert( \ !std::is_same::value, \ "GenericTensorImplementation must be the last in TensorTypes list"); \ @@ -1269,7 +1269,7 @@ struct DispatchHelper, ExtraArgs...> { template \ struct DispatchHelper, ExtraArgs...> { \ template \ - static bool call(Op* /* unused */, const TypeMeta& meta) { \ + static bool call(Op* /* unused */, const TypeMeta meta) { \ CAFFE_THROW("Unsupported type of tensor: ", meta.name()); \ } \ template \ @@ -1287,7 +1287,7 @@ struct DispatchHelper, ExtraArgs...> { TensorTypes, \ ExtraArgs...> { \ template \ - static bool call(Op* op, const TypeMeta&) { \ + static bool call(Op* op, const TypeMeta) { \ return op->template DoRunWithOtherType(); \ } \ template \ diff --git a/caffe2/core/plan_executor.cc b/caffe2/core/plan_executor.cc index 97c309c078e4..06e27ef5be54 100644 --- a/caffe2/core/plan_executor.cc +++ b/caffe2/core/plan_executor.cc @@ -133,8 +133,8 @@ std::function getContinuationTest( // if the blob doesn't exist or is not initialized, return false inline bool getShouldStop(const Blob* b) { if (!b || - b->meta().id() == - TypeIdentifier::uninitialized()) { // not exist or uninitialized + b->meta() == + ScalarType::Undefined) { // not exist or uninitialized return false; } diff --git a/caffe2/core/tensor.h b/caffe2/core/tensor.h index 27f8b471b71b..83df5306e177 100644 --- a/caffe2/core/tensor.h +++ b/caffe2/core/tensor.h @@ -299,14 +299,14 @@ class CAFFE2_API Tensor final { void ShareExternalPointer( void* src, - const TypeMeta& data_type, + const TypeMeta data_type, size_t nbytes = 0, MemoryDeleter d = nullptr) const { CAFFE_ENFORCE_WITH_CALLER( impl_->is_contiguous(), "Right now ShareExternalPointer is only supported for contiguous Tensor."); CAFFE_ENFORCE_WITH_CALLER( - data_type.id() != caffe2::TypeIdentifier::uninitialized(), + data_type != ScalarType::Undefined, "To share with a raw external pointer you need to pass in an " "initialized data_type(TypeMeta)."); impl_.get()->ShareExternalPointer( @@ -315,7 +315,7 @@ class CAFFE2_API Tensor final { void ShareExternalPointer( at::DataPtr&& data_ptr, - const TypeMeta& data_type, + const TypeMeta data_type, size_t nbytes) { impl_.get()->ShareExternalPointer(std::move(data_ptr), data_type, nbytes); } @@ -342,7 +342,7 @@ class CAFFE2_API Tensor final { return impl_.get()->data(); } - inline void* raw_mutable_data(const TypeMeta& meta) const { + inline void* raw_mutable_data(const TypeMeta meta) const { return impl_.get()->raw_mutable_data(meta); } @@ -358,7 +358,7 @@ class CAFFE2_API Tensor final { inline void* raw_mutable_data() const { const auto& data_type = impl_->dtype(); CAFFE_ENFORCE_WITH_CALLER( - data_type.id() != caffe2::TypeIdentifier::uninitialized(), + data_type != ScalarType::Undefined, "Calling raw_mutable_data() without meta, but the current meta is " "of unknown type."); return raw_mutable_data(data_type); @@ -469,7 +469,7 @@ class CAFFE2_API Tensor final { /** * Returns the TypeMeta object associated with the current data type. */ - inline const TypeMeta& dtype() const { + inline const TypeMeta dtype() const { return impl_->dtype(); } @@ -477,7 +477,7 @@ class CAFFE2_API Tensor final { * (To be deprecated) Returns the TypeMeta object associated with the current * data type. */ - inline const TypeMeta& meta() const { + inline const TypeMeta meta() const { return impl_->dtype(); } diff --git a/caffe2/core/types.cc b/caffe2/core/types.cc index d1007fe76e86..c738fc50a288 100644 --- a/caffe2/core/types.cc +++ b/caffe2/core/types.cc @@ -8,7 +8,7 @@ namespace caffe2 { -TensorProto::DataType TypeMetaToDataType(const TypeMeta& meta) { +TensorProto::DataType TypeMetaToDataType(const TypeMeta meta) { static_assert( sizeof(int) == 4, "int in this compiler does not equal to 4 bytes."); static std::map data_type_map{ @@ -36,7 +36,7 @@ TensorProto::DataType TypeMetaToDataType(const TypeMeta& meta) { it == data_type_map.end() ? TensorProto_DataType_UNDEFINED : it->second); } -const TypeMeta& DataTypeToTypeMeta(const TensorProto::DataType& dt) { +const TypeMeta DataTypeToTypeMeta(const TensorProto::DataType& dt) { static std::map type_meta_map{ {TensorProto_DataType_FLOAT, TypeMeta::Make()}, {TensorProto_DataType_INT32, TypeMeta::Make()}, diff --git a/caffe2/core/types.h b/caffe2/core/types.h index c0e8d7bbfb3d..5dda5a5e0810 100644 --- a/caffe2/core/types.h +++ b/caffe2/core/types.h @@ -47,10 +47,10 @@ inline int32_t GetDimFromOrderString(const std::string& str) { inline constexpr char NameScopeSeparator() { return '/'; } // From TypeMeta to caffe2::DataType protobuffer enum. -CAFFE2_API TensorProto::DataType TypeMetaToDataType(const TypeMeta& meta); +CAFFE2_API TensorProto::DataType TypeMetaToDataType(const TypeMeta meta); // From caffe2::DataType protobuffer enum to TypeMeta -CAFFE2_API const TypeMeta& DataTypeToTypeMeta(const TensorProto::DataType& dt); +CAFFE2_API const TypeMeta DataTypeToTypeMeta(const TensorProto::DataType& dt); } // namespace caffe2 diff --git a/caffe2/ideep/utils/ideep_context.h b/caffe2/ideep/utils/ideep_context.h index 823b4bec16bd..d0f1207a08f6 100644 --- a/caffe2/ideep/utils/ideep_context.h +++ b/caffe2/ideep/utils/ideep_context.h @@ -91,7 +91,7 @@ class IDEEPContext final : public BaseContext { template inline void - CopyItems(const TypeMeta& meta, size_t n, const void* src, void* dst) { + CopyItems(const TypeMeta meta, size_t n, const void* src, void* dst) { if (meta.copy()) { meta.copy()(src, dst, n); } else { diff --git a/caffe2/operators/dataset_ops.cc b/caffe2/operators/dataset_ops.cc index 95fcfc1ab923..c311ad23e4ed 100644 --- a/caffe2/operators/dataset_ops.cc +++ b/caffe2/operators/dataset_ops.cc @@ -407,7 +407,7 @@ class UnPackRecordsOp : public Operator { // Precomputer the output sizes to avoid resizing std::vector> outputDims(numTensors); - std::vector metas(numTensors); + std::vector metas(numTensors); CAFFE_ENFORCE( numRows > 0 || InputSize() > 1, @@ -428,7 +428,7 @@ class UnPackRecordsOp : public Operator { // Checks to ensure that dimensions/sizes match CAFFE_ENFORCE_EQ(outputDims[j].size(), input.dim()); - CAFFE_ENFORCE(*metas[j] == input.dtype()); + CAFFE_ENFORCE(metas[j] == input.dtype()); // We look from first dimension, because we concat on the first. for (int k = 1; k < input.dim(); ++k) { CAFFE_ENFORCE_EQ(input.sizes()[k], outputDims[j][k]); @@ -442,7 +442,7 @@ class UnPackRecordsOp : public Operator { std::vector destinations(numTensors); for (int i = 0; i < numTensors; ++i) { Output(i)->Resize(outputDims[i]); - destinations[i] = Output(i)->raw_mutable_data(*metas[i]); + destinations[i] = Output(i)->raw_mutable_data(metas[i]); } for (int i = 0; i < numRows; ++i) { @@ -450,7 +450,7 @@ class UnPackRecordsOp : public Operator { const auto& input = tensors[i][j]; context_.CopyItemsSameDevice( - *metas[j], + metas[j], input.numel(), input.raw_data() /* src */, destinations[j] /* dst */ @@ -468,7 +468,7 @@ class UnPackRecordsOp : public Operator { void getShapeAndMetaFromInput( const Shared2DTensorVectorPtr& inputs, std::vector>& outputDims, - std::vector& metas) { + std::vector& metas) { const auto& inputZero = inputs->at(0); const auto numTensors = inputZero.size(); @@ -479,13 +479,13 @@ class UnPackRecordsOp : public Operator { for (int i = 0; i < numTensors; ++i) { outputDims[i] = inputZero[i].sizes().vec(); outputDims[i][0] = 0; - metas[i] = &inputZero[i].dtype(); + metas[i] = inputZero[i].dtype(); } } void getShapeAndMetaFromPrototypeBlobs( std::vector>& outputDims, - std::vector& metas) { + std::vector& metas) { const auto numTensors = fields_.size(); CAFFE_ENFORCE_EQ(numTensors, InputSize() - 1); CAFFE_ENFORCE_EQ(numTensors, OutputSize()); @@ -493,7 +493,7 @@ class UnPackRecordsOp : public Operator { const auto& input = Input(i + 1); outputDims[i] = input.sizes().vec(); outputDims[i][0] = 0; - metas[i] = &input.dtype(); + metas[i] = input.dtype(); } } diff --git a/caffe2/operators/dataset_ops.h b/caffe2/operators/dataset_ops.h index 70a294e14136..fc890014dbb2 100644 --- a/caffe2/operators/dataset_ops.h +++ b/caffe2/operators/dataset_ops.h @@ -146,7 +146,7 @@ class TreeWalker { return size; } - inline const TypeMeta& meta() const { + inline const TypeMeta meta() const { return walker_.input(fieldId_).dtype(); } diff --git a/caffe2/operators/index_ops.h b/caffe2/operators/index_ops.h index 890753caf2fe..2f5705cb4c26 100644 --- a/caffe2/operators/index_ops.h +++ b/caffe2/operators/index_ops.h @@ -18,7 +18,7 @@ using int64_tValue = int64_t; struct IndexBase { public: - IndexBase(int64_tValue maxElements, const TypeMeta& type) + IndexBase(int64_tValue maxElements, const TypeMeta type) : maxElements_{maxElements}, meta_(type), frozen_{false} {} void Freeze() { @@ -35,7 +35,7 @@ struct IndexBase { virtual ~IndexBase() {} - const TypeMeta& Type() const { + const TypeMeta Type() const { return meta_; } diff --git a/caffe2/operators/numpy_tile_op.h b/caffe2/operators/numpy_tile_op.h index 8a39b40df0f8..ac9886ec503a 100644 --- a/caffe2/operators/numpy_tile_op.h +++ b/caffe2/operators/numpy_tile_op.h @@ -92,7 +92,7 @@ class NumpyTileOp : public Operator { private: void DoTile( - const TypeMeta& meta, + const TypeMeta meta, int item_size, int outer_dim, int inner_dim, diff --git a/caffe2/operators/self_binning_histogram_op.cc b/caffe2/operators/self_binning_histogram_op.cc index 8cecf0267ea3..111abd18094c 100644 --- a/caffe2/operators/self_binning_histogram_op.cc +++ b/caffe2/operators/self_binning_histogram_op.cc @@ -35,7 +35,11 @@ OPERATOR_SCHEMA(SelfBinningHistogram) "logspace_start", "A float that's used as the starting point for logarithmic spacing. " "Since logarithmic spacing cannot contain <=0 values this value will " - "be used to represent all such values."); + "be used to represent all such values.") + .Arg( + "abs", + "Apply abs() on every input value." + ); SHOULD_NOT_DO_GRADIENT(SelfBinningHistogram); } // namespace caffe2 diff --git a/caffe2/operators/self_binning_histogram_op.h b/caffe2/operators/self_binning_histogram_op.h index d29d02b2deb9..6fb6c8f14a08 100644 --- a/caffe2/operators/self_binning_histogram_op.h +++ b/caffe2/operators/self_binning_histogram_op.h @@ -19,7 +19,8 @@ class SelfBinningHistogramOp final : public Operator { bin_spacing_(this->template GetSingleArgument( "bin_spacing", "linear")), - logspace_start_(this->template GetSingleArgument("logspace_start", 1e-24)) + logspace_start_(this->template GetSingleArgument("logspace_start", 1e-24)), + abs_(this->template GetSingleArgument("abs", false)) { CAFFE_ENFORCE_GE( num_bins_, 1, "Number of bins must be greater than or equal to 1."); @@ -64,13 +65,14 @@ class SelfBinningHistogramOp final : public Operator { total_count += N; const auto* x_data = x.template data(); for (int64_t data_idx = 0; data_idx < N; data_idx++) { + const T val = this->abs_ ? abs(x_data[data_idx]) : x_data[data_idx]; if (!first_seen) { - max = x_data[data_idx]; - min = x_data[data_idx]; + max = val; + min = val; first_seen = true; } else { - max = std::max(x_data[data_idx], max); - min = std::min(x_data[data_idx], min); + max = std::max(val, max); + min = std::min(val, min); } } } @@ -130,10 +132,11 @@ class SelfBinningHistogramOp final : public Operator { const int64_t N = x.numel(); const auto* x_data = x.template data(); for (int64_t data_idx = 0; data_idx < N; data_idx++) { + const T val = this->abs_ ? abs(x_data[data_idx]) : x_data[data_idx]; const auto bisection_it = std::upper_bound( histogram_values_data, histogram_values_data + num_edges_, - x_data[data_idx]); + val); const int bisection_idx = bisection_it - histogram_values_data; if (bisection_idx > 0 && bisection_idx < num_edges_) { histogram_counts_data[bisection_idx - 1]++; @@ -156,6 +159,7 @@ class SelfBinningHistogramOp final : public Operator { int num_edges_; std::string bin_spacing_; float logspace_start_; + bool abs_; // automatically apply abs() on the input values void CheckInputs() { const auto& input_zero = Input(0); diff --git a/caffe2/operators/tile_op.cc b/caffe2/operators/tile_op.cc index 40684c50575b..b0d797fce7ff 100644 --- a/caffe2/operators/tile_op.cc +++ b/caffe2/operators/tile_op.cc @@ -71,7 +71,7 @@ bool TileOp::DoRunWithType() { // size from axis up const int inner_size = X.size_from_dim(axis); - const TypeMeta& meta = X.dtype(); + const TypeMeta meta = X.dtype(); const int item_size = X.itemsize(); const char* X_ptr = reinterpret_cast(X.raw_data()); char* Y_ptr = reinterpret_cast(Y->raw_mutable_data(meta)); diff --git a/caffe2/operators/utility_ops.cc b/caffe2/operators/utility_ops.cc index b691c24e984a..9abcf5ab0b86 100644 --- a/caffe2/operators/utility_ops.cc +++ b/caffe2/operators/utility_ops.cc @@ -59,6 +59,7 @@ REGISTER_CPU_OPERATOR(GatherRanges, GatherRangesOp); REGISTER_CPU_OPERATOR(LengthsGather, LengthsGatherOp); REGISTER_CPU_OPERATOR(LengthsToSegmentIds, LengthsToSegmentIdsOp); REGISTER_CPU_OPERATOR(LengthsToRanges, LengthsToRangesOp); +REGISTER_CPU_OPERATOR(LengthsToOffsets, LengthsToOffsetsOp); REGISTER_CPU_OPERATOR(SegmentIdsToLengths, SegmentIdsToLengthsOp); REGISTER_CPU_OPERATOR(SegmentIdsToRanges, SegmentIdsToRangesOp); REGISTER_CPU_OPERATOR(LengthsToWeights, LengthsToWeightsOp); @@ -522,20 +523,20 @@ Another output LENGTHS represents each example length within OUTPUT "LENGTHS", "1-D tensor of size N with lengths over gathered data" " for each row in a batch. sum(LENGTHS) == OUTPUT.size()") - .TensorInferenceFunction(OpSchema::NeedsAllInputShapes([]( - const OperatorDef& /* unused */, const vector& in) { - std::vector out(2); - - int total = 1; - for (auto d : in[0].dims()) { - total *= d; - } - out[0].add_dims(total); - out[0].set_data_type(in[0].data_type()); - out[1].add_dims(in[1].dims(0)); - out[1].set_data_type(in[1].data_type()); - return out; - })); + .TensorInferenceFunction(OpSchema::NeedsAllInputShapes( + [](const OperatorDef& /* unused */, const vector& in) { + std::vector out(2); + + int total = 1; + for (auto d : in[0].dims()) { + total *= d; + } + out[0].add_dims(total); + out[0].set_data_type(in[0].data_type()); + out[1].add_dims(in[1].dims(0)); + out[1].set_data_type(in[1].data_type()); + return out; + })); OPERATOR_SCHEMA(LengthsGather) .NumInputs(3) @@ -636,6 +637,30 @@ For example, `[1, 3, 0, 2]` transforms into `[[0, 1], [1, 3], [4, 0], [4, 2]]`. "ranges", "2D tensor of shape len(lengths) X 2 and the same type as `lengths`"); +OPERATOR_SCHEMA(LengthsToOffsets) + .NumInputs(1) + .NumOutputs(1) + .SetDoc(R"DOC( +Given a vector of segment lengths, returns a vector of offsets from these lengths, +which will have the same size as the input vector. Output is going to have +the same type as input. For long tensors explicit casting from int32 to int64 +might be necessary prior to this op. + +For example, `[1, 3, 0, 2]` transforms into `[0, 1, 4, 4]`. +)DOC") + .Input(0, "lengths", "1D tensor of int32 or int64 segment lengths.") + .Output(0, "offsets", "1D tensor of the same shape and type as `lengths`") + .TensorInferenceFunction([](const OperatorDef& def, + const vector& in) { + const ArgumentHelper args(def); + bool include_last_offset = + args.GetSingleArgument("include_last_offset", false); + vector out_shape(in[0].dims().begin(), in[0].dims().end()); + out_shape[0] += include_last_offset ? 1 : 0; + return vector{ + CreateTensorShape(out_shape, in[0].data_type())}; + }); + OPERATOR_SCHEMA(SegmentIdsToLengths) .NumInputs(1, 2) .NumOutputs(1) diff --git a/caffe2/operators/utility_ops.h b/caffe2/operators/utility_ops.h index a82b5666fb7b..bdc9c0bfbfd9 100644 --- a/caffe2/operators/utility_ops.h +++ b/caffe2/operators/utility_ops.h @@ -918,6 +918,45 @@ class LengthsToRangesOp : public Operator { } }; +template +class LengthsToOffsetsOp : public Operator { + public: + USE_OPERATOR_CONTEXT_FUNCTIONS; + + template + explicit LengthsToOffsetsOp(Args&&... args) + : Operator(std::forward(args)...), + include_last_offset_(this->template GetSingleArgument( + "include_last_offset", + false)) {} + + bool RunOnDevice() override { + auto& input = Input(0); + auto* output = Output(0); + auto* input_data = input.template data(); + + CAFFE_ENFORCE(input.sizes().size() == 1, "Input must be a vector."); + auto size = input.numel(); + + output->Resize(size + (include_last_offset_ ? 1 : 0)); + auto* output_data = output->template mutable_data(); + + int32_t offset = 0; + for (int i = 0; i < size; ++i) { + auto len = input_data[i]; + output_data[i] = offset; + offset += len; + } + if (include_last_offset_) { + output_data[size] = offset; + } + return true; + } + + private: + bool include_last_offset_; +}; + template class SegmentIdsToLengthsOp : public Operator { public: diff --git a/caffe2/python/hypothesis_test.py b/caffe2/python/hypothesis_test.py index 045677f8422a..9298134f651c 100644 --- a/caffe2/python/hypothesis_test.py +++ b/caffe2/python/hypothesis_test.py @@ -994,6 +994,38 @@ def op_ref(x): inputs=[np.array(lengths, dtype=np.int32)], reference=op_ref) + @given( + lengths=st.lists( + st.integers(min_value=0, max_value=10), min_size=0, max_size=10 + ), + include_last_offset=st.booleans(), + **hu.gcs_cpu_only + ) + @settings(deadline=None) + def test_lengths_to_offsets(self, lengths, include_last_offset, gc, dc): + op = core.CreateOperator( + "LengthsToOffsets", + ["lengths"], + ["ranges"], + include_last_offset=include_last_offset, + ) + + def op_ref(x): + if not x.size: + arr = [x.reshape(0)] + else: + arr = [np.concatenate(([0], np.cumsum(x)[:-1]))] + if include_last_offset: + arr[0] = np.concatenate((arr[0], np.array([np.sum(x)]))) + return tuple(arr) + + self.assertReferenceChecks( + device_option=gc, + op=op, + inputs=[np.array(lengths, dtype=np.int32)], + reference=op_ref, + ) + @given(prediction=hu.arrays(dims=[10, 3], elements=hu.floats(allow_nan=False, allow_infinity=False, diff --git a/caffe2/python/layers/last_n_window_collector.py b/caffe2/python/layers/last_n_window_collector.py index a16b731a2f78..5e6874b4cca0 100644 --- a/caffe2/python/layers/last_n_window_collector.py +++ b/caffe2/python/layers/last_n_window_collector.py @@ -1,10 +1,6 @@ ## @package last_n_window_collector # Module caffe2.python.layers.last_n_window_collector - - - - from caffe2.python import core, schema from caffe2.python.layers.layers import ModelLayer diff --git a/caffe2/python/operator_test/self_binning_histogram_test.py b/caffe2/python/operator_test/self_binning_histogram_test.py index 14a37872ee5a..afcf5ea57e3e 100644 --- a/caffe2/python/operator_test/self_binning_histogram_test.py +++ b/caffe2/python/operator_test/self_binning_histogram_test.py @@ -8,9 +8,10 @@ class TestSelfBinningHistogramBase(object): - def __init__(self, bin_spacing, dtype): + def __init__(self, bin_spacing, dtype, abs=False): self.bin_spacing = bin_spacing self.dtype = dtype + self.abs = abs def _check_histogram(self, arrays, num_bins, expected_values=None, expected_counts=None): # Check that sizes match and counts add up. @@ -20,28 +21,39 @@ def _check_histogram(self, arrays, num_bins, expected_values=None, expected_coun self.assertTrue(np.size(counts) == num_bins) self.assertTrue(np.sum(counts) == sum([np.size(array) for array in arrays])) - + # Check counts if expected_counts is None: # Check that counts are correct for the returned values if expected_counts is not given. expected_counts = np.zeros(num_bins, dtype='i') for array in arrays: - for i in array: + for input_val in array: + input_val = abs(input_val) if self.abs else input_val found = False for pos in range(np.size(values)): - if values[pos] > i: + if values[pos] > input_val: found = True break - self.assertTrue(found, "input array must fit inside values array") + self.assertTrue(found, f"input value must fit inside values array: " + f"input={input_val}, last_value={values[-1]}") if self.bin_spacing == "linear": - self.assertTrue(pos > 0, "first value should be the smallest") + self.assertTrue(pos > 0, + f"input should not be smaller than the first bin value: " + f"input={input_val}, 1st bin value={values[pos]}") if pos == 0: self.assertEqual(self.bin_spacing, "logarithmic") expected_counts[pos] += 1 else: expected_counts[pos - 1] += 1 self.assertTrue(np.array_equal(expected_counts, counts), f"expected:{expected_counts}\ncounts:{counts}") + # Check values if expected_values is not None: - self.assertTrue(np.array_equal(expected_values, values), f"expected:{expected_values}\ncounts:{values}") + self.assertTrue(np.allclose(expected_values, values, rtol=1e-02, atol=1e-05), + f"expected:{expected_values}\nvalues:{values}") + # Ideally, the output values are sorted in a non-decreasing order. + for idx in range(len(values) - 1): + self.assertTrue(values[idx] <= values[idx + 1]) + if self.abs: + self.assertTrue(values[0] >= 0) def _run_single_op_net(self, arrays, num_bins, logspacing_start=None): @@ -57,6 +69,7 @@ def _run_single_op_net(self, arrays, num_bins, logspacing_start=None): num_bins=num_bins, bin_spacing=self.bin_spacing, logspacing_start=logspacing_start, + abs=self.abs ) else: net.SelfBinningHistogram( @@ -64,6 +77,7 @@ def _run_single_op_net(self, arrays, num_bins, logspacing_start=None): ["histogram_values", "histogram_counts"], num_bins=num_bins, bin_spacing=self.bin_spacing, + abs=self.abs ) workspace.RunNetOnce(net) @@ -82,10 +96,25 @@ def test_histogram_device_consistency(self, rows, cols, gc, dc): def test_histogram_bin_to_fewer(self): X = np.array([-2.0, -2.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 4.0, 6.0, 9.0], dtype=self.dtype) + if self.bin_spacing == 'linear': + if not self.abs: + expected_values = [-2., 0.2, 2.4, 4.6, 6.8, 9.] + expected_counts = [5, 2, 2, 1, 1, 0] + else: + expected_values = [0., 1.8, 3.6, 5.4, 7.2, 9.] + expected_counts = [4, 4, 1, 1, 1, 0] + else: + expected_values = [1.e-24, 9.8e-20, 9.6e-15, 9.4e-10, 9.2e-05, 9.] + if not self.abs: + expected_counts = [5, 0, 0, 0, 6, 0] + else: + expected_counts = [3, 0, 0, 0, 8, 0] self._run_single_op_net([X], 5) self._check_histogram( [X], 6, + expected_values=expected_values, + expected_counts=expected_counts ) def test_histogram_bin_to_more(self): @@ -99,10 +128,20 @@ def test_histogram_bin_to_more(self): def test_histogram_bin_to_two(self): """This test roughly tests [min,max+EPSILON] and [N,0]""" X = np.array([-2.0, -2.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 4.0, 6.0, 9.0], dtype=self.dtype) + if self.bin_spacing == 'linear': + if not self.abs: + expected_values = [-2., 9.] + else: + expected_values = [0., 9.] + else: + expected_values = [1.e-24, 9.] + expected_counts = [11, 0] self._run_single_op_net([X], 1) self._check_histogram( [X], 2, + expected_values=expected_values, + expected_counts=expected_counts ) def test_histogram_min_max_equal(self): @@ -129,7 +168,7 @@ def test_histogram_min_max_equal(self): def test_histogram_min_max_equal_nonzero(self): X = np.array([1., 1., 1., 1., 1.], dtype=self.dtype) logspacing_start = 1e-24 - self._run_single_op_net([X], 3, 1e-24) + self._run_single_op_net([X], 3, logspacing_start) self._check_histogram( [X], 4, @@ -143,33 +182,58 @@ def test_histogram_empty_input_tensor(self): self._check_histogram( [X], 2, + expected_values=[0., 0.], + expected_counts=[0, 0] ) self._run_single_op_net([X], 10) self._check_histogram( [X], 11, + expected_values=[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + expected_counts=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] ) def test_histogram_multi_input(self): X1 = np.array([-2.0, -2.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 4.0, 6.0, 9.0], dtype=self.dtype) X2 = np.array([-5.0, -3.0, 7, 7, 0.0, 1.0, 2.0, -3.0, 4.0, 6.0, 9.0], dtype=self.dtype) + if self.bin_spacing == 'linear': + if not self.abs: + expected_values = [-5., -2.2, 0.6, 3.4, 6.2, 9.] + expected_counts = [3, 6, 5, 4, 4, 0] + else: + expected_values = [0., 1.8, 3.6, 5.4, 7.2, 9.] + expected_counts = [6, 7, 3, 4, 2, 0] + else: + expected_values = [1.e-24, 9.8e-20, 9.6e-15, 9.4e-10, 9.2e-05, 9.] + if not self.abs: + expected_counts = [9, 0, 0, 0, 13, 0] + else: + expected_counts = [4, 0, 0, 0, 18, 0] self._run_single_op_net([X1, X2], 5) self._check_histogram( [X1, X2], 6, + expected_values=expected_values, + expected_counts=expected_counts ) def test_histogram_very_small_range_for_stride_underflow(self): """Tests a large number of bins for a very small range of values. - This test uses float type. 1-e38 is very small, and with 1M bins, it + This test uses float type. 1-e302 is very small, and with 1M bins, it causes numeric underflow. This test is to show that this is handled. + + Note: this test was flaky due to how compiler and OS handls floats. + Previously, 1-e38 does not induce overflow and cuases test error for some + combinations of compiler and OS. Now 1-e302 should be small enough. """ - X = np.array([0, 1e-38], dtype='f') - self._run_single_op_net([X], 1000000) + X = np.array([0, 1e-302], dtype='f') + large_bin_number = 1000000 + self._run_single_op_net([X], large_bin_number) self._check_histogram( [X], - 1000001, + large_bin_number + 1, + expected_counts=[2] + [0] * large_bin_number # [2, 0, 0, ..., 0] ) @@ -200,6 +264,35 @@ def __init__(self, *args, **kwargs): TestSelfBinningHistogramBase.__init__(self, bin_spacing="logarithmic", dtype='f') hu.HypothesisTestCase.__init__(self, *args, **kwargs) +class TestSelfBinningHistogramLinearWithAbs(TestSelfBinningHistogramBase, hu.HypothesisTestCase): + def __init__(self, *args, **kwargs): + TestSelfBinningHistogramBase.__init__(self, bin_spacing="linear", dtype='d', abs=True) + hu.HypothesisTestCase.__init__(self, *args, **kwargs) + +class TestSelfBinningHistogramLogarithmicWithAbs(TestSelfBinningHistogramBase, hu.HypothesisTestCase): + def __init__(self, *args, **kwargs): + TestSelfBinningHistogramBase.__init__(self, bin_spacing="logarithmic", dtype='d', abs=True) + hu.HypothesisTestCase.__init__(self, *args, **kwargs) + +class TestSelfBinningHistogramLinearFloatWithAbs(TestSelfBinningHistogramBase, hu.HypothesisTestCase): + def __init__(self, *args, **kwargs): + TestSelfBinningHistogramBase.__init__(self, bin_spacing="linear", dtype='f', abs=True) + hu.HypothesisTestCase.__init__(self, *args, **kwargs) + +class TestSelfBinningHistogramLogarithmicFloatWithAbs(TestSelfBinningHistogramBase, hu.HypothesisTestCase): + def __init__(self, *args, **kwargs): + TestSelfBinningHistogramBase.__init__(self, bin_spacing="logarithmic", dtype='f', abs=True) + hu.HypothesisTestCase.__init__(self, *args, **kwargs) + +class TestSelfBinningHistogramLinearWithNoneAbs(TestSelfBinningHistogramBase, hu.HypothesisTestCase): + def __init__(self, *args, **kwargs): + TestSelfBinningHistogramBase.__init__(self, bin_spacing="linear", dtype='d', abs=None) + hu.HypothesisTestCase.__init__(self, *args, **kwargs) + +class TestSelfBinningHistogramLinearFloatWithNoneAbs(TestSelfBinningHistogramBase, hu.HypothesisTestCase): + def __init__(self, *args, **kwargs): + TestSelfBinningHistogramBase.__init__(self, bin_spacing="linear", dtype='f', abs=None) + hu.HypothesisTestCase.__init__(self, *args, **kwargs) if __name__ == "__main__": global_options = ["caffe2"] diff --git a/caffe2/python/pybind_state.cc b/caffe2/python/pybind_state.cc index 2923b98c565f..65a246e4a39c 100644 --- a/caffe2/python/pybind_state.cc +++ b/caffe2/python/pybind_state.cc @@ -118,7 +118,7 @@ static_assert( sizeof(int) == sizeof(int32_t), "We make an assumption that int is always int32 for numpy " "type mapping."); -int CaffeToNumpyType(const TypeMeta& meta) { +int CaffeToNumpyType(const TypeMeta meta) { #ifdef USE_NUMPY static std::map numpy_type_map{ {TypeMeta::Id(), NPY_BOOL}, @@ -143,7 +143,7 @@ int CaffeToNumpyType(const TypeMeta& meta) { #endif // USE_NUMPY } -const TypeMeta& NumpyTypeToCaffe(int numpy_type) { +const TypeMeta NumpyTypeToCaffe(int numpy_type) { #ifdef USE_NUMPY static std::map caffe_type_map{ {NPY_BOOL, TypeMeta::Make()}, diff --git a/caffe2/python/pybind_state.h b/caffe2/python/pybind_state.h index b8f9dbaf3719..b3926e941194 100644 --- a/caffe2/python/pybind_state.h +++ b/caffe2/python/pybind_state.h @@ -103,8 +103,8 @@ static_assert( "We make an assumption that int is always int32 for numpy " "type mapping."); -int CaffeToNumpyType(const TypeMeta& dtype); -const TypeMeta& NumpyTypeToCaffe(int numpy_type); +int CaffeToNumpyType(const TypeMeta dtype); +const TypeMeta NumpyTypeToCaffe(int numpy_type); class TensorFetcher : public BlobFetcherBase { public: @@ -114,7 +114,7 @@ class TensorFetcher : public BlobFetcherBase { // Checks whether the data with type `dtype` needs to be copied in the context // of `tensor` - bool NeedsCopy(const Tensor* tensor, const TypeMeta& dtype) const { + bool NeedsCopy(const Tensor* tensor, const TypeMeta dtype) const { #ifdef USE_NUMPY return tensor->GetDeviceType() != CPU || CaffeToNumpyType(dtype) == NPY_OBJECT; @@ -200,9 +200,9 @@ class TensorFeeder : public BlobFeederBase { auto g = MakeGuard([&]() { Py_XDECREF(array); }); const auto npy_type = PyArray_TYPE(array); - const TypeMeta& dtype = NumpyTypeToCaffe(npy_type); + const TypeMeta dtype = NumpyTypeToCaffe(npy_type); CAFFE_ENFORCE( - dtype.id() != TypeIdentifier::uninitialized(), + dtype != ScalarType::Undefined, "This numpy data type is not supported: ", PyArray_TYPE(array), "."); diff --git a/caffe2/python/pybind_state_dlpack.cc b/caffe2/python/pybind_state_dlpack.cc index 7b1ec2b8e141..a7204481224f 100644 --- a/caffe2/python/pybind_state_dlpack.cc +++ b/caffe2/python/pybind_state_dlpack.cc @@ -14,7 +14,7 @@ const DLDeviceType* CaffeToDLDeviceType(int device_type) { return it == dl_device_type_map.end() ? nullptr : &it->second; } -const DLDataType* CaffeToDLType(const TypeMeta& meta) { +const DLDataType* CaffeToDLType(const TypeMeta meta) { static std::map dl_type_map{ {TypeMeta::Id(), DLDataType{0, 8, 1}}, {TypeMeta::Id(), DLDataType{0, 16, 1}}, @@ -30,7 +30,7 @@ const DLDataType* CaffeToDLType(const TypeMeta& meta) { return it == dl_type_map.end() ? nullptr : &it->second; } -const TypeMeta& DLTypeToCaffe(const DLDataType& dl_type) { +const TypeMeta DLTypeToCaffe(const DLDataType& dl_type) { try { if (dl_type.lanes != 1) { throw std::invalid_argument("invalid type"); diff --git a/caffe2/python/pybind_state_dlpack.h b/caffe2/python/pybind_state_dlpack.h index 54f3157e7634..bcdbc50a61d4 100644 --- a/caffe2/python/pybind_state_dlpack.h +++ b/caffe2/python/pybind_state_dlpack.h @@ -16,9 +16,9 @@ namespace py = pybind11; const DLDeviceType* CaffeToDLDeviceType(int device_type); -const DLDataType* CaffeToDLType(const TypeMeta& meta); +const DLDataType* CaffeToDLType(const TypeMeta meta); -const TypeMeta& DLTypeToCaffe(const DLDataType& dl_type); +const TypeMeta DLTypeToCaffe(const DLDataType& dl_type); // TODO: remove context template @@ -40,7 +40,7 @@ class DLPackWrapper { if (tensor->numel() <= 0) { tensor->Resize(0); } - if (tensor->dtype().id() == TypeIdentifier::uninitialized()) { + if (tensor->dtype() == ScalarType::Undefined) { // treat uninitialized tensor as float tensor tensor->template mutable_data(); } diff --git a/caffe2/python/pybind_state_ideep.cc b/caffe2/python/pybind_state_ideep.cc index 8d09b0aaa326..bbeaf524f055 100644 --- a/caffe2/python/pybind_state_ideep.cc +++ b/caffe2/python/pybind_state_ideep.cc @@ -97,7 +97,7 @@ class IDeepFetcher : public BlobFetcherBase { }; class IDeepFeeder : public BlobFeederBase { - itensor::data_type type_transform(const TypeMeta &meta) { + itensor::data_type type_transform(const TypeMeta meta) { if (meta == TypeMeta::Make()) return itensor::data_type::f32; else if (meta == TypeMeta::Make()) @@ -119,10 +119,10 @@ class IDeepFeeder : public BlobFeederBase { PyArrayObject *array = PyArray_GETCONTIGUOUS(original_array); auto g = MakeGuard([&]() { Py_XDECREF(array); }); const auto npy_type = PyArray_TYPE(array); - const TypeMeta &meta = NumpyTypeToCaffe(npy_type); + const TypeMeta meta = NumpyTypeToCaffe(npy_type); CAFFE_ENFORCE_NE( - meta.id(), - TypeIdentifier::uninitialized(), + meta, + ScalarType::Undefined, "This numpy data type is not supported: ", PyArray_TYPE(array), "."); @@ -172,7 +172,7 @@ class IDeepFeeder : public BlobFeederBase { auto g = MakeGuard([&]() { Py_XDECREF(array); }); const auto npy_type = PyArray_TYPE(array); - const TypeMeta &meta = NumpyTypeToCaffe(npy_type); + const TypeMeta meta = NumpyTypeToCaffe(npy_type); // TODO: if necessary, use dispatcher. if ((in_place && blob->IsType()) diff --git a/caffe2/python/session.py b/caffe2/python/session.py index de3b09931a30..fb2b57c4f5ee 100644 --- a/caffe2/python/session.py +++ b/caffe2/python/session.py @@ -192,7 +192,7 @@ def _compile_task_group(cls, task_group, setup_net_list=None): task = task_group.to_task() plan = core.Plan('task_group_plan') plan.AddStep(task.get_step()) - return (plan, task.output_list(), task.workspace_type) + return (plan, task.output_list(), task.workspace_type()) def _run_compiled(self, compiled): plan, output_list, workspace_type = compiled diff --git a/caffe2/python/workspace.py b/caffe2/python/workspace.py index 99983e84f097..0aa46ee2d4b3 100644 --- a/caffe2/python/workspace.py +++ b/caffe2/python/workspace.py @@ -335,7 +335,7 @@ def StringifyNetName(name): def GetNetName(net): if isinstance(net, basestring): return net - if type(net).__name__ == "Net": + if type(net).__name__ == "Net" or type(net).__name__ == "NetWithShapeInference": return net.Name() if isinstance(net, caffe2_pb2.NetDef): return net.name diff --git a/caffe2/quantization/server/fully_connected_dnnlowp_op.cc b/caffe2/quantization/server/fully_connected_dnnlowp_op.cc index c7e6804c1dcf..4a5a6e6b7ad0 100644 --- a/caffe2/quantization/server/fully_connected_dnnlowp_op.cc +++ b/caffe2/quantization/server/fully_connected_dnnlowp_op.cc @@ -190,6 +190,9 @@ bool FullyConnectedDNNLowPOp::RunOnDevice() { if (!dequantize_output_) { Y_int32_.resize(Y->size()); + if (Y_int32_.size() < Y_int32_.capacity() / 2) { + Y_int32_.shrink_to_fit(); + } DoNothing<> doNothingObj{}; if (quantize_channelwise_ || filter_qparams_[0].zero_point) { @@ -443,6 +446,9 @@ bool FullyConnectedDNNLowPOp::RunOnDevice() { #endif Y_int32_.resize(Y->size()); + if (Y_int32_.size() < Y_int32_.capacity() / 2) { + Y_int32_.shrink_to_fit(); + } for (int i = 0; i < M; ++i) { for (int j = 0; j < N; ++j) { int32_t sum = 0; diff --git a/caffe2/serialize/inline_container.cc b/caffe2/serialize/inline_container.cc index 63f5a34aa23b..7928d5e3de86 100644 --- a/caffe2/serialize/inline_container.cc +++ b/caffe2/serialize/inline_container.cc @@ -306,6 +306,7 @@ void PyTorchStreamWriter::setup(const string& file_name) { file_name, std::ofstream::out | std::ofstream::trunc | std::ofstream::binary); valid("opening archive ", file_name.c_str()); + TORCH_CHECK(file_stream_, "File ", file_name, " cannot be opened."); writer_func_ = [this](const void* buf, size_t nbytes) -> size_t { file_stream_.write(static_cast(buf), nbytes); return !file_stream_ ? 0 : nbytes; diff --git a/cmake/Codegen.cmake b/cmake/Codegen.cmake index dbfd55e2d0d5..db02f7a8fb16 100644 --- a/cmake/Codegen.cmake +++ b/cmake/Codegen.cmake @@ -178,9 +178,6 @@ if(INTERN_BUILD_ATEN_OPS) --force_schema_registration --op_registration_whitelist ${OP_REGISTRATION_WHITELIST}) endif() - if(USE_VULKAN) - set(GEN_VULKAN_FLAGS --vulkan) - endif() set(GEN_COMMAND "${PYTHON_EXECUTABLE}" -m tools.codegen.gen diff --git a/codecov.yml b/codecov.yml index 7ed3d662bb39..525f85e01898 100644 --- a/codecov.yml +++ b/codecov.yml @@ -3,13 +3,18 @@ coverage: project: default: threshold: 1% +codecov: + notify: + after_n_builds: 2 comment: layout: "diff" behavior: once require_changes: true require_base: yes require_head: yes + after_n_builds: 2 branches: - "master" fixes: - "/opt/conda/lib/python3.8/site-packages/::project/" + - "C:/Users/circleci/project/build/win_tmp/build/::project/" diff --git a/docs/source/benchmark_utils.rst b/docs/source/benchmark_utils.rst new file mode 100644 index 000000000000..8e46d017cf1c --- /dev/null +++ b/docs/source/benchmark_utils.rst @@ -0,0 +1,12 @@ +.. role:: hidden + :class: hidden-section + +Benchmark Utils - torch.utils.benchmark +================================================== + +.. automodule:: torch.utils.benchmark +.. currentmodule:: torch.utils.benchmark + +.. autoclass:: Timer + :members: + :show-inheritance: diff --git a/docs/source/index.rst b/docs/source/index.rst index 54a82a07b1a9..7e11f617d43f 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -54,6 +54,7 @@ PyTorch is an optimized tensor library for deep learning using GPUs and CPUs. torch.random sparse storage + torch.utils.benchmark torch.utils.bottleneck torch.utils.checkpoint torch.utils.cpp_extension diff --git a/docs/source/linalg.rst b/docs/source/linalg.rst index 834b6a60ac93..14d3ca1767e9 100644 --- a/docs/source/linalg.rst +++ b/docs/source/linalg.rst @@ -14,3 +14,4 @@ Functions .. autofunction:: det .. autofunction:: norm +.. autofunction:: tensorsolve diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst index 7110631088d7..3bc806067870 100644 --- a/docs/source/tensors.rst +++ b/docs/source/tensors.rst @@ -349,6 +349,8 @@ view of a storage and defines numeric operations on it. .. automethod:: hypot_ .. automethod:: i0 .. automethod:: i0_ + .. automethod:: igamma + .. automethod:: igamma_ .. automethod:: ifft .. automethod:: index_add_ .. automethod:: index_add diff --git a/docs/source/torch.rst b/docs/source/torch.rst index b3c8410300c6..e36a3f944a7a 100644 --- a/docs/source/torch.rst +++ b/docs/source/torch.rst @@ -85,6 +85,7 @@ Indexing, Slicing, Joining, Mutating Ops cat chunk + column_stack dstack gather hstack @@ -94,6 +95,7 @@ Indexing, Slicing, Joining, Mutating Ops narrow nonzero reshape + row_stack split squeeze stack @@ -310,6 +312,7 @@ Pointwise Ops logit hypot i0 + igamma mul multiply mvlgamma diff --git a/mypy.ini b/mypy.ini index 1fd1ce884520..122fde2d6cf6 100644 --- a/mypy.ini +++ b/mypy.ini @@ -77,9 +77,6 @@ ignore_errors = True [mypy-torch._tensor_str] ignore_errors = True -[mypy-torch.nn.modules.batchnorm] -ignore_errors = True - [mypy-torch.nn.modules.container] ignore_errors = True @@ -89,12 +86,6 @@ ignore_errors = True [mypy-torch.nn.modules.fold] ignore_errors = True -[mypy-torch.nn.modules.instancenorm] -ignore_errors = True - -[mypy-torch.nn.modules.linear] -ignore_errors = True - [mypy-torch.nn.modules.loss] ignore_errors = True @@ -113,9 +104,6 @@ ignore_errors = True [mypy-torch.nn.modules.rnn] ignore_errors = True -[mypy-torch.nn.modules.sparse] -ignore_errors = True - [mypy-torch.nn.parallel._functions] ignore_errors = True diff --git a/test/cpp/api/autograd.cpp b/test/cpp/api/autograd.cpp index 8a8aa75541ac..bf72f2215972 100644 --- a/test/cpp/api/autograd.cpp +++ b/test/cpp/api/autograd.cpp @@ -7,6 +7,7 @@ #include using namespace torch::autograd; +using namespace torch::test; #define ASSERT_VARIABLE_EQ(a,b) ASSERT_TRUE(torch::allclose((a),(b))) #define EXPECT_VARIABLE_EQ(a,b) EXPECT_TRUE(torch::allclose((a),(b))) @@ -154,6 +155,39 @@ TEST(AutogradAPITests, RetainGrad) { ASSERT_VARIABLE_EQ(input * 18, input.grad()); } +TEST(AutogradAPITests, AnomalyMode) { + // Needs to have backtrace as warning and then throw an error + torch::autograd::DetectAnomalyGuard detect_anomaly; + { + WarningCapture warnings; + auto x = torch::tensor({5.0}, torch::requires_grad()); + auto y = x * x; + auto z = y * y; + y += 1; + ASSERT_THROWS_WITH(z.backward(), "inplace"); + ASSERT_TRUE( + warnings.str().find("Traceback of forward") != std::string::npos); + } + { + WarningCapture warnings; + // Double backward + auto x = torch::tensor({0.0}, torch::requires_grad()); + auto y = x.pow(1.5); + auto gr = + grad({y}, {x}, {}, /*retain_graph=*/true, /*create_backward=*/true); + ASSERT_THROWS_WITH(grad({gr[0]}, {x});, "returned nan"); + auto msgs = warnings.messages(); + ASSERT_EQ(msgs.size(), 2); + ASSERT_TRUE( + msgs[0].find("Traceback of forward call that caused the error") != + std::string::npos); + ASSERT_TRUE( + msgs[1].find( + "Traceback of forward call that induced the previous calculation") != + std::string::npos); + } +} + TEST(CustomAutogradTest, CustomFunction) { struct MyFunction : public Function { static Variable forward(AutogradContext *ctx, Variable var1, int mul, Variable var2) { @@ -211,7 +245,7 @@ TEST(CustomAutogradTest, FunctionReturnsUndefined) { }; auto x = torch::ones(1, torch::requires_grad()); - + MyFunction::apply(x).backward(); ASSERT_FALSE(x.grad().defined()); diff --git a/test/cpp/jit/test_autodiff.cpp b/test/cpp/jit/test_autodiff.cpp index 3993c63b1708..38ddfee5fdd2 100644 --- a/test/cpp/jit/test_autodiff.cpp +++ b/test/cpp/jit/test_autodiff.cpp @@ -81,6 +81,7 @@ variable_list grad( grad_outputs, true, false, + false, fmap(inputs, get_edge)); } diff --git a/test/cpp/jit/test_lite_interpreter.cpp b/test/cpp/jit/test_lite_interpreter.cpp index b262075a42aa..3f486951a559 100644 --- a/test/cpp/jit/test_lite_interpreter.cpp +++ b/test/cpp/jit/test_lite_interpreter.cpp @@ -46,6 +46,25 @@ TEST(LiteInterpreterTest, UpsampleNearest2d) { ASSERT_TRUE(resd.equal(refd)); } +TEST(LiteInterpreterTest, CheckAttrAccess) { + Module m("m"); + m.register_attribute("mobile_optimized", BoolType::get(), true); + + std::stringstream ss; + m._save_for_mobile(ss); + mobile::Module bc = _load_for_mobile(ss); + bool mobile_optimized = bc.attr("mobile_optimized", false).toBool(); + + AT_ASSERT(mobile_optimized); + m.setattr("mobile_optimized", false); + ss = std::stringstream(); + m._save_for_mobile(ss); + bc = _load_for_mobile(ss); + mobile_optimized = bc.attr("mobile_optimized", false).toBool(); + + AT_ASSERT(!mobile_optimized); +} + TEST(LiteInterpreterTest, Add) { Module m("m"); m.register_parameter("foo", torch::ones({}), false); diff --git a/test/cpp/jit/test_misc.cpp b/test/cpp/jit/test_misc.cpp index ca4fb2e7620d..54265530eb12 100644 --- a/test/cpp/jit/test_misc.cpp +++ b/test/cpp/jit/test_misc.cpp @@ -65,6 +65,7 @@ #include #include #include +#include #include #include #include @@ -1118,6 +1119,33 @@ TEST(RecordFunctionTest, Basic) { clearCallbacks(); } +TEST(RecordFunctionTest, OperatorNameOverload) { + std::set operator_names; + + at::addGlobalCallback(at::RecordFunctionCallback( + [&operator_names](const at::RecordFunction& fn) { + c10::optional op_name = + fn.operator_name(); + if (op_name.has_value()) { + operator_names.insert(c10::toString(*op_name)); + } else { + operator_names.insert("No Operator Name"); + } + }) + .scopes({at::RecordScope::FUNCTION})); + auto t = torch::randn({1, 2, 3}, at::kCPU); + t.set_requires_grad(false); + auto t2 = t.pow(2); + + at::clearCallbacks(); + EXPECT_TRUE(operator_names.count("No Operator Name") == 0) + << "Expected that all traced operators had an associated OperatorName object"; + EXPECT_TRUE(operator_names.count("aten::randn") == 1) + << "Expected aten::randn to have been called and recorded, but it was not"; + EXPECT_TRUE(operator_names.count("aten::pow.Tensor_Scalar") == 1) + << "Expected aten::pow.Tensor_Scalar to have been called and recorded, but it was not"; +} + class TestThreadLocalDebugInfo : public c10::DebugInfoBase { public: int getModelId() const { diff --git a/test/cpp/tensorexpr/test_kernel.cpp b/test/cpp/tensorexpr/test_kernel.cpp index a433f08691ac..29744c3fb7ec 100644 --- a/test/cpp/tensorexpr/test_kernel.cpp +++ b/test/cpp/tensorexpr/test_kernel.cpp @@ -551,9 +551,10 @@ void testKernelSumOneAxis() { // Check the IR we produced const std::string& verification_pattern = R"IR( -# CHECK: int v = 0 -# CHECK: int v_1 = 0 -# CHECK: input1)IR"; +# CHECK: for (int v = 0; v < +# CHECK-NEXT: sum +# CHECK-NEXT: for (int v_1 = 0; v_1 < +# CHECK-NEXT: sum)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); std::vector stack = fmap(inputs); @@ -612,7 +613,7 @@ void testKernelSumMultipleAxes() { # CHECK: int v_1 = 0 # CHECK: int v_2 = 0 # CHECK: int v_3 = 0 -# CHECK: input1)IR"; +# CHECK: sum)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); std::vector stack = fmap(inputs); @@ -641,20 +642,17 @@ void testKernelSoftmax2D() { const std::string& verification_template = R"IR( - # CHECK: for (int i0 = 0; i0 < 5 - # CHECK-NEXT: for (int i1 = 0; i1 < 3 - # CHECK-NEXT: input1 - # CHECK: for (int i${other_dim}_1 = 0; i${other_dim}_1 < ${other_dim_size} - # CHECK: for (int i${softmax_dim}_1 = 0; i${softmax_dim}_1 < ${softmax_dim_size} + # CHECK: for (int i${other_dim} = 0; i${other_dim} < ${other_dim_size} + # CHECK: for (int i${softmax_dim} = 0; i${softmax_dim} < ${softmax_dim_size} # CHECK-NEXT: aten_softmax_max - # CHECK: for (int i0_2 = 0; i0_2 < 5 - # CHECK-NEXT: for (int i1_2 = 0; i1_2 < 3 + # CHECK: for (int i0_1 = 0; i0_1 < 5 + # CHECK-NEXT: for (int i1_1 = 0; i1_1 < 3 # CHECK-NEXT: aten_softmax_exp - # CHECK: for (int i${other_dim}_3 = 0; i${other_dim}_3 < ${other_dim_size} - # CHECK: for (int i${softmax_dim}_3 = 0; i${softmax_dim}_3 < ${softmax_dim_size} + # CHECK: for (int i${other_dim}_2 = 0; i${other_dim}_2 < ${other_dim_size} + # CHECK: for (int i${softmax_dim}_2 = 0; i${softmax_dim}_2 < ${softmax_dim_size} # CHECK-NEXT: aten_softmax_sum - # CHECK: for (int i0_4 = 0; i0_4 < 5 - # CHECK-NEXT: for (int i1_4 = 0; i1_4 < 3 + # CHECK: for (int i0_3 = 0; i0_3 < 5 + # CHECK-NEXT: for (int i1_3 = 0; i1_3 < 3 # CHECK-NEXT: aten_softmax)IR"; for (int softmax_dim = 0; softmax_dim < a.dim(); ++softmax_dim) { @@ -705,25 +703,21 @@ void testKernelSoftmax3D() { const std::string& verification_template = R"IR( - # CHECK: for (int i0 = 0; i0 < 3 - # CHECK-NEXT: for (int i1 = 0; i1 < 4 - # CHECK-NEXT: for (int i2 = 0; i2 < 5 - # CHECK-NEXT: input1 - # CHECK: for (int i${dim1}_1 = 0; i${dim1}_1 < ${dim1_size} - # CHECK-NEXT: for (int i${dim2}_1 = 0; i${dim2}_1 < ${dim2_size} - # CHECK: for (int i${softmax_dim}_1 = 0; i${softmax_dim}_1 < ${softmax_dim_size} + # CHECK: for (int i${dim1} = 0; i${dim1} < ${dim1_size} + # CHECK-NEXT: for (int i${dim2} = 0; i${dim2} < ${dim2_size} + # CHECK: for (int i${softmax_dim} = 0; i${softmax_dim} < ${softmax_dim_size} # CHECK-NEXT: aten_softmax_max - # CHECK: for (int i0_2 = 0; i0_2 < 3 - # CHECK-NEXT: for (int i1_2 = 0; i1_2 < 4 - # CHECK-NEXT: for (int i2_2 = 0; i2_2 < 5 + # CHECK: for (int i0_1 = 0; i0_1 < 3 + # CHECK-NEXT: for (int i1_1 = 0; i1_1 < 4 + # CHECK-NEXT: for (int i2_1 = 0; i2_1 < 5 # CHECK-NEXT: aten_softmax_exp - # CHECK: for (int i${dim1}_3 = 0; i${dim1}_3 < ${dim1_size} - # CHECK-NEXT: for (int i${dim2}_3 = 0; i${dim2}_3 < ${dim2_size} - # CHECK: for (int i${softmax_dim}_3 = 0; i${softmax_dim}_3 < ${softmax_dim_size} + # CHECK: for (int i${dim1}_2 = 0; i${dim1}_2 < ${dim1_size} + # CHECK-NEXT: for (int i${dim2}_2 = 0; i${dim2}_2 < ${dim2_size} + # CHECK: for (int i${softmax_dim}_2 = 0; i${softmax_dim}_2 < ${softmax_dim_size} # CHECK-NEXT: aten_softmax_sum - # CHECK: for (int i0_4 = 0; i0_4 < 3 - # CHECK-NEXT: for (int i1_4 = 0; i1_4 < 4 - # CHECK-NEXT: for (int i2_4 = 0; i2_4 < 5 + # CHECK: for (int i0_3 = 0; i0_3 < 3 + # CHECK-NEXT: for (int i1_3 = 0; i1_3 < 4 + # CHECK-NEXT: for (int i2_3 = 0; i2_3 < 5 # CHECK-NEXT: aten_softmax)IR"; for (int softmax_dim = 0; softmax_dim < a.dim(); ++softmax_dim) { @@ -782,30 +776,25 @@ void testKernelSoftmax4D() { const std::string& verification_template = R"IR( - # CHECK: for (int i0 = 0; i0 < 2 - # CHECK-NEXT: for (int i1 = 0; i1 < 3 - # CHECK-NEXT: for (int i2 = 0; i2 < 2 - # CHECK-NEXT: for (int i3 = 0; i3 < 3 - # CHECK-NEXT: input1 - # CHECK: for (int i${dim1}_1 = 0; i${dim1}_1 < ${dim1_size} - # CHECK-NEXT: for (int i${dim2}_1 = 0; i${dim2}_1 < ${dim2_size} - # CHECK-NEXT: for (int i${dim3}_1 = 0; i${dim3}_1 < ${dim3_size} - # CHECK: for (int i${softmax_dim}_1 = 0; i${softmax_dim}_1 < ${softmax_dim_size} + # CHECK: for (int i${dim1} = 0; i${dim1} < ${dim1_size} + # CHECK-NEXT: for (int i${dim2} = 0; i${dim2} < ${dim2_size} + # CHECK-NEXT: for (int i${dim3} = 0; i${dim3} < ${dim3_size} + # CHECK: for (int i${softmax_dim} = 0; i${softmax_dim} < ${softmax_dim_size} # CHECK-NEXT: aten_softmax_max - # CHECK: for (int i0_2 = 0; i0_2 < 2 - # CHECK-NEXT: for (int i1_2 = 0; i1_2 < 3 - # CHECK-NEXT: for (int i2_2 = 0; i2_2 < 2 - # CHECK-NEXT: for (int i3_2 = 0; i3_2 < 3 + # CHECK: for (int i0_1 = 0; i0_1 < 2 + # CHECK-NEXT: for (int i1_1 = 0; i1_1 < 3 + # CHECK-NEXT: for (int i2_1 = 0; i2_1 < 2 + # CHECK-NEXT: for (int i3_1 = 0; i3_1 < 3 # CHECK-NEXT: aten_softmax_exp - # CHECK: for (int i${dim1}_3 = 0; i${dim1}_3 < ${dim1_size} - # CHECK-NEXT: for (int i${dim2}_3 = 0; i${dim2}_3 < ${dim2_size} - # CHECK-NEXT: for (int i${dim3}_3 = 0; i${dim3}_3 < ${dim3_size} - # CHECK: for (int i${softmax_dim}_3 = 0; i${softmax_dim}_3 < ${softmax_dim_size} + # CHECK: for (int i${dim1}_2 = 0; i${dim1}_2 < ${dim1_size} + # CHECK-NEXT: for (int i${dim2}_2 = 0; i${dim2}_2 < ${dim2_size} + # CHECK-NEXT: for (int i${dim3}_2 = 0; i${dim3}_2 < ${dim3_size} + # CHECK: for (int i${softmax_dim}_2 = 0; i${softmax_dim}_2 < ${softmax_dim_size} # CHECK-NEXT: aten_softmax_sum - # CHECK: for (int i0_4 = 0; i0_4 < 2 - # CHECK-NEXT: for (int i1_4 = 0; i1_4 < 3 - # CHECK-NEXT: for (int i2_4 = 0; i2_4 < 2 - # CHECK-NEXT: for (int i3_4 = 0; i3_4 < 3 + # CHECK: for (int i0_3 = 0; i0_3 < 2 + # CHECK-NEXT: for (int i1_3 = 0; i1_3 < 3 + # CHECK-NEXT: for (int i2_3 = 0; i2_3 < 2 + # CHECK-NEXT: for (int i3_3 = 0; i3_3 < 3 # CHECK-NEXT: aten_softmax)IR"; for (int softmax_dim = 0; softmax_dim < a.dim(); ++softmax_dim) { @@ -853,5 +842,88 @@ void testKernelSoftmax4D() { } } +void testKernelInlineProducerIntoReduction() { + KernelScope kernel_scope; + + // Inline producer (mul) into reduction (sum). + const auto graph_string = R"IR( + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), + %1 : Float(5, 3, strides=[3, 1], device=cpu)): + %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) + %3 : int = prim::Constant[value=7]() + %4 : Float(5, 3, strides=[3, 1]) = aten::sum(%2, %3) + return (%4))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + TensorExprKernel k(graph); + Stmt* s = k.getCodeGenStmt(); + std::ostringstream oss; + oss << *s; + + // Check the IR we produced. + // We should have only one loop in the end. + const std::string& verification_pattern = + R"IR( + # CHECK: for (int v = 0; v < 5; + # CHECK-NEXT: for (int v_1 = 0; v_1 < 3; + # CHECK-NEXT: sum + # CHECK-NOT: for)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + std::vector inputs = {a, b}; + std::vector stack = fmap(inputs); + k.run(stack); + auto o = stack[0].toTensor(); + auto ref = (a * b).sum(at::kDouble); + ASSERT_TRUE(at::allclose(o, ref)); +} + +void testKernelInlineReductionIntoConsumer() { + KernelScope kernel_scope; + + // Inline producer (mul %2) into reduction (sum %4) but DO NOT + // inline the reduction into consumer (mul %4). + const auto graph_string = R"IR( + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), + %1 : Float(5, 3, strides=[3, 1], device=cpu)): + %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) + %3 : int = prim::Constant[value=6]() + %4 : Float(5, 3, strides=[3, 1]) = aten::sum(%2, %3) + %5 : Float(5, 3, strides=[3, 1]) = aten::mul(%2, %4) + return (%5))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + TensorExprKernel k(graph); + Stmt* s = k.getCodeGenStmt(); + std::ostringstream oss; + oss << *s; + + // Check the IR we produced. + // We should have two loops in the end. + const std::string& verification_pattern = + R"IR( + # CHECK: for (int v = 0; v < 5; + # CHECK-NEXT: for (int v_1 = 0; v_1 < 3; + # CHECK-NEXT: sum + # CHECK: for (int v_2 = 0; v_2 < 5; + # CHECK-NEXT: for (int v_3 = 0; v_3 < 3; + # CHECK-NEXT: aten_mul + # CHECK-NOT: for)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + std::vector inputs = {a, b}; + std::vector stack = fmap(inputs); + k.run(stack); + auto o = stack[0].toTensor(); + auto ref = (a * b).sum(at::kFloat) * (a * b); + ASSERT_TRUE(at::allclose(o, ref)); +} + } // namespace jit } // namespace torch diff --git a/test/cpp/tensorexpr/test_reductions.cpp b/test/cpp/tensorexpr/test_reductions.cpp index 0555c310cc5b..d31c6e9bca30 100644 --- a/test/cpp/tensorexpr/test_reductions.cpp +++ b/test/cpp/tensorexpr/test_reductions.cpp @@ -1337,8 +1337,8 @@ void testReduceInlineReduction() { } LoopNest l1({y}); - ASSERT_THROWS_WITH( - l1.computeInline(x->buf()), "cannot inline a reduction computation"); + // Cannot inline a reduction computation + ASSERT_FALSE(l1.computeInline(x->buf())); } void testReduceInlineConsumer() { diff --git a/test/cpp/tensorexpr/tests.h b/test/cpp/tensorexpr/tests.h index 5f6011dede9d..6f7b50790477 100644 --- a/test/cpp/tensorexpr/tests.h +++ b/test/cpp/tensorexpr/tests.h @@ -367,6 +367,8 @@ namespace jit { _(KernelSoftmax2D) \ _(KernelSoftmax3D) \ _(KernelSoftmax4D) \ + _(KernelInlineProducerIntoReduction) \ + _(KernelInlineReductionIntoConsumer) \ _(FuserPass_1) \ _(FuserPass_2) \ _(FuserPass_3) \ diff --git a/test/distributed/_pipeline/sync/test_worker.py b/test/distributed/_pipeline/sync/test_worker.py index 0247a71ba4a8..fb306b52fad9 100644 --- a/test/distributed/_pipeline/sync/test_worker.py +++ b/test/distributed/_pipeline/sync/test_worker.py @@ -13,6 +13,7 @@ from torch.distributed._pipeline.sync.microbatch import Batch from torch.distributed._pipeline.sync.stream import CPUStream from torch.distributed._pipeline.sync.worker import Task, spawn_workers +from torch.testing._internal.common_utils import TEST_WITH_TSAN class fake_device: @@ -24,6 +25,7 @@ class fake_device: index = None +@pytest.mark.skipif(TEST_WITH_TSAN, reason="False positive in TSAN") def test_join_running_workers(): count = 0 @@ -47,6 +49,7 @@ def call_in_worker(i, f): assert count == 10 +@pytest.mark.skipif(TEST_WITH_TSAN, reason="False positive in TSAN") def test_join_running_workers_with_exception(): class ExpectedException(Exception): pass diff --git a/test/jit/test_backends.py b/test/jit/test_backends.py index 89330ddbd2d9..9f61ec77a1f6 100644 --- a/test/jit/test_backends.py +++ b/test/jit/test_backends.py @@ -5,7 +5,9 @@ import torch import torch._C +from torch.testing import FileCheck from pathlib import Path + from torch.testing._internal.common_utils import ( IS_FBCODE, IS_MACOS, @@ -34,6 +36,13 @@ def to_test_backend_multi(module, method_compile_spec): return torch._C._jit_to_backend("test_backend", module, method_compile_spec) +def to_test_backend_selective(module, method_compile_spec, submodules): + def _to_test_backend(module): + return to_test_backend(module, method_compile_spec) + + return torch._C._jit_to_backend_selective(module, _to_test_backend, submodules) + + class BasicModule(torch.nn.Module): """ A simple Module used to test to_backend lowering machinery. @@ -81,9 +90,9 @@ def check_function(self, function_name, input): backend_method = self.lowered_module.__getattr__(function_name) # Run methods. - python_output = python_method(input, input) - jit_output = jit_method(input, input) - backend_output = backend_method(input, input) + python_output = python_method(*input) + jit_output = jit_method(*input) + backend_output = backend_method(*input) # The answers returned by Python, JIT and to_backend should all match. self.assertEqual(python_output, backend_output) @@ -95,6 +104,24 @@ def save_load(self): """ self.lowered_module = self.getExportImportCopy(self.lowered_module) + def test_execution(self): + """ + Stub for correctness tests. + """ + pass + + def test_save_load(self): + """ + Stub for serialization tests. + """ + pass + + def test_errors(self): + """ + Stub for testing error checking. + """ + pass + class BasicModuleTest(JitBackendTestCase): """ @@ -116,9 +143,9 @@ def test_execution(self): input = torch.randn(5) # Test all three module methods. - self.check_function("accum", input) - self.check_function("sub_accum", input) - self.check_function("forward", input) + self.check_function("accum", (input, input)) + self.check_function("sub_accum", (input, input)) + self.check_function("forward", (input, input)) @skipIfRocm def test_save_load(self): @@ -166,8 +193,12 @@ def setUp(self): self.module = NestedModuleTest.NestedModule(BasicModule()) # Both modules in self.scripted_module are ScriptModules. self.scripted_module = torch.jit.script(NestedModuleTest.NestedModule(BasicModule())) + + # First, script another instance of NestedModule with share_types=False so that it can be + # selectively lowered without modifying the type of self.scripted_module. lowered_module = to_test_backend_multi( - self.scripted_module, {"forward": {"": ""}} + torch.jit.script(BasicModule()), + {"accum": {"": ""}, "sub_accum": {"": ""}, "forward": {"": ""}}, ) # self.lowered_module is a ScriptModule, but its submodule is a lowered module. self.lowered_module = torch.jit.script(NestedModuleTest.NestedModule(lowered_module)) @@ -177,7 +208,7 @@ def test_execution(self): input = torch.randn(5) # Test forward. - self.check_function("forward", input) + self.check_function("forward", (input, input)) def test_save_load(self): # Lowered module should produce the same outputs. @@ -190,6 +221,161 @@ def test_save_load(self): self.test_execution() +class SelectiveLoweringTest(JitBackendTestCase): + """ + Tests for the selective lowering API. + """ + class OuterModule(torch.nn.Module): + def __init__(self, sub1, sub2, other): + super().__init__() + self.sub1 = sub1 + self.sub2 = sub2 + self.other = other + + def forward(self, x, y): + # Call the module that will be lowered directly to test + # type remapping in modules that are not its parent. + a, b = self.sub1.submodule.forward(x, y) + c, d = self.sub2.forward(x, y) + e, f = self.other.forward(x, y) + return a + c + e, b + d + f + + class MiddleModule(torch.nn.Module): + def __init__(self, submodule): + super().__init__() + self.submodule = submodule + + def forward(self, x, y): + return self.submodule.forward(x, y) + + def setUp(self): + super().setUp() + OuterModule = SelectiveLoweringTest.OuterModule + MiddleModule = SelectiveLoweringTest.MiddleModule + + def script_without_type_sharing(mod): + return torch.jit._recursive.create_script_module(mod, torch.jit._recursive.infer_methods_to_compile, share_types=False) + # Create Python, JIT and backend versions of a hierarchy that looks like this: + # --------- OuterModule -------- + # | | | + # MiddleModule MiddleModule MiddleModule + # | | | + # BasicModule BasicModule BasicModule + # + # Two BasicModules will be lowered and the third will not. + self.module = OuterModule(MiddleModule(BasicModule()), MiddleModule(BasicModule()), MiddleModule(BasicModule())) + self.scripted_module = script_without_type_sharing(OuterModule(MiddleModule( + BasicModule()), MiddleModule(BasicModule()), MiddleModule(BasicModule()))) + self.lowered_module = script_without_type_sharing(OuterModule(MiddleModule( + BasicModule()), MiddleModule(BasicModule()), MiddleModule(BasicModule()))) + self.lowered_module = to_test_backend_selective(self.lowered_module, {"forward": ""}, [ + "sub1.submodule", "sub2.submodule"]) + + def test_execution(self): + input = torch.randn(5) + self.check_function("forward", (input, input)) + + self.test_selective_lowering_type_remap() + + def test_save_load(self): + self.test_execution() + self.save_load() + self.test_execution() + + self.test_selective_lowering_type_remap() + + def test_selective_lowering_type_remap(self): + """ + Check that type remapping and replacement occurred during selective lowering. + """ + # Check that self.lowered_module was not lowered, but that it does contain test_backendLoweredModule due to it + # calling the lowered module directly. + FileCheck() \ + .check("OuterModule") \ + .check("BasicModule") \ + .run(self.scripted_module.graph) + FileCheck() \ + .check("OuterModule") \ + .check_not("__torch__.torch.classes.__backends__.test_backend") \ + .check("test_backendLoweredModule") \ + .run(self.lowered_module.graph) + + # Check that self.lowered_module.sub1/sub2 were not lowered but that BasicModule has been replaced in their graphs. + FileCheck() \ + .check("MiddleModule") \ + .check("BasicModule") \ + .check_not("test_backendLoweredModule") \ + .run(self.scripted_module.sub1.graph) + FileCheck() \ + .check("MiddleModule") \ + .check_not("__torch__.torch.classes.__backends__.test_backend") \ + .check("test_backendLoweredModule") \ + .check_not("BasicModule") \ + .run(self.lowered_module.sub1.graph) + + FileCheck() \ + .check("MiddleModule") \ + .check("BasicModule") \ + .check_not("test_backendLoweredModule") \ + .run(self.scripted_module.sub2.graph) + FileCheck() \ + .check("MiddleModule") \ + .check_not("__torch__.torch.classes.__backends__.test_backend") \ + .check("test_backendLoweredModule") \ + .check_not("BasicModule") \ + .run(self.lowered_module.sub2.graph) + + # Check that self.lowered_module.sub1/sub2.submodule were lowered. Its graph should mention + # __torch__.torch.classes.__backends__.test_backend, the TorchBind class for executing functions + # on the test JIT backend. + FileCheck() \ + .check("test_backendLoweredModule") \ + .check_not("BasicModule") \ + .check("__torch__.torch.classes.__backends__.test_backend") \ + .run(self.lowered_module.sub1.submodule.graph) + + FileCheck() \ + .check("test_backendLoweredModule") \ + .check_not("BasicModule") \ + .check("__torch__.torch.classes.__backends__.test_backend") \ + .run(self.lowered_module.sub2.submodule.graph) + + # Check that self.other and self.other.submodule have been left untouched by the selective lowering process. + FileCheck() \ + .check("MiddleModule") \ + .check("BasicModule") \ + .check_not("__torch__.torch.classes.__backends__.test_backend") \ + .check_not("test_backendLoweredModule") \ + .run(self.scripted_module.other.graph) + FileCheck() \ + .check("BasicModule") \ + .check_not("__torch__.torch.classes.__backends__.test_backend") \ + .check_not("test_backendLoweredModule") \ + .run(self.scripted_module.other.submodule.graph) + + def test_errors(self): + """ + Check errors associated with selective lowering. + """ + # Check error messages thrown when attempting to lower something that is not a ScriptModule. + with self.assertRaisesRegex(RuntimeError, r"Object .* is not a ScriptModule"): + to_test_backend_selective(torch.nn.ReLU(), {"forward": ""}, ["submodule"]) + + MiddleModule = SelectiveLoweringTest.MiddleModule + mod = MiddleModule(BasicModule()) + mod.new_attr = 3 + + with self.assertRaisesRegex(RuntimeError, r"Attribute named new_attr is not a Module"): + to_test_backend_selective(torch.jit.script(mod), {"forward": ""}, ["new_attr"]) + + # Check error message thrown when module hierarchy doesn't have unique types. + OuterModule = SelectiveLoweringTest.OuterModule + mod = OuterModule(MiddleModule(BasicModule()), MiddleModule(BasicModule()), MiddleModule(BasicModule())) + + with self.assertRaisesRegex(RuntimeError, r"Selective lowering is only supported for module hierarchies with unique types"): + to_test_backend_selective(torch.jit.script(mod), {"forward": ""}, ["sub1.submodule"]) + + class TestBackends(JitTestCase): """ This class wraps and invokes all subclasses of JitBackendTestCase so that each one @@ -200,19 +386,27 @@ def __init__(self, name): super().__init__(name) self.basic_module_test = BasicModuleTest(name) self.nested_module_test = NestedModuleTest(name) + self.selective_lowering_test = SelectiveLoweringTest(name) def setUp(self): super().setUp() if not TEST_WITH_ROCM: self.basic_module_test.setUp() self.nested_module_test.setUp() + self.selective_lowering_test.setUp() @skipIfRocm def test_execution(self): self.basic_module_test.test_execution() self.nested_module_test.test_execution() + self.selective_lowering_test.test_execution() @skipIfRocm def test_save_load(self): self.basic_module_test.test_save_load() self.nested_module_test.test_save_load() + self.selective_lowering_test.test_save_load() + + @skipIfRocm + def test_errors(self): + self.selective_lowering_test.test_errors() diff --git a/test/jit/test_freezing.py b/test/jit/test_freezing.py index 5f16b0229c2d..9c59b00a04cd 100644 --- a/test/jit/test_freezing.py +++ b/test/jit/test_freezing.py @@ -7,6 +7,7 @@ from torch.testing._internal.common_quantization import skipIfNoFBGEMM from torch.jit._recursive import wrap_cpp_module +from typing import Any import io @@ -1222,3 +1223,35 @@ def forward(self, cond: bool): mod_eager = Mod() self.assertEqual(mod_eager(True), frozen_mod(True)) self.assertEqual(mod_eager(False), frozen_mod(False)) + + def test_freeze_module_with_non_static_module_dict_index(self): + """ + Test that a Module contained a non-static ModuleDict index + cannot be frozen. + """ + @torch.jit.interface + class ModuleInterface(torch.nn.Module): + def forward(self, inp: Any) -> Any: + pass + + class ImplementsInterface(torch.nn.Module): + def forward(self, inp: Any) -> Any: + if isinstance(inp, torch.Tensor): + return torch.max(inp, dim=0) + + return inp + + # Test annotation of submodule. + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.d = torch.nn.ModuleDict({"module": ImplementsInterface()}) + + def forward(self, x: torch.Tensor, key: str) -> Any: + value: ModuleInterface = self.d[key] + return value.forward(x) + + m = torch.jit.script(Mod()) + m.eval() + with self.assertRaisesRegex(RuntimeError, "Freezing modules containing prim::ModuleDictIndex is not supported"): + mf = torch._C._freeze_module(m._c) diff --git a/test/jit/test_module_containers.py b/test/jit/test_module_containers.py index b53bf10a70c2..e261124bedb5 100644 --- a/test/jit/test_module_containers.py +++ b/test/jit/test_module_containers.py @@ -1,7 +1,7 @@ import os import sys -from typing import List +from typing import Any, List, Tuple from collections import OrderedDict import torch import torch.nn as nn @@ -428,3 +428,64 @@ def forward(self, inputs): m = MyModule() self.checkModule(m, [torch.randn(2, 2)]) + + def test_typed_module_dict(self): + """ + Test that a type annotation can be provided for a ModuleDict that allows + non-static indexing. + """ + @torch.jit.interface + class ModuleInterface(torch.nn.Module): + def forward(self, inp: Any) -> Any: + pass + + class ImplementsInterface(torch.nn.Module): + def forward(self, inp: Any) -> Any: + if isinstance(inp, torch.Tensor): + return torch.max(inp, dim=0) + + return inp + + class DoesNotImplementInterface(torch.nn.Module): + def forward(self, inp: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.max(inp, dim=0) + + # Test annotation of submodule. + + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.d = torch.nn.ModuleDict({"module": ImplementsInterface()}) + + def forward(self, x: torch.Tensor, key: str) -> Any: + value: ModuleInterface = self.d[key] + return value.forward(x) + + m = Mod() + self.checkModule(m, (torch.randn(2, 2), "module")) + + # Test annotation of self. + class ModDict(torch.nn.ModuleDict): + def __init__(self): + super().__init__({"module": ImplementsInterface()}) + + def forward(self, x: torch.Tensor, key: str) -> Any: + submodule: ModuleInterface = self[key] + return submodule.forward(x) + + m = ModDict() + self.checkModule(m, (torch.randn(2, 2), "module")) + + # Test error message thrown when annotated attribute does not comply with the + # annotation. + class ModWithWrongAnnotation(torch.nn.ModuleDict): + def __init__(self): + super().__init__() + self.d = torch.nn.ModuleDict({"module": DoesNotImplementInterface()}) + + def forward(self, x: torch.Tensor, key: str) -> Any: + submodule: ModuleInterface = self.d[key] + return submodule.forward(x) + + with self.assertRaisesRegex(RuntimeError, r"Attribute module is not of annotated type"): + torch.jit.script(ModWithWrongAnnotation()) diff --git a/test/jit/test_save_load.py b/test/jit/test_save_load.py index 178db8357e8f..23751c4fd92b 100644 --- a/test/jit/test_save_load.py +++ b/test/jit/test_save_load.py @@ -938,3 +938,12 @@ def forward(self, a): x = torch.tensor([1., 2., 3., 4.]) self.assertTrue(torch.equal(m(x), m2(x))) + + def test_save_nonexit_file(self): + class Foo(torch.nn.Module): + def forward(self, x): + return 2 * x + + script_module = torch.jit.script(Foo()) + with self.assertRaises(RuntimeError): + script_module.save("NonExist/path/test.pt") diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index f08f772aa8e1..38ddb094794e 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -430,6 +430,16 @@ def forward(self, input): m1 = torch.randn(3, 4, 5, 6, 7) self.run_test(MyModel(), m1) + @skipIfUnsupportedMinOpsetVersion(11) + @disableScriptTest() # Need type inference + def test_index_mask_nd(self): + class MyModel(torch.nn.Module): + def forward(self, input): + return input[input > 0] + + m1 = torch.randn(3, 4, 5, 6, 7) + self.run_test(MyModel(), m1) + @disableScriptTest() def test_dict(self): class MyModel(torch.nn.Module): @@ -452,6 +462,42 @@ def forward(self, x_in): x = {"test_key_in": torch.randn(1, 2, 3)} self.run_test(MyModel(), (x,)) + def test_none_as_input(self): + class Model(torch.nn.Module): + def forward(self, x, y): + if y is not None: + return x + y + return x + + x = torch.randn(2, 3) + self.run_test(Model(), (x, None)) + + def test_none_as_tuple_input(self): + class Model(torch.nn.Module): + def forward(self, x, y): + if y[0] is not None: + return x + y[0] + if y[1] is not None: + return x + y[1] + return x + + x = torch.randn(2, 3) + y = torch.randn(2, 3) + self.run_test(Model(), (x, (None, y))) + + def test_none_as_named_input(self): + class Model(torch.nn.Module): + def forward(self, x, y=None, z=None): + if y is not None: + return x + y + if z is not None: + return x + z + return x + + x = torch.randn(2, 3) + z = torch.randn(2, 3) + self.run_test(Model(), (x, None, z)) + @skipIfUnsupportedMinOpsetVersion(9) def test_cste_script(self): class MyModel(torch.jit.ScriptModule): diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py index f2911642cf3a..153fe74ba913 100644 --- a/test/quantization/test_quantize_fx.py +++ b/test/quantization/test_quantize_fx.py @@ -3,16 +3,20 @@ import torch.nn as nn import torch.nn.quantized as nnq import torch.nn.quantized.dynamic as nnqd +import torch.nn.intrinsic as nni import torch.nn.intrinsic.quantized as nniq import torch.multiprocessing as mp # graph mode quantization based on fx -from torch.quantization import ( - QuantType, - quant_type_to_str, +from torch.quantization.quantize_fx import ( prepare_fx, convert_fx, prepare_qat_fx, +) + +from torch.quantization import ( + QuantType, + quant_type_to_str, default_qconfig, default_dynamic_qconfig, default_dynamic_quant_observer, @@ -46,9 +50,6 @@ from torch.testing._internal.common_quantization import NodeSpec as ns -from torch.testing._internal.common_quantization import ( - test_only_eval_fn, -) from torch.testing import FileCheck import copy @@ -57,6 +58,124 @@ import unittest import io +class TestFuseFx(QuantizationTestCase): + def test_fuse_conv_bn_relu(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1d = nn.Conv1d(1, 1, 1) + self.conv2d = nn.Conv2d(1, 1, 1) + self.conv3d = nn.Conv3d(1, 1, 1) + self.bn1d = nn.BatchNorm1d(1) + self.bn2d = nn.BatchNorm2d(1) + self.bn3d = nn.BatchNorm3d(1) + self.conv1d2 = nn.Conv1d(1, 1, 1) + self.conv2d2 = nn.Conv2d(1, 1, 1) + self.conv3d2 = nn.Conv3d(1, 1, 1) + self.bn1d2 = nn.BatchNorm1d(1) + self.bn2d2 = nn.BatchNorm2d(1) + self.bn3d2 = nn.BatchNorm3d(1) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.conv1d(x) + x = self.bn1d(x) + x = self.conv2d(x) + x = self.bn2d(x) + x = self.conv3d(x) + x = self.bn3d(x) + x = self.conv1d2(x) + x = self.bn1d2(x) + x = self.relu(x) + x = self.conv2d2(x) + x = self.bn2d2(x) + x = self.relu(x) + x = self.conv3d2(x) + x = self.bn3d2(x) + x = self.relu(x) + return x + + # test train mode + m = M().train() + # currently we don't check if the module are configured with qconfig before fusion + # TODO: if we decide to do that in the future, this test needs to + # be updated + # train mode fuse_fx is called in prepare_qat_fx + m = prepare_qat_fx(m, {}) + expected_nodes = [ + ns.call_module(nni.ConvBn2d), + ns.call_module(nni.ConvBn3d), + ns.call_module(nni.ConvBnReLU2d), + ns.call_module(nni.ConvBnReLU3d), + ] + # ConvBnRelu1d is not fused + expected_occurrence = { + ns.call_module(nn.ReLU): 1 + } + self.checkGraphModuleNodes( + m, + expected_node_list=expected_nodes, + expected_node_occurrence=expected_occurrence) + + # test eval mode + m = M().eval() + from torch.quantization.quantize_fx import fuse_fx + # fuse_fx is a top level api and only supports eval mode + m = fuse_fx(m) + expected_nodes = [ + ns.call_module(nn.Conv2d), + ns.call_module(nn.Conv3d), + ns.call_module(nni.ConvReLU2d), + ns.call_module(nni.ConvReLU3d), + ] + # ConvBnRelu1d is not fused + expected_occurrence = { + ns.call_module(nn.ReLU): 1 + } + self.checkGraphModuleNodes( + m, + expected_node_list=expected_nodes, + expected_node_occurrence=expected_occurrence) + + def test_fuse_module_relu(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1d = nn.Conv1d(1, 1, 1) + self.conv2d = nn.Conv2d(1, 1, 1) + self.conv3d = nn.Conv3d(1, 1, 1) + self.bn1d = nn.BatchNorm1d(1) + self.bn2d = nn.BatchNorm2d(1) + self.bn3d = nn.BatchNorm3d(1) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.conv1d(x) + x = self.relu(x) + x = self.conv2d(x) + x = self.relu(x) + x = self.conv3d(x) + x = self.relu(x) + x = self.bn1d(x) + x = self.relu(x) + x = self.bn2d(x) + x = self.relu(x) + x = self.bn3d(x) + x = self.relu(x) + return x + + m = M().eval() + from torch.quantization.quantize_fx import fuse_fx + m = fuse_fx(m) + expected_nodes = [ + ns.call_module(nni.ConvReLU1d), + ns.call_module(nni.ConvReLU2d), + ns.call_module(nni.ConvReLU3d), + ns.call_module(nni.BNReLU2d), + ns.call_module(nni.BNReLU3d), + ] + self.checkGraphModuleNodes(m, expected_node_list=expected_nodes) + @skipIfNoFBGEMM class TestQuantizeFx(QuantizationTestCase): def _get_conv_linear_test_cases(self): @@ -268,30 +387,6 @@ def forward(self, x): model_device = next(iter(model_devices)) self.assertEqual(model_device, device) - @skipIfNoFBGEMM - def test_inplace_option(self): - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(3, 3, 3) - - def forward(self, x): - return self.conv(x) - - model = M().eval() - qconfig_dict = {'': default_qconfig} - prepared = prepare_fx(model, qconfig_dict) - test_only_eval_fn(model, self.img_data_2d) - non_inplace_model = convert_fx(prepared, inplace=True) - - prepared = prepare_fx(model, qconfig_dict) - test_only_eval_fn(model, self.img_data_2d) - inplace_model = convert_fx(prepared, inplace=True) - - non_inplace_res = non_inplace_model(self.img_data_2d[0][0]) - inplace_res = inplace_model(self.img_data_2d[0][0]) - self.assertEqual(non_inplace_res, inplace_res) - @skipIfNoFBGEMM def test_dict_output(self): """ Make sure quantization runs for models with dictionary output @@ -858,13 +953,13 @@ def forward(self, x): print(m.__dict__.keys()) m.eval() qconfig_dict = {'': torch.quantization.default_qconfig} - prepared = torch.quantization.prepare_fx(m, qconfig_dict) + prepared = prepare_fx(m, qconfig_dict) # calibrate prepared(torch.randn(4, 1, 4, 4)) # copy prepared_copy = copy.deepcopy(prepared) # quantize, should run with no errors - quantized = torch.quantization.convert_fx(prepared_copy) + quantized = convert_fx(prepared_copy) @skipIfNoFBGEMM @@ -1683,7 +1778,7 @@ def forward(self, x): ref_m = convert(ref_m) self.assertEqual(m(data), ref_m(data)) - def test_qembedding_module(self): + def test_embedding(self): class M(torch.nn.Module): def __init__(self): super().__init__() @@ -1695,15 +1790,23 @@ def forward(self, indices): model = M().eval() indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3]) quantized_node = ns.call_module(nnq.Embedding) - self.checkGraphModeFxOp( - model, - [[indices]], - QuantType.DYNAMIC, - quantized_node, - custom_qconfig=float_qparams_dynamic_qconfig - ) - - def test_qembedding_bag_module(self): + configs = [ + (float_qparams_dynamic_qconfig, ns.call_module(nnq.Embedding)), + (None, ns.call_module(nn.Embedding)) + ] + + for qconfig, node in configs: + qconfig_dict = {"": qconfig} + m = prepare_fx(model, qconfig_dict) + self.checkGraphModuleNodes(m, expected_node_occurrence={ + ns.call_module(torch.quantization.MinMaxObserver): 0 + }) + m = convert_fx(m) + self.checkGraphModuleNodes(m, expected_node=node) + # make sure it runs + m(indices) + + def test_embedding_bag(self): class M(torch.nn.Module): def __init__(self): super().__init__() @@ -1712,13 +1815,13 @@ def __init__(self): def forward(self, indices, offsets): return self.emb(indices, offsets) - model = M().eval() indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3]) offsets = torch.tensor([0, 19, 20, 28, 28, 32]) quantized_node = ns.call_module(nnq.EmbeddingBag) inputs = (indices, offsets) for dtype in [torch.quint8, torch.quint4x2]: + model = M().eval() float_qparams_observer = PerChannelMinMaxObserver.with_args(dtype=dtype, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0) @@ -1732,6 +1835,17 @@ def forward(self, indices, offsets): custom_qconfig=float_qparams_qconfig ) + # check it works in None qconfig + qconfig_dict = {"": None} + m = M().eval() + m = prepare_fx(model, qconfig_dict) + self.checkGraphModuleNodes(m, expected_node_occurrence={ + ns.call_module(torch.quantization.MinMaxObserver): 0 + }) + m = convert_fx(m) + self.checkGraphModuleNodes(m, expected_node=ns.call_module(nn.EmbeddingBag)) + # make sure it runs + m(*inputs) class TestQuantizeFxModels(QuantizationTestCase): def _test_model_impl( diff --git a/test/quantization/test_quantized_op.py b/test/quantization/test_quantized_op.py index 36f317529285..96568662c052 100644 --- a/test/quantization/test_quantized_op.py +++ b/test/quantization/test_quantized_op.py @@ -91,9 +91,9 @@ def pool_output_shape(input_size, kernel_size, padding, stride, output_size = ( (input_size + 2 * padding - dilation * (kernel_size - 1) - 1 + (stride - 1 if ceiling_mode else 0)) // stride + 1) - if (padding > 0 and + if (ceiling_mode and ((output_size - 1) * stride >= input_size + padding)): - output_size += 1 + output_size -= 1 return output_size """ @@ -3611,6 +3611,122 @@ def test_qconv_transpose2d( Y_q = qconv_op(X_q) self.assertEqual(Y_q_ref, Y_q) + """Tests the correctness of quantized convolution op.""" + @given(batch_size=st.integers(1, 3), + input_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]), + time=st.integers(2, 5), + height=st.integers(10, 16), + width=st.integers(7, 14), + output_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]), + groups=st.integers(1, 3), + kernel_t=st.integers(1, 7), + kernel_h=st.integers(1, 7), + kernel_w=st.integers(1, 7), + stride_t=st.integers(1, 2), + stride_h=st.integers(1, 2), + stride_w=st.integers(1, 2), + pad_t=st.integers(0, 2), + pad_h=st.integers(0, 2), + pad_w=st.integers(0, 2), + o_pad_t=st.integers(0, 2), + o_pad_h=st.integers(0, 2), + o_pad_w=st.integers(0, 2), + dilation=st.integers(1, 2), + X_scale=st.floats(1.2, 1.6), + X_zero_point=st.integers(0, 4), + W_scale=st.lists(st.floats(0.2, 1.6), min_size=1, max_size=2), + W_zero_point=st.lists(st.integers(-5, 5), min_size=1, max_size=2), + Y_scale=st.floats(4.2, 5.6), + Y_zero_point=st.integers(0, 4), + use_bias=st.booleans()) + @override_qengines + def test_qconv_transpose3d( + self, + batch_size, + input_channels_per_group, + time, + height, + width, + output_channels_per_group, + groups, + kernel_t, + kernel_h, + kernel_w, + stride_t, + stride_h, + stride_w, + pad_t, + pad_h, + pad_w, + o_pad_t, + o_pad_h, + o_pad_w, + dilation, + X_scale, + X_zero_point, + W_scale, + W_zero_point, + Y_scale, + Y_zero_point, + use_bias): + if qengine_is_qnnpack(): + return # QNNPACK doesn't support this + assume(o_pad_t < stride_t or o_pad_t < dilation) + assume(o_pad_h < stride_h or o_pad_h < dilation) + assume(o_pad_w < stride_w or o_pad_w < dilation) + + input_channels = input_channels_per_group * groups + output_channels = output_channels_per_group * groups + kernels = (kernel_t, kernel_h, kernel_w) + strides = (stride_t, stride_h, stride_w) + pads = (pad_t, pad_h, pad_w) + o_pads = (o_pad_t, o_pad_h, o_pad_w) + dilations = (dilation, dilation, dilation) + + qconv = torch.ops.quantized.conv_transpose3d + qconv_prepack = torch.ops.quantized.conv_transpose3d_prepack + conv_op = torch.nn.ConvTranspose3d( + in_channels=input_channels, + out_channels=output_channels, + kernel_size=kernels, + stride=strides, + padding=pads, + output_padding=o_pads, + groups=groups, + dilation=dilations, + bias=use_bias + ) + X_q, W_q, bias_float = self._test_qconv_impl( + qconv, qconv_prepack, conv_op, batch_size, + input_channels_per_group, (time, height, width), + output_channels_per_group, groups, kernels, strides, pads, o_pads, + dilations, X_scale, X_zero_point, W_scale, W_zero_point, + Y_scale, Y_zero_point, use_bias, use_relu=False, + use_channelwise=False, use_transpose=True) + + # Test the module implementation + qconv_op = torch.nn.quantized.ConvTranspose3d( + in_channels=input_channels, + out_channels=output_channels, + kernel_size=kernels, + stride=strides, + padding=pads, + output_padding=o_pads, + groups=groups, + dilation=dilations, + bias=use_bias + ) + qconv_op.scale = Y_scale + qconv_op.zero_point = Y_zero_point + qconv_op.set_weight_bias(W_q, bias_float) + + Y_dq_ref = conv_op(X_q.dequantize()) + Y_q_ref = torch.quantize_per_tensor(Y_dq_ref, scale=Y_scale, + zero_point=Y_zero_point, + dtype=torch.quint8) + Y_q = qconv_op(X_q) + self.assertEqual(Y_q_ref, Y_q) + @given( inputs=hu.tensor_conv( spatial_dim=1, batch_size_range=(1, 3), @@ -3863,22 +3979,26 @@ def test_qconv3d( stride_w=st.integers(1, 2), pad_d=st.integers(1, 2), pad_h=st.integers(1, 2), pad_w=st.integers(1, 2), - channelwise=st.booleans(), - qengine=st.sampled_from(("fbgemm",))) + o_pad=st.integers(0, 2), + channelwise=st.booleans()) + @override_qengines def test_qconv3d_unpack( - self, inputs, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, - channelwise, qengine + self, inputs, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, o_pad, + channelwise ): - if qengine not in supported_qengines: - return - - with override_quantized_engine(qengine): - qconv3d_prepack = torch.ops.quantized.conv3d_prepack - qconv3d_unpack = torch.ops.quantized.conv3d_unpack - self._test_qconv_unpack_impl( - qconv3d_prepack, qconv3d_unpack, inputs, - (stride_d, stride_h, stride_w), (pad_d, pad_h, pad_w), None, - channelwise) + if qengine_is_qnnpack(): + return # QNNPACK doesn't support this + transposed = inputs[-1] + if transposed: + qconv_prepack = torch.ops.quantized.conv_transpose3d_prepack + qconv_unpack = torch.ops.quantized.conv_transpose3d_unpack + else: + qconv_prepack = torch.ops.quantized.conv3d_prepack + qconv_unpack = torch.ops.quantized.conv3d_unpack + self._test_qconv_unpack_impl( + qconv_prepack, qconv_unpack, inputs, + (stride_d, stride_h, stride_w), (pad_d, pad_h, pad_w), (o_pad, o_pad, o_pad), + channelwise) class TestPadding(TestCase): @given(batch_size=st.integers(1, 64), diff --git a/test/quantization/test_workflow_module.py b/test/quantization/test_workflow_module.py index a16b01d36d46..cd722d59d2a2 100644 --- a/test/quantization/test_workflow_module.py +++ b/test/quantization/test_workflow_module.py @@ -736,7 +736,7 @@ def test_histogram_observer_same_inputs(self): self.assertEqual(myobs.max_val, 8.0) self.assertEqual(myobs.histogram, [2., 3., 3.]) - @given(N=st.sampled_from([10, 1000, 10**6]), + @given(N=st.sampled_from([10, 1000]), bins=st.sampled_from([256, 512, 1024, 2048]), dtype=st.sampled_from([torch.qint8, torch.quint8]), qscheme=st.sampled_from([torch.per_tensor_affine, torch.per_tensor_symmetric]), diff --git a/test/run_test.py b/test/run_test.py index ad4603e809f2..0fa84c00044c 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -313,6 +313,11 @@ def run_test(test_module, test_directory, options, launcher_cmd=None, extra_unit if extra_unittest_args: assert isinstance(extra_unittest_args, list) unittest_args.extend(extra_unittest_args) + + # If using pytest, replace -f with equivalent -x + if options.pytest: + unittest_args = [arg if arg != '-f' else '-x' for arg in unittest_args] + # Can't call `python -m unittest test_*` here because it doesn't run code # in `if __name__ == '__main__': `. So call `python test_*.py` instead. argv = [test_module + '.py'] + unittest_args diff --git a/test/test_autograd.py b/test/test_autograd.py index eec016cc9623..3c4db8267e87 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -866,6 +866,64 @@ def call_backwards(): torch.autograd.backward([z, q], [torch.ones(5, 5), torch.ones(5, 5)]) self.assertRaises(RuntimeError, call_backwards) + def test_backward_with_inputs(self): + x = torch.randn(2, 2, requires_grad=True) + y = torch.randn(2, 2, requires_grad=True) + + def fn(): + return x ** 2 + y * x + y ** 2 + + gradient = torch.ones(2, 2) + x_grad_expected = 2 * x + y + y_grad_expected = x + 2 * y + + @torch.no_grad() + def reset_grad(): + x.grad.zero_() + y.grad.zero_() + + torch.autograd.backward(fn(), gradient, inputs=[x, y]) + self.assertEqual(x.grad, x_grad_expected) + self.assertEqual(y.grad, y_grad_expected) + + reset_grad() + torch.autograd.backward(fn(), gradient, inputs=[x]) + self.assertEqual(x.grad, x_grad_expected) + self.assertEqual(y.grad, torch.zeros(2, 2)) + + reset_grad() + torch.autograd.backward(fn(), gradient, inputs=[y]) + self.assertEqual(y.grad, y_grad_expected) + self.assertEqual(x.grad, torch.zeros(2, 2)) + + reset_grad() + self.assertRaisesRegex(RuntimeError, 'cannot be empty', + lambda: torch.autograd.backward(fn(), gradient, inputs=[])) + + def test_backward_with_nonleaf_inputs(self): + x = torch.randn(2, 2, requires_grad=True) + x_nonleaf = x * 1 + y = torch.randn(2, 2, requires_grad=True) + z = torch.randn(2, 2, requires_grad=True) + + out = x_nonleaf ** 2 + y * x_nonleaf + y ** 2 + + out.backward(torch.ones(2, 2), create_graph=True, inputs=[x, y]) + x_grad_expected = 2 * x + y + y_grad_expected = x + 2 * y + + self.assertEqual(y.grad, y_grad_expected) + self.assertEqual(x.grad, x_grad_expected) + + self.assertRaisesRegex(RuntimeError, 'not a leaf Tensor', + lambda: out.backward(torch.ones(2, 2), create_graph=True, inputs=[x, y, x_nonleaf])) + + # backward doesn't have an allow_unused flag, so the behavior of backward + # when variable is not part of the graph is as if allow_used were true + # x.grad will simply be None. + out.backward(torch.ones(2, 2), create_graph=True, inputs=[z]) + self.assertIsNone(z.grad) + def test_dependent_backward(self): x = torch.randn(10, requires_grad=True) y = x ** 2 @@ -2881,6 +2939,14 @@ def test_pow_scalar_base(self): a = torch.arange(1, 13, dtype=torch.double).view(3, 4).requires_grad_() gradcheck(lambda a: torch.pow(2, a), (a,)) + def test_igamma(self): + # 1e-3 offset to avoid zeros + # NOTE: derivative for s is not implemented + s = (torch.rand(100, dtype=torch.double) + 1e-3) + x = (torch.rand(100, dtype=torch.double) + 1e-3).requires_grad_() + gradcheck(torch.igamma, (s, x)) + gradgradcheck(torch.igamma, (s, x)) + @skipIfNoLapack def test_pinverse(self): # Why is pinverse tested this way, and not ordinarily as other linear algebra methods? @@ -4918,7 +4984,8 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks, # and only run for floating point separate_complex_tests = ['view_as_real', 'real', 'imag', 'asin', 'acos', 'div', 'log', - 'log10', 'log1p', 'log2', 'pow', 'tan', 'reciprocal', 'rsqrt', '__rdiv__'] + 'log10', 'log1p', 'log2', 'pow', 'tan', 'reciprocal', 'rsqrt', + '__rdiv__', 'add', 'sub'] # NOTE: Some non-holomorphic are separately tested in TestAutogradComplex until gradcheck works properly # for non-holomorphic functions @@ -4930,7 +4997,8 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks, 'chunk', 'split', 'split_with_sizes', 'repeat', 'expand', 'zero_', 'eq_', 'ne_', 'add', '__radd__', 'sum', 'conj', 'sin', 'cos', 'mul', 'sinh', 'cosh', '__rmul__', 'sgn', 'abs', 'dot', 'vdot', 'tensor_split', 'matmul', - 'bmm', 'mv', 'ger', 'diagonal', 'atan', 'angle', 'tanh', 'fill_', 'sub'] + separate_complex_tests + 'bmm', 'mv', 'ger', 'diagonal', 'atan', 'angle', 'tanh', 'fill_', 'sub', + 'exp'] + separate_complex_tests # this list corresponds to cases that are not currently implemented skip_cuda_list = ['bmm_complex', 'matmul_4d_4d_complex'] diff --git a/test/test_cuda.py b/test/test_cuda.py index d20af4082d03..651ae88da25f 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -504,22 +504,35 @@ def _test_copy_non_blocking(a, b): y = torch.ones(10000000, dtype=torch.uint8).cuda() _test_copy_non_blocking(x, y) - @unittest.skip("skipped because test could be flaky, see #35144") def test_to_non_blocking(self): - def _test_to_non_blocking(a, non_blocking): - stream = torch.cuda.current_stream() - with torch.cuda.stream(stream): - b = a.to('cuda', non_blocking=non_blocking) - self.assertEqual(stream.query(), not non_blocking) - stream.synchronize() - self.assertEqual(a, b) + stream = torch.cuda.current_stream() - # 10MB copies - x = torch.ones(10000000, dtype=torch.uint8) - _test_to_non_blocking(x, True) + def _test_to_non_blocking(a, non_blocking, dst): + torch.cuda.synchronize() + # Pushes an 0.1 second spin to stream so if the copy is non blocking, + # stream will almost surely be active when we query(). + torch.cuda._sleep(int(100 * get_cycles_per_ms())) + b = a.to(device=dst, non_blocking=non_blocking) + self.assertEqual(stream.query(), not non_blocking) + stream.synchronize() + self.assertEqual(a, b) + self.assertTrue(b.is_pinned() == (non_blocking and dst == "cpu")) + + for dst, try_non_blocking in product(("cuda", "cpu"), (True, False)): + # Creates source on the opposite device from destination. + src = torch.randn(1000000, + device="cuda" if dst == "cpu" else "cpu", + pin_memory=True if dst == "cuda" else False) + _test_to_non_blocking(src, try_non_blocking, dst) - y = torch.ones(10000000, dtype=torch.uint8) - _test_to_non_blocking(y, False) + def test_to_cpu_blocking_by_default(self): + src = torch.randn(1000000, device="cuda") + torch.cuda.synchronize() + torch.cuda._sleep(int(100 * get_cycles_per_ms())) + dst = src.to(device="cpu") + self.assertEqual(torch.cuda.current_stream().query(), True) + self.assertEqual(src, dst) + self.assertFalse(dst.is_pinned()) def test_serialization_array_with_storage(self): x = torch.randn(5, 5).cuda() @@ -965,7 +978,6 @@ def test_streams_multi_gpu_eq(self): self.assertNotEqual(hash(s0), hash(s3)) @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") - @skipIfRocm def test_streams_priority(self): low, high = torch.cuda.Stream.priority_range() s0 = torch.cuda.Stream(device=0, priority=low) diff --git a/test/test_dataloader.py b/test/test_dataloader.py index 67a9c8477e8b..0d6ee2e03bd6 100644 --- a/test/test_dataloader.py +++ b/test/test_dataloader.py @@ -1564,7 +1564,7 @@ def test_proper_exit(self): # In all cases, all processes should end properly. if use_workers: exit_methods = [None, 'loader_error', 'loader_kill', 'worker_error', 'worker_kill'] - persistent_workers = self.persistent_workers + persistent_workers = self.persistent_workers else: exit_methods = [None, 'loader_error', 'loader_kill'] persistent_workers = False @@ -1840,6 +1840,12 @@ def test_default_collate_shared_tensor(self): finally: _utils.worker._worker_info = old + def test_excessive_thread_creation_warning(self): + with self.assertWarnsRegex( + UserWarning, + r"excessive worker creation might get DataLoader running slow or even freeze"): + dataloader = DataLoader(self.dataset, batch_size=2, num_workers=1000) + class StringDataset(Dataset): def __init__(self): diff --git a/test/test_fx.py b/test/test_fx.py index 4945ea857ede..05e5a821f4ef 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -1079,6 +1079,26 @@ def foo(x, y): x, y = torch.randn(3, 4), torch.randn(3, 4) self.checkGraphModule(foo, (x, y)) + def test_direct_param_use(self): + class TransposeTest(torch.nn.Module): + def __init__(self): + super().__init__() + self.b = torch.nn.Parameter(torch.rand(4, 3)) + + def forward(self, x): + return self.b + + class Foo(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = TransposeTest() + + def forward(self, x): + return self.a.b, self.a.b.t(), self.a.b.view(12) + + traced = torch.fx.symbolic_trace(Foo()) + assert(all('constant' not in node.target for node in traced.graph.nodes)) + if __name__ == '__main__': run_tests() diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py index 34065987a4b2..c292b0828417 100644 --- a/test/test_fx_experimental.py +++ b/test/test_fx_experimental.py @@ -1,9 +1,12 @@ import torch from torch.fx.symbolic_trace import symbolic_trace +from torch.fx.graph_module import GraphModule from torch.fx.experimental import GraphManipulation from torch.fx.experimental.Partitioner import Partitioner, Device, PartitionerConfig from torch.testing._internal.common_utils import run_tests from torch.testing._internal.jit_utils import JitTestCase +from torch.fx.experimental.partitioner_utils import get_latency_of_one_partition, \ + NodeLatency class TestFXExperimental(JitTestCase): def test_find_single_partition(self): @@ -29,6 +32,7 @@ def forward(self, a, b): module_with_submodules = ret.module_with_submodules dag = ret.dag self.assertEqual(traced(a, b), module_with_submodules(a, b)) + assert dag.nodes[0].logical_device_ids == [0] def test_size_based_partition(self): class TestModule(torch.nn.Module): @@ -62,40 +66,44 @@ def forward(self, a, b): module_with_submodules = ret.module_with_submodules dag = ret.dag self.assertEqual(traced(a, b), module_with_submodules(a, b)) - assert len(module_with_submodules.graph.nodes) == 7 + for i, node in enumerate(dag.nodes): + assert node.logical_device_ids == [i] def test_partition_combining(self): class TestModule(torch.nn.Module): def __init__(self): super().__init__() - self.linear_0 = torch.nn.Linear(4, 4) + self.linear = torch.nn.Linear(4, 4) - def forward(self, a, b): + def forward(self, a): + b = torch.rand(4) add_1 = a + b - c = self.linear_0(a) - add_2 = c + add_1 - return add_2 + linear_1 = self.linear(add_1) + add_2 = torch.rand(4) + a + add_3 = add_2 + linear_1 + return add_3 m = TestModule() traced = symbolic_trace(m) a = torch.rand(4) - b = torch.rand(4) GraphManipulation.get_size_of_all_nodes( traced, - [a, b] + [a] ) partitioner = Partitioner() devices = [ - Device('dev_0', 125, 0), - Device('dev_1', 125, 1), - Device('dev_2', 125, 2) + Device('dev_0', 120, 0), + Device('dev_1', 144, 1) ] - partitioner_config = PartitionerConfig(devices) + partitioner_config = PartitionerConfig(devices, is_sparse_nn=False) ret = partitioner.partition_graph(traced, m, partitioner_config) module_with_submodules = ret.module_with_submodules dag = ret.dag - self.assertEqual(traced(a, b), module_with_submodules(a, b)) - assert len(module_with_submodules.graph.nodes) == 5 + self.assertEqual(traced(a), module_with_submodules(a)) + assert dag.nodes[0].logical_device_ids == [0] + assert dag.nodes[0].size_bytes == 80 + assert dag.nodes[1].logical_device_ids == [1] + assert dag.nodes[1].size_bytes == 144 def test_sparse_nn_partition(self): class MyRecommendationModule(torch.nn.Module): @@ -157,5 +165,52 @@ def forward(self, a, b, offset): self.assertEqual(traced(a, b, offset), module_with_submodules(a, b, offset)) assert len(module_with_submodules.graph.nodes) == 24 + def test_partition_latency(self): + class TestModule(torch.nn.Module): + def __init__(self): + super(TestModule, self).__init__() + self.linear = torch.nn.Linear(4, 4) + + def forward(self, a): + add_1 = a + torch.rand(4) + add_2 = add_1 + torch.rand(4) + linear_1 = self.linear(add_1) + add_4 = add_2 + linear_1 + add_5 = add_2 + add_4 + return add_5 + + def get_node_to_latency_mapping(fx_module: GraphModule): + """Given a fx module, generate node latency for each node + based on the size of each node + """ + node_to_latency_mapping: Dict[Node, NodeLatency] = {} + for node in fx_module.graph.nodes: + if node.op not in {'output', 'placeholder', 'get_attr'}: + if node.size_bytes.total_size == node.size_bytes.output_size: + node_to_latency_mapping[node] = NodeLatency(node.size_bytes.total_size, 2. * node.size_bytes.total_size) + else: + node_to_latency_mapping[node] = NodeLatency(node.size_bytes.total_size, node.size_bytes.output_size) + return node_to_latency_mapping + + m = TestModule() + traced = symbolic_trace(m) + a = torch.rand(4) + GraphManipulation.get_size_of_all_nodes(traced, [a]) + node_to_latency_mapping = get_node_to_latency_mapping(traced) + devices = [ + Device('dev_0', 200, 0), + Device('dev_1', 200, 0) + ] + partitioner = Partitioner() + partitioner_config = PartitionerConfig(devices, False) + ret = partitioner.partition_graph(traced, m, partitioner_config) + module_with_submodules = ret.module_with_submodules + self.assertEqual(traced(a), module_with_submodules(a)) + partitions = partitioner.partitions + partition_latency_0 = get_latency_of_one_partition(partitions[0], node_to_latency_mapping) + assert (128., 80., 160.) == partition_latency_0 + partition_latency_1 = get_latency_of_one_partition(partitions[1], node_to_latency_mapping) + assert (16., 32., 32) == partition_latency_1 + if __name__ == '__main__': run_tests() diff --git a/test/test_jit.py b/test/test_jit.py index dbe3fff20245..378c88eaa1cf 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -82,7 +82,7 @@ from itertools import product import itertools from textwrap import dedent -from typing import List, Dict, Optional, Tuple, Union +from typing import List, Dict, NamedTuple, Optional, Tuple, Union import inspect import math import functools @@ -413,6 +413,15 @@ def forward(self, x): self.assertEqual(origin_result, m3(input.cpu())) self.assertEqual(origin_result, m4(input.cuda(0))) + def test_trace_retains_train(self): + class M(torch.nn.Module): + def forward(self, x): + return x + m = M() + m.eval() + tm = torch.jit.trace(m, (torch.rand(3))) + self.assertEqual(tm.training, m.training) + @unittest.skipIf(not RUN_CUDA, "restore device requires CUDA") def test_restore_shared_storage_on_cuda(self): class Foo(torch.jit.ScriptModule): @@ -13297,7 +13306,7 @@ def backward(grad_output): ''') cu = torch.jit.CompilationUnit(code) g = cu.tanh.graph - FileCheck().check_count("prim::Function_0", 2).check("None = prim::Constant") \ + FileCheck().check_count("prim::Closure_0", 2).check("None = prim::Constant") \ .check_next("return").run(g) code = dedent(''' @@ -13314,7 +13323,7 @@ def backward(grad_output): ''') cu = torch.jit.CompilationUnit(code) g = cu.tanh.graph - FileCheck().check_count("prim::Function_0", 2).check("int = prim::If") \ + FileCheck().check_count("prim::Closure_0", 2).check("int = prim::If") \ .run(g) code = dedent(''' @@ -13328,9 +13337,9 @@ def backward(grad_output): ''') cu = torch.jit.CompilationUnit(code) fc = FileCheck() - fc.check("prim::Function").check("(Tensor, None) = prim::TupleConstruct") + fc.check("prim::Closure").check("(Tensor, None) = prim::TupleConstruct") # Loop then two if's added in exit transform - fc.check("prim::Function").check("prim::Loop").check_count("prim::If", 2) + fc.check("prim::Closure").check("prim::Loop").check_count("prim::If", 2) fc.run(cu.loop_in_closure.graph) code = dedent(''' @@ -13796,6 +13805,23 @@ def test_non_primitive_types(x): out = test_non_primitive_types(_MyNamedTuple(value=torch.tensor(5.0))) self.assertEqual(out, torch.tensor(6.0)) + def test_namedtuple_type_inference(self): + _AnnotatedNamedTuple = NamedTuple('_NamedTupleAnnotated', [('value', int)]) + _UnannotatedNamedTuple = namedtuple('_NamedTupleUnAnnotated', ['value']) + + def test_check_named_tuple_value(): + named_tuple = _AnnotatedNamedTuple(1) + return named_tuple.value + + self.checkScript(test_check_named_tuple_value, ()) + + def test_error(): + return _UnannotatedNamedTuple(1) + + with self.assertRaisesRegex(RuntimeError, r"Expected a value of type \'Tensor \(inferred\)\' " + r"for argument \'value\' but instead found type \'int\'."): + torch.jit.script(test_error) + def test_isinstance_dynamic(self): @torch.jit.script def foo(a): diff --git a/test/test_linalg.py b/test/test_linalg.py index 127c674e5b05..d7de3841ab65 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -1,13 +1,15 @@ import torch import unittest import itertools +import warnings from math import inf, nan, isnan from random import randrange from torch.testing._internal.common_utils import \ (TestCase, run_tests, TEST_NUMPY, IS_MACOS, IS_WINDOWS, TEST_WITH_ASAN, make_tensor) from torch.testing._internal.common_device_type import \ - (instantiate_device_type_tests, dtypes, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride) + (instantiate_device_type_tests, dtypes, dtypesIfCUDA, + onlyCUDA, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride) from torch.testing._internal.jit_metaprogramming_utils import gen_script_fn_and_args from torch.autograd import gradcheck @@ -914,6 +916,126 @@ def test_nuclear_norm_exceptions_old(self, device): self.assertRaisesRegex(RuntimeError, "duplicate or invalid", torch.norm, x, "nuc", (0, 0)) self.assertRaisesRegex(IndexError, "Dimension out of range", torch.norm, x, "nuc", (0, 2)) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) + @dtypesIfCUDA(torch.float, torch.double) + @precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}) + def test_tensorsolve(self, device, dtype): + def run_test(a_shape, dims): + a = torch.randn(a_shape, dtype=dtype, device=device) + b = torch.randn(a_shape[:2], dtype=dtype, device=device) + result = torch.linalg.tensorsolve(a, b, dims=dims) + expected = np.linalg.tensorsolve(a.cpu().numpy(), b.cpu().numpy(), axes=dims) + self.assertEqual(result, expected) + + # check the out= variant + out = torch.empty_like(result) + ans = torch.linalg.tensorsolve(a, b, dims=dims, out=out) + self.assertEqual(ans, out) + self.assertEqual(ans, result) + + a_shapes = [(2, 3, 6), (3, 4, 4, 3)] + dims = [None, (0, 2)] + for a_shape, d in itertools.product(a_shapes, dims): + run_test(a_shape, d) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) + @dtypesIfCUDA(torch.float, torch.double) + def test_tensorsolve_empty(self, device, dtype): + # Check for empty inputs. NumPy does not work for these cases. + a = torch.empty(0, 0, 1, 2, 3, 0, dtype=dtype, device=device) + b = torch.empty(a.shape[:2], dtype=dtype, device=device) + x = torch.linalg.tensorsolve(a, b) + self.assertEqual(torch.tensordot(a, x, dims=len(x.shape)), b) + + # TODO: once "solve_cuda" supports complex dtypes, they shall be added to above tests + @unittest.expectedFailure + @onlyCUDA + @skipCUDAIfNoMagma + @dtypes(torch.cfloat, torch.cdouble) + def test_tensorsolve_xfailed(self, device, dtype): + a_shape = (2, 3, 6) + a = torch.randn(a_shape, dtype=dtype, device=device) + b = torch.randn(a_shape[:2], dtype=dtype, device=device) + result = torch.linalg.tensorsolve(a, b) + expected = np.linalg.tensorsolve(a.cpu().numpy(), b.cpu().numpy()) + self.assertEqual(result, expected) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) + @dtypesIfCUDA(torch.float, torch.double) + @precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}) + def test_tensorsolve_non_contiguous(self, device, dtype): + def run_test_permuted(a_shape, dims): + # check for permuted / transposed inputs + a = torch.randn(a_shape, dtype=dtype, device=device) + a = a.movedim((0, 2), (-2, -1)) + self.assertFalse(a.is_contiguous()) + b = torch.randn(a.shape[:2], dtype=dtype, device=device) + b = b.t() + self.assertFalse(b.is_contiguous()) + result = torch.linalg.tensorsolve(a, b, dims=dims) + expected = np.linalg.tensorsolve(a.cpu().numpy(), b.cpu().numpy(), axes=dims) + self.assertEqual(result, expected) + + def run_test_skipped_elements(a_shape, dims): + # check for inputs with skipped elements + a = torch.randn(a_shape, dtype=dtype, device=device) + a = a[::2] + self.assertFalse(a.is_contiguous()) + b = torch.randn(a_shape[:2], dtype=dtype, device=device) + b = b[::2] + self.assertFalse(b.is_contiguous()) + result = torch.linalg.tensorsolve(a, b, dims=dims) + expected = np.linalg.tensorsolve(a.cpu().numpy(), b.cpu().numpy(), axes=dims) + self.assertEqual(result, expected) + + # check non-contiguous out + out = torch.empty(2 * result.shape[0], *result.shape[1:], dtype=dtype, device=device)[::2] + self.assertFalse(out.is_contiguous()) + ans = torch.linalg.tensorsolve(a, b, dims=dims, out=out) + self.assertEqual(ans, out) + self.assertEqual(ans, result) + + a_shapes = [(2, 3, 6), (3, 4, 4, 3)] + dims = [None, (0, 2)] + for a_shape, d in itertools.product(a_shapes, dims): + run_test_permuted(a_shape, d) + + a_shapes = [(4, 3, 6), (6, 4, 4, 3)] + dims = [None, (0, 2)] + for a_shape, d in itertools.product(a_shapes, dims): + run_test_skipped_elements(a_shape, d) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32) + def test_tensorsolve_errors_and_warnings(self, device, dtype): + # tensorsolve expects the input that can be reshaped to a square matrix + a = torch.eye(2 * 3 * 4).reshape((2 * 3, 4, 2, 3, 4)) + b = torch.randn(8, 4) + self.assertTrue(np.prod(a.shape[2:]) != np.prod(b.shape)) + with self.assertRaisesRegex(RuntimeError, r'Expected self to satisfy the requirement'): + torch.linalg.tensorsolve(a, b) + + # if non-empty out tensor with wrong shape is passed a warning is given + out = torch.empty_like(a) + b = torch.randn(6, 4) + with warnings.catch_warnings(record=True) as w: + # Trigger warning + torch.linalg.tensorsolve(a, b, out=out) + # Check warning occurs + self.assertEqual(len(w), 1) + self.assertTrue("An output with one or more elements was resized" in str(w[-1].message)) + + # dtypes should match + out = torch.empty_like(a).to(torch.int) + with self.assertRaisesRegex(RuntimeError, "result dtype Int does not match self dtype"): + torch.linalg.tensorsolve(a, b, out=out) instantiate_device_type_tests(TestLinalg, globals()) diff --git a/test/test_nn.py b/test/test_nn.py index 25960d817ab8..f692213e4659 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -3055,6 +3055,13 @@ def test_embedding_functional(self): res_F = F.embedding(a, embeddings) self.assertEqual(res_old, res_F) + embed_old = torch.nn.Embedding(4, 3) + embed_old = embed_old.from_pretrained(embeddings, padding_idx=2) + res_old = embed_old(a) + res_F = F.embedding(a, embeddings, padding_idx=2) + + self.assertEqual(res_old, res_F) + @unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines, 'Linear_FP16_weight requires FBGEMM. FBGEMM is only optimized for CPUs' ' with instruction set support avx2 or newer.') @@ -10515,6 +10522,27 @@ def test(nonlinearity, *args, **kwargs): test('threshold', 3, 2) test('threshold', 3, 2, inplace=True) + def test_pooling_shape(self, device): + ''' Test the output shape calculation for pooling functions ''' + + # Checks output shape against expected for 1D, 2D and 3D + def check(expected_out_shape, sizes, *args, **kwargs): + for kernel in ['max', 'avg']: + for i in [1, 2, 3]: + if hasattr(torch.nn.functional, f'{kernel}_pool{i}d'): + op = getattr(torch.nn.functional, f'{kernel}_pool{i}d') + t = torch.randn(sizes[:i + 2], device=device) + self.assertEqual(op(t, *args, **kwargs).shape, expected_out_shape[:i + 2]) + + check((1, 1, 3, 3, 4), (1, 1, 5, 6, 7), kernel_size=1, stride=2, padding=0, ceil_mode=True) + check((1, 1, 2, 3, 3), (1, 1, 3, 4, 5), kernel_size=2, stride=2, padding=1, ceil_mode=False) + check((1, 1, 2, 3, 3), (1, 1, 3, 4, 5), kernel_size=2, stride=2, padding=1, ceil_mode=True) + + # Test case from issue https://github.com/pytorch/pytorch/issues/45357 + x = torch.randn(1, 1, 6, 7, device=device) + y = torch.nn.functional.max_pool2d(x, 1, stride=(2, 2), padding=0, ceil_mode=True) + self.assertEqual(y.size(), (1, 1, 3, 4)) + @onlyOnCPUAndCUDA # TODO: fix on XLA def test_adaptive_avg_pool2d_output_size_one(self, device): def helper(size, memory_format): @@ -10605,8 +10633,10 @@ def check(x, args, message): def test_max_pool1d_corner_cases(self, device, dtype): def check(x, args, expected): model = torch.nn.MaxPool1d(*args) - tensor = torch.tensor(x, device=device, dtype=dtype) - self.assertEqual(model(tensor), torch.tensor(expected, device=device, dtype=dtype)) + if isinstance(x, list): + x = torch.tensor(x, device=device, dtype=dtype) + expected = torch.tensor(expected, device=device, dtype=dtype) + self.assertEqual(model(x), expected) # Pooling args: (kernel_size, stride, padding, dilation, return_indices, ceil_mode) check([[]], (1, None, 0, 1, False, False), [[]]) @@ -10618,7 +10648,7 @@ def check(x, args, expected): check([[1, 2]], (2, 1, 1, 2, False, False), [[2, 1]]) check([[1, 2]], (2, 2, 1, 2, False, True), [[2, 2]]) - empty_tensor = torch.empty((2, 0, 1), dtype=torch.float32) + empty_tensor = torch.empty((2, 0, 1), device=device, dtype=dtype) check(empty_tensor, (1, None, 0, 1, False, False), empty_tensor) @onlyCPU @@ -10628,8 +10658,7 @@ def test_max_pool1d(self, device, dtype): def check(x, *args, **kwargs): model = torch.nn.MaxPool1d(*args, **kwargs) ref_model = torch.nn.MaxPool1d(*args, **kwargs, return_indices=True) - tensor = torch.tensor(x, device=device, dtype=dtype) - self.assertEqual(model(tensor), ref_model(tensor)[0]) + self.assertEqual(model(x), ref_model(x)[0]) sizes = [random.sample(range(8, 128), 3) for _ in range(3)] kernel_sizes = random.sample(range(1, 5), 3) @@ -10640,10 +10669,11 @@ def check(x, *args, **kwargs): for size, kernel_size, stride, dilation, ceil_mode in \ itertools.product(sizes, kernel_sizes, strides, dilations, ceil_modes): padding = random.sample(range(0, math.floor(kernel_size / 2) + 1), 1) - check(torch.randn(size), kernel_size, stride, padding, dilation, ceil_mode=ceil_mode) + check(torch.randn(size, device=device, dtype=dtype), + kernel_size, stride, padding, dilation, ceil_mode=ceil_mode) # Non-contiguous test - tensor = torch.randn(5, 151, 33)[::2, ::3, ::2] + tensor = torch.randn(5, 151, 33, device=device, dtype=dtype)[::2, ::3, ::2] check(tensor, 3, 2, 1, 2, ceil_mode=True) check(tensor.transpose(1, 2), 3, 2, 1, 2, ceil_mode=True) @@ -10738,6 +10768,15 @@ def fn(weight): fn = fn_wrapper(device) _assertGradAndGradgradChecks(self, fn, (weight, )) + def fn_wrapper(device): + def padding_fn(weight): + inp = torch.tensor([[0, 1, 1, 2], [1, 1, 0, 2]], dtype=torch.long).to(device) + return torch.nn.functional.embedding(inp, weight, padding_idx=1) + return padding_fn + + fn = fn_wrapper(device) + _assertGradAndGradgradChecks(self, fn, (weight, )) + def test_embedding_scalar_weight_error(self, device): indices = torch.rand(2, 2, device=device).long() weight = torch.tensor(1.0, device=device) @@ -10834,6 +10873,8 @@ def test_embedding_padding_idx(self, device, dtype): embedding.zero_grad() self.assertEqual(after, pre) + # Test fails on Vg20 + @skipCUDAIfRocm @dtypesIfCUDA(torch.half, torch.float) @dtypes(torch.float) def test_softmax_results(self, device, dtype): @@ -11433,6 +11474,8 @@ def test_embedding_max_norm_device(self, device, dtype): self.assertEqual(output[1], output[2]) self.assertTrue(output.data.norm(p=2, dim=1).le(1).all()) + # Test fails on Vg20 + @skipCUDAIfRocm @onlyCUDA @dtypes(torch.half, torch.float) def test_softmax(self, device, dtype): diff --git a/test/test_op_aliases.py b/test/test_op_aliases.py index 8a106d7860d1..a6d37c9d52e2 100644 --- a/test/test_op_aliases.py +++ b/test/test_op_aliases.py @@ -6,6 +6,7 @@ from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, skipCPUIfNoLapack, skipCUDAIfNoMagma, onlyCPU) +import collections # Information for generating an alias test # NOTE: ending the alias_name with an underscore will interpret the test @@ -150,6 +151,8 @@ def __init__(self, AliasInfo('true_divide_', torch.Tensor.true_divide_, 'div_', torch.Tensor.div_, lambda d: torch.randn(20, device=d), get_args=lambda d: (torch.rand(20, device=d) + .1,), decorators=(onlyCPU,)), + AliasInfo('row_stack', torch.row_stack, 'vstack', torch.vstack, + lambda d: ((torch.randn(20, device=d), torch.randn(20, device=d)))), ) # Placeholder test class for validating that aliases are correctly @@ -157,6 +160,14 @@ def __init__(self, class TestOpNormalization(JitTestCase): pass + +# Clone input tensor and sequence of Tensors +def clone_inp(inp): + if isinstance(inp, collections.Sequence): + return list(map(torch.clone, inp)) + else: + return inp.clone() + # Generates alias tests and adds them to the specified class (cls) def create_alias_tests(cls): for info in alias_infos: @@ -180,10 +191,18 @@ def _fn(t): arg_string = ', '.join((str(arg) for arg in info.get_args(device))) script = fn_template.format(alias_name=info.alias_name, args=arg_string) else: - fn_template = ''' - def _fn(t): + is_input_tensor_list = isinstance(info.get_input(device), collections.Sequence) + # For sequence of Tensors, annotate the type to be List[Tensor] + if is_input_tensor_list: + fn_template = ''' + def _fn(t: List[Tensor]): return op(t{args}) - ''' + ''' + else: + fn_template = ''' + def _fn(t): + return op(t{args}) + ''' arg_string = ", " + ', '.join((str(arg) for arg in info.get_args(device))) script = fn_template.format(args=arg_string) @@ -192,8 +211,8 @@ def _fn(t): # Acquires and checks the graph remaps the alias inp = info.get_input(device) - scripted(inp.clone()) - graph = scripted.graph_for(inp.clone()) + scripted(clone_inp(inp)) + graph = scripted.graph_for(clone_inp(inp)) FileCheck().check(info.original_name).check_not(info.alias_name).run(graph) # Checks that tracing converts aliases @@ -203,9 +222,9 @@ def _fn(t): def _fn(t, info=info, args=args): return info.alias_op(t, *args) - traced = torch.jit.trace(_fn, (inp.clone(),)) - traced(inp.clone()) - graph = traced.graph_for(inp.clone()) + traced = torch.jit.trace(_fn, (clone_inp(inp),)) + traced(clone_inp(inp)) + graph = traced.graph_for(clone_inp(inp)) FileCheck().check(info.original_name).check_not(info.alias_name).run(graph) # Applies decorators @@ -223,10 +242,10 @@ def _test_alias_computation(self, device, info=info): inp = info.get_input(device) args = info.get_args(device) - alias_input = inp.clone() + alias_input = clone_inp(inp) alias_result = alias_op(alias_input, *args) - original_input = inp.clone() + original_input = clone_inp(inp) original_result = alias_op(original_input, *args) self.assertEqual(alias_input, original_input, atol=0, rtol=0) diff --git a/test/test_quantization.py b/test/test_quantization.py index 26d3ee020983..682d9baff68e 100644 --- a/test/test_quantization.py +++ b/test/test_quantization.py @@ -61,9 +61,15 @@ from quantization.test_quantize_jit import TestQuantizeDynamicJitOps # noqaa: F401 # 3. GraphModule based graph mode quantization -from quantization.test_quantize_fx import TestQuantizeFx # noqa: F401 -from quantization.test_quantize_fx import TestQuantizeFxOps # noqa: F401 -from quantization.test_quantize_fx import TestQuantizeFxModels # noqa: F401 +try: + from quantization.test_quantize_fx import TestFuseFx # noqa: F401 + from quantization.test_quantize_fx import TestQuantizeFx # noqa: F401 + from quantization.test_quantize_fx import TestQuantizeFxOps # noqa: F401 + from quantization.test_quantize_fx import TestQuantizeFxModels # noqa: F401 +except ImportError: + # In FBCode we separate FX out into a separate target for the sake of dev + # velocity. These are covered by a separate test target `quantization_fx` + pass # Tooling: numeric_suite from quantization.test_numeric_suite import TestEagerModeNumericSuite # noqa: F401 diff --git a/test/test_tensor_creation_ops.py b/test/test_tensor_creation_ops.py index b0777c7fa12a..a40c77bdcbf4 100644 --- a/test/test_tensor_creation_ops.py +++ b/test/test_tensor_creation_ops.py @@ -702,6 +702,23 @@ def test_eye(self, device): for dtype in torch.testing.get_all_dtypes(): if dtype == torch.bfloat16: continue + # Test the RuntimeError is raised when either m or n is a negative number + for n, m in ((-1, 1), (1, -1), (-1, -1)): + with self.assertRaisesRegex(RuntimeError, 'must be greater or equal to'): + torch.eye(n, m, device=device, dtype=dtype) + + # Test when the `m` parameter is not provided + for n in (3, 5, 7): + res1 = torch.eye(n, device=device, dtype=dtype) + naive_eye = torch.zeros(n, n, dtype=dtype, device=device) + naive_eye.diagonal(dim1=-2, dim2=-1).fill_(1) + self.assertEqual(naive_eye, res1) + + # Check eye_out outputs + res2 = torch.empty(0, device=device, dtype=dtype) + torch.eye(n, out=res2) + self.assertEqual(res1, res2) + for n, m in product([3, 5, 7], repeat=2): # Construct identity using diagonal and fill res1 = torch.eye(n, m, device=device, dtype=dtype) @@ -1182,7 +1199,7 @@ def seed(generator): self.assertTrue((res1 < 6).all().item()) self.assertTrue((res1 >= 0).all().item()) - @dtypes(torch.half, torch.float, torch.double, + @dtypes(torch.half, torch.float, torch.bfloat16, torch.double, torch.complex32, torch.complex64, torch.complex128) def test_randn(self, device, dtype): SIZE = 100 diff --git a/test/test_torch.py b/test/test_torch.py index f1b0b23b1518..92f04549992a 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -26,7 +26,7 @@ from torch.testing._internal.common_methods_invocations import tri_tests_args, run_additional_tri_tests, \ _compare_trilu_indices from torch.testing._internal.common_utils import \ - (TestCase, iter_indices, TEST_NUMPY, TEST_SCIPY, TEST_WITH_ROCM, run_tests, + (TestCase, iter_indices, TEST_NUMPY, TEST_SCIPY, TEST_WITH_ASAN, TEST_WITH_ROCM, run_tests, skipIfNoLapack, suppress_warnings, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN, do_test_dtypes, IS_SANDCASTLE, load_tests, slowTest, skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf, BytesIOContext, @@ -5993,6 +5993,41 @@ def test_diagonal_multidim(self, device, dtype): self.assertEqual(expected.shape, result.shape) self.assertEqual(expected, result) + def _test_trace(self, device, dtype, legacy): + def test(shape): + tensor = make_tensor(shape, device, dtype, low=-9, high=9) + diag = tensor.diag() + if legacy: + # NB: trace on cpu doesn't do type promotion... #47127 + expected_dtype = dtype + else: + expected_dtype = tensor.sum().dtype + expected_dtype = torch_to_numpy_dtype_dict[expected_dtype] + + result = np.trace(tensor.cpu().numpy(), dtype=expected_dtype) + expected = torch.tensor(result, device=device) + self.assertEqual(tensor.trace(), expected) + + shapes = ( + [10, 1], + [1, 10], + [100, 100], + [20, 100], + [100, 20], + ) + for shape in shapes: + test(shape) + + @onlyCPU + @dtypes(*torch.testing.get_all_dtypes(include_complex=False, include_bool=False, include_half=False, include_bfloat16=False)) + def test_trace_legacy(self, device, dtype): + self._test_trace(device, dtype, legacy=True) + + @onlyCUDA + @dtypes(*torch.testing.get_all_dtypes(include_complex=False, include_bool=False, include_bfloat16=False)) + def test_trace(self, device, dtype): + self._test_trace(device, dtype, legacy=False) + @onlyCPU @dtypes(torch.float) def test_broadcast_tensors(self, device, dtype): @@ -10536,6 +10571,23 @@ def check(op, a, args, key): check(torch.median, [[nan, nan], [1, 2]], [1], [[nan, 1]]) check(torch.nanmedian, [[nan, nan], [1, 2]], [1], [[nan, 1.]]) + # Discontiguous and strided tensors + a = torch.arange(12, device=device) + self.assertEqual(a[::2].median(), torch.tensor(4, device=device)) + self.assertEqual(a[::2].nanmedian(), torch.tensor(4, device=device)) + + a.resize_(3, 4) + self.assertEqual(a.T.median(), torch.tensor(5, device=device)) + self.assertEqual(a.T.nanmedian(), torch.tensor(5, device=device)) + self.assertEqual(a[::2, ::2].median(-1)[0], torch.tensor([0, 8], device=device)) + self.assertEqual(a[::2, ::2].nanmedian(-1)[0], torch.tensor([0, 8], device=device)) + + a.resize_(2, 3, 2) + self.assertEqual(a.T.median(), torch.tensor(5, device=device)) + self.assertEqual(a.T.nanmedian(), torch.tensor(5, device=device)) + self.assertEqual(a[:, ::2, :].median(-1)[0], torch.tensor([[0, 4], [6, 10]], device=device)) + self.assertEqual(a[:, ::2, :].nanmedian(-1)[0], torch.tensor([[0, 4], [6, 10]], device=device)) + @onlyOnCPUAndCUDA @dtypes(torch.float, torch.double) @@ -13690,6 +13742,8 @@ def test_binary_op_mem_overlap(self, device, dtype): ("atan2", True, True, 'cuda'), ("hypot", True, True, 'cpu'), ("hypot", True, True, 'cuda'), + ("igamma", True, True, 'cpu'), + ("igamma", True, True, 'cuda'), ("nextafter", True, True, 'cpu'), ("nextafter", True, True, 'cuda'), ("le", True, True, 'cpu'), @@ -13930,10 +13984,11 @@ def test_cpu_tensor_pow_cuda_scalar_tensor(self, device): @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') @dtypes(*(torch.testing.get_all_dtypes(include_bool=False, include_bfloat16=False))) def test_complex_scalar_pow_tensor(self, device, dtype): - complexes = [0.5j, 1. + 1.j, -1.5j, 2.2 - 1.6j] - tensor = torch.rand(100).to(dtype=dtype, device=device) + complexes = [0.5j, 1. + 1.j, -1.5j, 2.2 - 1.6j, 1 + 0j] + exp = make_tensor((100,), device, dtype, low=-2, high=2) + exp[0] = exp[10] = exp[20] = 0 for base in complexes: - self._test_pow(base, tensor) + self._test_pow(base, exp) @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') def test_tensor_pow_tensor(self, dev): @@ -15539,7 +15594,8 @@ def test_orgqr_errors(self, device): ((10,), (2,), r"'input' should be 2 dimensional"), ((10, 6), (20,), r"input.size\(1\) must be greater than or equal to input2.size\(0\)"), ((6, 10), (5,), r"input.size\(0\) must be greater than or equal to input.size\(1\)"), - ((0, 0), (0,), r"'input' should not be empty") + ((0, 0), (0,), r"'input' should not be empty"), + ((2, 2), (2, 0,), r"'tau' should not be empty") ] for a_size, tau_size, error_regex in test_cases: a = torch.rand(*a_size, device=device) @@ -17472,6 +17528,70 @@ def test_hypot(self, device, dtype): expected = np.hypot(input[0].cpu().numpy(), input[1].cpu().numpy()) self.assertEqual(actual, expected) + def _helper_test_igamma(self, loglo, loghi, device, dtype): + exp1 = 2.71828182846 + vec1 = torch.logspace(loglo, loghi, steps=500, base=exp1, + dtype=torch.float64, device=device).unsqueeze(-1) + vec1 = vec1.to(dtype) + inputs = [ + (vec1, vec1.transpose(0, 1)), + (vec1, vec1), # for large number, it should approach 0.5 + (vec1, 0.5 * vec1), # test for considerable ratio + (vec1, 2.0 * vec1), + (vec1[::2, :], vec1[::2, :]), # contiguous/discontiguous tests + (vec1[::2, :], vec1[:vec1.shape[0] // 2, :]), + (vec1[:vec1.shape[0] // 2, :], vec1[::2, :]), + ] + half_prec = dtype in [torch.bfloat16, torch.float16] + for input0, input1 in inputs: + actual = torch.igamma(input0, input1) + if half_prec: + input0 = input0.to(torch.float) + input1 = input1.to(torch.float) + expected = scipy.special.gammainc(input0.cpu().numpy(), input1.cpu().numpy()) + expected = torch.from_numpy(expected).to(dtype) + self.assertEqual(actual, expected) + + @skipCUDAIfRocm # see issue https://github.com/pytorch/pytorch/issues/46531 + @dtypesIfCPU(torch.float16, torch.bfloat16, torch.float32, torch.float64) + @dtypes(torch.float32, torch.float64) + @unittest.skipIf(not TEST_SCIPY, "SciPy not found") + @onlyOnCPUAndCUDA + def test_igamma_common(self, device, dtype): + # test igamma for reasonable range of values + loglo = -4 # approx 0.018 + loghi = 4 # approx 54.6 + self._helper_test_igamma(loglo, loghi, device, dtype) + + @dtypesIfCPU(torch.float16, torch.bfloat16, torch.float32, torch.float64) + @dtypes(torch.float32, torch.float64) + @onlyOnCPUAndCUDA + def test_igamma_edge_cases(self, device, dtype): + tkwargs = {"dtype": dtype, "device": device} + infs = torch.zeros((3,), **tkwargs) + float("inf") + zeros = torch.zeros((3,), **tkwargs) + ones = torch.ones((3,), **tkwargs) + zero_to_large = torch.tensor([0., 1., 1e3], **tkwargs) + small_to_inf = torch.tensor([1e-3, 1., float("inf")], **tkwargs) + nans = torch.zeros((3,), **tkwargs) + float("nan") + inpouts = [ + # (a , x), out + ((zeros, small_to_inf), ones), + ((small_to_inf, zeros), zeros), + ((infs, zero_to_large), zeros), + ((zero_to_large, infs), ones), + ((zeros, zeros), nans), + ((infs, infs), nans), + ((-small_to_inf, small_to_inf), nans), + ] + for inputs, output in inpouts: + input0, input1 = inputs + calc = torch.igamma(input0, input1) + if torch.all(torch.isnan(output)): + self.assertTrue(torch.all(torch.isnan(calc))) + else: + self.assertEqual(calc, output) + @dtypes(torch.int64, torch.float64) def test_remainder_edge_cases(self, device, dtype): # Test variations of negative values used as input @@ -17495,8 +17615,8 @@ def test_remainder_edge_cases(self, device, dtype): r = a.remainder(b) r_expected = torch.tensor([0, 0, 0, 0, -3, 3, -2, 2] * 10000, dtype=dtype, device=device) self.assertEqual(r, r_expected) - # Test nan cases + a = torch.tensor([-34, 0, 34] * 20000, dtype=dtype, device=device) b = torch.zeros(3 * 20000, dtype=dtype, device=device) self.assertTrue(torch.isnan(a.remainder(b)).all()) @@ -19194,8 +19314,12 @@ def verify_against_numpy(t): def _test_special_stacks(self, dim, at_least_dim, torch_fn, np_fn, device, dtype): # Test error for non-tuple argument + t = torch.randn(10) + with self.assertRaisesRegex(TypeError, "must be tuple of Tensors, not Tensor"): + torch_fn(t) + # Test error for a single array with self.assertRaisesRegex(TypeError, "must be tuple of Tensors, not Tensor"): - torch_fn(torch.randn(10)) + torch_fn((t)) # Test 0-D num_tensors = random.randint(1, 5) @@ -19245,25 +19369,41 @@ def _test_special_stacks(self, dim, at_least_dim, torch_fn, np_fn, device, dtype @unittest.skipIf(not TEST_NUMPY, "NumPy not found") @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False) + torch.testing.get_all_complex_dtypes())) - def test_hstack(self, device, dtype): - self._test_special_stacks(1, 1, torch.hstack, np.hstack, device, dtype) + def test_hstack_column_stack(self, device, dtype): + ops = ((torch.hstack, np.hstack), (torch.column_stack, np.column_stack)) + for torch_op, np_op in ops: + self._test_special_stacks(1, 1, torch_op, np_op, device, dtype) + + # Test torch.column_stack with combinations of 1D and 2D tensors input + one_dim_tensor = torch.arange(0, 10).to(dtype=dtype, device=device) + two_dim_tensor = torch.arange(0, 100).to(dtype=dtype, device=device).reshape(10, 10) + inputs = two_dim_tensor, one_dim_tensor, two_dim_tensor, one_dim_tensor + torch_result = torch.column_stack(inputs) + + np_inputs = [input.cpu().numpy() for input in inputs] + np_result = np.column_stack(np_inputs) + + self.assertEqual(np_result, + torch_result) @onlyOnCPUAndCUDA @unittest.skipIf(not TEST_NUMPY, "NumPy not found") @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False) + torch.testing.get_all_complex_dtypes())) - def test_vstack(self, device, dtype): - self._test_special_stacks(0, 2, torch.vstack, np.vstack, device, dtype) - for i in range(5): - # Test dimension change for 1D tensor of size (N) and 2D tensor of size (1, N) - n = random.randint(1, 10) - input_a = self._generate_input((n,), dtype, device, with_extremal=False) - input_b = self._generate_input((1, n), dtype, device, with_extremal=False) - torch_input = [input_a, input_b] - np_input = [input.cpu().numpy() for input in torch_input] - actual = torch.vstack(torch_input) - expected = np.vstack(np_input) - self.assertEqual(actual, expected) + def test_vstack_row_stack(self, device, dtype): + ops = ((torch.vstack, np.vstack), (torch.row_stack, np.row_stack)) + for torch_op, np_op in ops: + self._test_special_stacks(0, 2, torch_op, np_op, device, dtype) + for i in range(5): + # Test dimension change for 1D tensor of size (N) and 2D tensor of size (1, N) + n = random.randint(1, 10) + input_a = self._generate_input((n,), dtype, device, with_extremal=False) + input_b = self._generate_input((1, n), dtype, device, with_extremal=False) + torch_input = [input_a, input_b] + np_input = [input.cpu().numpy() for input in torch_input] + actual = torch_op(torch_input) + expected = np_op(np_input) + self.assertEqual(actual, expected) @onlyOnCPUAndCUDA @unittest.skipIf(not TEST_NUMPY, "NumPy not found") @@ -19359,6 +19499,17 @@ def compare_helper_(like_fn, t): tp = t.permute(p) compare_helper_(like_fn, tp) + @unittest.skipIf(TEST_WITH_ASAN, "Integer overflows are not allowed under ASAN") + @dtypes(*torch.testing.get_all_dtypes()) + def test_muldiv_scalar(self, device, dtype): + x = make_tensor((10, 3), device, dtype, low=None, high=None) + s = make_tensor((1,), 'cpu', dtype, low=None, high=None).item() + y = torch.full_like(x, s) + self.assertEqual(x * s, x * y) + self.assertEqual(s * x, y * x) + self.assertEqual(x / s, x / y) + self.assertEqual(s / x, y / x) + # Tests that compare a device's computation with the (gold-standard) CPU's. class TestDevicePrecision(TestCase): exact_dtype = True @@ -21115,6 +21266,7 @@ class TestTorch(AbstractTestCases._TestTorchMixin): instantiate_device_type_tests(TestViewOps, globals()) instantiate_device_type_tests(TestDevicePrecision, globals(), except_for='cpu') instantiate_device_type_tests(TestTensorDeviceOps, globals()) + instantiate_device_type_tests(TestTorchMathOps, globals(), only_for='cpu') if __name__ == '__main__': diff --git a/test/test_type_promotion.py b/test/test_type_promotion.py index c6efa8f0d90d..c023553de402 100644 --- a/test/test_type_promotion.py +++ b/test/test_type_promotion.py @@ -919,8 +919,8 @@ def test_unary_op_out_casting(self, device, dtypes): t = torch.tensor((1), dtype=dtypes[0], device=device) out = torch.empty(0, dtype=dtypes[1], device=device) - ops = (torch.neg, torch.floor, torch.ceil, torch.cos, torch.erf) - float_only_ops = {torch.floor, torch.ceil, torch.cos, torch.erf} + ops = (torch.neg, torch.floor, torch.ceil, torch.erf) + float_only_ops = {torch.floor, torch.ceil, torch.erf} real_only_ops = {torch.floor, torch.ceil, torch.erf} for op in ops: if dtypes[0] is not dtypes[1]: diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py index 6d4dc91ff5bd..9f3353376913 100644 --- a/test/test_unary_ufuncs.py +++ b/test/test_unary_ufuncs.py @@ -288,7 +288,7 @@ def test_reference_numerics(self, device, dtype, op): # NOTE: For these dtypes, PyTorch computes in the default scalar type (float) # while NumPy computes in float16 self.assertEqualHelper(actual, expected, msg, dtype=dtype, - exact_dtype=exact_dtype, rtol=1e-4, atol=1e-3) + exact_dtype=exact_dtype, rtol=1e-3, atol=1e-2) continue self.assertEqualHelper(actual, expected, msg, dtype=dtype, exact_dtype=exact_dtype) diff --git a/third_party/fbgemm b/third_party/fbgemm index 23cb1db72b03..5b7566f412aa 160000 --- a/third_party/fbgemm +++ b/third_party/fbgemm @@ -1 +1 @@ -Subproject commit 23cb1db72b03e29984eefe58c5c99d733a85435d +Subproject commit 5b7566f412aaeaab2c97338948f25e0ef0b8ac4b diff --git a/third_party/nccl/nccl b/third_party/nccl/nccl index 033d799524fb..cd5a9b73c302 160000 --- a/third_party/nccl/nccl +++ b/third_party/nccl/nccl @@ -1 +1 @@ -Subproject commit 033d799524fb97629af5ac2f609de367472b2696 +Subproject commit cd5a9b73c3028d2496666201588111a8c8d84878 diff --git a/third_party/pybind11 b/third_party/pybind11 index 25abf7efba0b..59a2ac2745d8 160000 --- a/third_party/pybind11 +++ b/third_party/pybind11 @@ -1 +1 @@ -Subproject commit 25abf7efba0b2990f5a6dfb0a31bc65c0f2f4d17 +Subproject commit 59a2ac2745d8a57ac94c6accced73620d59fb844 diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index fa03db91694c..2aef2b36a6fc 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -165,11 +165,11 @@ self: grad * -((-self * self + 1).rsqrt()).conj() - name: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor - self: grad - other: maybe_multiply(grad, alpha) + self: handle_r_to_c(self.scalar_type(), grad) + other: handle_r_to_c(other.scalar_type(), maybe_multiply(grad, alpha.conj())) - name: add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor - self: grad + self: handle_r_to_c(self.scalar_type(), grad) - name: addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor self: maybe_multiply(grad, beta) @@ -434,7 +434,7 @@ self: 0.5 * sqrt(M_PI) * exp(self.erfinv().pow(2)) * grad - name: exp(Tensor self) -> Tensor - self: grad * result + self: grad * result.conj() - name: exp2(Tensor self) -> Tensor self: grad * result * M_LN2 @@ -540,6 +540,10 @@ - name: i0(Tensor self) -> Tensor self: not_implemented("i0") +- name: igamma(Tensor self, Tensor other) -> Tensor + self: 'not_implemented("igamma: input")' + other: grad * exp((self - 1) * log(other) - other - lgamma(self)) + - name: index.Tensor(Tensor self, Tensor?[] indices) -> Tensor self: index_backward(zeros_like(self), indices, grad) indices: TensorList() @@ -995,11 +999,11 @@ self: std_backward(result, grad, self, dim, unbiased, keepdim) - name: sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor - self: grad - other: -grad * alpha + self: handle_r_to_c(self.scalar_type(), grad) + other: handle_r_to_c(other.scalar_type(), -grad * alpha.conj()) - name: sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor - self: grad + self: handle_r_to_c(self.scalar_type(), grad) - name: rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor self: -grad * alpha @@ -1182,7 +1186,7 @@ weight: embedding_backward(grad, indices, weight.size(0), padding_idx, scale_grad_by_freq, sparse) - name: embedding_dense_backward(Tensor grad_output, Tensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor - grad_output: embedding_dense_double_backward(grad, indices) + grad_output: embedding_dense_double_backward(grad, indices, padding_idx) indices: non_differentiable - name: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor) diff --git a/tools/autograd/gen_annotated_fn_args.py b/tools/autograd/gen_annotated_fn_args.py index 7b4b0ece8da6..661694f3d6ba 100644 --- a/tools/autograd/gen_annotated_fn_args.py +++ b/tools/autograd/gen_annotated_fn_args.py @@ -20,6 +20,7 @@ get_py_variable_methods, op_name, ) +import argparse import textwrap from .gen_autograd import load_aten_declarations diff --git a/tools/autograd/gen_autograd.py b/tools/autograd/gen_autograd.py index da937a4377fa..2783eb644bc6 100644 --- a/tools/autograd/gen_autograd.py +++ b/tools/autograd/gen_autograd.py @@ -302,7 +302,8 @@ def main(): parser.add_argument('autograd', metavar='AUTOGRAD', help='path to autograd directory') args = parser.parse_args() - gen_autograd(args.declarations, args.out, args.autograd) + gen_autograd(args.declarations, args.out, args.autograd, + SelectiveBuilder.get_nop_selector()) if __name__ == '__main__': diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py index 2e822e68a998..9223896666de 100644 --- a/tools/autograd/gen_python_functions.py +++ b/tools/autograd/gen_python_functions.py @@ -701,9 +701,17 @@ def group_overloads(declarations, is_python_method): result = [] for x, dictionary in sorted(grouped.items()): if 'base' not in dictionary: + candidates = [] + non_out_name = dictionary['out']['operator_name'] + for declaration in declarations: + if declaration['name'] == non_out_name and not declaration['deprecated']: + signature = get_python_signature(declaration, is_python_method, skip_outputs=True) + candidates.append(signature) raise RuntimeError( - "'base' not in dictionary for {}. keys are {}".format( - x, list(dictionary.keys()))) + "While identifying overloads, we found an out schema {} without a corresponding non-out variant. " + "We expected the non-out variant to have schema: \n- {}\nPlease check that you spelled the schema " + "correctly in native_functions.yaml. We discovered the following candidate(s): \n" + .format(dictionary['signature'], x) + "\n".join("- {}".format(candidate) for candidate in candidates)) result.append(dictionary) return sort_declarations(result) @@ -871,7 +879,6 @@ def go(f: NativeFunction) -> PythonSignature: src_args: Dict[str, PythonArgument] = {a.name: PythonArgument( name=a.name, type=a.type, - cpp_type_str=a.cpp_type_str, default=None, default_init=None, ) for a in itertools.chain(python_sig.input_args, python_sig.input_kwargs)} diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index 570943755d2e..62c130854f10 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -166,7 +166,8 @@ 'unbind', 'split', 'split_with_sizes', 'unsafe_split', 'split_with_sizes_backward', 'dot', 'vdot', 'cholesky', 'triangular_solve', 'mm', '_unsafe_view', 'mv', 'ger', 'bmm', 'diagonal', 'cholesky', 'atan', 'log', 'log10', 'log1p', 'log2', 'reciprocal', - 'tan', 'pow', 'rsqrt', 'tanh', 'tanh_backward', 'asinh', 'acosh', 'take', 'fill_' + 'tan', 'pow', 'rsqrt', 'tanh', 'tanh_backward', 'asinh', 'acosh', 'take', 'fill_', + 'exp', 'nonzero' } # Some operators invalidate the grad_accumulator. Let's reset it. @@ -284,9 +285,6 @@ grad_fn->set_next_edges(collect_next_edges( ${args_with_derivatives} )); """) -CALL_DEFAULT = CodeTemplate("""\ -TypeDefault::${type_wrapper_name}(${args})""") - CALL_DISPATCH_VIA_NAMESPACE = CodeTemplate("""\ at::${api_name}(${unpacked_args})""") @@ -808,7 +806,7 @@ def emit_trace_body(declaration): def emit_body(declaration): - strategy = dispatch_strategy(declaration) + assert dispatch_strategy(declaration) == 'use_derived' arguments = declaration['arguments'] returns = declaration['returns'] @@ -865,8 +863,7 @@ def find_args_with_derivatives(differentiable_inputs): requires_derivative = ( base_name not in DONT_REQUIRE_DERIVATIVE and name not in DONT_REQUIRE_DERIVATIVE and - len(differentiable_inputs) > 0 and len(differentiable_outputs) > 0 and - strategy == 'use_derived') + len(differentiable_inputs) > 0 and len(differentiable_outputs) > 0) if func is not None and not requires_derivative: raise RuntimeError('ERROR: derivative ignored for {} -- specified an autograd function without derivative' @@ -1150,28 +1147,20 @@ def enforce_same_tensorimpl_and_storage(env, call): def emit_call(env, tie_return_values): combined = nested_dict(env, declaration) - if strategy == 'use_derived': - # We only care about adding `at::AutoNonVariableTypeMode` guard for non-variable dispatch - # (which corresponds to 'use_derived' strategy). The purpose of this guard is to make sure - # the baseType operations still dispatch to non-Variable type, even if the arguments passed - # in are now Variables. - # See NOTE [ Treating Variables as non-Variables in type dispatch ] for details. - base_type_call = emit_dispatch_call(combined['api_name'], 'self_', combined['unpacked_args']) - if not modifies_arguments and not returns_void: - call = DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES.substitute( - base_type_call=base_type_call) - - call += wrap_output(tie_return_values, 'tmp') - else: - call = DISPATCH_TO_NON_VAR_TYPE_WITHOUT_RETURN_VALUES.substitute( - base_type_call=base_type_call) + # We only care about adding `at::AutoNonVariableTypeMode` guard for non-variable dispatch + # (which corresponds to 'use_derived' strategy). The purpose of this guard is to make sure + # the baseType operations still dispatch to non-Variable type, even if the arguments passed + # in are now Variables. + # See NOTE [ Treating Variables as non-Variables in type dispatch ] for details. + base_type_call = emit_dispatch_call(combined['api_name'], 'self_', combined['unpacked_args']) + if not modifies_arguments and not returns_void: + call = DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES.substitute( + base_type_call=base_type_call) + + call += wrap_output(tie_return_values, 'tmp') else: - args = maybe_unwrap_optional_tensors(declaration, declaration['arguments'], declaration['args']) - - call = CALL_DEFAULT.substitute(declaration, args=args) - if not modifies_arguments and not returns_void: - call = '{} = {}'.format(tie_return_values, call) - call = call + ';' + call = DISPATCH_TO_NON_VAR_TYPE_WITHOUT_RETURN_VALUES.substitute( + base_type_call=base_type_call) call = enforce_same_tensorimpl_and_storage(env, call) return call @@ -1211,16 +1200,14 @@ def emit_increment_version(): declare_returned_variables, tie_return_values, get_return_value = format_return_variables(declaration) - if strategy != 'use_type': - body.extend(unpack_args(env, declaration)) + body.extend(unpack_args(env, declaration)) if requires_derivative: body.extend(emit_check_inplace()) body.extend(setup_derivative(differentiable_inputs)) body.append(declare_returned_variables) body.append(emit_call(env, tie_return_values)) - if strategy == 'use_derived': - body.extend(emit_increment_version()) + body.extend(emit_increment_version()) if requires_derivative: # set_flags has to appear after version_counter, because rebase_history # requires that the counter is incremented before it is called diff --git a/tools/autograd/templates/TraceType.cpp b/tools/autograd/templates/TraceType.cpp index d08c1e3cc5aa..3ac52ed08edc 100644 --- a/tools/autograd/templates/TraceType.cpp +++ b/tools/autograd/templates/TraceType.cpp @@ -1,6 +1,5 @@ #include "torch/csrc/autograd/VariableTypeUtils.h" -#include #include #include "torch/csrc/autograd/function.h" diff --git a/tools/autograd/templates/VariableType.cpp b/tools/autograd/templates/VariableType.cpp index 079427cd97dc..ba2f99369f8d 100644 --- a/tools/autograd/templates/VariableType.cpp +++ b/tools/autograd/templates/VariableType.cpp @@ -1,7 +1,6 @@ #include "torch/csrc/autograd/VariableTypeUtils.h" #include "torch/csrc/autograd/FunctionsManual.h" -#include #include // ${generated_comment} diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index 86fc691c298c..65f5ec1c6903 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -41,7 +41,7 @@ def libtorch_generated_sources(gencode_pattern): "autograd/generated/TraceType_4.cpp", ]] -# copied from https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/CMakeLists.txt +# copied from https://github.com/pytorch/pytorch/blob/f99a693cd9ff7a9b5fdc71357dac66b8192786d3/aten/src/ATen/core/CMakeLists.txt jit_core_headers = [ "torch/csrc/utils/memory.h", "torch/csrc/WindowsTorchApiMacro.h", @@ -69,7 +69,7 @@ jit_core_sources = [ "torch/csrc/jit/frontend/source_range.cpp", ] -# copied from https://github.com/pytorch/pytorch/blob/master/tools/cpp_build/torch/CMakeLists.txt +# copied from https://github.com/pytorch/pytorch/blob/0bde610c14b92d351b968a0228df29e92442b1cc/torch/CMakeLists.txt # There are some common files used in both internal lite-interpreter and full-jit. Making a separate # list for the shared files. @@ -545,6 +545,7 @@ libtorch_python_core_sources = [ libtorch_python_distributed_core_sources = [ "torch/csrc/distributed/c10d/comm.cpp", "torch/csrc/distributed/c10d/default_comm_hooks.cpp", + "torch/csrc/distributed/c10d/python_comm_hook.cpp", "torch/csrc/distributed/c10d/init.cpp", "torch/csrc/distributed/c10d/reducer.cpp", ] diff --git a/tools/codegen/api/python.py b/tools/codegen/api/python.py index f04a648aade8..bb02407004ab 100644 --- a/tools/codegen/api/python.py +++ b/tools/codegen/api/python.py @@ -1,5 +1,6 @@ from tools.codegen.api.types import * import tools.codegen.api.cpp as cpp +import tools.codegen.local as local from tools.codegen.gen import pythonify_default from tools.codegen.model import * @@ -175,11 +176,6 @@ class PythonArgument: name: str type: Type - - # Consistent with 'type' for most cases, except for some TensorOptions fields - # which are hardcoded (see 'signature()' method). - cpp_type_str: str - default: Optional[str] # Used to generate the default init expr for some PythonArgParser outputs, e.g.: @@ -193,29 +189,15 @@ class PythonArgument: # Compute argument formal for python argument parsing. # Needs to be consistent with torch/csrc/utils/python_arg_parser.h. def argument_str(self, *, method: bool = False) -> str: - name = self.name - typename = _simple_type(self.cpp_type_str) - - # [old codegen] TODO: remove this and make optional types in simple_type - # to be consistent across tensor and other types after make Tensor? be - # optional instead of undefined - if self.type.is_nullable() and '?' not in typename: - typename = f'{typename}?' + type_str = argument_type_str(self.type) # s/self/input/ outside method bindings # [old codegen] TODO: remove this? doesn't rename in codegen, it's just # for the parse string - if name == 'self' and typename == 'Tensor' and not method: + name = self.name + if name == 'self' and type_str == 'Tensor' and not method: name = 'input' - # add list size annotation - size = self.size - if size is not None: - if typename.endswith('?'): - typename = f'{typename[:-1]}[{size}]?' - else: - typename = f'{typename}[{size}]' - # add default if self.default is not None: default = { @@ -223,15 +205,9 @@ def argument_str(self, *, method: bool = False) -> str: 'c10::nullopt': 'None', '{}': 'None', }.get(self.default, self.default) - return f'{typename} {name}={default}' + return f'{type_str} {name}={default}' else: - return f'{typename} {name}' - - @property - def size(self) -> Optional[int]: - l = self.type.is_list_like() - return l.size \ - if l is not None and l.size is not None and str(l.elem) != 'bool' else None + return f'{type_str} {name}' @dataclass(frozen=True) class PythonOutArgument(PythonArgument): @@ -252,7 +228,6 @@ def from_outputs(outputs: Tuple[PythonArgument, ...]) -> Optional['PythonOutArgu return PythonOutArgument( name=outputs[0].name, type=outputs[0].type, - cpp_type_str=outputs[0].cpp_type_str, default='None', default_init=None, outputs=outputs, @@ -263,7 +238,6 @@ def from_outputs(outputs: Tuple[PythonArgument, ...]) -> Optional['PythonOutArgu return PythonOutArgument( name='out', type=ListType(BaseType(BaseTy.Tensor), size), - cpp_type_str='TensorList', default='None', default_init=None, outputs=outputs, @@ -368,7 +342,6 @@ def signature_str(self, *, skip_outputs: bool = False) -> str: class DispatchLambdaArgument: name: str type_str: str - cpp_type_str: str is_out_arg: bool # To pass PyObjects arguments to C++ function (via the lambda wrapper), @@ -424,28 +397,6 @@ class DispatchLambdaArgumentExprs: # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -# The original simple_type is derived from the 'type' field in Declaration.yaml, -# which is generated from the C++ argument type, following some seemingly -# artificial rules: -# -# Concrete C++ types are preferred in most cases, e.g.: -# 'IntArrayRef' instead of 'int[]' -# 'int64_t' instead of 'int' -# -# Constant/Reference annotation and optional field are handled specially, e.g.: -# 'ScalarType?' instead of 'c10::optional' -# 'Tensor' instead of 'const Tensor &' / 'Tensor &' -# -# TODO: This needs to be consistent with python_arg_parser - can we simplify it? -def _simple_type(cpp_type_str: str) -> str: - simple_type = cpp_type_str.replace(' &', '').replace('const ', '') - opt_match = re.match(r'c10::optional<(.+)>', simple_type) - if opt_match: - typename = opt_match.group(1) - # HACK: 'Layout?' needs to be hardcoded to 'Layout'! - simple_type = f'{typename}?' if typename != 'Layout' else 'Layout' - return simple_type - def _cpp_signature(f: NativeFunction, *, method: bool = False) -> cpp.CppSignature: return CppSignatureGroup.from_schema(f.func, method=method).signature @@ -459,6 +410,49 @@ def has_tensor_options(f: NativeFunction) -> bool: # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +def argument_type_str(t: Type) -> str: + if isinstance(t, BaseType): + if t.name == BaseTy.Tensor: + return 'Tensor' + elif t.name == BaseTy.int: + return 'int64_t' + elif t.name == BaseTy.float: + return 'double' + elif t.name == BaseTy.str: + return 'std::string' + elif t.name in [BaseTy.bool, BaseTy.QScheme, BaseTy.Scalar, + BaseTy.ScalarType, BaseTy.Generator, BaseTy.Storage, + BaseTy.Layout, BaseTy.Device, BaseTy.MemoryFormat, + BaseTy.Dimname, BaseTy.Stream, BaseTy.ConstQuantizerPtr]: + # These python schema type names line up with their function schema names + return t.name.name + + elif isinstance(t, OptionalType): + elem = argument_type_str(t.elem) + if elem == 'Layout': + # TODO: fix this special case in PythonArgParser? + return 'Layout' + else: + return f'{elem}?' + + elif isinstance(t, ListType): + if str(t.elem) == 'bool': + assert t.size is not None + return f'std::array' + elif str(t.elem) == 'int': + return f'IntArrayRef[{t.size}]' if t.size is not None else 'IntArrayRef' + elif str(t.elem) == 'Tensor': + return f'TensorList[{t.size}]' if t.size is not None else 'TensorList' + elif str(t.elem) == 'Tensor?': + # TODO: clone the old codegen behavior but does it make sense? + return 'TensorList?' + elif str(t.elem) == 'Dimname': + return f'DimnameList[{t.size}]' if t.size is not None else 'DimnameList' + elem = argument_type_str(t.elem) + return f'ArrayRef<{elem}>' + + raise RuntimeError(f'unrecognized type {repr(t)}') + def argument(cpp_arg: CppArgument) -> PythonArgument: a = cpp_arg.argument if not isinstance(a, Argument): @@ -468,7 +462,6 @@ def argument(cpp_arg: CppArgument) -> PythonArgument: return PythonArgument( name=a.name, type=a.type, - cpp_type_str=cpp_arg.type, # TODO: directly translate a.default to python default default=str(pythonify_default(cpp.default_expr(a.default, a.type))) if a.default is not None else None, @@ -515,55 +508,37 @@ def signature(f: NativeFunction, *, method: bool = False) -> PythonSignature: has_tensor_return = any(r.type.is_tensor_like() for r in f.func.returns) name: str = cpp.name(f.func) - has_options_arg = has_tensor_options(f) - - is_like_function = name.endswith('_like') or f.category_override == 'like' - is_new_function = name.startswith('new_') or f.category_override == 'new' - is_factory_function = has_tensor_return and not has_tensor_input_arg \ - or f.category_override == 'factory' - is_like_or_new_function_with_options = \ - (is_like_function or is_new_function) and has_options_arg + is_factory_function = f.category_override == 'factory' or (has_tensor_return and not has_tensor_input_arg) + is_like_or_new_function = f.category_override in ('new', 'like') or name.startswith('new_') or name.endswith('_like') tensor_options_args: List[PythonArgument] = [] - if is_factory_function or has_options_arg: + if is_factory_function or is_like_or_new_function: tensor_options_args.append(PythonArgument( name='dtype', - cpp_type_str='const ScalarType &', type=BaseType(BaseTy.ScalarType), default=_dtype_default_type_hack(name), - default_init='self.scalar_type()' - if is_like_or_new_function_with_options else None, + default_init='self.scalar_type()' if is_like_or_new_function else None, )) - - if is_factory_function or is_like_or_new_function_with_options: tensor_options_args.append(PythonArgument( name='layout', - cpp_type_str='c10::optional', - type=BaseType(BaseTy.Layout), + type=OptionalType(BaseType(BaseTy.Layout)), default='torch.strided', - default_init='layout_from_backend(self.options().backend())' - if is_like_or_new_function_with_options else None, + default_init='layout_from_backend(self.options().backend())' if is_like_or_new_function else None, )) tensor_options_args.append(PythonArgument( name='device', - cpp_type_str='const Device &', type=BaseType(BaseTy.Device), default='None', - default_init='self.device()' - if is_like_or_new_function_with_options else None, + default_init='self.device()' if is_like_or_new_function else None, )) tensor_options_args.append(PythonArgument( name='pin_memory', - cpp_type_str='bool', type=BaseType(BaseTy.bool), default='False', default_init=None, )) - - if has_tensor_return and (is_factory_function or is_like_function or is_new_function): tensor_options_args.append(PythonArgument( name='requires_grad', - cpp_type_str='bool', type=BaseType(BaseTy.bool), default='False', default_init=None, @@ -660,7 +635,6 @@ def dispatch_lambda_arg(cpp_arg: CppArgument) -> DispatchLambdaArgument: return DispatchLambdaArgument( name=cpp_arg.name, type_str=type_str, - cpp_type_str=cpp_arg.type, is_out_arg=is_out_arg, ) @@ -750,107 +724,91 @@ def cpp_dispatch_exprs(f: NativeFunction, method: bool, *, # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -# TODO: should emit these unpack methods directly from Type to avoid -# indirect translation via cpp_type_str. -UNPACK_METHODS = { - 'const Tensor &': 'tensor', - 'Tensor &': 'tensor', - 'Stream': 'stream', - 'c10::optional': 'optionalTensor', - 'const c10::optional&': 'optionalTensor', - 'c10::optional': 'generator', - 'Storage': 'storage', - 'Storage &': 'storage', - 'const ScalarType &': 'scalartype', - 'const Device &': 'device', - 'c10::optional': 'toDimnameListOptional', - 'c10::optional': 'scalartypeOptional', - 'c10::optional': 'layoutOptional', - 'c10::optional': 'memoryformatOptional', - 'c10::optional': 'scalarOptional', - 'c10::optional': 'intlistOptional', - 'c10::optional': 'toInt64Optional', - 'c10::optional': 'toBoolOptional', - 'c10::optional': 'toDoubleOptional', - 'c10::optional>': 'doublelistOptional', - 'ArrayRef': 'doublelist', - 'IntArrayRef': 'intlist', - 'Scalar': 'scalar', - 'ScalarType': 'scalartype', - 'Dimname': 'dimname', - 'DimnameList': 'dimnamelist', - 'TensorList': 'tensorlist', - 'int64_t': 'toInt64', - 'bool': 'toBool', - 'double': 'toDouble', - 'std::string': 'string', - 'c10::optional': 'stringOptional', -} - -UNPACK_WITH_SIZE_METHODS = { - 'TensorList': 'tensorlist_n<{}>', - 'DimnameList': 'dimnamelist', - 'IntArrayRef': 'intlist', - 'c10::optional': 'intlistOptional', -} - -UNPACK_WITH_DEFAULT_METHODS = { - 'const ScalarType &': 'scalartypeWithDefault', - 'const Device &': 'deviceWithDefault', - 'c10::optional': 'layoutWithDefault', -} +# We explicitly enumerate the PythonArgParser unpacking methods for all +# supported types. This might be more verbose than necessary, partially +# because of the irregularity of unpacking method naming, partially +# because we want to mimic the old codegen behavior - to reject +# unexpected and/or unsupported cases which the old codegen rejects. +# For certain cases it is intentionally more restrictive than necessary, +# e.g.: it doesn't accepts doublelist with definite size. +def arg_parser_unpack_method(t: Type, has_default: bool) -> str: + if has_default and str(t) not in ('ScalarType', 'Device', 'Layout?'): + raise RuntimeError(f'type \'{t}\' does not supported unpacking with default') + + if isinstance(t, BaseType): + if t.name in [BaseTy.Tensor, BaseTy.Stream, BaseTy.Storage, + BaseTy.Scalar, BaseTy.Dimname]: + # These unpack methods line up with their schema names + return t.name.name.lower() + elif t.name == BaseTy.ScalarType: + return 'scalartypeWithDefault' if has_default else 'scalartype' + elif t.name == BaseTy.Device: + return 'deviceWithDefault' if has_default else 'device' + elif t.name == BaseTy.int: + return 'toInt64' + elif t.name == BaseTy.bool: + return 'toBool' + elif t.name == BaseTy.float: + return 'toDouble' + elif t.name == BaseTy.str: + return 'string' + + elif isinstance(t, OptionalType): + if str(t.elem) == 'Tensor': + if local.use_c10_dispatcher().dispatcher_uses_new_style(): + return 'optionalTensor' + else: + return 'tensor' + + elif isinstance(t.elem, BaseType): + if t.elem.name in [BaseTy.ScalarType, BaseTy.Scalar, + BaseTy.int, BaseTy.bool, + BaseTy.float, BaseTy.str]: + # Regular cases: append 'Optional' to elem's unpacking method + return arg_parser_unpack_method(t.elem, False) + 'Optional' + elif t.elem.name == BaseTy.MemoryFormat: + return 'memoryformatOptional' + elif t.elem.name == BaseTy.Generator: + return 'generator' + elif t.elem.name == BaseTy.Layout: + return 'layoutWithDefault' if has_default else 'layoutOptional' + + elif isinstance(t.elem, ListType): + if str(t.elem.elem) == 'int': + # accept definite size + return 'intlistOptional' + elif str(t.elem) == 'float[]': + return 'doublelistOptional' + elif str(t.elem) == 'Dimname[]': + return 'toDimnameListOptional' + + elif isinstance(t, ListType): + if str(t.elem) == 'Tensor' or str(t.elem) == 'Tensor?': + # accept and use definite size + if t.size is not None: + return f'tensorlist_n<{t.size}>' + else: + return 'tensorlist' + elif str(t.elem) == 'Dimname': + # accept definite size + return 'dimnamelist' + elif str(t.elem) == 'int': + # accept definite size + return 'intlist' + elif str(t) == 'float[]': + return 'doublelist' + + raise RuntimeError(f'type \'{t}\' is not supported by PythonArgParser') # Return RHS expression for python argument using PythonArgParser output. # e.g. for arg name 'foo', arg type 'bool', arg_index = 2, returns '_r.toBool(2)' def arg_parser_output_expr( - arg_index: int, a: PythonArgument, la: Optional[DispatchLambdaArgument] + arg_index: int, a: PythonArgument ) -> PythonArgParserOutputExpr: - # The same python signature (and python schema string) is usually - # associated with two aten C++ functions: the base version and the - # out-place variant. Usually the two functions have the same set of - # arguments - of course, except for the output arguments. But in some - # cases they might have slightly different C++ argument types - - # affected by the 'use_c10_dispatcher' state. - # - # More specially, 'Tensor?' type can be translated into - # either 'const c10::optional&' or 'const Tensor &'. - # Unfortunately, this difference can affect how we should access arg - # parser output. The former expects '_r.optionalTensor(i)' while the - # latter expects '_r.tensor(i)'. - # - # Because of this subtle difference, we cannot solely use the shared - # python signature to determine the RHS expr for both C++ variants. - # We could create and use each C++ variant's own python signature, - # but we have to fix the argument index difference between the two - # python signatures like the old codegen does - and it feels wrong as - # technically there is only one shared python signature! - # - # So here we pass in the lambda wrapper's argument and use it to - # decide what PythonArgParser unpack method to use. - # - # TODO: this seems too complicated - maybe we can simplify after full - # c10 dispatch migration? - typename = la.cpp_type_str \ - if a.name != 'out' and la is not None else a.cpp_type_str - - if a.default_init is not None: - # Note: only introduced in tensor_options_args - if typename not in UNPACK_WITH_DEFAULT_METHODS: - raise RuntimeError( - f'type \'{typename}\' is not supported in default_init') - unpack_with_default = UNPACK_WITH_DEFAULT_METHODS[typename] - expr = f'_r.{unpack_with_default}({arg_index}, {a.default_init})' - elif a.size is not None: - if typename not in UNPACK_WITH_SIZE_METHODS: - raise RuntimeError( - f'type \'{typename}\' with definite size ({a.size}) is not supported') - unpack_with_size = UNPACK_WITH_SIZE_METHODS[typename].format(a.size) - expr = f'_r.{unpack_with_size}({arg_index})' - else: - unpack = UNPACK_METHODS.get(typename) - if unpack is None: - raise RuntimeError(f'type \'{typename}\' is not supported') - expr = f'_r.{unpack}({arg_index})' + has_default = a.default_init is not None + unpack_method = arg_parser_unpack_method(a.type, has_default) + default = f', {a.default_init}' if has_default else '' + expr = f'_r.{unpack_method}({arg_index}{default})' return PythonArgParserOutputExpr( name=a.name, @@ -863,17 +821,14 @@ def arg_parser_output_expr( def arg_parser_output_exprs( ps: PythonSignature, f: NativeFunction, *, method: bool ) -> Dict[str, PythonArgParserOutputExpr]: - lambda_args = dispatch_lambda_args(ps, f, method=method) - lambda_args_map = dict(map(lambda a: (a.name, a), lambda_args)) - return {e.name: e for i, a in enumerate(ps.arguments()) - for e in (arg_parser_output_expr(i, a, lambda_args_map.get(a.name)), )} + for e in (arg_parser_output_expr(i, a), )} -# argument name to 'simple_type' for scattered tensor options fields +# argument name to type for scattered tensor options fields TENSOR_OPTIONS_FIELDS = { 'dtype': 'ScalarType', 'device': 'Device', - 'layout': 'Layout', + 'layout': 'Layout?', 'pin_memory': 'bool', 'requires_grad': 'bool', } @@ -909,7 +864,7 @@ def dispatch_lambda_exprs( ]) for i, out_arg in enumerate(a.outputs): lambda_args_exprs[out_arg.name] = f'out[{i}]' - elif a.cpp_type_str == 'c10::optional': + elif str(a.type) == 'Dimname[]?': # [old codegen] # TODO: make this part of something more general, or get rid of it. # optional> are special. The PythonArgParser returns an @@ -937,9 +892,9 @@ def dispatch_lambda_exprs( if a.name not in TENSOR_OPTIONS_FIELDS: raise RuntimeError( f'{f.func}: unrecognized tensor options field \'{a.name}\' in python binding arguments') - if _simple_type(a.cpp_type_str) != TENSOR_OPTIONS_FIELDS.get(a.name): + if str(a.type) != TENSOR_OPTIONS_FIELDS.get(a.name): raise RuntimeError( - f'{f.func}: unrecognized type \'{_simple_type(a.cpp_type_str)}\' for tensor options field \'{a.name}\'') + f'{f.func}: unrecognized type \'{str(a.type)}\' for tensor options field \'{a.name}\'') if not all(map(lambda a: a in tensor_options_args_names, TENSOR_OPTIONS_FIELDS.keys())): raise RuntimeError( f'{f.func}: incomplete tensor options args: {tensor_options_args_names}') diff --git a/tools/codegen/api/types.py b/tools/codegen/api/types.py index e495244a183e..433f3cf5fd67 100644 --- a/tools/codegen/api/types.py +++ b/tools/codegen/api/types.py @@ -341,12 +341,8 @@ class NativeExpr: class NativeArgument: type: str name: str - # Native function arguments have defaults for some reasons (e.g., - # the function prototypes in CPUType.h are defaulted). There isn't - # really any good reason to do this, as these functions are only - # ever called from a context where all defaulted arguments are - # guaranteed to be given explicitly. - # TODO: Remove this + # Native function arguments have defaults to make it a little + # easier to call them directly to bypass dispatch. default: Optional[str] argument: Union[Argument, TensorOptionsArguments] diff --git a/tools/codegen/gen.py b/tools/codegen/gen.py index af358d9d1b7c..134f7163518b 100644 --- a/tools/codegen/gen.py +++ b/tools/codegen/gen.py @@ -158,11 +158,10 @@ def cpp_string(s: str) -> str: # Dispatch keywords in native_functions.yaml that support all backends. KEYWORD_ALL_BACKENDS = ('DefaultBackend', 'Math') -# Generates {dispatch}Type.cpp and {dispatch}Type.h (e.g., CPUType.cpp -# and CPUType.h). This function is also reused to implement per-operator -# registration. It also generates TypeDefault.cpp and TypeDefault.h when -# dispatch target is for all backends (dispatch is None or dispatch in -# KEYWORD_ALL_BACKENDS). +# Generates {dispatch}Type.cpp (e.g., CPUType.cpp). This function is also +# reused to implement per-operator registration. It also generates +# TypeDefault.cpp when dispatch target is for all backends (dispatch is None or +# dispatch in KEYWORD_ALL_BACKENDS). # # {dispatch}Type.cpp # - The primary function of this file is to register all of the @@ -179,36 +178,29 @@ def cpp_string(s: str) -> str: # (as would be the case if you directly registered native:: # functions). # -# {dispatch}Type.h -# - In principle, this file shouldn't exist at all; historically, -# it existed so that we could directly access these functions -# outside of the registration API for the implementation of -# static dispatch. Should be deleted now! -# # This function is also used for a secondary purpose: the registration # logic is also reused to implement per-operator registration. def compute_type_method( dispatch: Optional[str], *, + # TODO: Give more precise type Union[Literal[Target.DEFINITION, + # Target.REGISTRATION]]; requires Literal from typing_extensions + # which we don't have a dep for yet. target: Target, # Selector object to determine which operators to generate # registration code for. - selector: SelectiveBuilder, - # Only valid for generating registrations. If True, only generate - # def() invocations (for schema registration); do not generate - # any impl() invocations for, e.g., catch-all kernels - def_only: bool = False + selector: SelectiveBuilder ) -> Callable[[NativeFunction], Optional[str]]: - if def_only: - assert target is Target.REGISTRATION and dispatch is None + if dispatch is None: + assert target is Target.REGISTRATION @with_native_function def func(f: NativeFunction) -> Optional[str]: + # Has to be here as mypy won't transfer asserts into closures + assert target is not Target.DECLARATION + if dispatch is not None: - if f.dispatch is None or dispatch not in f.dispatch: - return None - else: - if f.dispatch is not None and target is not Target.REGISTRATION: + if dispatch not in f.dispatch: return None op_name = f"aten::{f.func.name}" @@ -219,24 +211,18 @@ def func(f: NativeFunction) -> Optional[str]: returns_type = native.returns_type(f.func.returns) args = native.arguments(f.func) args_str = ', '.join(map(str, args)) - dispatch_to_all_backends = dispatch is None or dispatch in KEYWORD_ALL_BACKENDS + dispatch_to_all_backends = dispatch is not None and dispatch in KEYWORD_ALL_BACKENDS - if target is Target.DECLARATION: - return f"{returns_type} {name}({args_str});" - elif target is Target.DEFINITION: - if f.dispatch is None: - cpp_name = cpp.name(f.func) - impl_name = f"at::native::{cpp_name}" - else: - assert dispatch is not None - impl_name = f"at::native::{f.dispatch[dispatch]}" + if target is Target.DEFINITION: + assert dispatch is not None + impl_name = f"at::native::{f.dispatch[dispatch]}" args_exprs_str = ', '.join(a.name for a in args) return_kw = " return " cuda_guard = "" - if dispatch_to_all_backends or 'CUDA' in dispatch or 'Vulkan' == dispatch: # type: ignore + if dispatch_to_all_backends or 'CUDA' in dispatch: self_args = (a for a in f.func.arguments if a.name == "self") # There is precedence for which argument we use to do @@ -261,7 +247,7 @@ def func(f: NativeFunction) -> Optional[str]: # TODO: There is probably a simpler version of this that # works just as well. - if f.device_guard and (dispatch_to_all_backends or 'Vulkan' == dispatch) and has_tensor_options: + if f.device_guard and dispatch_to_all_backends and has_tensor_options: cuda_guard = cuda_guard_from_tensor_options elif f.device_guard and dispatch is not None and 'CUDA' in dispatch and has_tensor_options: cuda_guard = f"""\ @@ -284,20 +270,18 @@ def func(f: NativeFunction) -> Optional[str]: """ elif target is Target.REGISTRATION: - dispatcher_sig = DispatcherSignature.from_schema(f.func) - - if dispatch_to_all_backends: - type_name = f'TypeDefault::{name}' + if dispatch is None: + return f'm.def({cpp_string(str(f.func))});\n' + elif f.manual_kernel_registration: + return None else: - type_name = f'{dispatch}Type::{name}' + if dispatch_to_all_backends: + type_name = f'TypeDefault::{name}' + else: + type_name = f'{dispatch}Type::{name}' - # def registration only happens in TypeDefault - def_registration = "" - if dispatch is None: - def_registration = f'm.def({cpp_string(str(f.func))});\n' + dispatcher_sig = DispatcherSignature.from_schema(f.func) - impl_registration = "" - if not def_only and not f.manual_kernel_registration and (dispatch is not None or f.dispatch is None): # Figure out which signature the function is if local.use_c10_dispatcher() is UseC10Dispatcher.full: payload = f"TORCH_FN({type_name})" @@ -321,9 +305,7 @@ def func(f: NativeFunction) -> Optional[str]: if dispatch is not None: payload = f"torch::dispatch(DispatchKey::{dispatch},\n{payload})\n" - impl_registration = f'm.impl("{f.func.name}",\n{payload});\n' - - return f"{def_registration}{impl_registration}" + return f'm.impl("{f.func.name}",\n{payload});\n' else: assert_never(target) @@ -439,10 +421,7 @@ def compute_aten_op(f: NativeFunction) -> str: # actual kernel definitions we keep in aten/src/ATen/native/ @with_native_function def compute_native_function_declaration(f: NativeFunction) -> List[str]: - if f.dispatch is None: - ns = [cpp.name(f.func)] - else: - ns = list(f.dispatch.values()) + ns = list(f.dispatch.values()) rs = [] # Sometimes a function name shows up multiple times; only generate @@ -763,8 +742,7 @@ def compute_declaration_yaml(f: NativeFunction) -> object: is_factory_method = any(isinstance(a.argument, TensorOptionsArguments) for a in cpp_args) \ and Variant.method not in f.variants - # Having only Math in dispatch section is equivalent to no dispatch section. - is_abstract = f.dispatch is not None and set(f.dispatch.keys()) != set({'Math'}) # type ignore + is_abstract = f.dispatch.keys() != {'Math'} return OrderedDict([ ('name', cpp.name(f.func)), @@ -803,7 +781,7 @@ def compute_declaration_yaml(f: NativeFunction) -> object: ('device_guard', f.device_guard), ('with_gil', False), ('deprecated', False), - ('has_math_kernel', f.dispatch is not None and 'Math' in f.dispatch), + ('has_math_kernel', 'Math' in f.dispatch), ]) @with_native_function @@ -814,8 +792,9 @@ def compute_registration_declarations(f: NativeFunction) -> str: args_str = ', '.join(map(str, args)) comment_data : Dict[str, str] = { 'schema': f'aten::{f.func}', - 'dispatch': str(f.dispatch is not None), - 'default': str(f.dispatch is not None and any(k in f.dispatch for k in KEYWORD_ALL_BACKENDS)) + # TODO: What exactly is the semantics of the 'dispatch' field? + 'dispatch': str(f.dispatch.keys() != {'Math'}), + 'default': str(any(k in f.dispatch for k in KEYWORD_ALL_BACKENDS)) } return f"""{returns_type} {name}({args_str}); // {json.dumps(comment_data)} """ @@ -931,11 +910,6 @@ def main() -> None: '--rocm', action='store_true', help='reinterpret CUDA as ROCm/HIP and adjust filepaths accordingly') - # TODO: remove this, we should just unconditionally generate Vulkan - parser.add_argument( - '--vulkan', - action='store_true', - help='Generate Vulkan backend functions') # TODO: --op_registration_whitelist will be removed when all call-sites # for gen.py are moved over to using the operator YAML file for mobile # custom build. @@ -1005,48 +979,36 @@ def make_file_manager(install_dir: str) -> FileManager: cuda_fm = make_file_manager(options.install_dir) extra_cuda_headers = '''\ -#include #include #include #include ''' if options.rocm: extra_cuda_headers = '''\ -#include #include #include #include ''' - backends = ["CPU", "SparseCPU", "MkldnnCPU", "CUDA", "SparseCUDA", "QuantizedCPU", "QuantizedCUDA"] - if options.vulkan: - backends.append("Vulkan") + backends = [ + "CPU", + "SparseCPU", + "MkldnnCPU", + "CUDA", + "SparseCUDA", + "QuantizedCPU", + "QuantizedCUDA", + ] if options.backend_whitelist: backends = [b for b in backends if b in options.backend_whitelist] for dispatch in backends: h_template = 'TypeDerived.h' cpp_template = 'TypeDerived.cpp' - # TODO: delete this special case - if 'Sparse' in dispatch: - cpp_template = 'SparseTypeDerived.cpp' fm = cuda_fm if 'CUDA' in dispatch else cpu_fm - fm.write_with_template(f'{dispatch}Type.h', h_template, lambda: { - 'Type': f'{dispatch}Type', - 'extra_cuda_headers': extra_cuda_headers if 'CUDA' in dispatch else '', # TODO: remove this - 'type_derived_method_declarations': list(mapMaybe( - compute_type_method(dispatch, target=Target.DECLARATION, selector=selector), - native_functions - )), - }) fm.write_with_template(f'{dispatch}Type.cpp', cpp_template, lambda: { 'Type': f'{dispatch}Type', - # TODO: remove this 'extra_cuda_headers': extra_cuda_headers if 'CUDA' in dispatch else '', - # TODO: remove this - 'storage_tensor_headers': '#include ', - # TODO: remove this - 'Generator': 'CUDAGeneratorImpl' if 'CUDA' in dispatch else 'CPUGeneratorImpl', 'legacy_th_headers': '#include ' if dispatch == "CPU" else '#include ' if dispatch == "CUDA" else @@ -1064,23 +1026,13 @@ def make_file_manager(install_dir: str) -> FileManager: }) del fm - cpu_fm.write('TypeDefault.h', lambda: { - 'type_method_declarations': - list(mapMaybe( - compute_type_method(None, target=Target.DECLARATION, selector=selector), - native_functions)) + - list(mapMaybe( - compute_type_method('Math', target=Target.DECLARATION, selector=selector), - native_functions)) + - list(mapMaybe( - compute_type_method('DefaultBackend', target=Target.DECLARATION, selector=selector), - native_functions)), - }) + schema_selector = selector + if options.force_schema_registration: + schema_selector = SelectiveBuilder.get_nop_selector() + + # TODO: split this file into separate files cpu_fm.write('TypeDefault.cpp', lambda: { 'type_method_definitions': - list(mapMaybe( - compute_type_method(None, target=Target.DEFINITION, selector=selector), - native_functions)) + list(mapMaybe( compute_type_method('Math', target=Target.DEFINITION, selector=selector), native_functions)) + @@ -1089,10 +1041,12 @@ def make_file_manager(install_dir: str) -> FileManager: native_functions)), 'function_registrations': list(mapMaybe( - compute_type_method(None, target=Target.REGISTRATION, selector=selector), - native_functions)) + list(mapMaybe( - compute_type_method('Math', target=Target.REGISTRATION, selector=selector), - native_functions)), + compute_type_method(None, target=Target.REGISTRATION, selector=schema_selector), + native_functions)), + + 'math_function_registrations': list(mapMaybe( + compute_type_method('Math', target=Target.REGISTRATION, selector=selector), + native_functions)), 'default_backend_function_registrations': list(mapMaybe( compute_type_method('DefaultBackend', target=Target.REGISTRATION, selector=selector), @@ -1123,16 +1077,6 @@ def make_file_manager(install_dir: str) -> FileManager: list(mapMaybe(compute_backend_select(target=Target.REGISTRATION), native_functions)), }) - if options.force_schema_registration: - def computeSchemaRegister() -> Dict[str, object]: - schema_registrations = list(mapMaybe( - compute_type_method(None, target=Target.REGISTRATION, selector=SelectiveBuilder.get_nop_selector(), def_only=True), - native_functions)) - return { - 'schema_registrations': schema_registrations, - } - cpu_fm.write('SchemaRegister.cpp', computeSchemaRegister) - cpu_fm.write('Declarations.yaml', lambda: format_yaml([compute_declaration_yaml(f) for f in native_functions])) cpu_fm.write('RegistrationDeclarations.h', lambda: { 'registration_declarations': [compute_registration_declarations(f) for f in native_functions], diff --git a/tools/codegen/model.py b/tools/codegen/model.py index 56605b7130db..95cb7438a814 100644 --- a/tools/codegen/model.py +++ b/tools/codegen/model.py @@ -98,13 +98,15 @@ class NativeFunction: # registrations don't participate in codegen-based selective build! manual_kernel_registration: bool - # Distinguish between a missing dispatch dict (historically, this - # means to register a catch-all kernel) and a present but empty - # dispatch dict (this means register nothing; arguably, this should - # subsume manual_kernel_registration). + # A mapping of dispatch keys to names of functions implementing + # them. In native_functions.yaml, the dispatch entry is optional; in that + # case, that is equivalent to having written: + # + # dispatch: + # Math: $operator_name # # TODO: str key could be replaced with more explicit enum - dispatch: Optional[Dict[str, str]] + dispatch: Dict[str, str] # The location in the YAML file were this native function entry was # defined. This is for conveniently reporting error messages! @@ -162,9 +164,8 @@ def from_yaml(ei: Dict[str, object], loc: 'Location') -> 'NativeFunction': raw_dispatch = e.pop('dispatch', None) assert raw_dispatch is None or isinstance(raw_dispatch, dict), e - dispatch: Optional[Dict[str, str]] = None + dispatch: Dict[str, str] = {} if raw_dispatch is not None: - dispatch = {} for ks, v in raw_dispatch.items(): if ks == '__line__': continue # not worth tracking line numbers for dispatch entries @@ -172,9 +173,14 @@ def from_yaml(ei: Dict[str, object], loc: 'Location') -> 'NativeFunction': assert isinstance(v, str), e for k in ks.split(","): dispatch[k.strip()] = v + else: + from tools.codegen.api import cpp + dispatch['Math'] = cpp.name(func) - # Throws if both DefaultBackend and Math are provided - assert not (dispatch is not None and 'DefaultBackend' in dispatch and 'Math' in dispatch) + assert not ('DefaultBackend' in dispatch and 'Math' in dispatch), \ + "cannot specify both DefaultBackend and Math on a single kernel; each " \ + "strictly subsumes the other. If you wanted to provide an explicit autograd " \ + "implementation, specify DefaultBackend; otherwise specify Math only" e.pop('__line__') assert not e, f"leftover entries: {e}" diff --git a/tools/generate_torch_version.py b/tools/generate_torch_version.py index b5ea62ff29bb..8129f38eb0ef 100644 --- a/tools/generate_torch_version.py +++ b/tools/generate_torch_version.py @@ -2,6 +2,7 @@ import os import subprocess from pathlib import Path +from distutils.util import strtobool def get_sha(): try: @@ -27,7 +28,7 @@ def get_torch_version(sha=None): if __name__ == "__main__": parser = argparse.ArgumentParser(description="Generate torch/version.py from build and environment metadata.") - parser.add_argument("--is_debug", type=bool, help="Whether this build is debug mode or not.") + parser.add_argument("--is_debug", type=strtobool, help="Whether this build is debug mode or not.") parser.add_argument("--cuda_version", type=str) parser.add_argument("--hip_version", type=str) @@ -47,7 +48,7 @@ def get_torch_version(sha=None): # NB: This is not 100% accurate, because you could have built the # library code with DEBUG, but csrc without DEBUG (in which case # this would claim to be a release build when it's not.) - f.write("debug = {}\n".format(repr(args.is_debug))) + f.write("debug = {}\n".format(repr(bool(args.is_debug)))) f.write("cuda = {}\n".format(repr(args.cuda_version))) f.write("git_version = {}\n".format(repr(sha))) f.write("hip = {}\n".format(repr(args.hip_version))) diff --git a/tools/jit/gen_unboxing_wrappers.py b/tools/jit/gen_unboxing_wrappers.py index 8d1fb00fc8d2..f2896fac7f22 100644 --- a/tools/jit/gen_unboxing_wrappers.py +++ b/tools/jit/gen_unboxing_wrappers.py @@ -535,7 +535,8 @@ def main(): parser.add_argument('template_path', metavar='TEMPLATE_PATH', help='path to templates directory') args = parser.parse_args() - gen_unboxing_wrappers(args.declarations, args.out, args.template_path) + gen_unboxing_wrappers(args.declarations, args.out, args.template_path, + SelectiveBuilder.get_nop_selector()) if __name__ == '__main__': diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index 4396155c73ea..a1c800debb59 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -839,7 +839,7 @@ def _get_named_tuple_properties(obj): the_type = torch.jit.annotations.ann_to_type(obj.__annotations__[field], fake_range()) annotations.append(the_type) else: - annotations.append(torch._C.TensorType.get()) + annotations.append(torch._C.TensorType.getInferred()) return type(obj).__name__, fields, annotations diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index 33c9c6a0e307..12dd77497454 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -1542,6 +1542,20 @@ def add_docstr_all(method, docstr): In-place version of :meth:`~Tensor.i0` """) +add_docstr_all('igamma', + r""" +igamma(other) -> Tensor + +See :func:`torch.igamma` +""") + +add_docstr_all('igamma_', + r""" +igamma_(other) -> Tensor + +In-place version of :meth:`~Tensor.igamma` +""") + add_docstr_all('indices', r""" indices() -> Tensor diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index ba02b5fc3110..4f0de3335414 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -1818,6 +1818,40 @@ def merge_dicts(*dicts): Alias for :func:`torch.clamp`. """.format(**common_args)) +add_docstr(torch.column_stack, + r""" +column_stack(tensors, *, out=None) -> Tensor + +Creates a new tensor by horizontally stacking the tensors in :attr:`tensors`. + +Equivalent to ``torch.hstack(tensors)``, except each zero or one dimensional tensor ``t`` +in :attr:`tensors` is first reshaped into a ``(t.numel(), 1)`` column before being stacked horizontally. + +Args: + tensors (sequence of Tensors): sequence of tensors to concatenate + +Keyword args: + {out} + +Example:: + + >>> a = torch.tensor([1, 2, 3]) + >>> b = torch.tensor([4, 5, 6]) + >>> torch.column_stack((a, b)) + tensor([[1, 4], + [2, 5], + [3, 6]]) + >>> a = torch.arange(5) + >>> b = torch.arange(10).reshape(5, 2) + >>> torch.column_stack((a, b, b)) + tensor([[0, 0, 1, 0, 1], + [1, 2, 3, 2, 3], + [2, 4, 5, 4, 5], + [3, 6, 7, 6, 7], + [4, 8, 9, 8, 9]]) + +""".format(**common_args)) + add_docstr(torch.complex, r""" complex(real, imag, *, out=None) -> Tensor @@ -3316,6 +3350,47 @@ def merge_dicts(*dicts): """.format(**common_args)) +add_docstr(torch.igamma, + r""" +igamma(input, other, *, out=None) -> Tensor + +Computes the regularized lower incomplete gamma function: + +.. math:: + \text{out}_{i} = \frac{1}{\Gamma(\text{input}_i)} \int_0^{\text{other}_i} t^{\text{input}_i-1} e^{-t} dt + +where both :math:`\text{input}_i` and :math:`\text{other}_i` are weakly positive +and at least one is strictly positive. +If both are zero or either is negative then :math:`\text{out}_i=\text{nan}`. +:math:`\Gamma(\cdot)` in the equation above is the gamma function, + +.. math:: + \Gamma(\text{input}_i) = \int_0^\infty t^{(\text{input}_i-1)} e^{-t} dt. + +See :func:`torch.lgamma` for a related function. + +Supports :ref:`broadcasting to a common shape ` +and float inputs. + +.. note:: + The backward pass with respect to :attr:`input` is not yet supported. + Please open an issue on PyTorch's Github to request it. + +""" + r""" +Args: + input (Tensor): the first non-negative input tensor + other (Tensor): the second non-negative input tensor + +Keyword args: + {out} + +Example:: + + >>> a = torch.igamma(torch.tensor([4.0]), torch.tensor([3.0, 4.0, 5.0])) + tensor([0.3528, 0.5665, 0.7350]) + +""".format(**common_args)) + add_docstr(torch.index_select, r""" index_select(input, dim, index, *, out=None) -> Tensor @@ -6679,6 +6754,12 @@ def merge_dicts(*dicts): torch.uint8 """) +add_docstr(torch.row_stack, + r""" +row_stack(tensors, *, out=None) -> Tensor + +Alias of :func:`torch.vstack`. +""".format(**common_args)) add_docstr(torch.round, r""" diff --git a/torch/autograd/__init__.py b/torch/autograd/__init__.py index 4e44536d931c..eabb07fd9de0 100644 --- a/torch/autograd/__init__.py +++ b/torch/autograd/__init__.py @@ -71,6 +71,7 @@ def backward( retain_graph: Optional[bool] = None, create_graph: bool = False, grad_variables: Optional[_TensorOrTensors] = None, + inputs: Optional[Sequence[torch.Tensor]] = None, ) -> None: r"""Computes the sum of gradients of given tensors w.r.t. graph leaves. @@ -116,6 +117,11 @@ def backward( create_graph (bool, optional): If ``True``, graph of the derivative will be constructed, allowing to compute higher order derivative products. Defaults to ``False``. + inputs (sequence of Tensor): Inputs w.r.t. which the gradient will be + accumulated into ``.grad``. All other Tensors will be ignored. If not + provided, the gradient is accumulated into all the leaf Tensors that were + used to compute the attr::tensors. All the provided inputs must be leaf + Tensors. """ if grad_variables is not None: warnings.warn("'grad_variables' is deprecated. Use 'grad_tensors' instead.") @@ -125,8 +131,11 @@ def backward( raise RuntimeError("'grad_tensors' and 'grad_variables' (deprecated) " "arguments both passed to backward(). Please only " "use 'grad_tensors'.") + if inputs is not None and len(inputs) == 0: + raise RuntimeError("'inputs' argument to backward() cannot be empty.") tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tuple(tensors) + inputs = tuple(inputs) if inputs is not None else tuple() grad_tensors_ = _tensor_or_tensors_to_tuple(grad_tensors, len(tensors)) grad_tensors_ = _make_grads(tensors, grad_tensors_) @@ -134,8 +143,8 @@ def backward( retain_graph = create_graph Variable._execution_engine.run_backward( - tensors, grad_tensors_, retain_graph, create_graph, - allow_unreachable=True) # allow_unreachable flag + tensors, grad_tensors_, retain_graph, create_graph, inputs, + allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag def grad( @@ -213,7 +222,7 @@ def grad( return Variable._execution_engine.run_backward( outputs, grad_outputs_, retain_graph, create_graph, - inputs, allow_unused) + inputs, allow_unused, accumulate_grad=False) # This function applies in case of gradient checkpointing for memory diff --git a/torch/csrc/DynamicTypes.cpp b/torch/csrc/DynamicTypes.cpp index f7e48c3b682d..92e8a93c284e 100644 --- a/torch/csrc/DynamicTypes.cpp +++ b/torch/csrc/DynamicTypes.cpp @@ -59,7 +59,7 @@ at::DeprecatedTypeProperties* get_type(at::Backend backend, at::ScalarType scala PyTypeObject* getPyTypeObject( const at::Storage& storage, - const caffe2::TypeMeta& dtype) { + const caffe2::TypeMeta dtype) { at::ScalarType scalarType = at::typeMetaToScalarType(dtype); auto attype = &at::getDeprecatedTypeProperties( at::dispatchKeyToBackend(c10::computeDispatchKey(scalarType, c10::nullopt, storage.device_type())), @@ -106,7 +106,7 @@ THPLayout* getTHPLayout(at::Layout layout) { PyObject* createPyObject( const at::Storage& storage, - const caffe2::TypeMeta& data_type) { + const caffe2::TypeMeta data_type) { auto type = getPyTypeObject(storage, data_type); auto obj = THPObjectPtr(type->tp_alloc(type, 0)); if (!obj) throw python_error(); diff --git a/torch/csrc/DynamicTypes.h b/torch/csrc/DynamicTypes.h index 0877fb317cb3..d93d0e3b5cf5 100644 --- a/torch/csrc/DynamicTypes.h +++ b/torch/csrc/DynamicTypes.h @@ -6,6 +6,7 @@ #include #include +#include #include #include @@ -29,7 +30,7 @@ void registerLayoutObject(THPLayout *thp_layout, at::Layout layout); PyObject* createPyObject( const at::Storage& storage, - const caffe2::TypeMeta& data_type); + const caffe2::TypeMeta data_type); at::Storage createStorage(PyObject* obj); bool isStorage(PyObject* obj); diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index d09b3428e4f1..a5df6329030d 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -85,9 +85,9 @@ static PyObject * THPModule_initNames(PyObject *self, PyObject *arg) THPObjectPtr types(PySequence_Fast(arg, "expected a sequence")); if (!types) return nullptr; - int num_classes = PySequence_Fast_GET_SIZE(types.get()); + auto num_classes = PySequence_Fast_GET_SIZE(types.get()); names.reserve(names.size() + num_classes); - for (size_t i = 0; i < num_classes; i++) { + for (Py_ssize_t i = 0; i < num_classes; i++) { PyObject* obj = PySequence_Fast_GET_ITEM(types.get(), i); THPUtils_assert(PyType_Check(obj), "expected a PyTypeObject"); PyTypeObject* type = (PyTypeObject*)obj; @@ -864,7 +864,30 @@ Call this whenever a new thread is created in order to propagate values from ASSERT_TRUE(set_module_attr("_GLIBCXX_USE_CXX11_ABI", Py_False)); #endif - auto defaultGenerator = at::detail::getDefaultCPUGenerator(); +// See note [Pybind11 ABI constants] +#define SET_STR_DEFINE(name) \ + ASSERT_TRUE(set_module_attr("_" # name, THPUtils_packString(name))) + +#ifdef PYBIND11_COMPILER_TYPE + SET_STR_DEFINE(PYBIND11_COMPILER_TYPE); +#else + ASSERT_TRUE(set_module_attr("_" C10_STRINGIZE(PYBIND11_COMPILER_TYPE), Py_None)); +#endif + +#ifdef PYBIND11_STDLIB + SET_STR_DEFINE(PYBIND11_STDLIB); +#else + ASSERT_TRUE(set_module_attr("_" C10_STRINGIZE(PYBIND11_STDLIB), Py_None)); +#endif + +#ifdef PYBIND11_BUILD_ABI + SET_STR_DEFINE(PYBIND11_BUILD_ABI); +#else + ASSERT_TRUE(set_module_attr("_" C10_STRINGIZE(PYBIND11_BUILD_ABI), Py_None)); +#endif +#undef SET_STR_DEFINE + + const auto& defaultGenerator = at::detail::getDefaultCPUGenerator(); THPDefaultCPUGenerator = (THPGenerator*)THPGenerator_initDefaultGenerator(defaultGenerator); // This reference is meant to be given away, so no need to incref here. ASSERT_TRUE(set_module_attr("default_generator", (PyObject*)THPDefaultCPUGenerator, /* incref= */ false)); diff --git a/torch/csrc/api/include/torch/linalg.h b/torch/csrc/api/include/torch/linalg.h index 5ce90dcc972e..c0bae62510e6 100644 --- a/torch/csrc/api/include/torch/linalg.h +++ b/torch/csrc/api/include/torch/linalg.h @@ -28,6 +28,14 @@ inline Tensor& norm_out(Tensor& result, const Tensor& self, std::string ord, opt return torch::linalg_norm_out(result, self, ord, opt_dim, keepdim, opt_dtype); } +inline Tensor tensorsolve(const Tensor& self, const Tensor& other, optional dims) { + return torch::linalg_tensorsolve(self, other, dims); +} + +inline Tensor& tensorsolve_out(Tensor& result, const Tensor& self, const Tensor& other, optional dims) { + return torch::linalg_tensorsolve_out(result, self, other, dims); +} + } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ @@ -53,4 +61,22 @@ inline Tensor& linalg_norm_out(Tensor& result, const Tensor& self, std::string o return detail::norm_out(result, self, ord, opt_dim, keepdim, opt_dtype); } +/// Computes a tensor `x` such that `tensordot(input, x, dims=x.dim()) = other`. +/// +/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.tensorsolve +/// +/// Example: +/// ``` +/// auto a = torch::eye(2*3*4).reshape({2*3, 4, 2, 3, 4}); +/// auto b = torch::randn(2*3, 4); +/// auto x = torch::linalg::tensorsolve(a, b); +/// ``` +inline Tensor tensorsolve(const Tensor& input, const Tensor& other, optional dims) { + return detail::tensorsolve(input, other, dims); +} + +inline Tensor& tensorsolve_out(Tensor& result, const Tensor& input, const Tensor& other, optional dims) { + return detail::tensorsolve_out(result, input, other, dims); +} + }} // torch::linalg diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 0fb0e44e9450..d3752bce04cc 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -2694,16 +2694,18 @@ Tensor constant_pad_nd_backward(const Tensor& grad, IntArrayRef pad) { return at::constant_pad_nd(grad, negated_pad, 0); } -Tensor embedding_dense_double_backward(const Tensor & grad, const Tensor & indices) { - // since first backward takes care of padding_idx - // and scaling by frequency, we don't need to worry - // about it here. +Tensor embedding_dense_double_backward(const Tensor & grad, const Tensor & indices, int64_t padding_idx) { + // since first backward takes care of scaling by frequency, + // we don't need to worry about it here. auto gg_weight = grad.index_select(0, indices.reshape(-1)); // reshape gradient as per the shape of indices auto size = indices.sizes().vec(); size.push_back(-1); + if (padding_idx >= 0) { + gg_weight.masked_fill_((indices == padding_idx).reshape({-1, 1}), 0); + } return gg_weight.view(size); } diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index 1f5ba99e83a2..46f26610c127 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -118,7 +118,7 @@ at::Tensor logdet_backward(const at::Tensor & grad, const at::Tensor& self, cons at::Tensor slogdet_backward(const at::Tensor& grad_logabsdet, const at::Tensor& self, const at::Tensor& signdet, const at::Tensor& logabsdet); at::Tensor log1p_backward(const at::Tensor& grad, const at::Tensor& self); at::Tensor sparse_constructor_values_backward(const at::Tensor& sparse_grad_out, const at::Tensor& indices, at::IntArrayRef values_shape); -at::Tensor embedding_dense_double_backward(const at::Tensor & grad, const at::Tensor & indices); +at::Tensor embedding_dense_double_backward(const at::Tensor & grad, const at::Tensor & indices, int64_t padding_idx); at::Tensor index_backward(at::Tensor zeros_like_self, at::TensorList indices, const at::Tensor& grad); at::Tensor _cudnn_ctc_loss_backward(const at::Tensor& grad_out, const at::Tensor& loss, const at::Tensor& raw_grad, bool zero_infinity); diff --git a/torch/csrc/autograd/anomaly_mode.cpp b/torch/csrc/autograd/anomaly_mode.cpp index bbb76fba656f..e8afa6f8fc52 100644 --- a/torch/csrc/autograd/anomaly_mode.cpp +++ b/torch/csrc/autograd/anomaly_mode.cpp @@ -1,9 +1,77 @@ +#include +#include #include +#include +#include -namespace torch { namespace autograd { +namespace torch { +namespace autograd { bool AnomalyMode::_enabled = false; +namespace { +std::mutex& get_anomaly_guard_lock() { + static std::mutex anomaly_guard_lock{}; + return anomaly_guard_lock; +} + +uint32_t& get_anomaly_counter() { + static uint32_t counter = 0; + return counter; +} +} // namespace + +DetectAnomalyGuard::DetectAnomalyGuard() { + TORCH_WARN_ONCE( + "This mode should be enabled only for debugging as the different tests will slow down your program execution."); + std::lock_guard lock(get_anomaly_guard_lock()); + uint32_t& counter = get_anomaly_counter(); + counter++; + AnomalyMode::set_enabled(true); +} + +DetectAnomalyGuard::~DetectAnomalyGuard() { + std::lock_guard lock(get_anomaly_guard_lock()); + uint32_t& counter = get_anomaly_counter(); + counter--; + AnomalyMode::set_enabled(counter > 0); +} + AnomalyMetadata::~AnomalyMetadata() = default; -}} +void AnomalyMetadata::store_stack() { + traceback_ = c10::get_backtrace(/* frames_to_skip */ 1); +} + +void AnomalyMetadata::print_stack(const std::string& current_node_name) { + TORCH_WARN( + "Error detected in ", + current_node_name, + ". ", + "Traceback of forward call that caused the error:\n", + traceback_); + + auto& cur_parent = parent_; + // if there is no "parent_" in metadata, then it means this metadata's node + // is the root and stop printing the traceback + while (cur_parent) { + auto parent_metadata = cur_parent->metadata(); + TORCH_WARN( + "\n\n", + "Previous calculation was induced by ", + cur_parent->name(), + ". " + "Traceback of forward call that induced the previous calculation:\n", + parent_metadata->traceback_); + // get the parent of this node, if this node is a root, pyparent is simply + // null + cur_parent = parent_metadata->parent_; + } +} + +void AnomalyMetadata::assign_parent(const std::shared_ptr& parent_node) { + parent_ = parent_node; +} + +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/anomaly_mode.h b/torch/csrc/autograd/anomaly_mode.h index 013600b230fc..a4e4210dabfe 100644 --- a/torch/csrc/autograd/anomaly_mode.h +++ b/torch/csrc/autograd/anomaly_mode.h @@ -21,12 +21,43 @@ struct TORCH_API AnomalyMode { static bool _enabled; }; +/// A RAII guard that enables Anomaly Detection Mode. +/// +/// Anomaly detection mode is useful for debugging problems happening +/// in the backward, such as unexpectedly modified tensors or NaNs +/// occuring in the backward. +/// +/// The enabling of anomaly mode is global - as soon as there is one +/// such guard, it is enabled for all computation and threads. It also +/// comes with a significant performance penalty. +/// +/// Example: +/// @code +/// auto x = torch::tensor({1.}, torch::requires_grad()); +/// { +/// torch::autograd::DetectAnomalyGuard detect_anomaly; +/// auto x = torch::tensor({5.0}, torch::requires_grad()); +/// auto y = x * x; +/// auto z = y * y; +/// y += 1; +/// z.backward(); +/// } +/// @endcode +class TORCH_API DetectAnomalyGuard { + public: + DetectAnomalyGuard(); + ~DetectAnomalyGuard(); +}; struct TORCH_API AnomalyMetadata { virtual ~AnomalyMetadata(); - virtual void store_stack() = 0; - virtual void print_stack(const std::string& current_node_name) = 0; - virtual void assign_parent(const std::shared_ptr& parent_node) = 0; + virtual void store_stack(); + virtual void print_stack(const std::string& current_node_name); + virtual void assign_parent(const std::shared_ptr& parent_node); + + private: + std::string traceback_; + std::shared_ptr parent_; }; }} diff --git a/torch/csrc/autograd/autograd.cpp b/torch/csrc/autograd/autograd.cpp index b8756ff1c7b4..9c2879aa460d 100644 --- a/torch/csrc/autograd/autograd.cpp +++ b/torch/csrc/autograd/autograd.cpp @@ -68,7 +68,8 @@ variable_list run_backward( bool keep_graph, bool create_graph, const variable_list& inputs, - bool allow_unused) { + bool allow_unused, + bool accumulate_grad) { size_t num_tensors = outputs.size(); edge_list roots; roots.reserve(num_tensors); @@ -104,7 +105,7 @@ variable_list run_backward( } variable_list grad_inputs = Engine::get_default_engine().execute( - roots, grad_outputs, keep_graph, create_graph, output_edges); + roots, grad_outputs, keep_graph, create_graph, accumulate_grad, output_edges); // check if grad_inputs contains None or not base on the allow_unused flag if (!inputs.empty() && !allow_unused) { size_t num_inputs = inputs.size(); @@ -129,7 +130,7 @@ void backward( if (!retain_graph) { retain_graph = create_graph; } - run_backward(tensors, gradients, retain_graph.value(), create_graph, {}, /*allow_unused=*/true); + run_backward(tensors, gradients, retain_graph.value(), create_graph, {}, /*allow_unused=*/true, /*accumulate_grad=*/true); } variable_list grad( @@ -144,7 +145,7 @@ variable_list grad( retain_graph = create_graph; } return run_backward( - outputs, gradients, retain_graph.value(), create_graph, inputs, allow_unused); + outputs, gradients, retain_graph.value(), create_graph, inputs, allow_unused, /*accumulate_grad=*/false); } } // namespace autograd diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index e952b0afc772..a0a608f5c36b 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -843,6 +843,7 @@ auto Engine::execute(const edge_list& roots, const variable_list& inputs, bool keep_graph, bool create_graph, + bool accumulate_grad, const edge_list& outputs) -> variable_list { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) validate_outputs(roots, const_cast(inputs), [](const std::string& msg) { @@ -867,7 +868,7 @@ auto Engine::execute(const edge_list& roots, compute_dependencies(graph_root.get(), *graph_task); if (!outputs.empty()) { - graph_task->init_to_execute(*graph_root, outputs); + graph_task->init_to_execute(*graph_root, outputs, accumulate_grad); } execute_with_graph_task(graph_task, graph_root); @@ -1079,16 +1080,21 @@ void Engine::add_thread_pool_task(const std::weak_ptr& graph_task) { thread_pool_shared_->work_.notify_one(); } -void GraphTask::init_to_execute(Node& graph_root, const edge_list& outputs) { +void GraphTask::init_to_execute(Node& graph_root, const edge_list& outputs, bool accumulate_grad) { exec_info_[&graph_root].needed_ = true; int output_idx = 0; for (auto & output_edge : outputs) { Node *output = output_edge.function.get(); auto & info = exec_info_[output]; - if (!info.captures_) - info.captures_ = make_unique>(); - info.captures_->emplace_back(output_edge.input_nr, output_idx++); + if (accumulate_grad) { + info.needed_ = true; + } else { + if (!info.captures_) { + info.captures_ = make_unique>(); + } + info.captures_->emplace_back(output_edge.input_nr, output_idx++); + } } captured_vars_.resize(output_idx); @@ -1136,7 +1142,7 @@ void GraphTask::init_to_execute(Node& graph_root, const edge_list& outputs) { auto it = exec_info_.find(edge.function.get()); return it != exec_info_.end() && it->second.should_execute(); }); - exec_info_[frame.fn_].needed_ = needed; + exec_info_[frame.fn_].needed_ |= needed; stack.pop_back(); } } diff --git a/torch/csrc/autograd/engine.h b/torch/csrc/autograd/engine.h index 0dde6e735d10..59c8844e0ac5 100644 --- a/torch/csrc/autograd/engine.h +++ b/torch/csrc/autograd/engine.h @@ -109,7 +109,7 @@ struct GraphTask: std::enable_shared_from_this { std::unordered_set leaf_streams; - void init_to_execute(Node& graph_root, const edge_list& outputs); + void init_to_execute(Node& graph_root, const edge_list& outputs, bool accumulate_grad); // The value of worker_device in the thread that created this task. // See Note [Reentrant backwards] @@ -272,6 +272,7 @@ struct TORCH_API Engine { const variable_list& inputs, bool keep_graph, bool create_graph, + bool accumulate_grad, const edge_list& outputs = {}); // Given a pre-populated GraphTask and GraphRoot, computes the backward pass @@ -284,7 +285,7 @@ struct TORCH_API Engine { std::shared_ptr graph_root); virtual std::unique_ptr make_anomaly_metadata() { - return nullptr; + return std::make_unique(); } // We pass cpu_ready_queue to evaluate_function, so that it knows diff --git a/torch/csrc/autograd/python_engine.cpp b/torch/csrc/autograd/python_engine.cpp index 709296c350ef..cc9f04c053c2 100644 --- a/torch/csrc/autograd/python_engine.cpp +++ b/torch/csrc/autograd/python_engine.cpp @@ -87,13 +87,14 @@ variable_list PythonEngine::execute( const variable_list& inputs, bool keep_graph, bool create_graph, + bool accumulate_grad, const edge_list& outputs) { TORCH_CHECK(!PyGILState_Check(), "The autograd engine was called while holding the GIL. If you are using the C++ " "API, the autograd engine is an expensive operation that does not require the " "GIL to be held so you should release it with 'pybind11::gil_scoped_release no_gil;'" ". If you are not using the C++ API, please report a bug to the pytorch team.") try { - return Engine::execute(roots, inputs, keep_graph, create_graph, outputs); + return Engine::execute(roots, inputs, keep_graph, create_graph, accumulate_grad, outputs); } catch (python_error& e) { e.restore(); throw; @@ -128,14 +129,14 @@ PyObject *THPEngine_run_backward(PyObject *self, PyObject *args, PyObject *kwarg unsigned char create_graph = 0; PyObject *inputs = nullptr; unsigned char allow_unreachable = 0; - const char *accepted_kwargs[] = { + unsigned char accumulate_grad = 0; + const char *accepted_kwargs[] = { // NOLINT "tensors", "grad_tensors", "keep_graph", "create_graph", "inputs", - "allow_unreachable", nullptr + "allow_unreachable", "accumulate_grad", nullptr }; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OObb|Ob", (char**)accepted_kwargs, - &tensors, &grad_tensors, &keep_graph, &create_graph, &inputs, &allow_unreachable)) + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OObb|Obb", (char**)accepted_kwargs, + &tensors, &grad_tensors, &keep_graph, &create_graph, &inputs, &allow_unreachable, &accumulate_grad)) return nullptr; - THPUtils_assert(PyTuple_Check(tensors), "tensors argument is expected to " "be a tuple, but got %s", THPUtils_typename(tensors)); THPUtils_assert(PyTuple_Check(grad_tensors), "grad_tensors argument is " @@ -147,7 +148,7 @@ PyObject *THPEngine_run_backward(PyObject *self, PyObject *args, PyObject *kwarg "gradients", num_tensors, num_gradients); // The user either called autograd.backward(...) or autograd.grad(...) to get here - bool backward_api_called = inputs == nullptr; + bool backward_api_called = accumulate_grad; TORCH_CHECK(!backward_api_called || at::impl::VmapMode::current_vmap_level() == 0, "backward() called inside torch.vmap. This is not supported, " "please call backward() outside torch.vmap or instead use " @@ -193,7 +194,7 @@ PyObject *THPEngine_run_backward(PyObject *self, PyObject *args, PyObject *kwarg } std::vector output_edges; - if (!backward_api_called) { + if (inputs != nullptr) { int num_inputs = PyTuple_GET_SIZE(inputs); output_edges.reserve(num_inputs); for (int i = 0; i < num_inputs; ++i) { @@ -210,7 +211,11 @@ PyObject *THPEngine_run_backward(PyObject *self, PyObject *args, PyObject *kwarg const auto output_nr = input_var->cdata.output_nr(); auto grad_fn = input_var->cdata.grad_fn(); if (!grad_fn) { - grad_fn = torch::autograd::impl::try_get_grad_accumulator(input_var->cdata); + grad_fn = torch::autograd::impl::try_get_grad_accumulator(input_var->cdata); + } + if (accumulate_grad) { + THPUtils_assert(input_var->cdata.is_leaf(), + "One of the differentiated Tensors given as 'inputs' to backward is not a leaf Tensor"); } THPUtils_assert(input_var->cdata.requires_grad(), "One of the differentiated Tensors does not require grad"); @@ -226,10 +231,10 @@ PyObject *THPEngine_run_backward(PyObject *self, PyObject *args, PyObject *kwarg { pybind11::gil_scoped_release no_gil; auto& engine = python::PythonEngine::get_python_engine(); - outputs = engine.execute(roots, grads, keep_graph, create_graph, output_edges); + outputs = engine.execute(roots, grads, keep_graph, create_graph, accumulate_grad, output_edges); } - if (!backward_api_called) { + if (!backward_api_called && inputs != nullptr) { int num_inputs = PyTuple_GET_SIZE(inputs); THPObjectPtr py_outputs {PyTuple_New(num_inputs)}; if (!py_outputs) return nullptr; diff --git a/torch/csrc/autograd/python_engine.h b/torch/csrc/autograd/python_engine.h index 7d722d43d504..0a36a4529cf6 100644 --- a/torch/csrc/autograd/python_engine.h +++ b/torch/csrc/autograd/python_engine.h @@ -23,6 +23,7 @@ struct PythonEngine : public Engine { const variable_list& inputs, bool keep_graph, bool create_graph, + bool accumulate_grad, const edge_list& outputs = {}) override; std::shared_ptr execute_with_graph_task( diff --git a/torch/csrc/distributed/autograd/engine/dist_engine.cpp b/torch/csrc/distributed/autograd/engine/dist_engine.cpp index d642ef53101a..20f0e46304e4 100644 --- a/torch/csrc/distributed/autograd/engine/dist_engine.cpp +++ b/torch/csrc/distributed/autograd/engine/dist_engine.cpp @@ -279,7 +279,7 @@ void DistEngine::computeDependencies( // Create a dummy GraphRoot and run init_to_execute with it. GraphRoot dummyRoot(edges, {}); - graphTask->init_to_execute(dummyRoot, outputEdges); + graphTask->init_to_execute(dummyRoot, outputEdges, /*accumulate_grad=*/false); for (auto& mapEntry : graphTask->exec_info_) { auto& execInfo = mapEntry.second; if (!execInfo.captures_) { diff --git a/torch/csrc/distributed/c10d/comm.cpp b/torch/csrc/distributed/c10d/comm.cpp index f9fa2bbd7e8f..4ee19888e4c2 100644 --- a/torch/csrc/distributed/c10d/comm.cpp +++ b/torch/csrc/distributed/c10d/comm.cpp @@ -85,54 +85,5 @@ void broadcast_coalesced( } } -PythonCommHook::~PythonCommHook() { - py::gil_scoped_acquire ag; - state_.dec_ref(); - hook_.dec_ref(); - // Explicitly set state_ and hook_ to nullptr to prevent py::object's dtor - // to decref on the PyObject again. - // See Note [Destructing py::object] in python_ivalue.h - state_.ptr() = nullptr; - hook_.ptr() = nullptr; -} - -c10::intrusive_ptr PythonCommHook::runHook( - GradBucket& bucket) { - py::gil_scoped_acquire acquire; - - py::object py_fut = hook_(state_, bucket); - - try { - return py_fut.cast>()->fut; - } catch (const py::cast_error& e) { - auto type = py_fut.get_type(); - auto errMsg = c10::str( - e.what(), - ". DDP communication hook's callback must return a " - "torch.futures.Future or torch._C.Future object, but got ", - type.attr("__module__").cast(), - ".", - type.attr("__qualname__").cast()); - throw std::runtime_error(errMsg); - } -} - -std::vector PythonCommHook::parseHookResult( - const c10::IValue& result) { - TORCH_INTERNAL_ASSERT( - result.isPyObject() || result.isTensorList(), - "expected the hook result is either a PyObject or TensorList"); - - if (result.isPyObject()) { - py::gil_scoped_acquire ag; - py::object obj = torch::jit::toPyObject(result); - auto value = torch::jit::toIValue( - obj, c10::ListType::create(c10::TensorType::get())); - - return value.toTensorVector(); - } - - return result.toTensorVector(); -} } // namespace c10d diff --git a/torch/csrc/distributed/c10d/comm.h b/torch/csrc/distributed/c10d/comm.h index 20e0f948808d..1eb78b201569 100644 --- a/torch/csrc/distributed/c10d/comm.h +++ b/torch/csrc/distributed/c10d/comm.h @@ -1,8 +1,8 @@ #pragma once #include +#include #include -#include namespace c10d { @@ -48,7 +48,7 @@ class TORCH_PYTHON_API CommHookInterface { // Passes the input grad bucket to the registered communication hook. // Once the tensors in the bucket are ready, kicks off the hook asynchronously // and returns a future that holds the communication results. - virtual c10::intrusive_ptr runHook( + virtual c10::intrusive_ptr runHook( GradBucket& bucket) = 0; // Returns the resulting tensors once the communication hook result is ready. @@ -58,28 +58,6 @@ class TORCH_PYTHON_API CommHookInterface { const c10::IValue& result) = 0; }; -class TORCH_PYTHON_API PythonCommHook : public CommHookInterface { - public: - // Takes a state and a callable hook. The inputs are Python objects. - // The state is passed to the hook in runHook method, and it can be used to - // maintain and update any state information during the execution of the hook. - // The hook performs user-specified processing and returns a future indicating - // asychronous communication of gradients. - PythonCommHook(py::object state, py::object hook) - : state_(std::move(state)), hook_(std::move(hook)) {} - - ~PythonCommHook() override; - - c10::intrusive_ptr runHook(GradBucket& bucket) override; - - std::vector parseHookResult(const c10::IValue& result) override; - - private: - // Only needed for stateful communication. - py::object state_; - py::object hook_; -}; - // This CppCommHook interface only requires implementing runHook method that // potentially uses a state. template diff --git a/torch/csrc/distributed/c10d/default_comm_hooks.cpp b/torch/csrc/distributed/c10d/default_comm_hooks.cpp index 0f7a24acc40d..10da31bf0b03 100644 --- a/torch/csrc/distributed/c10d/default_comm_hooks.cpp +++ b/torch/csrc/distributed/c10d/default_comm_hooks.cpp @@ -6,7 +6,7 @@ namespace c10d { -c10::intrusive_ptr AllReduceCommHook::runHook( +c10::intrusive_ptr AllReduceCommHook::runHook( GradBucket& bucket) { auto allreduce_work = state_->allreduce(bucket.getTensorsRef()); @@ -19,7 +19,7 @@ c10::intrusive_ptr AllReduceCommHook::runHook( return fut->then(div_by_process_group_size, fut->elementType()); } -c10::intrusive_ptr FP16CompressCommHook::runHook( +c10::intrusive_ptr FP16CompressCommHook::runHook( GradBucket& bucket) { auto& tensors = bucket.getTensorsRef(); for (auto& tensor : tensors) { diff --git a/torch/csrc/distributed/c10d/default_comm_hooks.h b/torch/csrc/distributed/c10d/default_comm_hooks.h index 5e53e01ac688..dd141fe0eb58 100644 --- a/torch/csrc/distributed/c10d/default_comm_hooks.h +++ b/torch/csrc/distributed/c10d/default_comm_hooks.h @@ -8,13 +8,13 @@ namespace c10d { class AllReduceCommHook : public CppCommHookInterface { ~AllReduceCommHook() override {} - c10::intrusive_ptr runHook(GradBucket& bucket) override; + c10::intrusive_ptr runHook(GradBucket& bucket) override; }; class FP16CompressCommHook : public CppCommHookInterface { ~FP16CompressCommHook() override {} - c10::intrusive_ptr runHook(GradBucket& bucket) override; + c10::intrusive_ptr runHook(GradBucket& bucket) override; }; } // namespace c10d diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 47a3ebabe941..27155d81e5c1 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -26,6 +26,7 @@ #include #include +#include #include #include #include diff --git a/torch/csrc/distributed/c10d/python_comm_hook.cpp b/torch/csrc/distributed/c10d/python_comm_hook.cpp new file mode 100644 index 000000000000..6b25018d38a3 --- /dev/null +++ b/torch/csrc/distributed/c10d/python_comm_hook.cpp @@ -0,0 +1,60 @@ +#include + +#include +#include +#include +#include + +namespace c10d { + +PythonCommHook::~PythonCommHook() { + py::gil_scoped_acquire ag; + state_.dec_ref(); + hook_.dec_ref(); + // Explicitly set state_ and hook_ to nullptr to prevent py::object's dtor + // to decref on the PyObject again. + // See Note [Destructing py::object] in python_ivalue.h + state_.ptr() = nullptr; + hook_.ptr() = nullptr; +} + +c10::intrusive_ptr PythonCommHook::runHook( + GradBucket& bucket) { + py::gil_scoped_acquire acquire; + + py::object py_fut = hook_(state_, bucket); + + try { + return py_fut.cast>()->fut; + } catch (const py::cast_error& e) { + auto type = py_fut.get_type(); + auto errMsg = c10::str( + e.what(), + ". DDP communication hook's callback must return a " + "torch.futures.Future or torch._C.Future object, but got ", + type.attr("__module__").cast(), + ".", + type.attr("__qualname__").cast()); + throw std::runtime_error(errMsg); + } +} + +std::vector PythonCommHook::parseHookResult( + const c10::IValue& result) { + TORCH_INTERNAL_ASSERT( + result.isPyObject() || result.isTensorList(), + "expected the hook result is either a PyObject or TensorList"); + + if (result.isPyObject()) { + py::gil_scoped_acquire ag; + py::object obj = torch::jit::toPyObject(result); + auto value = torch::jit::toIValue( + obj, c10::ListType::create(c10::TensorType::get())); + + return value.toTensorVector(); + } + + return result.toTensorVector(); +} + +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/python_comm_hook.h b/torch/csrc/distributed/c10d/python_comm_hook.h new file mode 100644 index 000000000000..e38ba096460f --- /dev/null +++ b/torch/csrc/distributed/c10d/python_comm_hook.h @@ -0,0 +1,34 @@ +#pragma once + +#include + +#include +#include +#include +#include + +namespace c10d { + +class TORCH_PYTHON_API PythonCommHook : public CommHookInterface { + public: + // Takes a state and a callable hook. The inputs are Python objects. + // The state is passed to the hook in runHook method, and it can be used to + // maintain and update any state information during the execution of the hook. + // The hook performs user-specified processing and returns a future indicating + // asychronous communication of gradients. + PythonCommHook(py::object state, py::object hook) + : state_(std::move(state)), hook_(std::move(hook)) {} + + ~PythonCommHook() override; + + c10::intrusive_ptr runHook(GradBucket& bucket) override; + + std::vector parseHookResult(const c10::IValue& result) override; + + private: + // Only needed for stateful communication. + py::object state_; + py::object hook_; +}; + +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp index abb26dfaa10c..6f0b22b73739 100644 --- a/torch/csrc/distributed/c10d/reducer.cpp +++ b/torch/csrc/distributed/c10d/reducer.cpp @@ -119,8 +119,19 @@ Reducer::Reducer( // This is used later on when the autograd graph is traversed // to check for parameters for which no gradient is computed, if // find_unused_parameters=True. + // We maintain a mapping of gradient accumulator to vector of variables, + // since multiple parameters may share the same grad accumulator. if (find_unused_parameters_) { - gradAccToVariableMap_[grad_accumulator.get()] = index; + auto gradAcc = gradAccToVariablesMap_.find(grad_accumulator.get()); + if (gradAcc == gradAccToVariablesMap_.end()) { + std::vector indexVec{index}; + gradAccToVariablesMap_[grad_accumulator.get()] = + std::move(indexVec); + } else { + // Scenario where we have indices whose corresponding parameters + // share the same grad accumulator. + gradAcc->second.push_back(index); + } } // The gradient accumulator is stored as weak_ptr in the autograd @@ -997,14 +1008,15 @@ void Reducer::prepare_for_backward( } // Find accumulator functions that don't show up in this graph. - for (const auto& it : gradAccToVariableMap_) { + for (const auto& it : gradAccToVariablesMap_) { // If the accumulator function is present in the graph, we know // a gradient will be computed for the corresponding parameter. - if (seen.count(it.first) > 0) { - continue; + if (seen.count(it.first) == 0) { + auto& indices = it.second; + unused_parameters_.reserve(unused_parameters_.size() + indices.size()); + unused_parameters_.insert( + unused_parameters_.end(), indices.begin(), indices.end()); } - - unused_parameters_.push_back(it.second); } } diff --git a/torch/csrc/distributed/c10d/reducer.h b/torch/csrc/distributed/c10d/reducer.h index 25f81857d101..53b4ecc4f981 100644 --- a/torch/csrc/distributed/c10d/reducer.h +++ b/torch/csrc/distributed/c10d/reducer.h @@ -122,8 +122,8 @@ class Reducer { std::vector>> grad_accumulators_; - std::unordered_map - gradAccToVariableMap_; + std::unordered_map> + gradAccToVariablesMap_; std::vector>> hooks_; diff --git a/torch/csrc/jit/backends/backend_init.cpp b/torch/csrc/jit/backends/backend_init.cpp index d23a518d1ca3..596c19e6ba1b 100644 --- a/torch/csrc/jit/backends/backend_init.cpp +++ b/torch/csrc/jit/backends/backend_init.cpp @@ -2,11 +2,121 @@ #include #include #include +#include #include namespace torch { namespace jit { +// Get all types that are shared in the module hierarchy rooted at \p mod. +std::unordered_set getSharedModuleTypes(Module& mod) { + // Maintain a set of all TypePtrs. + std::unordered_set types; + // Maintain another set of TypePtrs that have been encountered more than once. + std::unordered_set duplicate_types; + + // Iterate over all modules in the hierarchy, including the root. + for (auto module : mod.modules()) { + auto module_type = module.type(); + if (types.count(module_type) > 0) { + duplicate_types.insert(module_type); + } + + types.insert(module_type); + } + + return duplicate_types; +} + +// Selectively lower \p mod to a backend. \p to_backend +// is called to lower modules. \p modules_to_lower contains +// qualified names of submodules of \p mod that should be lowered. +void toBackendSelectiveImpl( + Module& mod, + const py::function& to_backend, + const std::vector& modules_to_lower, + const std::unordered_set& duplicate_types) { + // This map will be used later to remap types in ancestor module graphs for + // all lowered submodules. + std::unordered_map type_remap; + + // For each module that should be lowered: + for (const auto& module_to_lower : modules_to_lower) { + // Use QualifiedName to parse the qualified module names. + c10::QualifiedName qual_module_name(module_to_lower); + auto& atoms = qual_module_name.atoms(); + + // Search through the module hierarchy using the atoms of + // qual_module_name until current points to the module to + // be lowered and parent points to its parent. + Module current = mod; + Module parent; + + for (size_t i = 0, e = atoms.size(); i < e; ++i) { + IValue submodule = current.attr(atoms[i]); + if (submodule.isModule()) { + if (i == e - 1) { + parent = current; + } + current = submodule.toModule(); + } else { + std::stringstream err; + err << "Attribute named " << atoms[i] << " is not a Module"; + throw std::runtime_error(err.str()); + } + } + + // Check that the parent type is not shared and therefore can be edited. + if (duplicate_types.count(parent.type()) > 0) { + throw py::cast_error(c10::str( + "Selective lowering is only supported for module hierarchies with unique types for selected modules; ", + parent.type()->repr_str(), + " is shared")); + } + + // Call to_backend on the module that needs to be lowered. It needs to be + // wrapped before doing so because _to_jit_backend accepts wrapped modules. + // The result needs to be unwrapped in order to access its type below. + auto lowered_submodule = + py::cast(to_backend(py::module::import("torch.jit._recursive") + .attr("wrap_cpp_module")(current)) + .attr("_c")); + + // Adjust the parent's type so that the type of the submodule matches + // the type of lowered_submodule. + auto parent_type = parent.type(); + + parent_type->unsafeChangeAttributeType( + atoms.back(), lowered_submodule.type()); + parent.setattr(atoms.back(), lowered_submodule._ivalue()); + + // Record the type mapping from old type -> lowered type. + type_remap[current.type()] = lowered_submodule.type(); + } + + // Having lowered all of the modules that needed to be lowered, remap types in + // all graphs in the hierarchy so that the graphs all use the new lowered + // type. + auto type_remap_fn = [&type_remap](TypePtr in) { + auto it = type_remap.find(in); + if (it == type_remap.end()) + return in; + return it->second; + }; + + // modules() iterates over all modules in the hierarchy including the root. + for (auto module : mod.modules()) { + auto module_type = module.type(); + for (auto& fn : module_type->methods()) { + auto method = module.get_method(fn->name()); + auto graph = method.graph(); + graph->remapTypes(type_remap_fn); + auto new_schema = fn->getSchema().cloneWithRemappedTypes(type_remap_fn); + fn->setSchema(new_schema); + } + } +} + void initJitBackendBindings(PyObject* module) { // Bind a function for lowering to each JIT backend. The name of the backend // must be the first argument. For example, to lower a Module to @@ -244,6 +354,32 @@ void initJitBackendBindings(PyObject* module) { py::cast(orig_module.attr("_c")), method_compile_spec)); }); + + m.def( + "_jit_to_backend_selective", + [=](py::handle orig_module, + const py::function& to_backend, + const std::vector& modules_to_lower) { + if (auto original_module = + as_module(py::cast(orig_module))) { + // Clone the Module to avoid editing types that are shared with + // Modules in other instances outside this hierarchy. + Module& mod = original_module.value(); + auto cloned_mod = mod.clone(); + // Get all shared module types. Type sharing is only a problem if the + // parent modules of the ones to lower are in this set. + auto shared_types = getSharedModuleTypes(cloned_mod); + toBackendSelectiveImpl( + cloned_mod, to_backend, modules_to_lower, shared_types); + // Wrap the result in a RecursiveScriptModule because that's what + // the caller passed in. + return py::module::import("torch.jit._recursive") + .attr("wrap_cpp_module")(cloned_mod); + } + + throw py::cast_error(c10::str( + "Object ", py::str(orig_module), " is not a ScriptModule")); + }); } } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/frontend/convert_to_ssa.cpp b/torch/csrc/jit/frontend/convert_to_ssa.cpp index 10109aa55824..1dd61c260bd6 100644 --- a/torch/csrc/jit/frontend/convert_to_ssa.cpp +++ b/torch/csrc/jit/frontend/convert_to_ssa.cpp @@ -5,19 +5,20 @@ #include #include #include -#include namespace torch { namespace jit { // At the beginning of the pass the Graph has already undergone type checking, // and writes or reads to a variable are emitted as Loads and Stores in the -// graph. a = 1 print(a) is represented as: -// -// %a.1 : int = prim::Constant[value=1]() -// prim::Store[name="a"](%a.1) -// %a : int = prim::Load[name="a"]() -// prim::Print(%a) +// graph. +// a = 1 +// print(a) +// is represented as: +// %a.1 : int = prim::Constant[value=1]() +// prim::Store[name="a"](%a.1) +// %a : int = prim::Load[name="a"]() +// prim::Print(%a) // // First, this pass recursively adds the Loads & Stores to control flow nodes // Then the graph is converted to SSA form. @@ -149,7 +150,7 @@ struct ControlFlowLoadStores { case prim::Loop: { addLoopLoadStores(n); } break; - case prim::Function: { + case prim::Closure: { for (auto b : n->blocks()) { addControlFlowLoadStores(b); } @@ -157,7 +158,7 @@ struct ControlFlowLoadStores { case prim::Store: { environment_stack->setVar(n->s(attr::name), n->input()->type()); } break; - case prim::LocalVariableScope: { + case prim::ListComprehensionScope: { addControlFlowLoadStores(n->blocks().at(0)); } break; } @@ -204,7 +205,7 @@ struct EraseLoadStores { n->output()->replaceAllUsesWith(var); n->destroy(); } break; - case prim::LocalVariableScope: { + case prim::ListComprehensionScope: { // writes within a local variable scope do not leak into // the rest of the graph auto body = n->blocks().at(0); @@ -279,7 +280,7 @@ struct LoopContinuations { assignExitContinuations(n->blocks().at(0)); assignExitContinuations(n->blocks().at(1)); } break; - case prim::Function: { + case prim::Closure: { LoopContinuations closure_block; closure_block.run(n->blocks().at(0)); } break; diff --git a/torch/csrc/jit/frontend/exit_transforms.cpp b/torch/csrc/jit/frontend/exit_transforms.cpp index 3126d78c3bd2..e14cb6428890 100644 --- a/torch/csrc/jit/frontend/exit_transforms.cpp +++ b/torch/csrc/jit/frontend/exit_transforms.cpp @@ -119,7 +119,7 @@ struct ExitTransformer { static bool isGraphOrClosureBlock(Block* block) { return block->owningNode() == nullptr || - owningNodeKind(block) == prim::Function; + owningNodeKind(block) == prim::Closure; } static void removeOutputs(Block* b) { @@ -425,7 +425,7 @@ struct ExitTransformer { case prim::With: { exit_pair = transformWith(node); } break; - case prim::Function: { + case prim::Closure: { // exits of closure declaration stay local to the closure transformExits(node->blocks().at(0)); } break; diff --git a/torch/csrc/jit/frontend/ir_emitter.cpp b/torch/csrc/jit/frontend/ir_emitter.cpp index 5294e02e739f..a4b239418cfb 100644 --- a/torch/csrc/jit/frontend/ir_emitter.cpp +++ b/torch/csrc/jit/frontend/ir_emitter.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -859,9 +860,12 @@ struct to_ir { return emitStatements(statements.begin(), statements.end()); } - // XXX - right now closures are used _only_ for defining gradients internally + // XXX: Right now closures are not generically implemented and are only used + // as an intermediate form for special tasks, like defining gradients or + // forked functions. + // // There are several unfinished aspects that make them unusable generally - // 1. We do not have a type, ivalue, operator to represent prim::Function, so + // 1. We do not have a type, ivalue, operator to represent prim::Closure, so // closure_node has type None // 2. There is no export logic for it yet, so it cannot be // exported/python_printed @@ -870,9 +874,19 @@ struct to_ir { // the changes to those variables will just get forgotten. // 4. There is no parsing support in frontend.py, this is intentional since it // prevents people from accidentally using this feature. + // + // This function leaves in the graph something like: + // + // %2 : None = prim::Closure() + // block0(): + // %1 : Tensor = prim::DoSomething(%0) + // -> (%1) + // + // A separate pass is required to erase this closure and replace it with + // something actually executable (see liftClosure and inlineForkedClosure). std::shared_ptr emitClosure( const std::function& emit_body) { - Node* closure_node = graph->insertNode(graph->create(prim::Function, 1)); + Node* closure_node = graph->insertNode(graph->create(prim::Closure, 1)); // it is not a real thing yet, so just say the type is None closure_node->output()->setType(NoneType::get()); Block* block = closure_node->addBlock(); @@ -1262,7 +1276,7 @@ struct to_ir { // comprehension introduces it's own scope. no variable assigned // leaks into the rest of the graph Node* n = - graph->insertNode(create(prim::LocalVariableScope, lc.range(), 0)); + graph->insertNode(create(prim::ListComprehensionScope, lc.range(), 0)); auto* comprehension_block = n->addBlock(); pushFrame(comprehension_block); WithInsertPoint guard(comprehension_block); @@ -2094,8 +2108,8 @@ struct to_ir { stmt.range(), *method.graph(), getAugOp(stmt, lhs->type()), - /*inputs=*/{lhs, rhs}, - /*attributes=*/{}, + /*args=*/{lhs, rhs}, + /*kwargs=*/{}, /*self=*/c10::nullopt); } } @@ -2665,9 +2679,9 @@ struct to_ir { if (auto special_form = dynamic_cast(sv.get())) { return emitApplySpecialForm(special_form->form(), apply, type_hint); } - auto inputs = getNamedValues(apply.inputs(), true); - auto attributes = emitAttributes(apply.attributes()); - return sv->call(loc, method, inputs, attributes, n_binders); + auto args = getNamedValues(apply.inputs(), true); + auto kwargs = emitAttributes(apply.attributes()); + return sv->call(loc, method, args, kwargs, n_binders); } // this function handles expressions that look like apply statements @@ -2688,9 +2702,9 @@ struct to_ir { } auto forked = emitSugaredExpr(Expr(trees[0]), 1); TreeList sliced_trees(trees.begin() + 1, trees.end()); - auto inputs = getNamedValues(sliced_trees, true); - auto attributes = emitAttributes(apply.attributes()); - return emitForkExpr(apply.range(), forked, inputs, attributes); + auto args = getNamedValues(sliced_trees, true); + auto kwargs = emitAttributes(apply.attributes()); + return emitForkExpr(apply.range(), forked, args, kwargs); } case prim::annotate: { checkApplyNumInputs(apply, 2); @@ -2932,7 +2946,7 @@ struct to_ir { return emitApplyExpr(apply, n_binders, type_hint); } break; case TK_SUBSCRIPT: { - return emitSubscript(Subscript(tree)); + return emitSubscript(Subscript(tree), type_hint); } break; default: return std::make_shared(emitSimpleExpr(tree, type_hint)); @@ -2965,11 +2979,15 @@ struct to_ir { return graph->insertConstant(maybe_out_stack->at(0), tree->range()); } + /** + * Emit a fork expression, of the form: + * torch.jit.fork(forked, *args, **kwargs) + */ std::shared_ptr emitForkExpr( SourceRange loc, const std::shared_ptr& forked, - at::ArrayRef inputs, - at::ArrayRef attributes) { + at::ArrayRef args, + at::ArrayRef kwargs) { auto g = method.graph(); Node* fork_node; TypePtr out_type; @@ -2989,8 +3007,7 @@ struct to_ir { fork_node->addInput(closure_output); } else { auto emit_closure_body = [&](Block* closure_block) { - auto fn_sugared_output = - forked->call(loc, method, inputs, attributes, 1); + auto fn_sugared_output = forked->call(loc, method, args, kwargs, 1); auto fn_simple_output = fn_sugared_output->asValue(loc, method); closure_block->registerOutput(fn_simple_output); out_type = fn_simple_output->type(); @@ -3788,7 +3805,9 @@ struct to_ir { ->output(); } - std::shared_ptr emitSubscript(const Subscript& subscript) { + std::shared_ptr emitSubscript( + const Subscript& subscript, + TypePtr type_hint = nullptr) { const SugaredValuePtr sv = emitSugaredExpr(subscript.value(), 1); const List& subscript_exprs = subscript.subscript_exprs(); const SourceRange& range = subscript.range(); @@ -3858,7 +3877,7 @@ struct to_ir { return std::make_shared( emitMultidimSlicing(range, sliceable, subscript_exprs)); } else { - return sv->getitem(range, method, idx); + return sv->getitem(range, method, idx, std::move(type_hint)); } } } diff --git a/torch/csrc/jit/frontend/schema_matching.cpp b/torch/csrc/jit/frontend/schema_matching.cpp index fb2e0f20f380..9fd3973f9b3d 100644 --- a/torch/csrc/jit/frontend/schema_matching.cpp +++ b/torch/csrc/jit/frontend/schema_matching.cpp @@ -584,8 +584,8 @@ Value* emitBuiltinCall( const SourceRange& loc, Graph& graph, Symbol name, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, const c10::optional& self) { const auto& variants = getAllOperatorsFor(name); const auto& builtin_functions = getAllBuiltinFunctionsFor(name); @@ -620,7 +620,7 @@ Value* emitBuiltinCall( throw error; } - auto matched = matchSchemas(schemas, loc, graph, inputs, attributes, self); + auto matched = matchSchemas(schemas, loc, graph, args, kwargs, self); if (matched.first < variants.size()) { return emitBuiltinNode(matched.second, loc, graph, name); diff --git a/torch/csrc/jit/frontend/schema_matching.h b/torch/csrc/jit/frontend/schema_matching.h index 88fe23a9682d..83e34bb33ae5 100644 --- a/torch/csrc/jit/frontend/schema_matching.h +++ b/torch/csrc/jit/frontend/schema_matching.h @@ -23,7 +23,7 @@ TORCH_API MatchedSchema matchSchema( const SourceRange& loc, Graph& graph, at::ArrayRef args, - at::ArrayRef kwarg, + at::ArrayRef kwargs, const c10::optional& self = c10::nullopt); TORCH_API std::pair matchSchemas( @@ -31,7 +31,7 @@ TORCH_API std::pair matchSchemas( const SourceRange& loc, Graph& graph, at::ArrayRef args, - at::ArrayRef kwarg, + at::ArrayRef kwargs, const c10::optional& self = c10::nullopt, bool render_errors = false); @@ -43,8 +43,8 @@ TORCH_API Value* emitBuiltinCall( const SourceRange& loc, Graph& graph, Symbol name, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, const c10::optional& self = c10::nullopt); TORCH_API c10::optional findInputWithName( diff --git a/torch/csrc/jit/frontend/sugared_value.cpp b/torch/csrc/jit/frontend/sugared_value.cpp index 69e86716f72e..8810a5a62019 100644 --- a/torch/csrc/jit/frontend/sugared_value.cpp +++ b/torch/csrc/jit/frontend/sugared_value.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -17,14 +18,14 @@ struct NoneValue : SugaredValue { std::shared_ptr PrintValue::call( const SourceRange& loc, Function& m, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) { auto& g = *m.graph(); - if (!attributes.empty()) + if (!kwargs.empty()) throw ErrorReport(loc) << "print doesn't accept any keyword arguments"; - std::vector lowered_inputs = toValues(*m.graph(), inputs); + std::vector lowered_inputs = toValues(*m.graph(), args); g.insertNode(g.create(prim::Print, lowered_inputs, 0)->setSourceRange(loc)); return std::make_shared(); } @@ -46,11 +47,11 @@ builtin_cast_method_to_scalar_type() { std::shared_ptr BuiltinFunction::call( const SourceRange& loc, Function& m, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) { return std::make_shared( - emitBuiltinCall(loc, *m.graph(), symbol, inputs, attributes, self)); + emitBuiltinCall(loc, *m.graph(), symbol, args, kwargs, self)); } // older versions of gcc/clang have a bug where enums can't be used as keys @@ -322,14 +323,14 @@ void SimpleValue::setAttr( std::shared_ptr SimpleValue::call( const SourceRange& loc, Function& m, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) { // allow our 'fake' closures to be called, used for fork serialization // at the moment, but can be expanded later Node* self = getValue()->node(); if (self->kind() == prim::TupleConstruct && self->inputs().size() == 2 && - self->inputs().at(0)->node()->kind() == prim::Function) { + self->inputs().at(0)->node()->kind() == prim::Closure) { std::shared_ptr graph = self->inputs().at(0)->node()->g(attr::Subgraph); Value* context = self->inputs().at(1); @@ -348,16 +349,15 @@ std::shared_ptr SimpleValue::call( auto ret = StrongFunctionPtr(std::move(cu), fn); std::vector ctx_inputs = {close_context}; - ctx_inputs.insert(ctx_inputs.end(), inputs.begin(), inputs.end()); - return FunctionValue(ret).call(loc, m, ctx_inputs, attributes, n_binders); + ctx_inputs.insert(ctx_inputs.end(), args.begin(), args.end()); + return FunctionValue(ret).call(loc, m, ctx_inputs, kwargs, n_binders); } if (auto class_type = getValue()->type()->cast()) { - return attr(loc, m, "__call__") - ->call(loc, m, inputs, attributes, n_binders); + return attr(loc, m, "__call__")->call(loc, m, args, kwargs, n_binders); } - return SugaredValue::call(loc, m, inputs, attributes, n_binders); + return SugaredValue::call(loc, m, args, kwargs, n_binders); } Value* SimpleValue::len(const SourceRange& loc, Function& m) { @@ -377,7 +377,8 @@ Value* SimpleValue::len(const SourceRange& loc, Function& m) { SugaredValuePtr SimpleValue::getitem( const SourceRange& loc, Function& m, - Value* idx) { + Value* idx, + TypePtr type_hint) { Value* val = getValue(); TypePtr val_type = val->type(); Graph& g = *m.graph(); @@ -393,6 +394,17 @@ SugaredValuePtr SimpleValue::getitem( return std::make_shared( g.insert(aten::select, {val, 0, idx}, {}, loc)); } else if (auto class_type = val_type->cast()) { + // Check if this is an indexing operation enabled by a type hint. + // The ModuleDict has already been checked during IR generation to make + // sure its contents implement the module interface referred to by + // type_hint. + if (class_type->is_module() && type_hint) { + auto res = g.insert(prim::ModuleDictIndex, {val, idx}, {}, loc); + res->setType(type_hint); + return std::make_shared(res); + } + + // Defer to the __getitem__ attr on the class. return attr(loc, m, "__getitem__")->call(loc, m, {idx}, {}, 1); } else { throw ErrorReport(loc) << "'" << val_type->repr_str() << "'" @@ -485,7 +497,8 @@ Value* RangeValue::len(const SourceRange& loc, Function& m) { SugaredValuePtr RangeValue::getitem( const SourceRange& loc, Function& m, - Value* idx) { + Value* idx, + TypePtr type_hint) { if (has_only_end_) { return std::make_shared(idx); } else { @@ -535,7 +548,8 @@ Value* IterableTree::len(const SourceRange& loc, Function& m) { SugaredValuePtr IterableTree::getitem( const SourceRange& loc, Function& m, - Value* idx) { + Value* idx, + TypePtr type_hint) { std::vector child_items; for (const SugaredValuePtr& child : children_) { child_items.emplace_back(child->getitem(loc, m, idx)); @@ -569,27 +583,27 @@ void IterableTree::addChild( std::shared_ptr MagicMethod::call( const SourceRange& loc, Function& m, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) { - if (inputs.size() > 0) { - Value* self = inputs[0].value(*m.graph()); + if (args.size() > 0) { + Value* self = args[0].value(*m.graph()); if (auto class_ptr = self->type()->cast()) { return SimpleValue(self) .attr(loc, m, desugared_name_) - ->call(loc, m, inputs.slice(1), attributes, n_binders); + ->call(loc, m, args.slice(1), kwargs, n_binders); } } TORCH_INTERNAL_ASSERT(base_value_); - return base_value_->call(loc, m, inputs, attributes, n_binders); + return base_value_->call(loc, m, args, kwargs, n_binders); } std::shared_ptr ClassValue::call( const SourceRange& loc, Function& m, // note: names for args will be 'argument 0', 'argument 1', etc.. - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) { AT_ASSERT(n_binders <= 1); @@ -602,7 +616,7 @@ std::shared_ptr ClassValue::call( } // Call the init function - MethodValue(self, "__init__").call(loc, m, inputs, attributes, n_binders); + MethodValue(self, "__init__").call(loc, m, args, kwargs, n_binders); return std::make_shared(self); } @@ -621,15 +635,15 @@ std::shared_ptr ClassValue::attr( std::shared_ptr NamedTupleConstructor::call( const SourceRange& loc, Function& m, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) { auto& g = *m.graph(); auto schema = type_->schema(); TORCH_INTERNAL_ASSERT(schema); auto qualname = type_->name(); - auto matched_schema = matchSchema(*schema, loc, g, inputs, attributes); + auto matched_schema = matchSchema(*schema, loc, g, args, kwargs); auto self = g.insertNode( diff --git a/torch/csrc/jit/frontend/sugared_value.h b/torch/csrc/jit/frontend/sugared_value.h index 3523523f5c23..28a18aceda49 100644 --- a/torch/csrc/jit/frontend/sugared_value.h +++ b/torch/csrc/jit/frontend/sugared_value.h @@ -84,8 +84,8 @@ struct TORCH_API SugaredValue const SourceRange& loc, Function& m, // note: names for args will be 'argument 0', 'argument 1', etc.. - at::ArrayRef inputs_, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) { // n_binders is always set to the number of variables an expression is // syntactically bound to: @@ -139,7 +139,8 @@ struct TORCH_API SugaredValue virtual std::shared_ptr getitem( const SourceRange& loc, Function& m, - Value* idx) { + Value* idx, + TypePtr type_hint = nullptr) { throw ErrorReport(loc) << "'" << kind() << "'" << " object is not subscriptable"; } @@ -181,8 +182,8 @@ struct TORCH_API SimpleValue : public SugaredValue { const SourceRange& loc, Function& m, // note: names for args will be 'argument 0', 'argument 1', etc.. - at::ArrayRef inputs_, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override; std::shared_ptr iter(const SourceRange& loc, Function& m) @@ -193,8 +194,11 @@ struct TORCH_API SimpleValue : public SugaredValue { } Value* len(const SourceRange& loc, Function& m) override; - SugaredValuePtr getitem(const SourceRange& loc, Function& m, Value* idx) - override; + SugaredValuePtr getitem( + const SourceRange& loc, + Function& m, + Value* idx, + TypePtr type_hint = nullptr) override; private: Value* value_; @@ -215,8 +219,8 @@ struct TORCH_API BuiltinFunction : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& m, - at::ArrayRef attributes, - at::ArrayRef inputs, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override; // try to create this builtin but if it doesn't exist or the self argument @@ -251,8 +255,11 @@ struct TORCH_API SugaredTupleValue : public SugaredValue { return "Tuple"; } - SugaredValuePtr getitem(const SourceRange& loc, Function& m, Value* idx) - override { + SugaredValuePtr getitem( + const SourceRange& loc, + Function& m, + Value* idx, + TypePtr type_hint = nullptr) override { if (!(idx->type()->cast() && toIValue(idx))) { throw ErrorReport(loc) << "Expected integer literal for index. " @@ -332,8 +339,8 @@ struct TORCH_API ClassValue : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& m, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override; std::shared_ptr attr( @@ -354,8 +361,8 @@ struct TORCH_API NamedTupleConstructor : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& m, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override; std::string kind() const override { @@ -384,8 +391,8 @@ struct FunctionValue : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& f, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override { std::vector schemas; for (Function* callee : callees_) { @@ -398,7 +405,7 @@ struct FunctionValue : public SugaredValue { } schemas.push_back(&callee->getSchema()); } - auto match = matchSchemas(schemas, loc, *f.graph(), inputs, attributes); + auto match = matchSchemas(schemas, loc, *f.graph(), args, kwargs); Value* output = f.graph()->insertFunctionCall(callees_[match.first], match.second); output->node()->setSourceRange(loc); @@ -417,7 +424,7 @@ struct FunctionValue : public SugaredValue { struct TORCH_API ClosureValue : public SugaredValue { ClosureValue(Value* value) : value_(value) { - TORCH_INTERNAL_ASSERT(value_->node()->kind() == prim::Function); + TORCH_INTERNAL_ASSERT(value_->node()->kind() == prim::Closure); } std::string kind() const override { return "closure"; @@ -442,11 +449,11 @@ struct MethodValue : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& f, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override { - std::vector inputsWithSelf = {self_}; - inputsWithSelf.insert(inputsWithSelf.end(), inputs.begin(), inputs.end()); + std::vector argsWithSelf = {self_}; + argsWithSelf.insert(argsWithSelf.end(), args.begin(), args.end()); std::vector schemas; for (const std::string& method_name : method_names_) { if (auto class_type = self_->type()->cast()) { @@ -466,8 +473,7 @@ struct MethodValue : public SugaredValue { false, "method constructed that is not a class or interface"); } } - auto match = - matchSchemas(schemas, loc, *f.graph(), inputsWithSelf, attributes); + auto match = matchSchemas(schemas, loc, *f.graph(), argsWithSelf, kwargs); Value* output = f.graph()->insertMethodCall(method_names_[match.first], match.second); output->node()->setSourceRange(loc); @@ -486,8 +492,8 @@ struct TORCH_API PrintValue : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& m, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override; }; @@ -500,16 +506,16 @@ struct TORCH_API CastValue : public BuiltinFunction { std::shared_ptr call( const SourceRange& loc, Function& m, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override { - if (inputs.size() == 1 && attributes.size() == 0) { - auto v = inputs[0].value(*m.graph()); + if (args.size() == 1 && kwargs.size() == 0) { + auto v = args[0].value(*m.graph()); if (v->type()->isSubtypeOf(type_)) { return std::make_shared(v); } } - return BuiltinFunction::call(loc, m, inputs, attributes, n_binders); + return BuiltinFunction::call(loc, m, args, kwargs, n_binders); } private: @@ -527,17 +533,17 @@ struct TORCH_API TensorCastValue : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& m, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override { - TORCH_INTERNAL_ASSERT(inputs.size() == 0 && attributes.size() == 0); + TORCH_INTERNAL_ASSERT(args.size() == 0 && kwargs.size() == 0); Value* dtype_const = m.graph()->insertConstant(dtype_, loc); - std::vector kwargs{self_, - NamedValue(loc, "dtype", dtype_const)}; + std::vector kwargs_{self_, + NamedValue(loc, "dtype", dtype_const)}; Value* casted_val = m.graph()->insert( /*opname=*/Symbol::fromQualString("aten::to"), - /*args=*/inputs, - /*kwargs=*/kwargs, + /*args=*/args, + /*kwargs=*/kwargs_, /*range=*/loc); return std::make_shared(casted_val); } @@ -560,8 +566,8 @@ struct TORCH_API MagicMethod : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& m, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override; private: @@ -604,8 +610,11 @@ struct TORCH_API RangeValue : SugaredValue { return "range"; } Value* len(const SourceRange& loc, Function& m) override; - SugaredValuePtr getitem(const SourceRange& loc, Function& m, Value* idx) - override; + SugaredValuePtr getitem( + const SourceRange& loc, + Function& m, + Value* idx, + TypePtr type_hint = nullptr) override; std::shared_ptr iter(const SourceRange& loc, Function& m) override; @@ -680,8 +689,11 @@ struct TORCH_API IterableTree : SugaredValue { std::vector get_base_iterables(); Value* len(const SourceRange& loc, Function& m) override; - SugaredValuePtr getitem(const SourceRange& loc, Function& m, Value* idx) - override; + SugaredValuePtr getitem( + const SourceRange& loc, + Function& m, + Value* idx, + TypePtr type_hint = nullptr) override; private: c10::optional unroll_length_ = c10::nullopt; @@ -735,11 +747,11 @@ struct TORCH_API ExceptionValue : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& m, - at::ArrayRef inputs, + at::ArrayRef args, at::ArrayRef /*attributes*/, size_t /*n_binders*/) override { auto exception_message = insertConstant(*m.graph(), message_ + ": ", loc); - for (auto& input : inputs) { + for (auto& input : args) { auto input_str = input.value(*m.graph()); if (!input_str->type()->isSubtypeOf(StringType::get())) { input_str = diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp index bb5872f35f4f..b055d29164a5 100644 --- a/torch/csrc/jit/ir/alias_analysis.cpp +++ b/torch/csrc/jit/ir/alias_analysis.cpp @@ -494,7 +494,7 @@ void AliasDb::analyzeImpl(Node* node) { case prim::MMBatchSide: case prim::BroadcastSizes: case prim::ChunkSizes: - case prim::Function: + case prim::Closure: case prim::CreateObject: case prim::tolist: return analyzeCreator(node); diff --git a/torch/csrc/jit/mobile/module.h b/torch/csrc/jit/mobile/module.h index 00a2005df8d5..22fc369f38e0 100644 --- a/torch/csrc/jit/mobile/module.h +++ b/torch/csrc/jit/mobile/module.h @@ -67,6 +67,16 @@ class TORCH_API Module { return metadata_; } + c10::IValue attr(const std::string& name, c10::IValue or_else) const { + if (auto r = object_->type()->findAttributeSlot(name)) { + return object_->getSlot(*r); + } + if (auto r = object_->type()->findConstantSlot(name)) { + return object_->type()->getConstant(*r); + } + return or_else; + } + private: c10::intrusive_ptr object_; std::unordered_map metadata_; diff --git a/torch/csrc/jit/passes/constant_propagation.cpp b/torch/csrc/jit/passes/constant_propagation.cpp index c3285f2e2426..2981daa0006a 100644 --- a/torch/csrc/jit/passes/constant_propagation.cpp +++ b/torch/csrc/jit/passes/constant_propagation.cpp @@ -89,7 +89,7 @@ namespace { std::unordered_set skip_list = { prim::If, prim::Loop, - prim::Function, + prim::Closure, prim::Constant, prim::AutogradZero, prim::Uninitialized, diff --git a/torch/csrc/jit/passes/freeze_module.cpp b/torch/csrc/jit/passes/freeze_module.cpp index 6b5beb4372a8..76b6f1d234ba 100644 --- a/torch/csrc/jit/passes/freeze_module.cpp +++ b/torch/csrc/jit/passes/freeze_module.cpp @@ -212,6 +212,13 @@ class AttributePropagator { for (Block* sub_block : n->blocks()) { blocks.push(sub_block); } + + // Modules with prim::ModuleDictIndex cannot be frozen because they + // return InterfaceTypes. + TORCH_CHECK( + n->kind() != prim::ModuleDictIndex, + "Freezing modules containing prim::ModuleDictIndex is not supported"); + if (n->kind() == prim::SetAttr || n->kind() == prim::GetAttr) { // By default if interface attributes are present then fail freezing. // If freezingInterfaces is on then Interfaces are folded similarly diff --git a/torch/csrc/jit/passes/inline_forked_closures.cpp b/torch/csrc/jit/passes/inline_forked_closures.cpp index ea5a977e4091..e97d71e32249 100644 --- a/torch/csrc/jit/passes/inline_forked_closures.cpp +++ b/torch/csrc/jit/passes/inline_forked_closures.cpp @@ -19,7 +19,7 @@ void inlineForkedClosure(Node* fork_closure) { Node* function_context_node = fork_closure->input()->node(); if (function_context_node->inputs().size() != 2 || - function_context_node->inputs().at(0)->node()->kind() != prim::Function || + function_context_node->inputs().at(0)->node()->kind() != prim::Closure || function_context_node->inputs().at(1)->node()->kind() != prim::TupleConstruct) { throw ErrorReport(fork_closure->sourceRange()) << "Cannot fork this value"; diff --git a/torch/csrc/jit/passes/lift_closures.cpp b/torch/csrc/jit/passes/lift_closures.cpp index 82e6f2216681..4f5941ce8afb 100644 --- a/torch/csrc/jit/passes/lift_closures.cpp +++ b/torch/csrc/jit/passes/lift_closures.cpp @@ -5,7 +5,7 @@ namespace torch { namespace jit { -// Closures are initially emitted as prim::Function nodes with a single block. +// Closures are initially emitted as prim::Closure nodes with a single block. // Here, we convert the block to a subgraph, adding all closed over variables // as a context tuple input to the closure node. // At this point the closure has already undergone conversion to SSA, @@ -58,7 +58,7 @@ void liftClosures(Block* block) { Node* n = *it; it++; switch (n->kind()) { - case prim::Function: { + case prim::Closure: { liftClosure(n); } break; default: { diff --git a/torch/csrc/jit/passes/metal_rewrite.cpp b/torch/csrc/jit/passes/metal_rewrite.cpp index b724aedd698c..ffc00001c797 100644 --- a/torch/csrc/jit/passes/metal_rewrite.cpp +++ b/torch/csrc/jit/passes/metal_rewrite.cpp @@ -15,8 +15,6 @@ namespace torch { namespace jit { -#ifdef USE_PYTORCH_METAL - namespace { void insertPrePackedConv2dOp(std::shared_ptr& graph) { @@ -160,16 +158,14 @@ void metalFusePrePackedConvWithClamp(script::Module& module) { void metalInsertCopyOps(script::Module& module) { auto graph = module.get_method("forward").graph(); auto&& outputs = graph->outputs(); - for (int i = 0; i < outputs.size(); ++i) { + for (size_t i = 0; i < outputs.size(); ++i) { Value* output = outputs[i]; - std::cout << "find output: " << *output->node() << std::endl; auto namedValue = NamedValue("", output); if (namedValue.type()->kind() == TypeKind::TensorType) { // find the insertion point WithInsertPoint ip(output->node()->next()); Value* replaced_output = graph->insert( Symbol::fromQualString("metal::copy_to_host"), {namedValue}); - std::cout << "insert: " << *replaced_output->node() << std::endl; // replaced the output graph->block()->replaceOutput(i, replaced_output); } @@ -190,40 +186,10 @@ script::Module metalOptimizeForMobile( metalFoldPrePackingOps(cloned_module); metalInsertCopyOps(cloned_module); removeDropout(cloned_module); + cloned_module.register_attribute( + "optimized_for_metal", BoolType::get(), true); return cloned_module; } -#else - -void metalInsertPrePackedOps(std::shared_ptr& graph) { - TORCH_INTERNAL_ASSERT( - "metal is not enabled. Please build with USE_PYTORCH_METAL=1"); -} - -void metalInsertPrePackedOps(script::Module& module) { - TORCH_INTERNAL_ASSERT( - "metal is not enabled. Please build with USE_PYTORCH_METAL=1"); -} - -TORCH_API void metalFusePrePackedConvWithClamp(script::Module& module) { - TORCH_INTERNAL_ASSERT( - "metal is not enabled. Please build with USE_PYTORCH_METAL=1"); -} - -TORCH_API void metalFoldPrePackingOps(script::Module& module) { - TORCH_INTERNAL_ASSERT( - "metal is not enabled. Please build with USE_PYTORCH_METAL=1"); -} - -script::Module metalOptimizeForMobile( - const script::Module& m, - const std::vector& preserved_methods) { - TORCH_INTERNAL_ASSERT( - "Mobile optimizaiton only available with metal at the moment. " - "metal is not enabled. Please build with USE_PYTORCH_METAL=1"); - return m; -} - -#endif } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/normalize_ops.cpp b/torch/csrc/jit/passes/normalize_ops.cpp index 9d9ac0203b90..a06a3f94f3b1 100644 --- a/torch/csrc/jit/passes/normalize_ops.cpp +++ b/torch/csrc/jit/passes/normalize_ops.cpp @@ -78,6 +78,7 @@ const std::unordered_map& getOperatorAliasMap() { {aten::divide, aten::div}, {aten::divide_, aten::div_}, {aten::multiply, aten::mul}, {aten::multiply_, aten::mul_}, {aten::true_divide, aten::div}, {aten::true_divide_, aten::div_}, + {aten::row_stack, aten::vstack}, }; return alias_map; } diff --git a/torch/csrc/jit/python/python_arg_flatten.cpp b/torch/csrc/jit/python/python_arg_flatten.cpp index 41cd3cd2b8af..b854ae14387a 100644 --- a/torch/csrc/jit/python/python_arg_flatten.cpp +++ b/torch/csrc/jit/python/python_arg_flatten.cpp @@ -21,6 +21,7 @@ static constexpr char TupleOpen = '('; static constexpr char TupleClose = ')'; static constexpr char Variable = 'v'; static constexpr char String = 's'; +static constexpr char NoneType = 'n'; } // namespace D namespace { @@ -62,6 +63,8 @@ void flatten_rec(PyObject* obj, ParsedArgs& args) { args.vars.push_back(var); args.desc.metadata.emplace_back(var); args.desc.structure.push_back(D::Variable); + } else if (strcmp(THPUtils_typename(obj), "NoneType") == 0) { + args.desc.structure.push_back(D::NoneType); } else { std::string msg = "Only tuples, lists and Variables supported as JIT inputs/outputs. " @@ -136,6 +139,8 @@ py::object unflatten_rec( throw std::runtime_error("Not enough Variables given to unflatten"); auto str = *str_it++; return py::reinterpret_borrow(THPUtils_packString(str)); + } else if (type == D::NoneType) { + return py::reinterpret_borrow(py::none()); } else { if (var_it == var_it_end) throw std::runtime_error("Not enough Variables given to unflatten"); diff --git a/torch/csrc/jit/python/python_ir.cpp b/torch/csrc/jit/python/python_ir.cpp index fc968237e4ba..20d0e6272a19 100644 --- a/torch/csrc/jit/python/python_ir.cpp +++ b/torch/csrc/jit/python/python_ir.cpp @@ -493,6 +493,7 @@ void initPythonIRBindings(PyObject* module_) { }) .def("sourceRange", [](Node& n) { return n.sourceRange().str(); }) .def("hasMultipleOutputs", [](Node& n) { return n.outputs().size() > 1; }) + .def("inputsSize", [](Node& n) { return n.inputs().size(); }) .def("outputsSize", [](Node& n) { return n.outputs().size(); }) .NS(kind) .def("inputsAt", [](Node& n, size_t i) { return n.inputs().at(i); }) diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp index 15d151b761ef..7da82644150e 100644 --- a/torch/csrc/jit/python/python_sugared_value.cpp +++ b/torch/csrc/jit/python/python_sugared_value.cpp @@ -1,4 +1,5 @@ #include + #include #include #include @@ -114,21 +115,20 @@ FunctionSchema PythonValue::getSchema( std::shared_ptr PythonValue::call( const SourceRange& loc, Function& m, - at::ArrayRef inputs_, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) { - std::vector inputsWithSelf; + std::vector argsWithSelf; if (moduleSelf_) { - inputsWithSelf.emplace_back(NamedValue("self", moduleSelf_)); + argsWithSelf.emplace_back(NamedValue("self", moduleSelf_)); } - inputsWithSelf.insert(inputsWithSelf.end(), inputs_.begin(), inputs_.end()); - inputs_ = inputsWithSelf; + argsWithSelf.insert(argsWithSelf.end(), args.begin(), args.end()); - auto schema = getSchema(inputs_.size(), n_binders, loc); - auto inputs = toValues(*m.graph(), inputs_); + auto schema = getSchema(argsWithSelf.size(), n_binders, loc); + auto inputs = toValues(*m.graph(), argsWithSelf); MatchedSchema matched_schema = - matchSchema(schema, loc, *m.graph(), inputs_, attributes); + matchSchema(schema, loc, *m.graph(), argsWithSelf, kwargs); // If if a function is marked as dropped, // we throw an exception if it is invoked. @@ -234,9 +234,11 @@ SugaredValuePtr ModuleValue::asTupleValue(const SourceRange& loc, Function& m) { SugaredValuePtr ModuleValue::getitem( const SourceRange& loc, Function& m, - Value* idx) { + Value* idx, + TypePtr type_hint) { if (concreteType_->getIterableModuleKind() == IterableModuleKind::LIST) { - return getSugaredDict(loc, m)->getModules()->getitem(loc, m, idx); + return getSugaredDict(loc, m)->getModules()->getitem( + loc, m, idx, type_hint); } else if ( concreteType_->getIterableModuleKind() == IterableModuleKind::DICT) { if (auto ivalue = toIValue(idx)) { @@ -252,6 +254,30 @@ SugaredValuePtr ModuleValue::getitem( } } throw ErrorReport(loc) << "Key Error, " << idx_str; + } else if (type_hint) { + // Check that all submodules comply with the type hint. + const auto& self_type = concreteType_->getJitType()->expect(); + for (size_t i = 0; i < self_type->numAttributes(); ++i) { + const auto& attr_type = self_type->getAttribute(i); + if (attr_type->is_module()) { + if (!attr_type->isSubtypeOf(type_hint)) { + auto loc = self_->node()->sourceRange(); + throw ErrorReport(loc) + << "Attribute " << self_type->getAttributeName(i) + << " is not of annotated type " << type_hint->annotation_str(); + } + } + } + + // Emit a prim::ModuleDictIndex operator. This is needed because it's + // difficult to construct a dict in the graph representing the ModuleDict + // and use aten::__getitem__ ops to index into it because any call to + // ModuleDict.setAttr would invalidate that emitted dict. + auto graph = m.graph(); + auto* getitem_node = + graph->insertNode(graph->create(prim::ModuleDictIndex, {self_, idx})); + getitem_node->output(0)->setType(type_hint); + return std::make_shared(getitem_node->output(0)); } throw ErrorReport(loc) << "Unable to extract string literal index. " @@ -652,8 +678,8 @@ void ModuleValue::setAttr( std::shared_ptr BooleanDispatchValue::call( const SourceRange& loc, Function& caller, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) { c10::optional result; Graph& graph = *(caller.graph()); @@ -662,14 +688,14 @@ std::shared_ptr BooleanDispatchValue::call( auto arg_name = py::str(dispatched_fn_["arg_name"]); ErrorReport error(loc); - if (index < inputs.size()) { + if (index < args.size()) { // Dispatch flag is in arg list - result = constant_as(inputs.at(index).value(graph)); + result = constant_as(args.at(index).value(graph)); error << "Argument for boolean dispatch at position " << index << " was not constant"; - } else if (auto i = findInputWithName(arg_name, attributes)) { + } else if (auto i = findInputWithName(arg_name, kwargs)) { // Dispatch flag is in kwargs - result = constant_as(attributes[*i].value(graph)); + result = constant_as(kwargs[*i].value(graph)); error << "Keyword argument '" << arg_name << "' for boolean dispatch at position was not constant"; } else { @@ -688,28 +714,28 @@ std::shared_ptr BooleanDispatchValue::call( } else { value = toSugaredValue(dispatched_fn_["if_false"], caller, loc); } - return value->call(loc, caller, inputs, attributes, n_binders); + return value->call(loc, caller, args, kwargs, n_binders); } std::shared_ptr PythonExceptionValue::call( const SourceRange& loc, Function& caller, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t /*n_binders*/) { Value* error_message = nullptr; - if (inputs.size() == 0) { + if (args.size() == 0) { error_message = insertConstant(*caller.graph(), "", loc); - } else if (inputs.size() == 1) { - error_message = inputs.at(0).value(*caller.graph()); + } else if (args.size() == 1) { + error_message = args.at(0).value(*caller.graph()); } else { std::vector message_values; - message_values.reserve(inputs.size() + attributes.size()); + message_values.reserve(args.size() + kwargs.size()); - for (auto inp : inputs) { + for (const auto& inp : args) { message_values.push_back(inp.value(*caller.graph())); } - for (auto kwarg_inp : attributes) { + for (const auto& kwarg_inp : kwargs) { message_values.push_back(kwarg_inp.value(*caller.graph())); } error_message = @@ -802,10 +828,10 @@ std::shared_ptr createSimpleEnumValue( std::shared_ptr PythonSliceClass::call( const SourceRange& loc, Function& caller, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t /*n_binders*/) { - if (!attributes.empty()) { + if (!kwargs.empty()) { throw ErrorReport(loc) << "Slice does not accept any keyword arguments"; } @@ -824,23 +850,23 @@ std::shared_ptr PythonSliceClass::call( Value* start; Value* stop; Value* step; - size_t n = inputs.size(); + size_t n = args.size(); // Slice's constructor signature is Slice(start=None, stop, step=None) if (n == 1) { // Case where only `stop` is specified. start = ValOr(nullptr, default_start); - stop = ValOr(inputs[0].value(graph), default_stop); + stop = ValOr(args[0].value(graph), default_stop); step = ValOr(nullptr, default_step); } else if (n == 2) { // Case where `start` and `stop` are specified. - start = ValOr(inputs[0].value(graph), default_start); - stop = ValOr(inputs[1].value(graph), default_stop); + start = ValOr(args[0].value(graph), default_start); + stop = ValOr(args[1].value(graph), default_stop); step = ValOr(nullptr, default_step); } else if (n == 3) { // Case where `start`, `stop` and `step` are all specified. - start = ValOr(inputs[0].value(graph), default_start); - stop = ValOr(inputs[1].value(graph), default_stop); - step = ValOr(inputs[2].value(graph), default_step); + start = ValOr(args[0].value(graph), default_start); + stop = ValOr(args[1].value(graph), default_stop); + step = ValOr(args[2].value(graph), default_step); } else { throw ErrorReport(loc) << "slice accepts exactly 1, 2 or 3 arguments, got: " << n; diff --git a/torch/csrc/jit/python/python_sugared_value.h b/torch/csrc/jit/python/python_sugared_value.h index ecb3c6da4ff4..12a5d87b063e 100644 --- a/torch/csrc/jit/python/python_sugared_value.h +++ b/torch/csrc/jit/python/python_sugared_value.h @@ -47,8 +47,8 @@ struct VISIBILITY_HIDDEN PythonValue : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& m, - at::ArrayRef inputs_, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override; std::string kind() const override; @@ -99,8 +99,8 @@ struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& caller, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override { return toSimple(the_list_); } @@ -120,10 +120,10 @@ struct VISIBILITY_HIDDEN ModuleDictMethod : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& f, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override { - if (inputs.size() || attributes.size()) { + if (args.size() || kwargs.size()) { throw ErrorReport(loc) << name_ << " method does not accept any arguments"; } @@ -175,11 +175,11 @@ struct VISIBILITY_HIDDEN ModuleValue : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& caller, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override { return attr(loc, caller, "forward") - ->call(loc, caller, inputs, attributes, n_binders); + ->call(loc, caller, args, kwargs, n_binders); } std::shared_ptr getSugaredDict( @@ -201,7 +201,8 @@ struct VISIBILITY_HIDDEN ModuleValue : public SugaredValue { std::shared_ptr getitem( const SourceRange& loc, Function& m, - Value* idx) override; + Value* idx, + TypePtr type_hint) override; private: Value* self_; @@ -268,8 +269,8 @@ struct VISIBILITY_HIDDEN BooleanDispatchValue : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& caller, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override; private: @@ -308,8 +309,8 @@ struct VISIBILITY_HIDDEN PythonExceptionValue : public ExceptionValue { std::shared_ptr call( const SourceRange& loc, Function& caller, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override; }; @@ -324,8 +325,8 @@ struct VISIBILITY_HIDDEN PythonSliceClass : public SugaredValue { std::shared_ptr call( const SourceRange& loc, Function& caller, - at::ArrayRef inputs, - at::ArrayRef attributes, + at::ArrayRef args, + at::ArrayRef kwargs, size_t n_binders) override; }; diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp index a99f7469ac65..7b9332586397 100644 --- a/torch/csrc/jit/python/script_init.cpp +++ b/torch/csrc/jit/python/script_init.cpp @@ -714,7 +714,8 @@ void initJitScriptBindings(PyObject* module) { auto m = py::handle(module).cast(); // NOLINTNEXTLINE(bugprone-unused-raii) - py::class_>(m, "Capsule"); + py::class_>( + m, "Capsule"); auto object_class = py::class_(m, "ScriptObject") diff --git a/torch/csrc/jit/runtime/operator.cpp b/torch/csrc/jit/runtime/operator.cpp index e36208dfb19f..dc1ff95cf735 100644 --- a/torch/csrc/jit/runtime/operator.cpp +++ b/torch/csrc/jit/runtime/operator.cpp @@ -287,7 +287,7 @@ bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) { prim::MMBatchSide, prim::BroadcastSizes, prim::ChunkSizes, - prim::Function, + prim::Closure, prim::TupleUnpack, prim::TupleIndex, prim::TupleSlice, diff --git a/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp b/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp index 3d7e8bb6f60e..d62697c1f9b6 100644 --- a/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp +++ b/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp @@ -40,8 +40,7 @@ C10_DEFINE_bool( namespace torch { namespace jit { -// TODO: keep the else clause for trial runs -#if defined(FBCODE_CAFFE2) || defined(C10_MOBILE) +#if defined(C10_MOBILE) static std::atomic executor_mode{true}; static std::atomic profiling_mode{false}; #else diff --git a/torch/csrc/jit/runtime/register_ops_utils.h b/torch/csrc/jit/runtime/register_ops_utils.h index ae974c063ef3..5bd85d20556a 100644 --- a/torch/csrc/jit/runtime/register_ops_utils.h +++ b/torch/csrc/jit/runtime/register_ops_utils.h @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include diff --git a/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp index f7dd9594347e..4706635a6a0c 100644 --- a/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp @@ -568,6 +568,18 @@ RegisterOperators reg( }; }, aliasAnalysisSpecialCase()), + // This operator is generated inside the compiler for indexing into + // ModuleDict without a statically determinable key. Accordingly, + // self must be a ModuleType and the output must be an InterfaceType. + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA( + "prim::ModuleDictIndex(Any self, str ind) -> Any"), + [](Stack* stack) { + IValue ind = pop(stack); + IValue module_dict = pop(stack); + push(stack, module_dict.toModule().attr(ind.toStringRef())); + }, + aliasAnalysisFromSchema()), Operator( "aten::dict() -> Dict(str, Tensor)", [](Stack* stack) { diff --git a/torch/csrc/jit/runtime/symbolic_script.cpp b/torch/csrc/jit/runtime/symbolic_script.cpp index 578586e9e9ff..6887be516e7b 100644 --- a/torch/csrc/jit/runtime/symbolic_script.cpp +++ b/torch/csrc/jit/runtime/symbolic_script.cpp @@ -1369,8 +1369,8 @@ std::pair, Value*> extractClosure(Value* closure) { Value* context = closure->node()->inputs().at(1); TORCH_CHECK( - fn->node()->kind() == prim::Function, - "closure tuple must contain a prim::Function"); + fn->node()->kind() == prim::Closure, + "closure tuple must contain a prim::Closure"); return std::make_pair(fn->node()->g(attr::Subgraph), context); } diff --git a/torch/csrc/jit/serialization/python_print.cpp b/torch/csrc/jit/serialization/python_print.cpp index e04339dacc22..9803829eb683 100644 --- a/torch/csrc/jit/serialization/python_print.cpp +++ b/torch/csrc/jit/serialization/python_print.cpp @@ -801,7 +801,7 @@ struct PythonPrintImpl { } level--; } break; - case prim::Function: { + case prim::Closure: { if (enforce_importable_) { throw ErrorReport(node->sourceRange()) << "closures are not exportable"; @@ -822,6 +822,15 @@ struct PythonPrintImpl { body_ << "):\n"; printBody(graph->block()); } break; + case prim::ModuleDictIndex: { + const auto dict = node->inputs().at(0); + const auto key = node->inputs().at(1); + const auto out = node->outputs().at(0); + assignValuesToTheirUniqueNames(out); + indent(); + body_ << useOf(out) << " : " << out->type()->annotation_str() << " = " + << useOf(dict) << "[" << useOf(key) << "]\n"; + } break; default: auto ss = std::make_shared(&source_range_stack_); printRHS(*ss, node); diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index 68ec5c2e304a..4705760b868d 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -1366,7 +1366,7 @@ Stmt* TensorExprKernel::generateStmt(BackendType backendType) { // Compute non-output tensors_ inline for (auto& p : tensors_) { - if (!l.hasLoopBodyFor(p.second) || hasReduction) { + if (!l.hasLoopBodyFor(p.second)) { continue; } l.computeInline(p.second->buf()); diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp index 3873cdd0ebf0..0400b6f14143 100644 --- a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp @@ -11,12 +11,17 @@ #include #include #include +#include #include #include #include #include #include +#if LLVM_VERSION_MAJOR >= 11 +#include +#endif + #include #include #include @@ -73,6 +78,22 @@ llvm::CmpInst::Predicate llvm_comparison_predicate( } } +#if LLVM_VERSION_MAJOR <= 9 +int ElementCount(int lanes) { + return lanes; +} +#else +llvm::ElementCount ElementCount(int lanes) { +#if LLVM_VERSION_MAJOR <= 11 + return llvm::ElementCount(static_cast(lanes), false); +#elif LLVM_VERSION_MAJOR == 12 + return llvm::ElementCount(llvm::PolySize::getFixed(lanes)); +#else +#error Only LLVM versions 8 through 12 are supported. +#endif +} +#endif + } // namespace class LLVMCodeGenImpl : public IRVisitor { @@ -188,7 +209,7 @@ static llvm::orc::JITTargetMachineBuilder makeTargetMachineBuilder() { } JTMB.setCodeGenOptLevel(llvm::CodeGenOpt::Default); - JTMB.setCPU(llvm::sys::getHostCPUName()); + JTMB.setCPU(llvm::sys::getHostCPUName().str()); JTMB.addFeatures(SubtargetFeatures.getFeatures()); JTMB.getOptions().AllowFPOpFusion = llvm::FPOpFusion::Fast; @@ -795,7 +816,7 @@ void LLVMCodeGenImpl::visit(const Cast* v) { llvm::Type* dstType = dtypeToLLVM(v->dtype()); if (v->dtype().lanes() > 1) { - dstType = llvm::VectorType::get(dstType, v->dtype().lanes()); + dstType = llvm::VectorType::get(dstType, ElementCount(v->dtype().lanes())); } llvm::Type* srcType = dtypeToLLVM(v->src_value()->dtype()); @@ -866,10 +887,11 @@ void LLVMCodeGenImpl::visit(const Ramp* v) { } llvm::Type* vecType = nullptr; + auto element_count = ElementCount(lanes); switch (v->dtype().scalar_type()) { -#define TYPE_CASE(_1, Name) \ - case ScalarType::Name: \ - vecType = llvm::VectorType::get(Name##Ty_, lanes); \ +#define TYPE_CASE(_1, Name) \ + case ScalarType::Name: \ + vecType = llvm::VectorType::get(Name##Ty_, element_count); \ break; AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); #undef TYPE_CASE @@ -939,10 +961,11 @@ void LLVMCodeGenImpl::visit(const Load* v) { llvm::Type* loadType = nullptr; + auto element_count = ElementCount(v->dtype().lanes()); switch (v->dtype().scalar_type()) { -#define TYPE_CASE(_1, Name) \ - case ScalarType::Name: \ - loadType = llvm::VectorType::get(Name##Ty_, v->dtype().lanes()); \ +#define TYPE_CASE(_1, Name) \ + case ScalarType::Name: \ + loadType = llvm::VectorType::get(Name##Ty_, element_count); \ break; AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); #undef TYPE_CASE @@ -1225,14 +1248,14 @@ static void applyMathFunctionAttributes(llvm::Function* f) { // TODO: Adding this attr should be correct, but as of LLVM 9.0.1 adding it // causes some math functions to incorrectly be turned into tail calls. // f->addFnAttr(llvm::Attribute::Speculatable); -#if LLVM_VERSION_MAJOR == 9 +#if LLVM_VERSION_MAJOR >= 9 f->addFnAttr(llvm::Attribute::NoFree); f->addFnAttr(llvm::Attribute::WillReturn); #endif } namespace { -#if LLVM_VERSION_MAJOR == 9 +#if LLVM_VERSION_MAJOR >= 9 using FunctionCallee = llvm::FunctionCallee; @@ -1258,7 +1281,7 @@ struct FunctionCallee { }; #else -#error Only LLVM versions 8 or 9 are supported. +#error Only LLVM versions 8 through 12 are supported. #endif } // namespace @@ -1282,48 +1305,50 @@ void LLVMCodeGenImpl::visit(const Intrinsics* v) { } break; #if defined(__AVX__) && !defined(_MSC_VER) -#define SIMD_UNARY_MATH_CASE(enum, name, type) \ - case enum: { \ - FunctionCallee callee; \ - std::string fname; \ - if (v->dtype().lanes() == 8) { \ - fname = "Sleef_" + std::string(name) + "8"; \ - llvm::Type* vecType = llvm::VectorType::get(type, v->dtype().lanes()); \ - callee = module_->getOrInsertFunction( \ - fname, llvm::FunctionType::get(vecType, {vecType}, false), {}); \ - call_simd_sleef = true; \ - } else if (v->dtype().lanes() == 4) { \ - fname = "Sleef_" + std::string(name) + "4"; \ - llvm::Type* vecType = llvm::VectorType::get(type, v->dtype().lanes()); \ - callee = module_->getOrInsertFunction( \ - fname, llvm::FunctionType::get(vecType, {vecType}, false), {}); \ - call_simd_sleef = true; \ - } else { \ - callee = module_->getOrInsertFunction( \ - name, llvm::FunctionType::get(type, {type}, false), {}); \ - } \ - call_ty = callee.getFunctionType(); \ - call_fn = callee.getCallee(); \ - applyMathFunctionAttributes(llvm::cast(call_fn)); \ +#define SIMD_UNARY_MATH_CASE(enum, name, type) \ + case enum: { \ + FunctionCallee callee; \ + std::string fname; \ + auto element_count = ElementCount(v->dtype().lanes()); \ + if (v->dtype().lanes() == 8) { \ + fname = "Sleef_" + std::string(name) + "8"; \ + llvm::Type* vecType = llvm::VectorType::get(type, element_count); \ + callee = module_->getOrInsertFunction( \ + fname, llvm::FunctionType::get(vecType, {vecType}, false), {}); \ + call_simd_sleef = true; \ + } else if (v->dtype().lanes() == 4) { \ + fname = "Sleef_" + std::string(name) + "4"; \ + llvm::Type* vecType = llvm::VectorType::get(type, element_count); \ + callee = module_->getOrInsertFunction( \ + fname, llvm::FunctionType::get(vecType, {vecType}, false), {}); \ + call_simd_sleef = true; \ + } else { \ + callee = module_->getOrInsertFunction( \ + name, llvm::FunctionType::get(type, {type}, false), {}); \ + } \ + call_ty = callee.getFunctionType(); \ + call_fn = callee.getCallee(); \ + applyMathFunctionAttributes(llvm::cast(call_fn)); \ } break; #else -#define SIMD_UNARY_MATH_CASE(enum, name, type) \ - case enum: { \ - FunctionCallee callee; \ - std::string fname; \ - if (v->dtype().lanes() == 4) { \ - fname = "Sleef_" + std::string(name) + "4"; \ - llvm::Type* vecType = llvm::VectorType::get(type, v->dtype().lanes()); \ - callee = module_->getOrInsertFunction( \ - fname, llvm::FunctionType::get(vecType, {vecType}, false), {}); \ - call_simd_sleef = true; \ - } else { \ - callee = module_->getOrInsertFunction( \ - name, llvm::FunctionType::get(type, {type}, false), {}); \ - } \ - call_ty = callee.getFunctionType(); \ - call_fn = callee.getCallee(); \ - applyMathFunctionAttributes(llvm::cast(call_fn)); \ +#define SIMD_UNARY_MATH_CASE(enum, name, type) \ + case enum: { \ + FunctionCallee callee; \ + std::string fname; \ + auto element_count = ElementCount(v->dtype().lanes()); \ + if (v->dtype().lanes() == 4) { \ + fname = "Sleef_" + std::string(name) + "4"; \ + llvm::Type* vecType = llvm::VectorType::get(type, element_count); \ + callee = module_->getOrInsertFunction( \ + fname, llvm::FunctionType::get(vecType, {vecType}, false), {}); \ + call_simd_sleef = true; \ + } else { \ + callee = module_->getOrInsertFunction( \ + name, llvm::FunctionType::get(type, {type}, false), {}); \ + } \ + call_ty = callee.getFunctionType(); \ + call_fn = callee.getCallee(); \ + applyMathFunctionAttributes(llvm::cast(call_fn)); \ } break; #endif SIMD_UNARY_MATH_CASE(kLog10, "log10f", FloatTy_) @@ -1353,54 +1378,56 @@ void LLVMCodeGenImpl::visit(const Intrinsics* v) { #undef SIMD_UNARY_MATH_CASE #if defined(__AVX__) && !defined(_MSC_VER) -#define SIMD_BINARY_MATH_CASE(enum, name, type) \ - case enum: { \ - FunctionCallee callee; \ - std::string fname; \ - if (v->dtype().lanes() == 8) { \ - fname = "Sleef_" + std::string(name) + "8"; \ - llvm::Type* vecType = llvm::VectorType::get(type, v->dtype().lanes()); \ - callee = module_->getOrInsertFunction( \ - fname, \ - llvm::FunctionType::get(vecType, {vecType, vecType}, false), \ - {}); \ - call_simd_sleef = true; \ - } else if (v->dtype().lanes() == 4) { \ - fname = "Sleef_" + std::string(name) + "4"; \ - llvm::Type* vecType = llvm::VectorType::get(type, v->dtype().lanes()); \ - callee = module_->getOrInsertFunction( \ - fname, \ - llvm::FunctionType::get(vecType, {vecType, vecType}, false), \ - {}); \ - call_simd_sleef = true; \ - } else { \ - callee = module_->getOrInsertFunction( \ - name, llvm::FunctionType::get(type, {type, type}, false), {}); \ - } \ - call_ty = callee.getFunctionType(); \ - call_fn = callee.getCallee(); \ - applyMathFunctionAttributes(llvm::cast(call_fn)); \ +#define SIMD_BINARY_MATH_CASE(enum, name, type) \ + case enum: { \ + FunctionCallee callee; \ + std::string fname; \ + auto element_count = ElementCount(v->dtype().lanes()); \ + if (v->dtype().lanes() == 8) { \ + fname = "Sleef_" + std::string(name) + "8"; \ + llvm::Type* vecType = llvm::VectorType::get(type, element_count); \ + callee = module_->getOrInsertFunction( \ + fname, \ + llvm::FunctionType::get(vecType, {vecType, vecType}, false), \ + {}); \ + call_simd_sleef = true; \ + } else if (v->dtype().lanes() == 4) { \ + fname = "Sleef_" + std::string(name) + "4"; \ + llvm::Type* vecType = llvm::VectorType::get(type, element_count); \ + callee = module_->getOrInsertFunction( \ + fname, \ + llvm::FunctionType::get(vecType, {vecType, vecType}, false), \ + {}); \ + call_simd_sleef = true; \ + } else { \ + callee = module_->getOrInsertFunction( \ + name, llvm::FunctionType::get(type, {type, type}, false), {}); \ + } \ + call_ty = callee.getFunctionType(); \ + call_fn = callee.getCallee(); \ + applyMathFunctionAttributes(llvm::cast(call_fn)); \ } break; #else -#define SIMD_BINARY_MATH_CASE(enum, name, type) \ - case enum: { \ - FunctionCallee callee; \ - std::string fname; \ - if (v->dtype().lanes() == 4) { \ - fname = "Sleef_" + std::string(name) + "4"; \ - llvm::Type* vecType = llvm::VectorType::get(type, v->dtype().lanes()); \ - callee = module_->getOrInsertFunction( \ - fname, \ - llvm::FunctionType::get(vecType, {vecType, vecType}, false), \ - {}); \ - call_simd_sleef = true; \ - } else { \ - callee = module_->getOrInsertFunction( \ - name, llvm::FunctionType::get(type, {type, type}, false), {}); \ - } \ - call_ty = callee.getFunctionType(); \ - call_fn = callee.getCallee(); \ - applyMathFunctionAttributes(llvm::cast(call_fn)); \ +#define SIMD_BINARY_MATH_CASE(enum, name, type) \ + case enum: { \ + FunctionCallee callee; \ + std::string fname; \ + auto element_count = ElementCount(v->dtype().lanes()); \ + if (v->dtype().lanes() == 4) { \ + fname = "Sleef_" + std::string(name) + "4"; \ + llvm::Type* vecType = llvm::VectorType::get(type, element_count); \ + callee = module_->getOrInsertFunction( \ + fname, \ + llvm::FunctionType::get(vecType, {vecType, vecType}, false), \ + {}); \ + call_simd_sleef = true; \ + } else { \ + callee = module_->getOrInsertFunction( \ + name, llvm::FunctionType::get(type, {type, type}, false), {}); \ + } \ + call_ty = callee.getFunctionType(); \ + call_fn = callee.getCallee(); \ + applyMathFunctionAttributes(llvm::cast(call_fn)); \ } break; #endif SIMD_BINARY_MATH_CASE(kAtan2, "atan2f", FloatTy_) @@ -1426,48 +1453,50 @@ void LLVMCodeGenImpl::visit(const Intrinsics* v) { } else if (v->dtype().scalar_type() == ScalarType::Double) { switch (v->op_type()) { #if defined(__AVX__) && !defined(_MSC_VER) -#define SIMD_UNARY_MATH_CASE(enum, name, type) \ - case enum: { \ - FunctionCallee callee; \ - std::string fname; \ - if (v->dtype().lanes() == 4) { \ - fname = "Sleef_" + std::string(name) + "d4"; \ - llvm::Type* vecType = llvm::VectorType::get(type, v->dtype().lanes()); \ - callee = module_->getOrInsertFunction( \ - fname, llvm::FunctionType::get(vecType, {vecType}, false), {}); \ - call_simd_sleef = true; \ - } else if (v->dtype().lanes() == 2) { \ - fname = "Sleef_" + std::string(name) + "d2"; \ - llvm::Type* vecType = llvm::VectorType::get(type, v->dtype().lanes()); \ - callee = module_->getOrInsertFunction( \ - fname, llvm::FunctionType::get(vecType, {vecType}, false), {}); \ - call_simd_sleef = true; \ - } else { \ - callee = module_->getOrInsertFunction( \ - name, llvm::FunctionType::get(type, {type}, false), {}); \ - } \ - call_ty = callee.getFunctionType(); \ - call_fn = callee.getCallee(); \ - applyMathFunctionAttributes(llvm::cast(call_fn)); \ +#define SIMD_UNARY_MATH_CASE(enum, name, type) \ + case enum: { \ + FunctionCallee callee; \ + std::string fname; \ + auto element_count = ElementCount(v->dtype().lanes()); \ + if (v->dtype().lanes() == 4) { \ + fname = "Sleef_" + std::string(name) + "d4"; \ + llvm::Type* vecType = llvm::VectorType::get(type, element_count); \ + callee = module_->getOrInsertFunction( \ + fname, llvm::FunctionType::get(vecType, {vecType}, false), {}); \ + call_simd_sleef = true; \ + } else if (v->dtype().lanes() == 2) { \ + fname = "Sleef_" + std::string(name) + "d2"; \ + llvm::Type* vecType = llvm::VectorType::get(type, element_count); \ + callee = module_->getOrInsertFunction( \ + fname, llvm::FunctionType::get(vecType, {vecType}, false), {}); \ + call_simd_sleef = true; \ + } else { \ + callee = module_->getOrInsertFunction( \ + name, llvm::FunctionType::get(type, {type}, false), {}); \ + } \ + call_ty = callee.getFunctionType(); \ + call_fn = callee.getCallee(); \ + applyMathFunctionAttributes(llvm::cast(call_fn)); \ } break; #else -#define SIMD_UNARY_MATH_CASE(enum, name, type) \ - case enum: { \ - FunctionCallee callee; \ - std::string fname; \ - if (v->dtype().lanes() == 2) { \ - fname = "Sleef_" + std::string(name) + "d2"; \ - llvm::Type* vecType = llvm::VectorType::get(type, v->dtype().lanes()); \ - callee = module_->getOrInsertFunction( \ - fname, llvm::FunctionType::get(vecType, {vecType}, false), {}); \ - call_simd_sleef = true; \ - } else { \ - callee = module_->getOrInsertFunction( \ - name, llvm::FunctionType::get(type, {type}, false), {}); \ - } \ - call_ty = callee.getFunctionType(); \ - call_fn = callee.getCallee(); \ - applyMathFunctionAttributes(llvm::cast(call_fn)); \ +#define SIMD_UNARY_MATH_CASE(enum, name, type) \ + case enum: { \ + FunctionCallee callee; \ + std::string fname; \ + if (v->dtype().lanes() == 2) { \ + fname = "Sleef_" + std::string(name) + "d2"; \ + llvm::Type* vecType = \ + llvm::VectorType::get(type, ElementCount(v->dtype().lanes())); \ + callee = module_->getOrInsertFunction( \ + fname, llvm::FunctionType::get(vecType, {vecType}, false), {}); \ + call_simd_sleef = true; \ + } else { \ + callee = module_->getOrInsertFunction( \ + name, llvm::FunctionType::get(type, {type}, false), {}); \ + } \ + call_ty = callee.getFunctionType(); \ + call_fn = callee.getCallee(); \ + applyMathFunctionAttributes(llvm::cast(call_fn)); \ } break; #endif SIMD_UNARY_MATH_CASE(kLog10, "log10", DoubleTy_) @@ -1508,54 +1537,56 @@ void LLVMCodeGenImpl::visit(const Intrinsics* v) { } break; #if defined(__AVX__) && !defined(_MSC_VER) -#define SIMD_BINARY_MATH_CASE(enum, name, type) \ - case enum: { \ - FunctionCallee callee; \ - std::string fname; \ - if (v->dtype().lanes() == 4) { \ - fname = "Sleef_" + std::string(name) + "d4"; \ - llvm::Type* vecType = llvm::VectorType::get(type, v->dtype().lanes()); \ - callee = module_->getOrInsertFunction( \ - fname, \ - llvm::FunctionType::get(vecType, {vecType, vecType}, false), \ - {}); \ - call_simd_sleef = true; \ - } else if (v->dtype().lanes() == 2) { \ - fname = "Sleef_" + std::string(name) + "d2"; \ - llvm::Type* vecType = llvm::VectorType::get(type, v->dtype().lanes()); \ - callee = module_->getOrInsertFunction( \ - fname, \ - llvm::FunctionType::get(vecType, {vecType, vecType}, false), \ - {}); \ - call_simd_sleef = true; \ - } else { \ - callee = module_->getOrInsertFunction( \ - name, llvm::FunctionType::get(type, {type, type}, false), {}); \ - } \ - call_ty = callee.getFunctionType(); \ - call_fn = callee.getCallee(); \ - applyMathFunctionAttributes(llvm::cast(call_fn)); \ +#define SIMD_BINARY_MATH_CASE(enum, name, type) \ + case enum: { \ + FunctionCallee callee; \ + std::string fname; \ + auto element_count = ElementCount(v->dtype().lanes()); \ + if (v->dtype().lanes() == 4) { \ + fname = "Sleef_" + std::string(name) + "d4"; \ + llvm::Type* vecType = llvm::VectorType::get(type, element_count); \ + callee = module_->getOrInsertFunction( \ + fname, \ + llvm::FunctionType::get(vecType, {vecType, vecType}, false), \ + {}); \ + call_simd_sleef = true; \ + } else if (v->dtype().lanes() == 2) { \ + fname = "Sleef_" + std::string(name) + "d2"; \ + llvm::Type* vecType = llvm::VectorType::get(type, element_count); \ + callee = module_->getOrInsertFunction( \ + fname, \ + llvm::FunctionType::get(vecType, {vecType, vecType}, false), \ + {}); \ + call_simd_sleef = true; \ + } else { \ + callee = module_->getOrInsertFunction( \ + name, llvm::FunctionType::get(type, {type, type}, false), {}); \ + } \ + call_ty = callee.getFunctionType(); \ + call_fn = callee.getCallee(); \ + applyMathFunctionAttributes(llvm::cast(call_fn)); \ } break; #else -#define SIMD_BINARY_MATH_CASE(enum, name, type) \ - case enum: { \ - FunctionCallee callee; \ - std::string fname; \ - if (v->dtype().lanes() == 2) { \ - fname = "Sleef_" + std::string(name) + "d2"; \ - llvm::Type* vecType = llvm::VectorType::get(type, v->dtype().lanes()); \ - callee = module_->getOrInsertFunction( \ - fname, \ - llvm::FunctionType::get(vecType, {vecType, vecType}, false), \ - {}); \ - call_simd_sleef = true; \ - } else { \ - callee = module_->getOrInsertFunction( \ - name, llvm::FunctionType::get(type, {type, type}, false), {}); \ - } \ - call_ty = callee.getFunctionType(); \ - call_fn = callee.getCallee(); \ - applyMathFunctionAttributes(llvm::cast(call_fn)); \ +#define SIMD_BINARY_MATH_CASE(enum, name, type) \ + case enum: { \ + FunctionCallee callee; \ + std::string fname; \ + auto element_count = ElementCount(v->dtype().lanes()); \ + if (v->dtype().lanes() == 2) { \ + fname = "Sleef_" + std::string(name) + "d2"; \ + llvm::Type* vecType = llvm::VectorType::get(type, element_count); \ + callee = module_->getOrInsertFunction( \ + fname, \ + llvm::FunctionType::get(vecType, {vecType, vecType}, false), \ + {}); \ + call_simd_sleef = true; \ + } else { \ + callee = module_->getOrInsertFunction( \ + name, llvm::FunctionType::get(type, {type, type}, false), {}); \ + } \ + call_ty = callee.getFunctionType(); \ + call_fn = callee.getCallee(); \ + applyMathFunctionAttributes(llvm::cast(call_fn)); \ } break; #endif SIMD_BINARY_MATH_CASE(kAtan2, "atan2", DoubleTy_) diff --git a/torch/csrc/jit/tensorexpr/llvm_jit.cpp b/torch/csrc/jit/tensorexpr/llvm_jit.cpp index 8a8e2bf48513..d80e83ce0eff 100644 --- a/torch/csrc/jit/tensorexpr/llvm_jit.cpp +++ b/torch/csrc/jit/tensorexpr/llvm_jit.cpp @@ -8,7 +8,6 @@ #include #include #include -#include #include #include #include @@ -30,7 +29,7 @@ namespace orc { // Lightly modified implementation from LLVM's Kaleidoscope JIT tutorial: // https://llvm.org/docs/tutorial/BuildingAJIT1.html -#if LLVM_VERSION_MAJOR == 9 +#if LLVM_VERSION_MAJOR >= 9 && LLVM_VERSION_MAJOR <= 12 class TORCH_API PytorchLLVMJITImpl { private: std::unique_ptr LLJ; @@ -40,401 +39,406 @@ class TORCH_API PytorchLLVMJITImpl { auto ProcSymbolsGenerator = cantFail(DynamicLibrarySearchGenerator::GetForCurrentProcess( LLJ->getDataLayout().getGlobalPrefix())); - LLJ->getMainJITDylib().setGenerator(std::move(ProcSymbolsGenerator)); + auto& JD = LLJ->getMainJITDylib(); +#if LLVM_VERSION_MAJOR == 9 + JD.setGenerator(std::move(ProcSymbolsGenerator)); +#else + JD.addGenerator(std::move(ProcSymbolsGenerator)); +#endif // Handle platform-specific symbol mangling MangleAndInterner Mangle(LLJ->getExecutionSession(), LLJ->getDataLayout()); // Register implementations of intrinsics - cantFail(LLJ->defineAbsolute( - *Mangle("log10f"), {llvm::pointerToJITTargetAddress(&log10f), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("log1pf"), {llvm::pointerToJITTargetAddress(&log1pf), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("logf"), {llvm::pointerToJITTargetAddress(&logf), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("log2f"), {llvm::pointerToJITTargetAddress(&log2f), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("expf"), {llvm::pointerToJITTargetAddress(&expf), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("erff"), {llvm::pointerToJITTargetAddress(&erff), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("cosf"), {llvm::pointerToJITTargetAddress(&cosf), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("sinf"), {llvm::pointerToJITTargetAddress(&sinf), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("tanf"), {llvm::pointerToJITTargetAddress(&tanf), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("acosf"), {llvm::pointerToJITTargetAddress(&acosf), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("asinf"), {llvm::pointerToJITTargetAddress(&asinf), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("atanf"), {llvm::pointerToJITTargetAddress(&atanf), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("coshf"), {llvm::pointerToJITTargetAddress(&coshf), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("sinhf"), {llvm::pointerToJITTargetAddress(&sinhf), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("tanhf"), {llvm::pointerToJITTargetAddress(&tanhf), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("sqrtf"), {llvm::pointerToJITTargetAddress(&sqrtf), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("fabsf"), {llvm::pointerToJITTargetAddress(&fabsf), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("floorf"), {llvm::pointerToJITTargetAddress(&floorf), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("ceilf"), {llvm::pointerToJITTargetAddress(&ceilf), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("roundf"), {llvm::pointerToJITTargetAddress(&roundf), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("truncf"), {llvm::pointerToJITTargetAddress(&truncf), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("atan2f"), {llvm::pointerToJITTargetAddress(&atan2f), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("fmodf"), {llvm::pointerToJITTargetAddress(&fmodf), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("remainderf"), - {llvm::pointerToJITTargetAddress(&remainderf), {}})); - - // FP32 Sleef functions -- SSE - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_acosf4"), - {llvm::pointerToJITTargetAddress(&Sleef_acosf4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_asinf4"), - {llvm::pointerToJITTargetAddress(&Sleef_asinf4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_atanf4"), - {llvm::pointerToJITTargetAddress(&Sleef_atanf4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_cosf4"), - {llvm::pointerToJITTargetAddress(&Sleef_cosf4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_sinf4"), - {llvm::pointerToJITTargetAddress(&Sleef_sinf4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_tanf4"), - {llvm::pointerToJITTargetAddress(&Sleef_tanf4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_coshf4"), - {llvm::pointerToJITTargetAddress(&Sleef_coshf4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_sinhf4"), - {llvm::pointerToJITTargetAddress(&Sleef_sinhf4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_tanhf4"), - {llvm::pointerToJITTargetAddress(&Sleef_tanhf4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_erff4"), - {llvm::pointerToJITTargetAddress(&Sleef_erff4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_erfcf4"), - {llvm::pointerToJITTargetAddress(&Sleef_erfcf4_u15), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_expf4"), - {llvm::pointerToJITTargetAddress(&Sleef_expf4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_expm1f4"), - {llvm::pointerToJITTargetAddress(&Sleef_expm1f4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_logf4"), - {llvm::pointerToJITTargetAddress(&Sleef_logf4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_log2f4"), - {llvm::pointerToJITTargetAddress(&Sleef_log2f4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_log10f4"), - {llvm::pointerToJITTargetAddress(&Sleef_log10f4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_log1pf4"), - {llvm::pointerToJITTargetAddress(&Sleef_log1pf4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_sqrtf4"), - {llvm::pointerToJITTargetAddress(&Sleef_sqrtf4_u05), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_fabsf4"), - {llvm::pointerToJITTargetAddress(&Sleef_fabsf4), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_floorf4"), - {llvm::pointerToJITTargetAddress(&Sleef_floorf4), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_ceilf4"), - {llvm::pointerToJITTargetAddress(&Sleef_ceilf4), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_truncf4"), - {llvm::pointerToJITTargetAddress(&Sleef_truncf4), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_roundf4"), - {llvm::pointerToJITTargetAddress(&Sleef_roundf4), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_lgammaf4"), - {llvm::pointerToJITTargetAddress(&Sleef_lgammaf4_u10), {}})); - - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_atan2f4"), - {llvm::pointerToJITTargetAddress(&Sleef_atan2f4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_powf4"), - {llvm::pointerToJITTargetAddress(&Sleef_powf4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_fmodf4"), - {llvm::pointerToJITTargetAddress(&Sleef_fmodf4), {}})); - - // FP32 Sleef functions -- AVX2 + cantFail(JD.define(absoluteSymbols({ + {Mangle("log10f"), + {llvm::pointerToJITTargetAddress(&log10f), JITSymbolFlags::None}}, + {Mangle("log1pf"), + {llvm::pointerToJITTargetAddress(&log1pf), JITSymbolFlags::None}}, + {Mangle("logf"), + {llvm::pointerToJITTargetAddress(&logf), JITSymbolFlags::None}}, + {Mangle("log2f"), + {llvm::pointerToJITTargetAddress(&log2f), JITSymbolFlags::None}}, + {Mangle("expf"), + {llvm::pointerToJITTargetAddress(&expf), JITSymbolFlags::None}}, + {Mangle("erff"), + {llvm::pointerToJITTargetAddress(&erff), JITSymbolFlags::None}}, + {Mangle("cosf"), + {llvm::pointerToJITTargetAddress(&cosf), JITSymbolFlags::None}}, + {Mangle("sinf"), + {llvm::pointerToJITTargetAddress(&sinf), JITSymbolFlags::None}}, + {Mangle("tanf"), + {llvm::pointerToJITTargetAddress(&tanf), JITSymbolFlags::None}}, + {Mangle("acosf"), + {llvm::pointerToJITTargetAddress(&acosf), JITSymbolFlags::None}}, + {Mangle("asinf"), + {llvm::pointerToJITTargetAddress(&asinf), JITSymbolFlags::None}}, + {Mangle("atanf"), + {llvm::pointerToJITTargetAddress(&atanf), JITSymbolFlags::None}}, + {Mangle("coshf"), + {llvm::pointerToJITTargetAddress(&coshf), JITSymbolFlags::None}}, + {Mangle("sinhf"), + {llvm::pointerToJITTargetAddress(&sinhf), JITSymbolFlags::None}}, + {Mangle("tanhf"), + {llvm::pointerToJITTargetAddress(&tanhf), JITSymbolFlags::None}}, + {Mangle("sqrtf"), + {llvm::pointerToJITTargetAddress(&sqrtf), JITSymbolFlags::None}}, + {Mangle("fabsf"), + {llvm::pointerToJITTargetAddress(&fabsf), JITSymbolFlags::None}}, + {Mangle("floorf"), + {llvm::pointerToJITTargetAddress(&floorf), JITSymbolFlags::None}}, + {Mangle("ceilf"), + {llvm::pointerToJITTargetAddress(&ceilf), JITSymbolFlags::None}}, + {Mangle("roundf"), + {llvm::pointerToJITTargetAddress(&roundf), JITSymbolFlags::None}}, + {Mangle("truncf"), + {llvm::pointerToJITTargetAddress(&truncf), JITSymbolFlags::None}}, + {Mangle("atan2f"), + {llvm::pointerToJITTargetAddress(&atan2f), JITSymbolFlags::None}}, + {Mangle("fmodf"), + {llvm::pointerToJITTargetAddress(&fmodf), JITSymbolFlags::None}}, + {Mangle("remainderf"), + {llvm::pointerToJITTargetAddress(&remainderf), + JITSymbolFlags::None}}, + + // FP32 Sleef functions -- SSE + {Mangle("Sleef_acosf4"), + {llvm::pointerToJITTargetAddress(&Sleef_acosf4_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_asinf4"), + {llvm::pointerToJITTargetAddress(&Sleef_asinf4_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_atanf4"), + {llvm::pointerToJITTargetAddress(&Sleef_atanf4_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_cosf4"), + {llvm::pointerToJITTargetAddress(&Sleef_cosf4_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_sinf4"), + {llvm::pointerToJITTargetAddress(&Sleef_sinf4_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_tanf4"), + {llvm::pointerToJITTargetAddress(&Sleef_tanf4_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_coshf4"), + {llvm::pointerToJITTargetAddress(&Sleef_coshf4_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_sinhf4"), + {llvm::pointerToJITTargetAddress(&Sleef_sinhf4_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_tanhf4"), + {llvm::pointerToJITTargetAddress(&Sleef_tanhf4_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_erff4"), + {llvm::pointerToJITTargetAddress(&Sleef_erff4_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_erfcf4"), + {llvm::pointerToJITTargetAddress(&Sleef_erfcf4_u15), + JITSymbolFlags::None}}, + {Mangle("Sleef_expf4"), + {llvm::pointerToJITTargetAddress(&Sleef_expf4_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_expm1f4"), + {llvm::pointerToJITTargetAddress(&Sleef_expm1f4_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_logf4"), + {llvm::pointerToJITTargetAddress(&Sleef_logf4_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_log2f4"), + {llvm::pointerToJITTargetAddress(&Sleef_log2f4_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_log10f4"), + {llvm::pointerToJITTargetAddress(&Sleef_log10f4_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_log1pf4"), + {llvm::pointerToJITTargetAddress(&Sleef_log1pf4_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_sqrtf4"), + {llvm::pointerToJITTargetAddress(&Sleef_sqrtf4_u05), + JITSymbolFlags::None}}, + {Mangle("Sleef_fabsf4"), + {llvm::pointerToJITTargetAddress(&Sleef_fabsf4), + JITSymbolFlags::None}}, + {Mangle("Sleef_floorf4"), + {llvm::pointerToJITTargetAddress(&Sleef_floorf4), + JITSymbolFlags::None}}, + {Mangle("Sleef_ceilf4"), + {llvm::pointerToJITTargetAddress(&Sleef_ceilf4), + JITSymbolFlags::None}}, + {Mangle("Sleef_truncf4"), + {llvm::pointerToJITTargetAddress(&Sleef_truncf4), + JITSymbolFlags::None}}, + {Mangle("Sleef_roundf4"), + {llvm::pointerToJITTargetAddress(&Sleef_roundf4), + JITSymbolFlags::None}}, + {Mangle("Sleef_lgammaf4"), + {llvm::pointerToJITTargetAddress(&Sleef_lgammaf4_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_atan2f4"), + {llvm::pointerToJITTargetAddress(&Sleef_atan2f4_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_powf4"), + {llvm::pointerToJITTargetAddress(&Sleef_powf4_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_fmodf4"), + {llvm::pointerToJITTargetAddress(&Sleef_fmodf4), + JITSymbolFlags::None}}, + + // FP32 Sleef functions -- AVX2 #if defined(__AVX__) && !defined(_MSC_VER) - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_acosf8"), - {llvm::pointerToJITTargetAddress(&Sleef_acosf8_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_asinf8"), - {llvm::pointerToJITTargetAddress(&Sleef_asinf8_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_atanf8"), - {llvm::pointerToJITTargetAddress(&Sleef_atanf8_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_cosf8"), - {llvm::pointerToJITTargetAddress(&Sleef_cosf8_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_sinf8"), - {llvm::pointerToJITTargetAddress(&Sleef_sinf8_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_tanf8"), - {llvm::pointerToJITTargetAddress(&Sleef_tanf8_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_coshf8"), - {llvm::pointerToJITTargetAddress(&Sleef_coshf8_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_sinhf8"), - {llvm::pointerToJITTargetAddress(&Sleef_sinhf8_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_tanhf8"), - {llvm::pointerToJITTargetAddress(&Sleef_tanhf8_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_erff8"), - {llvm::pointerToJITTargetAddress(&Sleef_erff8_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_erfcf8"), - {llvm::pointerToJITTargetAddress(&Sleef_erfcf8_u15), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_expf8"), - {llvm::pointerToJITTargetAddress(&Sleef_expf8_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_expm1f8"), - {llvm::pointerToJITTargetAddress(&Sleef_expm1f8_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_logf8"), - {llvm::pointerToJITTargetAddress(&Sleef_logf8_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_log2f8"), - {llvm::pointerToJITTargetAddress(&Sleef_log2f8_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_log10f8"), - {llvm::pointerToJITTargetAddress(&Sleef_log10f8_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_log1pf8"), - {llvm::pointerToJITTargetAddress(&Sleef_log1pf8_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_sqrtf8"), - {llvm::pointerToJITTargetAddress(&Sleef_sqrtf8_u05), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_fabsf8"), - {llvm::pointerToJITTargetAddress(&Sleef_fabsf8), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_floorf8"), - {llvm::pointerToJITTargetAddress(&Sleef_floorf8), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_ceilf8"), - {llvm::pointerToJITTargetAddress(&Sleef_ceilf8), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_truncf8"), - {llvm::pointerToJITTargetAddress(&Sleef_truncf8), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_roundf8"), - {llvm::pointerToJITTargetAddress(&Sleef_roundf8), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_lgammaf8"), - {llvm::pointerToJITTargetAddress(&Sleef_lgammaf8_u10), {}})); - - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_atan2f8"), - {llvm::pointerToJITTargetAddress(&Sleef_atan2f8_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_powf8"), - {llvm::pointerToJITTargetAddress(&Sleef_powf8_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_fmodf8"), - {llvm::pointerToJITTargetAddress(&Sleef_fmodf8), {}})); + {Mangle("Sleef_acosf8"), + {llvm::pointerToJITTargetAddress(&Sleef_acosf8_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_asinf8"), + {llvm::pointerToJITTargetAddress(&Sleef_asinf8_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_atanf8"), + {llvm::pointerToJITTargetAddress(&Sleef_atanf8_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_cosf8"), + {llvm::pointerToJITTargetAddress(&Sleef_cosf8_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_sinf8"), + {llvm::pointerToJITTargetAddress(&Sleef_sinf8_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_tanf8"), + {llvm::pointerToJITTargetAddress(&Sleef_tanf8_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_coshf8"), + {llvm::pointerToJITTargetAddress(&Sleef_coshf8_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_sinhf8"), + {llvm::pointerToJITTargetAddress(&Sleef_sinhf8_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_tanhf8"), + {llvm::pointerToJITTargetAddress(&Sleef_tanhf8_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_erff8"), + {llvm::pointerToJITTargetAddress(&Sleef_erff8_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_erfcf8"), + {llvm::pointerToJITTargetAddress(&Sleef_erfcf8_u15), + JITSymbolFlags::None}}, + {Mangle("Sleef_expf8"), + {llvm::pointerToJITTargetAddress(&Sleef_expf8_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_expm1f8"), + {llvm::pointerToJITTargetAddress(&Sleef_expm1f8_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_logf8"), + {llvm::pointerToJITTargetAddress(&Sleef_logf8_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_log2f8"), + {llvm::pointerToJITTargetAddress(&Sleef_log2f8_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_log10f8"), + {llvm::pointerToJITTargetAddress(&Sleef_log10f8_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_log1pf8"), + {llvm::pointerToJITTargetAddress(&Sleef_log1pf8_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_sqrtf8"), + {llvm::pointerToJITTargetAddress(&Sleef_sqrtf8_u05), + JITSymbolFlags::None}}, + {Mangle("Sleef_fabsf8"), + {llvm::pointerToJITTargetAddress(&Sleef_fabsf8), + JITSymbolFlags::None}}, + {Mangle("Sleef_floorf8"), + {llvm::pointerToJITTargetAddress(&Sleef_floorf8), + JITSymbolFlags::None}}, + {Mangle("Sleef_ceilf8"), + {llvm::pointerToJITTargetAddress(&Sleef_ceilf8), + JITSymbolFlags::None}}, + {Mangle("Sleef_truncf8"), + {llvm::pointerToJITTargetAddress(&Sleef_truncf8), + JITSymbolFlags::None}}, + {Mangle("Sleef_roundf8"), + {llvm::pointerToJITTargetAddress(&Sleef_roundf8), + JITSymbolFlags::None}}, + {Mangle("Sleef_lgammaf8"), + {llvm::pointerToJITTargetAddress(&Sleef_lgammaf8_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_atan2f8"), + {llvm::pointerToJITTargetAddress(&Sleef_atan2f8_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_powf8"), + {llvm::pointerToJITTargetAddress(&Sleef_powf8_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_fmodf8"), + {llvm::pointerToJITTargetAddress(&Sleef_fmodf8), + JITSymbolFlags::None}}, #endif - // FP64 Sleef functions -- SSE - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_acosd2"), - {llvm::pointerToJITTargetAddress(&Sleef_acosd2_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_asind2"), - {llvm::pointerToJITTargetAddress(&Sleef_asind2_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_atand2"), - {llvm::pointerToJITTargetAddress(&Sleef_atand2_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_cosd2"), - {llvm::pointerToJITTargetAddress(&Sleef_cosd2_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_sind2"), - {llvm::pointerToJITTargetAddress(&Sleef_sind2_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_tand2"), - {llvm::pointerToJITTargetAddress(&Sleef_tand2_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_coshd2"), - {llvm::pointerToJITTargetAddress(&Sleef_coshd2_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_sinhd2"), - {llvm::pointerToJITTargetAddress(&Sleef_sinhd2_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_tanhd2"), - {llvm::pointerToJITTargetAddress(&Sleef_tanhd2_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_erfd2"), - {llvm::pointerToJITTargetAddress(&Sleef_erfd2_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_erfcd2"), - {llvm::pointerToJITTargetAddress(&Sleef_erfcd2_u15), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_expd2"), - {llvm::pointerToJITTargetAddress(&Sleef_expd2_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_expm1d2"), - {llvm::pointerToJITTargetAddress(&Sleef_expm1d2_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_logd2"), - {llvm::pointerToJITTargetAddress(&Sleef_logd2_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_log2d2"), - {llvm::pointerToJITTargetAddress(&Sleef_log2d2_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_log10d2"), - {llvm::pointerToJITTargetAddress(&Sleef_log10d2_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_log1pd2"), - {llvm::pointerToJITTargetAddress(&Sleef_log1pd2_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_sqrtd2"), - {llvm::pointerToJITTargetAddress(&Sleef_sqrtd2_u05), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_fabsd2"), - {llvm::pointerToJITTargetAddress(&Sleef_fabsd2), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_floord2"), - {llvm::pointerToJITTargetAddress(&Sleef_floord2), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_ceild2"), - {llvm::pointerToJITTargetAddress(&Sleef_ceild2), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_truncd2"), - {llvm::pointerToJITTargetAddress(&Sleef_truncd2), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_roundd2"), - {llvm::pointerToJITTargetAddress(&Sleef_roundd2), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_lgammad2"), - {llvm::pointerToJITTargetAddress(&Sleef_lgammad2_u10), {}})); - - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_atan2d2"), - {llvm::pointerToJITTargetAddress(&Sleef_atan2d2_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_powd2"), - {llvm::pointerToJITTargetAddress(&Sleef_powd2_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_fmodd2"), - {llvm::pointerToJITTargetAddress(&Sleef_fmodd2), {}})); - - // FP64 Sleef functions -- AVX2 + // FP64 Sleef functions -- SSE + {Mangle("Sleef_acosd2"), + {llvm::pointerToJITTargetAddress(&Sleef_acosd2_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_asind2"), + {llvm::pointerToJITTargetAddress(&Sleef_asind2_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_atand2"), + {llvm::pointerToJITTargetAddress(&Sleef_atand2_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_cosd2"), + {llvm::pointerToJITTargetAddress(&Sleef_cosd2_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_sind2"), + {llvm::pointerToJITTargetAddress(&Sleef_sind2_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_tand2"), + {llvm::pointerToJITTargetAddress(&Sleef_tand2_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_coshd2"), + {llvm::pointerToJITTargetAddress(&Sleef_coshd2_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_sinhd2"), + {llvm::pointerToJITTargetAddress(&Sleef_sinhd2_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_tanhd2"), + {llvm::pointerToJITTargetAddress(&Sleef_tanhd2_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_erfd2"), + {llvm::pointerToJITTargetAddress(&Sleef_erfd2_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_erfcd2"), + {llvm::pointerToJITTargetAddress(&Sleef_erfcd2_u15), + JITSymbolFlags::None}}, + {Mangle("Sleef_expd2"), + {llvm::pointerToJITTargetAddress(&Sleef_expd2_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_expm1d2"), + {llvm::pointerToJITTargetAddress(&Sleef_expm1d2_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_logd2"), + {llvm::pointerToJITTargetAddress(&Sleef_logd2_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_log2d2"), + {llvm::pointerToJITTargetAddress(&Sleef_log2d2_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_log10d2"), + {llvm::pointerToJITTargetAddress(&Sleef_log10d2_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_log1pd2"), + {llvm::pointerToJITTargetAddress(&Sleef_log1pd2_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_sqrtd2"), + {llvm::pointerToJITTargetAddress(&Sleef_sqrtd2_u05), + JITSymbolFlags::None}}, + {Mangle("Sleef_fabsd2"), + {llvm::pointerToJITTargetAddress(&Sleef_fabsd2), + JITSymbolFlags::None}}, + {Mangle("Sleef_floord2"), + {llvm::pointerToJITTargetAddress(&Sleef_floord2), + JITSymbolFlags::None}}, + {Mangle("Sleef_ceild2"), + {llvm::pointerToJITTargetAddress(&Sleef_ceild2), + JITSymbolFlags::None}}, + {Mangle("Sleef_truncd2"), + {llvm::pointerToJITTargetAddress(&Sleef_truncd2), + JITSymbolFlags::None}}, + {Mangle("Sleef_roundd2"), + {llvm::pointerToJITTargetAddress(&Sleef_roundd2), + JITSymbolFlags::None}}, + {Mangle("Sleef_lgammad2"), + {llvm::pointerToJITTargetAddress(&Sleef_lgammad2_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_atan2d2"), + {llvm::pointerToJITTargetAddress(&Sleef_atan2d2_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_powd2"), + {llvm::pointerToJITTargetAddress(&Sleef_powd2_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_fmodd2"), + {llvm::pointerToJITTargetAddress(&Sleef_fmodd2), + JITSymbolFlags::None}}, + + // FP64 Sleef functions -- AVX2 #if defined(__AVX__) && !defined(_MSC_VER) - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_acosd4"), - {llvm::pointerToJITTargetAddress(&Sleef_acosd4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_asind4"), - {llvm::pointerToJITTargetAddress(&Sleef_asind4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_atand4"), - {llvm::pointerToJITTargetAddress(&Sleef_atand4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_cosd4"), - {llvm::pointerToJITTargetAddress(&Sleef_cosd4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_sind4"), - {llvm::pointerToJITTargetAddress(&Sleef_sind4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_tand4"), - {llvm::pointerToJITTargetAddress(&Sleef_tand4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_coshd4"), - {llvm::pointerToJITTargetAddress(&Sleef_coshd4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_sinhd4"), - {llvm::pointerToJITTargetAddress(&Sleef_sinhd4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_tanhd4"), - {llvm::pointerToJITTargetAddress(&Sleef_tanhd4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_erfd4"), - {llvm::pointerToJITTargetAddress(&Sleef_erfd4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_erfcd4"), - {llvm::pointerToJITTargetAddress(&Sleef_erfcd4_u15), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_expd4"), - {llvm::pointerToJITTargetAddress(&Sleef_expd4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_expm1d4"), - {llvm::pointerToJITTargetAddress(&Sleef_expm1d4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_logd4"), - {llvm::pointerToJITTargetAddress(&Sleef_logd4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_log2d4"), - {llvm::pointerToJITTargetAddress(&Sleef_log2d4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_log10d4"), - {llvm::pointerToJITTargetAddress(&Sleef_log10d4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_log1pd4"), - {llvm::pointerToJITTargetAddress(&Sleef_log1pd4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_sqrtd4"), - {llvm::pointerToJITTargetAddress(&Sleef_sqrtd4_u05), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_fabsd4"), - {llvm::pointerToJITTargetAddress(&Sleef_fabsd4), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_floord4"), - {llvm::pointerToJITTargetAddress(&Sleef_floord4), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_ceild4"), - {llvm::pointerToJITTargetAddress(&Sleef_ceild4), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_truncd4"), - {llvm::pointerToJITTargetAddress(&Sleef_truncd4), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_roundd4"), - {llvm::pointerToJITTargetAddress(&Sleef_roundd4), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_lgammad4"), - {llvm::pointerToJITTargetAddress(&Sleef_lgammad4_u10), {}})); - - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_atan2d4"), - {llvm::pointerToJITTargetAddress(&Sleef_atan2d4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_powd4"), - {llvm::pointerToJITTargetAddress(&Sleef_powd4_u10), {}})); - cantFail(LLJ->defineAbsolute( - *Mangle("Sleef_fmodd4"), - {llvm::pointerToJITTargetAddress(&Sleef_fmodd4), {}})); + {Mangle("Sleef_acosd4"), + {llvm::pointerToJITTargetAddress(&Sleef_acosd4_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_asind4"), + {llvm::pointerToJITTargetAddress(&Sleef_asind4_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_atand4"), + {llvm::pointerToJITTargetAddress(&Sleef_atand4_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_cosd4"), + {llvm::pointerToJITTargetAddress(&Sleef_cosd4_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_sind4"), + {llvm::pointerToJITTargetAddress(&Sleef_sind4_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_tand4"), + {llvm::pointerToJITTargetAddress(&Sleef_tand4_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_coshd4"), + {llvm::pointerToJITTargetAddress(&Sleef_coshd4_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_sinhd4"), + {llvm::pointerToJITTargetAddress(&Sleef_sinhd4_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_tanhd4"), + {llvm::pointerToJITTargetAddress(&Sleef_tanhd4_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_erfd4"), + {llvm::pointerToJITTargetAddress(&Sleef_erfd4_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_erfcd4"), + {llvm::pointerToJITTargetAddress(&Sleef_erfcd4_u15), + JITSymbolFlags::None}}, + {Mangle("Sleef_expd4"), + {llvm::pointerToJITTargetAddress(&Sleef_expd4_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_expm1d4"), + {llvm::pointerToJITTargetAddress(&Sleef_expm1d4_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_logd4"), + {llvm::pointerToJITTargetAddress(&Sleef_logd4_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_log2d4"), + {llvm::pointerToJITTargetAddress(&Sleef_log2d4_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_log10d4"), + {llvm::pointerToJITTargetAddress(&Sleef_log10d4_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_log1pd4"), + {llvm::pointerToJITTargetAddress(&Sleef_log1pd4_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_sqrtd4"), + {llvm::pointerToJITTargetAddress(&Sleef_sqrtd4_u05), + JITSymbolFlags::None}}, + {Mangle("Sleef_fabsd4"), + {llvm::pointerToJITTargetAddress(&Sleef_fabsd4), + JITSymbolFlags::None}}, + {Mangle("Sleef_floord4"), + {llvm::pointerToJITTargetAddress(&Sleef_floord4), + JITSymbolFlags::None}}, + {Mangle("Sleef_ceild4"), + {llvm::pointerToJITTargetAddress(&Sleef_ceild4), + JITSymbolFlags::None}}, + {Mangle("Sleef_truncd4"), + {llvm::pointerToJITTargetAddress(&Sleef_truncd4), + JITSymbolFlags::None}}, + {Mangle("Sleef_roundd4"), + {llvm::pointerToJITTargetAddress(&Sleef_roundd4), + JITSymbolFlags::None}}, + {Mangle("Sleef_lgammad4"), + {llvm::pointerToJITTargetAddress(&Sleef_lgammad4_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_atan2d4"), + {llvm::pointerToJITTargetAddress(&Sleef_atan2d4_u10), + JITSymbolFlags::None}}, + {Mangle("Sleef_powd4"), + {llvm::pointerToJITTargetAddress(&Sleef_powd4_u10), + JITSymbolFlags::None}}, + { + Mangle("Sleef_fmodd4"), { + llvm::pointerToJITTargetAddress(&Sleef_fmodd4), JITSymbolFlags::None + } + } #endif + }))); } Error addModule(std::unique_ptr M, std::unique_ptr C) { @@ -563,7 +567,7 @@ const DataLayout& PytorchLLVMJIT::getDataLayout() { } #else // LLVM_VERSION_MAJOR -#error Only LLVM versions 8 or 9 are supported. +#error Only LLVM versions 8 through 12 are supported. #endif } // end namespace orc diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp index 422de1d1bf25..71f5e2152f3b 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.cpp +++ b/torch/csrc/jit/tensorexpr/loopnest.cpp @@ -647,17 +647,18 @@ class FunctionInliner : public IRMutator { std::unordered_map> random_bindings_; }; -void LoopNest::computeInline(Stmt* s) { +bool LoopNest::computeInline(Stmt* s) { auto* s_store = dynamic_cast(s); if (s_store == nullptr) { throw std::logic_error("Could not find buffer producer to inline"); } - computeInline(s_store->buf()); + return computeInline(s_store->buf()); } -void LoopNest::computeInline(const Buf* b) { +bool LoopNest::computeInline(const Buf* b) { if (output_bufs_.count(b)) { - throw std::logic_error("Can't inline producers of output Tensors"); + // Cannot inline producers of output Tensors + return false; } // Find producers. @@ -667,20 +668,24 @@ void LoopNest::computeInline(const Buf* b) { if (s->buf() == b) { auto reductions = NodeFinder::find(s); if (!reductions.empty()) { - throw std::logic_error("cannot inline a reduction computation"); + // Cannot inline a reduction computation + return false; } if (relevant_store != nullptr) { - throw std::logic_error("cannot inline Buf with multiple Tensors"); + // Cannot inline Buf with multiple Tensors + return false; } relevant_store = s; } } + TORCH_INTERNAL_ASSERT(relevant_store); FunctionInliner inliner(relevant_store); root_stmt_ = root_stmt_->accept_mutator(&inliner); // No longer computing this intermediate tensor, so don't alloc it. intermediate_bufs_.erase(b); + return true; } // TODO: Unify with DepTracker diff --git a/torch/csrc/jit/tensorexpr/loopnest.h b/torch/csrc/jit/tensorexpr/loopnest.h index 3ba14a63abf4..cbebf9551d25 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.h +++ b/torch/csrc/jit/tensorexpr/loopnest.h @@ -51,8 +51,8 @@ class TORCH_API LoopNest { void vectorize(Stmt*); - void computeInline(Stmt* s); - void computeInline(const Buf* b); + bool computeInline(Stmt* s); + bool computeInline(const Buf* b); static void splitWithTail(For* f, int factor); static void splitWithTail( diff --git a/torch/csrc/jit/tensorexpr/stmt.h b/torch/csrc/jit/tensorexpr/stmt.h index abdaea147c00..55c1926b3541 100644 --- a/torch/csrc/jit/tensorexpr/stmt.h +++ b/torch/csrc/jit/tensorexpr/stmt.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include diff --git a/torch/csrc/utils/pybind.h b/torch/csrc/utils/pybind.h index 7df518f404c5..7bdd493406c5 100644 --- a/torch/csrc/utils/pybind.h +++ b/torch/csrc/utils/pybind.h @@ -17,6 +17,10 @@ namespace py = pybind11; +// This makes intrusive_ptr to be available as a custom pybind11 holder type, see +// https://pybind11.readthedocs.io/en/stable/advanced/smart_ptrs.html#custom-smart-pointers +PYBIND11_DECLARE_HOLDER_TYPE(T, c10::intrusive_ptr, true); + namespace pybind11 { namespace detail { // torch.autograd.Variable <-> at::Tensor conversions (without unwrapping) diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h index e50792c47bec..b0b81a9517da 100644 --- a/torch/csrc/utils/python_arg_parser.h +++ b/torch/csrc/utils/python_arg_parser.h @@ -601,7 +601,7 @@ inline c10::complex PythonArgs::toComplex(int i) { inline c10::complex PythonArgs::toComplexWithDefault(int i, c10::complex default_value) { if (!args[i]) return default_value; - return toDouble(i); + return toComplex(i); } inline bool PythonArgs::toBool(int i) { diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 32353c5dc023..3090721c20ff 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -1357,7 +1357,8 @@ def all_gather_object(object_list, obj, group=group.WORLD): object_list (list[Any]): Output list. It should be correctly sized as the size of the group for this collective and will contain the output. object (Any): Pickable Python object to be broadcast from current process. - group (ProcessGroup, optional): The process group to work on + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Returns: None. If the calling rank is part of this group, the output of the @@ -1369,6 +1370,13 @@ def all_gather_object(object_list, obj, group=group.WORLD): collective since it does not provide an ``async_op`` handle and thus will be a blocking call. + .. note:: For NCCL-based processed groups, internal tensor representations + of objects must be moved to the GPU device before communication takes + place. In this case, the device used is given by + ``torch.cuda.current_device()`` and it is the user's responsiblity to + ensure that this is set so that each rank has an individual GPU, via + ``torch.cuda.set_device()``. + .. warning:: :func:`all_gather_object` uses ``pickle`` module implicitly, which is known to be insecure. It is possible to construct malicious pickle data @@ -1380,16 +1388,19 @@ def all_gather_object(object_list, obj, group=group.WORLD): input_tensor, local_size = _object_to_tensor(obj) group_backend = get_backend(group) - my_rank = get_rank() is_nccl_backend = group_backend == Backend.NCCL + current_device = torch.device("cpu") if is_nccl_backend: - input_tensor, local_size = input_tensor.to(my_rank), local_size.to(my_rank) + # See note about using torch.cuda.current_device() here in docstring. + # We cannot simply use my_rank since rank == device is not necessarily + # true. + current_device = torch.cuda.current_device() + input_tensor = input_tensor.to(current_device) + local_size = local_size.to(current_device) # Gather all local sizes. This is so that we can find the max size, and index # until the correct size when deserializing the tensors. group_size = get_world_size(group=group) - object_sizes_tensor = torch.zeros(group_size, dtype=int).to( - my_rank if is_nccl_backend else "cpu" - ) + object_sizes_tensor = torch.zeros(group_size, dtype=int, device=current_device) object_size_list = [ object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) ] @@ -1399,8 +1410,8 @@ def all_gather_object(object_list, obj, group=group.WORLD): # Resize tensor to max size across all ranks. input_tensor.resize_(max_object_size) coalesced_output_tensor = torch.empty( - max_object_size * group_size, dtype=torch.uint8 - ).to(my_rank if is_nccl_backend else "cpu") + max_object_size * group_size, dtype=torch.uint8, device=current_device + ) # Output tensors are nonoverlapping views of coalesced_output_tensor output_tensors = [ coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] @@ -1427,7 +1438,8 @@ def gather_object(obj, object_gather_list=None, dst=0, group=group.WORLD): collective and will contain the output. Must be ``None`` on non-dst ranks. (default is ``None``) dst (int, optional): Destination rank. (default is 0) - group: (ProcessGroup, optional): The process group to work on. + group: (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Returns: None. On the ``dst`` rank, ``object_gather_list`` will contain the @@ -1453,20 +1465,22 @@ def gather_object(obj, object_gather_list=None, dst=0, group=group.WORLD): _validate_output_list_for_rank(my_rank, dst, object_gather_list) input_tensor, local_size = _object_to_tensor(obj) group_backend = get_backend(group) + current_device = torch.device("cpu") is_nccl_backend = group_backend == Backend.NCCL if is_nccl_backend: - input_tensor, local_size = input_tensor.to(my_rank), local_size.to(my_rank) + current_device = torch.cuda.current_device() + input_tensor = input_tensor.to(current_device) + local_size = local_size.to(current_device) # Gather all local sizes. This is so that we can find the max size, and index # until the correct size when deserializing the tensors. group_size = get_world_size(group=group) - object_sizes_tensor = torch.zeros(group_size, dtype=int).to( - my_rank if is_nccl_backend else "cpu" - ) + object_sizes_tensor = torch.zeros(group_size, dtype=int, device=current_device) object_size_list = [ object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) ] - # Allgather tensor sizes. An all-gather is needed here despite this being a gather, - # since each rank needs to broadcast a tensor of the same (maximal) size. + # Allgather tensor sizes. An all-gather is needed here despite this being a + # gather, since each rank needs to broadcast a tensor of the same (maximal) + # size. all_gather(object_size_list, local_size, group=group) max_object_size = max(object_size_list) # Resize tensor to max size across all ranks. @@ -1474,8 +1488,8 @@ def gather_object(obj, object_gather_list=None, dst=0, group=group.WORLD): # Avoid populating output tensors if the result won't be gathered on this rank. if my_rank == dst: coalesced_output_tensor = torch.empty( - max_object_size * group_size, dtype=torch.uint8 - ).to(my_rank if is_nccl_backend else "cpu") + max_object_size * group_size, dtype=torch.uint8, device=current_device + ) # Output tensors are nonoverlapping views of coalesced_output_tensor output_tensors = [ coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] @@ -1508,15 +1522,23 @@ def broadcast_object_list(object_list, src, group=group.WORLD): Each object must be picklable. Only objects on the ``src`` rank will be broadcast, but each rank must provide lists of equal sizes. src (int): Source rank from which to broadcast ``object_list``. - group: (ProcessGroup, optional): The process group to work on. + group: (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Returns: ``None``. If rank is part of the group, ``object_list`` will contain the broadcasted objects from ``src`` rank. - .. note:: Note that this API differs slightly from the broadcast collective - since it does not provide an ``async_op`` handle and thus will be a - blocking call. + .. note:: For NCCL-based processed groups, internal tensor representations + of objects must be moved to the GPU device before communication takes + place. In this case, the device used is given by + ``torch.cuda.current_device()`` and it is the user's responsiblity to + ensure that this is set so that each rank has an individual GPU, via + ``torch.cuda.set_device()``. + + .. note:: Note that this API differs slightly from the :func:`all_gather` + collective since it does not provide an ``async_op`` handle and thus + will be a blocking call. .. warning:: :func:`broadcast_object_list` uses ``pickle`` module implicitly, which @@ -1537,8 +1559,14 @@ def broadcast_object_list(object_list, src, group=group.WORLD): group_backend = get_backend(group) is_nccl_backend = group_backend == Backend.NCCL + current_device = torch.device("cpu") if is_nccl_backend: - object_sizes_tensor = object_sizes_tensor.to(my_rank) + # See note about using torch.cuda.current_device() here in docstring. + # We cannot simply use my_rank since rank == device is not necessarily + # true. + current_device = torch.cuda.current_device() + object_sizes_tensor = object_sizes_tensor.to(current_device) + object_sizes_tensor = object_sizes_tensor.to(current_device) # Broadcast object sizes broadcast(object_sizes_tensor, src=src, group=group) @@ -1550,7 +1578,7 @@ def broadcast_object_list(object_list, src, group=group.WORLD): object_tensor = torch.ByteTensor(torch.sum(object_sizes_tensor).item()) if is_nccl_backend: - object_tensor = object_tensor.to(my_rank) + object_tensor = object_tensor.to(current_device) broadcast(object_tensor, src=src, group=group) # Deserialize objects using their stored sizes. offset = 0 diff --git a/torch/distributed/nn/api/remote_module.py b/torch/distributed/nn/api/remote_module.py index 225cb4842bd1..e0aebf87ae84 100644 --- a/torch/distributed/nn/api/remote_module.py +++ b/torch/distributed/nn/api/remote_module.py @@ -17,6 +17,7 @@ import torch.distributed.rpc as rpc from torch import Tensor, device, dtype, nn from torch.distributed.nn.jit import instantiator +from torch.distributed.rpc.utils import _parse_remote_device from torch.nn.parameter import Parameter from torch.utils.hooks import RemovableHandle @@ -64,8 +65,7 @@ def _raise_not_supported(name): class _RemoteModule(nn.Module): def __init__( self, - on: str, - device: torch.device, + remote_device: str, module_cls: nn.Module, args: Tuple = None, kwargs: Dict[str, Any] = None, @@ -100,8 +100,10 @@ def __init__( ``def forward_async(input: Tensor) -> Future[Tensor]:``. Arguments: - on (str or WorkerInfo): id or name of the destination worker. - device (torch.device): Device on the destination worker where we‘d like to place this module. + remote_device (str): Device on the destination worker where we‘d like to place this module. + The format should be "/", where the device field can be parsed as torch.device type. + E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0". + In addition, the device field can be optional and the default value is "cpu". module_cls (nn.Module): For example, >>> class MyModule(nn.Module): >>> def forward(input): @@ -132,7 +134,7 @@ def __init__( >>> >>> rpc.init_rpc("worker0", rank=0, world_size=2) >>> remote_linear_module = RemoteModule( - >>> "worker1", "cpu", nn.Linear, args=(20, 30), + >>> "worker1/cpu", nn.Linear, args=(20, 30), >>> ) >>> input = torch.randn(128, 20) >>> ret_fut = remote_linear_module.forward_async(input) @@ -155,18 +157,22 @@ def __init__( args = args if args is not None else () kwargs = kwargs if kwargs is not None else {} - self.on = on + self.on, self.device = _parse_remote_device(remote_device) if _module_interface_cls is not None: # Users reply on this field to know if this generated RemoteModule is TorchScript-able. self.is_scriptable = True # Instantiate template on remote side. - fut = rpc.rpc_async(on, _instantiate_template, (_module_interface_cls,)) + fut = rpc.rpc_async( + self.on, _instantiate_template, (_module_interface_cls,) + ) # Instantiate template on local side. - generated_module = instantiator.instantiate_scriptable_remote_module_template( - _module_interface_cls + generated_module = ( + instantiator.instantiate_scriptable_remote_module_template( + _module_interface_cls + ) ) generated_methods = generated_module._generated_methods @@ -178,9 +184,9 @@ def __init__( # Create the module on the remote side. self.module_rref = rpc.rpc_sync( - on, + self.on, _create_module, - (module_cls, args, kwargs, device, _module_interface_cls), + (module_cls, args, kwargs, self.device, _module_interface_cls), ) # Install generated methods. @@ -329,8 +335,10 @@ class RemoteModule(_RemoteModule): ``def forward_async(input: Tensor) -> Future[Tensor]:``. Arguments: - to (str or WorkerInfo): id or name of the destination worker. - device (torch.device): Device on the destination worker where we‘d like to place this module. + remote_device (str): Device on the destination worker where we‘d like to place this module. + The format should be "/", where the device field can be parsed as torch.device type. + E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0". + In addition, the device field can be optional and the default value is "cpu". module_cls (nn.Module): For example, >>> class MyModule(nn.Module): >>> def forward(input): @@ -357,7 +365,7 @@ class RemoteModule(_RemoteModule): >>> >>> rpc.init_rpc("worker0", rank=0, world_size=2) >>> remote_linear_module = RemoteModule( - >>> "worker1", nn.Linear, args=(20, 30), + >>> "worker1/cpu", nn.Linear, args=(20, 30), >>> ) >>> input = torch.randn(128, 20) >>> ret_fut = remote_linear_module.forward_async(input) @@ -374,10 +382,9 @@ class RemoteModule(_RemoteModule): def __init__( self, - on: str, - device: torch.device, + remote_device: str, module_cls: nn.Module, args: Tuple = None, kwargs: Dict[str, Any] = None, ): - super().__init__(on, device, module_cls, args, kwargs) + super().__init__(remote_device, module_cls, args, kwargs) diff --git a/torch/distributed/rpc/utils.py b/torch/distributed/rpc/utils.py new file mode 100644 index 000000000000..15924c4a72f0 --- /dev/null +++ b/torch/distributed/rpc/utils.py @@ -0,0 +1,37 @@ +def _parse_remote_device(remote_device: str): + r""" + Parses the remote device. + + Arguments: + remote_device (str): Device on the destination worker where we‘d like to place this module. + The format should be "/", where the device field can be parsed as torch.device type. + E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0". + In addition, the device field can be optional and the default value is "cpu". + + Returns: + A workername and a device. + """ + fields = remote_device.split("/") + if len(fields) == 2: + [on, device] = fields + elif len(fields) == 1: + on = fields[0] + device = "cpu" + else: + raise RuntimeError( + "Could not parse remote_device: {}. The valid format is '/'".format( + remote_device + ) + ) + + # Since the workername in the input remote device won't be validated until the created remote module is executed, + # only do some very basic sanity check on workername at the module creation time. + # As currently there is no regex to describe the format of workername, just check whether the workername is empty. + if not on: + raise RuntimeError( + "The workername in remote_device '{}' cannot be empty. The valid format is '/'".format( + remote_device + ) + ) + + return on, device diff --git a/torch/functional.py b/torch/functional.py index e31ec40d63d7..cb9be8117fa8 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -1449,7 +1449,12 @@ def _lu_impl(A, pivot=True, get_infos=False, out=None): - **factorization** (*Tensor*): the factorization of size :math:`(*, m, n)` - - **pivots** (*IntTensor*): the pivots of size :math:`(*, m)` + - **pivots** (*IntTensor*): the pivots of size :math:`(*, \text{min}(m, n))`. + ``pivots`` stores all the intermediate transpositions of rows. + The final permutation ``perm`` could be reconstructed by + applying ``swap(perm[i], perm[pivots[i] - 1])`` for ``i = 0, ..., pivots.size(-1) - 1``, + where ``perm`` is initially the identity permutation of :math:`m` elements + (essentially this is what :func:`torch.lu_unpack` is doing). - **infos** (*IntTensor*, *optional*): if :attr:`get_infos` is ``True``, this is a tensor of size :math:`(*)` where non-zero values indicate whether factorization for the matrix or diff --git a/torch/fx/__init__.py b/torch/fx/__init__.py index 792a905432a5..c7fbd6fbf0ea 100644 --- a/torch/fx/__init__.py +++ b/torch/fx/__init__.py @@ -2,7 +2,7 @@ r''' **This feature is experimental and its stability is not currently guaranteed. Proceed at your own risk** -FX (Functional Transformations) is a toolkit for capturing and transforming functional PyTorch programs. It +FX is a toolkit for capturing and transforming functional PyTorch programs. It consists of GraphModule and a corresponding intermediate representation (IR). When GraphModule is constructed with an `nn.Module` instance as its argument, GraphModule will trace through the computation of that Module's `forward` method symbolically and record those operations in the FX intermediate representation. diff --git a/torch/fx/experimental/Partitioner.py b/torch/fx/experimental/Partitioner.py index fadbac42cdd0..9c1bdaaa1335 100644 --- a/torch/fx/experimental/Partitioner.py +++ b/torch/fx/experimental/Partitioner.py @@ -64,6 +64,17 @@ def recalculate_mem_size(self): for node in self.nodes: self.used_mem_bytes += get_extra_size_of(node, self.nodes) + def add_node(self, node): + input_nodes: Dict[Node, None] = {} + map_arg(node.args, lambda n: input_nodes.setdefault(n)) + map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) + # Add current node's input nodes if they are placeholder or constants + for n in input_nodes: + if n.op in {'placeholder', 'get_attr'}: + self.nodes.add(n) + self.nodes.add(node) + + class PartitionResult(NamedTuple): """NameTuple used for returning DAG and a new graph module """ @@ -106,6 +117,13 @@ def get_extra_size_of(node: Node, nodes: Set[Node]) -> int: raise RuntimeError('node has no size_bytes attr') return total_size_of_input_nodes +def calculate_mem_bytes_needed(p1, p2): + nodes = p1.nodes.union(p2.nodes) + mem_bytes_needed = 0 + for node in nodes: + mem_bytes_needed += get_extra_size_of(node, nodes) + return mem_bytes_needed + class Partitioner: """A graph module may not fit into one device. Partitioner class helps cut one graph into subgraphs (partitions), @@ -147,14 +165,17 @@ def partition_graph( if node.op == 'output': break total_size_of_graph += node.size_bytes.total_size - if total_size_of_graph <= available_mem_bytes: + device_with_max_mem = max(self.devices, key=lambda d: d.available_mem_bytes) + if total_size_of_graph <= device_with_max_mem.available_mem_bytes: self.find_single_partition(total_size_of_graph) - elif total_size_of_graph > len(self.devices) * available_mem_bytes: + elif total_size_of_graph > sum([d.available_mem_bytes for d in self.devices]): raise RuntimeError('Devices have no enough memory for the module') else: - if not all(device.available_mem_bytes == available_mem_bytes for device in self.devices): - raise RuntimeError('All devices must have same memory size!') if partitioner_config.is_sparse_nn: + if not all(device.available_mem_bytes == available_mem_bytes for device in self.devices): + raise RuntimeError('All devices must have same memory size!') + # sparse_nn_partition only support same memory size + # TODO: add different size support for sparse_nn_partition self.sparse_nn_partition(available_mem_bytes) else: self.size_based_partition(available_mem_bytes) @@ -174,57 +195,141 @@ def find_single_partition(self, total_size_of_graph) -> None: self.node_to_partitions[node] = partition_0.partition_id partition_0.nodes.add(node) partition_0.used_mem_bytes = total_size_of_graph - partition_0.logical_device_ids = [self.devices[0].logical_id] + partition_0.logical_device_ids = [0] return def size_based_partition(self, available_mem_bytes: int) -> None: - """This method partitions the graph based on memory size. - We assume all devices have the same memory size. + """This method is to partition the graph based on memory size. + It uses greedy approach. The result may not be the best. The basic idea is: - First, create a new partition. - Then traverse the graph through self.graph_module.graph.nodes - The traversal only focuses on op nodes - (call_function, call_module, call_method). - The placeholder nodes (placeholder) and constant nodes (get_attr) are skipped. - A placeholder (placeholder) or a constant (get_attr) - is added into a partition when it is a input node for a op node. - From one op node to another, check if a op node and its input nodes - can fit into the current partition. - If the current partition is full, create a new one - and continue traversing op nodes. - Then through self.combine_partition_based_on_size(), - partitions will be combined to keep - as less partitions as possible. + Step 1: + Find a device which has enough memory to fit the first node, create a empty partition + with the size of that device. + Then keep adding the following nodes into the partition until the partition is full. + Step 2: + Repeat Step 1 until no device left + Step 3: + If some nodes are left, create a partition for each left node (single node partition). + Try to combine those single node partitions with the non single node partitions + from Step 1 and Step 2. + If two partitions cannot be combined, but could fit into the same logical device, + Two partitions use the same logical device. """ - # Create the first partition + def find_device_based_on_size(node) -> Device: + """Given a node, this function is to find a logical device + that could fit the node. + """ + mem_size_needed = get_extra_size_of(node, set()) + device = Device('', -1, -1) + for d in self.devices: + if d not in occupied_devices and d.available_mem_bytes >= mem_size_needed: + device = d + break + if device.available_mem_bytes < 0: + raise RuntimeError(str(node) + 'is too large to fit any device') + occupied_devices.append(device) + return device + + def create_single_node_partition(node): + """Create a partition for a single node + """ + partition = self.create_partition() + total_size_needed = get_extra_size_of(node, set()) + partition.add_node(node) + partition.used_mem_bytes = total_size_needed + single_node_partitions.append(partition) + + # Track all single node partitions in Step 3 + single_node_partitions: List[Partition] = [] + # Track all non single node partitions in Step 1 and Step 2 + non_single_node_partitions: List[Partition] = [] + # Track partition and its left mem size + partition_to_left_mem_bytes: Dict[Partition, int] = {} + # Track all the devices that have been used + occupied_devices: List[Device] = [] partition = self.create_partition() - # Track the used mem for the current partition for node in self.graph_module.graph.nodes: if node.op in {'call_module', 'call_method', 'call_function'}: - total_size_of_input_nodes = get_extra_size_of(node, partition.nodes) - # The current node with its inputs cannot fit into the current partition - if total_size_of_input_nodes + partition.used_mem_bytes > available_mem_bytes: - partition = self.create_partition() + # Check if there are devices left + if len(self.partitions) <= len(self.devices): total_size_of_input_nodes = get_extra_size_of(node, partition.nodes) - # The current node may be too large to fit into a whole new partition - if total_size_of_input_nodes + partition.used_mem_bytes > available_mem_bytes: - raise RuntimeError(node.target + 'is too large to fit into a device') - # Add the current node into the current partition - partition.nodes.add(node) - partition.used_mem_bytes += total_size_of_input_nodes - # Find parent partitions and child partitions for each partition. + # Check if the current partition is the very first partition + if partition.used_mem_bytes == 0: + # Find a device to fit the first node, return available mem size + device = find_device_based_on_size(node) + occupied_devices.append(device) + # Update partition and its left mem size + partition_to_left_mem_bytes[partition] = device.available_mem_bytes + # Update available mem for the current partitio + partition.logical_device_ids.append(device.logical_id) + else: + # The current partition is not the first partition + # Check if the current node can fit into this partition + if partition_to_left_mem_bytes[partition] < total_size_of_input_nodes: + # Check if no device is left + if len(self.partitions) == len(self.devices): + # No device left, all the partitions before are non single node partitions + non_single_node_partitions = self.partitions[:] + # Create the first single node partition for the current node + create_single_node_partition(node) + continue + # Some devices are still left + device = find_device_based_on_size(node) + partition = self.create_partition() + total_size_of_input_nodes = get_extra_size_of(node, partition.nodes) + partition_to_left_mem_bytes[partition] = device.available_mem_bytes + partition.logical_device_ids.append(device.logical_id) + partition.add_node(node) + partition_to_left_mem_bytes[partition] -= total_size_of_input_nodes + partition.used_mem_bytes += total_size_of_input_nodes + # No device left, create single node partitions + else: + create_single_node_partition(node) self.set_parents_and_children() - # Combine small partitions - self.combine_partitions_based_on_size(self.partitions[:], available_mem_bytes) - # Reassign partition ids and update self.node_to_partitions. + # Check if having single node partitions + # If not, partition is done + if len(single_node_partitions) != 0: + # Going through all single node partitions, + # see if it can be combined with non single node partitions + # or at least fit into a logical device as a standaline partition + while single_node_partitions: + self.get_bfs_level_partition() + # Pick a single node partition + p1 = single_node_partitions.pop(0) + # Set up a flag + find_device = False + # Going through all non single partitions + # and find a device to fit p1 + for p2 in non_single_node_partitions: + # Calculate how many bytes are needed if combining p1 and p2 + mem_size_needed = calculate_mem_bytes_needed(p1, p2) + # Get the available size of p2 + available_mem_bytes = p2.used_mem_bytes + partition_to_left_mem_bytes[p2] + if mem_size_needed <= available_mem_bytes: + # Two partitions can be fit on the same device, + # check if combining them to be one partition + if abs(p1.bfs_level - p2.bfs_level) <= 1: + # Combining p1 and p2 into p0 + p0 = self.combine_two_partitions(p1, p2) + p0.logical_device_ids = p2.logical_device_ids + # Remove p2 from non_single_node_partitions + non_single_node_partitions.remove(p2) + # Add p0 to non_single_partitions + non_single_node_partitions.append(p0) + # Update partition_to_left_mem_bytes + partition_to_left_mem_bytes[p0] = available_mem_bytes - mem_size_needed + del partition_to_left_mem_bytes[p2] + else: + # Cannot combine two partitions, + # but two partitions can fit into p2's device + p1.logical_device_ids = p2.logical_device_ids + # Update partition_to_left_mem_bytes for p2 + partition_to_left_mem_bytes[p2] = available_mem_bytes - mem_size_needed + find_device = True + break + if not find_device: + raise RuntimeError('Lack of Devices') self.reorganize_partitions() - # Check if devices are enough for all partitions - if len(self.partitions) > len(self.devices): - msg = 'Need ' + str(len(self.partitions)) + ' devices, but only ' \ - + str(len(self.devices)) + ' provided' - raise RuntimeError(msg) - for i, partition in enumerate(self.partitions): - partition.logical_device_ids = [self.devices[i].logical_id] return def do_partition(self) -> GraphModule: @@ -310,13 +415,6 @@ def find_partition_to_combine_based_on_size( ) -> Tuple[bool, List[Partition]]: """step 1 in self.combine_partition_based_on_size()""" - def calculate_mem_bytes_needed(p1, p2): - nodes = p1.nodes.union(p2.nodes) - mem_bytes_needed = 0 - for node in nodes: - mem_bytes_needed += get_extra_size_of(node, nodes) - return mem_bytes_needed - find_combination = False smallest_partition = sorted_partitions.pop(0) for p in sorted_partitions[::-1]: @@ -336,42 +434,21 @@ def combine_two_partitions( self, partition_0: Partition, partition_1: Partition, - ) -> None: + ) -> Partition: """Given two partitions, combine them into a new one and remove the previous two partitions from self.partitions """ partition = self.create_partition() partition.nodes = partition_0.nodes.union(partition_1.nodes) - partition.parents = partition_0.parents.union(partition_1.parents) - partition.children = partition_0.children.union(partition_1.children) partition.recalculate_mem_size() - partition.bfs_level = max(partition_0.bfs_level, partition_1.bfs_level) - if partition_0 in partition.children: - partition.children.remove(partition_0) - if partition_0 in partition.parents: - partition.parents.remove(partition_0) - if partition_1 in partition.children: - partition.children.remove(partition_1) - if partition_1 in partition.parents: - partition.parents.remove(partition_1) - # Replace partition_0 and partition_1 with the new partition in children and parents - for p in partition.parents: - if partition_0 in p.children: - p.children.remove(partition_0) - p.children.add(partition) - if partition_1 in p.children: - p.children.remove(partition_1) - p.children.add(partition) - for p in partition.children: - if partition_0 in p.parents: - p.parents.remove(partition_0) - p.parents.add(partition) - if partition_1 in p.parents: - p.parents.remove(partition_1) - p.parents.add(partition_1) self.partitions.remove(partition_0) self.partitions.remove(partition_1) - return + # reset parents and children for all partitions + for partition in self.partitions: + partition.parents = set() + partition.children = set() + self.set_parents_and_children() + return partition def set_parents_and_children(self) -> None: # Go through all nodes in a partition. @@ -388,10 +465,8 @@ def set_parents_and_children(self) -> None: # that partition is not the child of the current partition for p in self.partitions: if p != partition and n in p.nodes and node not in p.nodes: - if p not in partition.children: - partition.children.add(p) - if partition not in p.parents: - p.parents.add(partition) + partition.children.add(p) + p.parents.add(partition) return def reorganize_partitions(self) -> None: @@ -487,7 +562,7 @@ def is_embedding_node(node: Node) -> bool: total_size_of_input_nodes = get_extra_size_of(node, partition.nodes) if total_size_of_input_nodes > available_mem_bytes: raise RuntimeError(node.target + 'is too large to fit into a device') - partition.nodes.add(node) + partition.add_node(node) partition.used_mem_bytes += total_size_of_input_nodes reset_partition_in_sparse_nn(partition, new_partition=False) # Set parents and children for each partition diff --git a/torch/fx/experimental/partitioner_utils.py b/torch/fx/experimental/partitioner_utils.py new file mode 100644 index 000000000000..0d26080a7fd9 --- /dev/null +++ b/torch/fx/experimental/partitioner_utils.py @@ -0,0 +1,74 @@ +from typing import NamedTuple, Dict, List +from torch.fx.node import Node, map_arg +from torch.fx.experimental.Partitioner import Partition + +class NodeLatency(NamedTuple): + # Latency due to the memory bandwidth + mem_latency: float + # Latency due to the computation + compute_latency: float + +class PartitionLatency(NamedTuple): + # Sum of all nodes' memory latency on the critical path + mem_latency: float + # Sum of all nodes' compute latency on the critical path + compute_latency: float + # Latency of the critical path + overall_latency: float + +def get_latency_of_one_partition( + partition: Partition, + node_to_latency_mapping: Dict[Node, NodeLatency] +) -> PartitionLatency: + """Given a partiton and its nodes' latency, return a PartitionLatency for this partition""" + + def get_top_nodes(partition: Partition) -> List[Node]: + """Given a partition, return a list of nodes on the top bfs level""" + top_nodes: List[Node] = [] + for node in partition.nodes: + # Skip placeholder and get_attr nodes + if node.op in {'placeholder', 'get_attr'}: + continue + input_nodes: Dict[Node, None] = {} + map_arg(node.args, lambda n: input_nodes.setdefault(n)) + map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) + # If a node has no input nodes in this partition, + # or its input nodes in this partition are placeholders and get_attrs + # this node is on the top bfs level in this partition + if not any([n in partition.nodes and n.op not in {'placeholder', 'get_attr'} for n in input_nodes]): + top_nodes.append(node) + return top_nodes + + def dfs_helper(node: Node, partition_latency) -> PartitionLatency: + """Given a top node of a partition, this function returns + the latency of the critical path in the partition + """ + node_latency = node_to_latency_mapping[node] + # Calculate the current overall latency of the partition + overall_latency = partition_latency.overall_latency + max(node_latency.compute_latency, node_latency.mem_latency) + # Update the mem latency of this path + mem_latency = partition_latency.mem_latency + node_latency.mem_latency + # Update the compute latency of this path + compute_latency = partition_latency.compute_latency + node_latency.compute_latency + # Get all users of this node that are in this partition + users = set(node.users).intersection(partition.nodes) + if users: + max_latency = PartitionLatency(mem_latency=0., compute_latency=0., overall_latency=0.) + for n in users: + # Get new partition latency recursively + new_partition_latency = dfs_helper(n, PartitionLatency(mem_latency, compute_latency, overall_latency)) + if new_partition_latency.overall_latency > max_latency.overall_latency: + max_latency = new_partition_latency + return max_latency + # If there is no user, the node is at bottom of the partition + return PartitionLatency(mem_latency, compute_latency, overall_latency) + # Main part starts + # Get all top level nodes of this partition + top_nodes = get_top_nodes(partition) + critical_path_latency = PartitionLatency(mem_latency=0., compute_latency=0., overall_latency=0.) + # Go through all top nodes and find the largest latency (critical pass latency) + for node in top_nodes: + partition_latency = dfs_helper(node, PartitionLatency(mem_latency=0., compute_latency=0., overall_latency=0.)) + if partition_latency.overall_latency > critical_path_latency.overall_latency: + critical_path_latency = partition_latency + return critical_path_latency diff --git a/torch/fx/graph.py b/torch/fx/graph.py index a9a7e5dc80c0..d76dd2e6ba22 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -383,8 +383,7 @@ def type_repr(o : Any): # repr() for inf and nan floating point values aren't parseable by # python as literals. Explicitly import the names from the `math` module. - import_strs = ['from math import inf, nan'] - import_strs.extend(f'import {name}' for name in sorted(modules_used)) + import_strs = [f'import {name}' for name in sorted(modules_used)] import_block = '\n'.join(import_strs) code = ''.join(body) diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index 6f72a29be184..24dde1ea13d4 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -4,6 +4,7 @@ from typing import Type, Dict, List, Any, Union from .graph import Graph import copy +import math # normal exec loses the source code, however we can patch # the linecache module to still recover it. @@ -28,7 +29,7 @@ def patched_getline(*args, **kwargs): linecache.getlines = patched_getline def _forward_from_src(src : str): - gbls: Dict[str, Any] = {} + gbls: Dict[str, Any] = {'inf': math.inf, 'nan': math.nan} exec_with_source(src, gbls) return gbls['forward'] diff --git a/torch/fx/symbolic_trace.py b/torch/fx/symbolic_trace.py index ab8a871adcf1..2865c97da22b 100644 --- a/torch/fx/symbolic_trace.py +++ b/torch/fx/symbolic_trace.py @@ -173,6 +173,21 @@ def trace(self, root: Union[torch.nn.Module, Callable]) -> Graph: fn, args = self.create_args_for_root(fn, isinstance(root, torch.nn.Module)) orig_call = torch.nn.Module.__call__ + orig_getattr = torch.nn.Module.__getattr__ + + parameter_proxy_cache = {} # Reduce number of get_attr calls + + # Method dispatch on parameters is not recorded unless it's directly used. + # Thus, we need to insert a proxy when __getattr__ requests a parameter. + def module_getattr_wrapper(mod, attr): + attr_val = orig_getattr(mod, attr) + if isinstance(attr_val, torch.nn.Parameter): + for n, p in self.root.named_parameters(): + if attr_val is p: + if n not in parameter_proxy_cache: + parameter_proxy_cache[n] = self.create_proxy('get_attr', n, (), {}) + return parameter_proxy_cache[n] + return attr_val def module_call_wrapper(mod, *args, **kwargs): def forward(*args, **kwargs): @@ -181,11 +196,14 @@ def forward(*args, **kwargs): return self.call_module(mod, forward, args, kwargs) try: + # Seems to be a mypy limitation: https://github.com/python/mypy/issues/2427 + torch.nn.Module.__getattr__ = module_getattr_wrapper # type: ignore torch.nn.Module.__call__ = module_call_wrapper self.create_node('output', 'output', (self.create_arg(fn(*args)),), {}, type_expr=fn.__annotations__.get('return', None)) finally: torch.nn.Module.__call__ = orig_call + torch.nn.Module.__getattr__ = orig_getattr # type: ignore return self.graph # Symbolic tracing API diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py index 74ee0bad2ad6..b9120f52379e 100644 --- a/torch/jit/_trace.py +++ b/torch/jit/_trace.py @@ -1010,7 +1010,6 @@ def check_unique(param): "TracedModules don't support parameter sharing between modules" ) id_set.add(param) - tmp_module.training = orig.training for name, param in orig._parameters.items(): @@ -1046,7 +1045,7 @@ def check_unique(param): self.__dict__["_name"] = type(orig).__name__ self.__dict__["_actual_script_module"] = script_module - for name in ("_parameters", "_buffers", "_modules"): + for name in ("_parameters", "_buffers", "_modules", "training"): delattr(self, name) def forward(self, *args, **kwargs): diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 074bb47faaac..ad0badf5eed9 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -139,3 +139,48 @@ >>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :]) (tensor(3.7417), tensor(11.2250)) """) + +tensorsolve = _add_docstr(_linalg.linalg_tensorsolve, r""" +linalg.tensorsolve(input, other, dims=None, *, out=None) -> Tensor + +Computes a tensor ``x`` such that ``tensordot(input, x, dims=x.ndim) = other``. +The resulting tensor ``x`` has the same shape as ``input[other.ndim:]``. + +Supports real-valued and, only on the CPU, complex-valued inputs. + +.. note:: If :attr:`input` does not satisfy the requirement + ``prod(input.shape[other.ndim:]) == prod(input.shape[:other.ndim])`` + after (optionally) moving the dimensions using :attr:`dims`, then a RuntimeError will be thrown. + +Args: + input (Tensor): "left-hand-side" tensor, it must satisfy the requirement + ``prod(input.shape[other.ndim:]) == prod(input.shape[:other.ndim])``. + other (Tensor): "right-hand-side" tensor of shape ``input.shape[other.ndim]``. + dims (Tuple[int]): dimensions of :attr:`input` to be moved before the computation. + Equivalent to calling ``input = movedim(input, dims, range(len(dims) - input.ndim, 0))``. + If None (default), no dimensions are moved. + +Keyword args: + out (Tensor, optional): The output tensor. Ignored if ``None``. Default: ``None`` + +Examples:: + + >>> a = torch.eye(2 * 3 * 4).reshape((2 * 3, 4, 2, 3, 4)) + >>> b = torch.randn(2 * 3, 4) + >>> x = torch.linalg.tensorsolve(a, b) + >>> x.shape + torch.Size([2, 3, 4]) + >>> torch.allclose(torch.tensordot(a, x, dims=x.ndim), b) + True + + >>> a = torch.randn(6, 4, 4, 3, 2) + >>> b = torch.randn(4, 3, 2) + >>> x = torch.linalg.tensorsolve(a, b, dims=(0, 2)) + >>> x.shape + torch.Size([6, 4]) + >>> a = a.permute(1, 3, 4, 0, 2) + >>> a.shape[b.ndim:] + torch.Size([6, 4]) + >>> torch.allclose(torch.tensordot(a, x, dims=x.ndim), b, atol=1e-6) + True +""") diff --git a/torch/nn/functional.pyi.in b/torch/nn/functional.pyi.in index 215fb0278dc6..6b0702581f2c 100644 --- a/torch/nn/functional.pyi.in +++ b/torch/nn/functional.pyi.in @@ -189,7 +189,8 @@ def embedding(input: Tensor, weight: Tensor, padding_idx: Optional[int] = ..., m def embedding_bag(input: Tensor, weight: Tensor, offsets: Optional[Tensor] = ..., max_norm: Optional[float] = ..., norm_type: float = ..., scale_grad_by_freq: bool = ..., mode: str = ..., - sparse: bool = ...) -> Tensor: ... + sparse: bool = ..., per_sample_weights: Optional[Tensor] = ..., + include_last_offset: bool = ...) -> Tensor: ... def batch_norm(input: Tensor, running_mean: Optional[Tensor], running_var: Optional[Tensor], weight: Optional[Tensor] = ..., bias: Optional[Tensor] = ..., training: bool = ..., diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py index f5ca6deb5b19..e76e307d36a6 100644 --- a/torch/nn/modules/batchnorm.py +++ b/torch/nn/modules/batchnorm.py @@ -54,9 +54,11 @@ def __init__( def reset_running_stats(self) -> None: if self.track_running_stats: - self.running_mean.zero_() - self.running_var.fill_(1) - self.num_batches_tracked.zero_() + # running_mean/running_var/num_batches... are registerd at runtime depending + # if self.track_running_stats is on + self.running_mean.zero_() # type: ignore[operator] + self.running_var.fill_(1) # type: ignore[operator] + self.num_batches_tracked.zero_() # type: ignore[operator] def reset_parameters(self) -> None: self.reset_running_stats() @@ -107,8 +109,8 @@ def forward(self, input: Tensor) -> Tensor: if self.training and self.track_running_stats: # TODO: if statement only here to tell the jit to skip emitting this when it is None - if self.num_batches_tracked is not None: - self.num_batches_tracked = self.num_batches_tracked + 1 + if self.num_batches_tracked is not None: # type: ignore + self.num_batches_tracked = self.num_batches_tracked + 1 # type: ignore if self.momentum is None: # use cumulative moving average exponential_average_factor = 1.0 / float(self.num_batches_tracked) else: # use exponential moving average @@ -128,6 +130,8 @@ def forward(self, input: Tensor) -> Tensor: passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are used for normalization (i.e. in eval mode when buffers are not None). """ + assert self.running_mean is None or isinstance(self.running_mean, torch.Tensor) + assert self.running_var is None or isinstance(self.running_var, torch.Tensor) return F.batch_norm( input, # If buffers are not to be tracked, ensure that they won't be updated @@ -487,6 +491,7 @@ def forward(self, input: Tensor) -> Tensor: exponential_average_factor = self.momentum if self.training and self.track_running_stats: + assert self.num_batches_tracked is not None self.num_batches_tracked = self.num_batches_tracked + 1 if self.momentum is None: # use cumulative moving average exponential_average_factor = 1.0 / self.num_batches_tracked.item() @@ -508,6 +513,8 @@ def forward(self, input: Tensor) -> Tensor: used for normalization (i.e. in eval mode when buffers are not None). """ # If buffers are not to be tracked, ensure that they won't be updated + assert self.running_mean is None or isinstance(self.running_mean, torch.Tensor) + assert self.running_var is None or isinstance(self.running_var, torch.Tensor) running_mean = self.running_mean if not self.training or self.track_running_stats else None running_var = self.running_var if not self.training or self.track_running_stats else None diff --git a/torch/nn/modules/instancenorm.py b/torch/nn/modules/instancenorm.py index a0f9c9a19afa..b27fd644993f 100644 --- a/torch/nn/modules/instancenorm.py +++ b/torch/nn/modules/instancenorm.py @@ -52,6 +52,8 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, def forward(self, input: Tensor) -> Tensor: self._check_input_dim(input) + assert self.running_mean is None or isinstance(self.running_mean, Tensor) + assert self.running_var is None or isinstance(self.running_var, Tensor) return F.instance_norm( input, self.running_mean, self.running_var, self.weight, self.bias, self.training or not self.track_running_stats, self.momentum, self.eps) diff --git a/torch/nn/modules/linear.py b/torch/nn/modules/linear.py index 45f66086d600..ea2c3a8f453b 100644 --- a/torch/nn/modules/linear.py +++ b/torch/nn/modules/linear.py @@ -102,10 +102,10 @@ def extra_repr(self) -> str: # This class exists solely for Transformer; it has an annotation stating # that bias is never None, which appeases TorchScript class _LinearWithBias(Linear): - bias: Tensor + bias: Tensor # type: ignore def __init__(self, in_features: int, out_features: int) -> None: - super().__init__(in_features, out_features, bias=True) + super().__init__(in_features, out_features, bias=True) # type: ignore class Bilinear(Module): @@ -208,7 +208,8 @@ class LazyLinear(LazyModuleMixin, Linear): """ - cls_to_become = Linear + cls_to_become = Linear # type: ignore[assignment] + weight: UninitializedParameter def __init__(self, out_features: int, bias: bool = True) -> None: super().__init__(0, out_features, bias) @@ -218,7 +219,7 @@ def reset_parameters(self) -> None: if not self.has_uninitialized_params() and self.in_features != 0: super().reset_parameters() - def initialize_parameters(self, input) -> None: + def initialize_parameters(self, input) -> None: # type: ignore if self.has_uninitialized_params(): with torch.no_grad(): self.in_features = input.shape[-1] diff --git a/torch/nn/modules/pooling.py b/torch/nn/modules/pooling.py index 734912684d8f..7a43fcc2ea2d 100644 --- a/torch/nn/modules/pooling.py +++ b/torch/nn/modules/pooling.py @@ -45,6 +45,10 @@ class MaxPool1d(_MaxPoolNd): for :attr:`padding` number of points. :attr:`dilation` is the stride between the elements within the sliding window. This `link`_ has a nice visualization of the pooling parameters. + Note: + When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding + or the input. Sliding windows that would start in the right padded region are ignored. + Args: kernel_size: The size of the sliding window, must be > 0. stride: The stride of the sliding window, must be > 0. Default value is :attr:`kernel_size`. @@ -104,6 +108,10 @@ class MaxPool2d(_MaxPoolNd): for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. + Note: + When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding + or the input. Sliding windows that would start in the right padded region are ignored. + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: - a single ``int`` -- in which case the same value is used for the height and width dimension @@ -174,6 +182,10 @@ class MaxPool3d(_MaxPoolNd): for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. + Note: + When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding + or the input. Sliding windows that would start in the right padded region are ignored. + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: - a single ``int`` -- in which case the same value is used for the depth, height and width dimension @@ -474,6 +486,10 @@ class AvgPool1d(_AvgPoolNd): If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides for :attr:`padding` number of points. + Note: + When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding + or the input. Sliding windows that would start in the right padded region are ignored. + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding` can each be an ``int`` or a one-element tuple. @@ -537,6 +553,10 @@ class AvgPool2d(_AvgPoolNd): If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides for :attr:`padding` number of points. + Note: + When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding + or the input. Sliding windows that would start in the right padded region are ignored. + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding` can either be: - a single ``int`` -- in which case the same value is used for the height and width dimension @@ -614,6 +634,10 @@ class AvgPool3d(_AvgPoolNd): If :attr:`padding` is non-zero, then the input is implicitly zero-padded on all three sides for :attr:`padding` number of points. + Note: + When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding + or the input. Sliding windows that would start in the right padded region are ignored. + The parameters :attr:`kernel_size`, :attr:`stride` can either be: - a single ``int`` -- in which case the same value is used for the depth, height and width dimension diff --git a/torch/nn/modules/sparse.py b/torch/nn/modules/sparse.py index 3589d4b815c9..8dd9c2c7a10d 100644 --- a/torch/nn/modules/sparse.py +++ b/torch/nn/modules/sparse.py @@ -99,8 +99,8 @@ class Embedding(Module): num_embeddings: int embedding_dim: int - padding_idx: int - max_norm: float + padding_idx: Optional[int] + max_norm: Optional[float] norm_type: float scale_grad_by_freq: bool weight: Tensor @@ -284,7 +284,7 @@ class EmbeddingBag(Module): num_embeddings: int embedding_dim: int - max_norm: float + max_norm: Optional[float] norm_type: float scale_grad_by_freq: bool weight: Tensor diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py index 0fceb2137a3b..7d64ed8a4174 100644 --- a/torch/nn/parallel/distributed.py +++ b/torch/nn/parallel/distributed.py @@ -329,7 +329,7 @@ class DistributedDataParallel(Module): Example:: >>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...') - >>> net = torch.nn.DistributedDataParallel(model, pg) + >>> net = torch.nn.parallel.DistributedDataParallel(model, pg) """ def __init__(self, module, device_ids=None, output_device=None, dim=0, broadcast_buffers=True, @@ -626,7 +626,7 @@ def no_sync(self): Example:: - >>> ddp = torch.nn.DistributedDataParallel(model, pg) + >>> ddp = torch.nn.parallel.DistributedDataParallel(model, pg) >>> with ddp.no_sync(): >>> for input in inputs: >>> ddp(input).backward() # no synchronization, accumulate grads diff --git a/torch/nn/parameter.pyi b/torch/nn/parameter.pyi index 2070614f2cb0..7e9e17eebf83 100644 --- a/torch/nn/parameter.pyi +++ b/torch/nn/parameter.pyi @@ -11,5 +11,5 @@ class Parameter(Tensor): class UninitializedParameter(Tensor): def __init__(self, data: Tensor=..., requires_grad: builtins.bool=...): ... - def materialize(self, shape: Tuple[int], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None): ... + def materialize(self, shape: Tuple[int, ...], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None): ... ... diff --git a/torch/nn/quantized/dynamic/modules/linear.py b/torch/nn/quantized/dynamic/modules/linear.py index 4a4a46bf780a..f220aed075c1 100644 --- a/torch/nn/quantized/dynamic/modules/linear.py +++ b/torch/nn/quantized/dynamic/modules/linear.py @@ -91,7 +91,8 @@ def from_float(cls, mod): from torch.quantization.qconfig import default_dynamic_qconfig weight_observer = default_dynamic_qconfig.weight() dtype = weight_observer.dtype - assert dtype in [torch.qint8, torch.float16], 'The only supported dtypes for dynamic quantized linear are qint8 and float16' + assert dtype in [torch.qint8, torch.float16], "The only supported dtypes for " \ + "dynamic quantized linear are qint8 and float16 got: {}".format(dtype) weight_observer(mod.weight) if dtype == torch.qint8: qweight = _quantize_weight(mod.weight.float(), weight_observer) diff --git a/torch/nn/quantized/modules/__init__.py b/torch/nn/quantized/modules/__init__.py index a40a3e3fbcac..a064c72dda98 100644 --- a/torch/nn/quantized/modules/__init__.py +++ b/torch/nn/quantized/modules/__init__.py @@ -7,7 +7,7 @@ from .normalization import LayerNorm, GroupNorm, InstanceNorm1d, \ InstanceNorm2d, InstanceNorm3d from .conv import Conv1d, Conv2d, Conv3d -from .conv import ConvTranspose1d, ConvTranspose2d +from .conv import ConvTranspose1d, ConvTranspose2d, ConvTranspose3d from .linear import Linear from .embedding_ops import Embedding, EmbeddingBag @@ -91,6 +91,7 @@ def from_float(mod): 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', + 'ConvTranspose3d', 'DeQuantize', 'ELU', 'Embedding', diff --git a/torch/nn/quantized/modules/conv.py b/torch/nn/quantized/modules/conv.py index 31c914d2bf35..e9f4a4c701eb 100644 --- a/torch/nn/quantized/modules/conv.py +++ b/torch/nn/quantized/modules/conv.py @@ -606,9 +606,6 @@ class ConvTranspose2d(_ConvTransposeNd): For details on input arguments, parameters, and implementation see :class:`~torch.nn.ConvTranspose2d`. - .. note:: Currently only the QNNPACK engine is implemented. - Please, set the `torch.backends.quantized.engine = 'qnnpack'` - For special notes, please, see :class:`~torch.nn.quantized.Conv2d` Attributes: @@ -620,6 +617,7 @@ class ConvTranspose2d(_ConvTransposeNd): Examples:: + >>> # QNNPACK or FBGEMM as backend >>> torch.backends.quantized.engine = 'qnnpack' >>> # With square kernels and equal stride >>> m = nnq.ConvTranspose2d(16, 33, 3, stride=2) @@ -684,3 +682,88 @@ def forward(self, input): raise ValueError("Input shape must be `(N, C, H, W)`!") return ops.quantized.conv_transpose2d( input, self._packed_params, self.scale, self.zero_point) + +class ConvTranspose3d(_ConvTransposeNd): + r"""Applies a 3D transposed convolution operator over an input image + composed of several input planes. + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.ConvTranspose3d`. + + .. note:: Currently only the FBGEMM engine is implemented. + Please, set the `torch.backends.quantized.engine = 'fbgemm'` + + For special notes, please, see :class:`~torch.nn.quantized.Conv3d` + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + See :class:`~torch.nn.ConvTranspose3d` for other attributes. + + Examples:: + + >>> torch.backends.quantized.engine = 'fbgemm' + >>> # With cubic kernels and equal stride + >>> m = nnq.ConvTranspose3d(16, 33, 3, stride=2) + >>> # non-cubic kernels and unequal stride and with padding + >>> m = nnq.ConvTranspose3d(16, 33, (3, 3, 5), stride=(2, 1, 1), padding=(4, 2, 2)) + >>> input = torch.randn(20, 16, 50, 100, 100) + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8) + >>> output = m(q_input) + >>> # exact output size can be also specified as an argument + >>> input = torch.randn(1, 16, 12, 12, 12) + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8) + >>> downsample = nnq.Conv3d(16, 16, 3, stride=2, padding=1) + >>> upsample = nnq.ConvTranspose3d(16, 16, 3, stride=2, padding=1) + >>> h = downsample(q_input) + >>> h.size() + torch.Size([1, 16, 6, 6, 6]) + >>> output = upsample(h, output_size=input.size()) + >>> output.size() + torch.Size([1, 16, 12, 12, 12]) + """ + + _FLOAT_MODULE = nn.ConvTranspose3d + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, output_padding=0, groups=1, bias=True, + dilation=1, padding_mode='zeros'): + kernel_size = _pair(kernel_size) + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + output_padding = _pair(output_padding) + + super(ConvTranspose3d, self).__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + True, output_padding, groups, bias, padding_mode) + + def _get_name(self): + return 'QuantizedConvTranpose3d' + + def set_weight_bias(self, w, b): + # type: (torch.Tensor, Optional[torch.Tensor]) -> None + self._packed_params = torch.ops.quantized.conv_transpose3d_prepack( + w, b, self.stride, self.padding, self.output_padding, self.dilation, + self.groups) + + def _weight_bias(self): + w, b = torch.ops.quantized.conv3d_unpack(self._packed_params) + return w, b + + def weight(self): + (w, _) = self._weight_bias() + return w + + def bias(self): + (_, b) = self._weight_bias() + return b + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 5: + raise ValueError("Input shape must be `(N, C, T, H, W)`!") + return ops.quantized.conv_transpose3d( + input, self._packed_params, self.scale, self.zero_point) diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index 255c15b9da4a..421c23ebed6e 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -44,7 +44,7 @@ def export(model, args, f, export_params=True, verbose=False, training=TrainingM model (torch.nn.Module): the model to be exported. args (tuple of arguments or torch.Tensor): the inputs to the model, e.g., such that ``model(*args)`` is a valid - invocation of the model. Any non-Tensor arguments will + invocation of the model. Any non-Tensor arguments (including None) will be hard-coded into the exported model; any Tensor arguments will become inputs of the exported model, in the order they occur in args. If args is a Tensor, this is equivalent diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py index 5a266a429965..f6acc4120dc2 100644 --- a/torch/onnx/symbolic_opset11.py +++ b/torch/onnx/symbolic_opset11.py @@ -581,6 +581,26 @@ def mm(g, self, other): return g.op("Gemm", self, other, beta_f=0.0, alpha_f=1.0) +def index(g, self, index): + if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK: + return g.op("ATen", self, index, operator_s="index") + + if sym_help._is_packed_list(index): + indices = sym_help._unpack_list(index) + else: + indices = [index] + + # Handle single mask index. + if len(indices) == 1: + index = indices[0] + if not sym_help._is_none(index) and (index.type().scalarType() == "Bool" or index.type().scalarType() == "Byte"): + from torch.onnx.symbolic_opset9 import nonzero + index = nonzero(g, index) + return g.op('GatherND', self, index) + from torch.onnx.symbolic_opset9 import index as index_opset9 + return index_opset9(g, self, index) + + def index_fill(g, self, dim, index, value): dim_value = sym_help._parse_arg(dim, 'i') if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK: diff --git a/torch/overrides.py b/torch/overrides.py index 8224944d9ff6..9a91866e3f5b 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -279,6 +279,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.clip: lambda input, min=None, max=None, out=None: -1, torch.clamp_min: lambda input, min, out=None: -1, torch.clamp_max: lambda input, max, out=None: -1, + torch.column_stack: lambda tensors, out=None: -1, torch.clone: lambda input: -1, torch.combinations: lambda input, r=2, with_replacement=False: -1, torch.complex: lambda real, imag: -1, @@ -388,6 +389,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.hstack: lambda tensors, out=None: -1, torch.hypot: lambda input, other, out=None: -1, torch.ifft: lambda input, signal_ndim, normalized=False: -1, + torch.igamma: lambda input, other, out=None: -1, torch.imag: lambda input, out=None: -1, torch.index_add: lambda input, dim, index, source: -1, torch.index_copy: lambda input, dim, index, source: -1, @@ -693,6 +695,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.roll: lambda input, shifts, dims=None: -1, torch.rot90: lambda input, k=1, dims=(0, 1): -1, torch.round: lambda input, out=None: -1, + torch.row_stack: lambda tensors, out=None: -1, # alias for torch.vstack torch.rowwise_prune: (lambda weight, mask, compressed_indices_dtype: -1), torch.rrelu: lambda input, lower=1. / 8, upper=1. / 3, training=False, inplace=False: -1, torch.rsqrt: lambda input, out=None: -1, @@ -738,6 +741,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.tan: lambda input, out=None: -1, torch.tanh: lambda input, out=None: -1, torch.tensordot: lambda a, b, dims=2: -1, + torch.linalg.tensorsolve: lambda a, b, dims=None: -1, torch.tensor_split: lambda input, indices_or_sections, dim=0: -1, torch.threshold: lambda input, threshold, value, inplace=False: -1, torch.topk: lambda input, k, dim=-1, descending=False, out=None: -1, @@ -840,7 +844,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: Tensor.apply_: lambda self, callable: -1, Tensor.as_strided: lambda self, size, stride: -1, Tensor.as_strided_: lambda self, size, stride: -1, - Tensor.backward: lambda self, gradient=None, retain_graph=None, create_graph=False: -1, + Tensor.backward: lambda self, gradient=None, retain_graph=None, create_graph=False, inputs=None: -1, Tensor.bfloat16: lambda self, memory_format=torch.preserve_format: -1, Tensor.bool: lambda self, memory_format=torch.preserve_format: -1, Tensor.byte: lambda self, memory_format=torch.preserve_format: -1, diff --git a/torch/quantization/__init__.py b/torch/quantization/__init__.py index d6daa79fae53..24e929b5fc8e 100644 --- a/torch/quantization/__init__.py +++ b/torch/quantization/__init__.py @@ -6,7 +6,7 @@ from .stubs import * from .quant_type import * from .quantize_jit import * -from .quantize_fx import * +# from .quantize_fx import * from .quantization_mappings import * from .fuser_method_mappings import * @@ -26,8 +26,8 @@ def default_eval_fn(model, calib_data): # Top level API for graph mode quantization on TorchScript 'quantize_jit', 'quantize_dynamic_jit', # Top level API for graph mode quantization on GraphModule(torch.fx) - 'fuse_fx', 'quantize_fx', # TODO: add quantize_dynamic_fx - 'prepare_fx', 'prepare_dynamic_fx', 'convert_fx', + # 'fuse_fx', 'quantize_fx', # TODO: add quantize_dynamic_fx + # 'prepare_fx', 'prepare_dynamic_fx', 'convert_fx', 'QuantType', 'quant_type_to_str', # quantization type # custom module APIs 'get_default_static_quant_module_mappings', 'get_static_quant_module_class', diff --git a/torch/quantization/fx/fusion_patterns.py b/torch/quantization/fx/fusion_patterns.py index d537630c2406..4c92192dc5be 100644 --- a/torch/quantization/fx/fusion_patterns.py +++ b/torch/quantization/fx/fusion_patterns.py @@ -9,15 +9,18 @@ # Fusion Patterns # --------------------- -@register_fusion_pattern((torch.nn.BatchNorm2d, torch.nn.Conv2d)) @register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv1d)) @register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv2d)) @register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv3d)) @register_fusion_pattern((torch.nn.functional.relu, torch.nn.Conv1d)) @register_fusion_pattern((torch.nn.functional.relu, torch.nn.Conv2d)) @register_fusion_pattern((torch.nn.functional.relu, torch.nn.Conv3d)) +@register_fusion_pattern((torch.nn.BatchNorm2d, torch.nn.Conv2d)) +@register_fusion_pattern((torch.nn.BatchNorm3d, torch.nn.Conv3d)) @register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d))) +@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm3d, torch.nn.Conv3d))) @register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm2d, torch.nn.Conv2d))) +@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm3d, torch.nn.Conv3d))) class ConvBNReLUFusion(): def __init__(self, quantizer, node): super().__init__() @@ -66,7 +69,8 @@ def fuse(self, quantizer, load_arg, fuse_custom_config_dict=None): fuser_method = get_fuser_method(op_type_list, additional_fuser_method_mapping) if fuser_method is None: raise NotImplementedError("Cannot fuse modules: {}".format(types)) - setattr(quantizer.modules[conv_parent_name], conv_name, fuser_method(*op_list)) + fused = fuser_method(*op_list) + setattr(quantizer.modules[conv_parent_name], conv_name, fused) # TODO: do we need to make sure bn is only used once? if self.bn_node is not None: @@ -77,8 +81,6 @@ def fuse(self, quantizer, load_arg, fuse_custom_config_dict=None): @register_fusion_pattern((torch.nn.functional.relu, torch.nn.Linear)) @register_fusion_pattern((torch.nn.ReLU, torch.nn.Linear)) -@register_fusion_pattern((torch.nn.functional.relu, torch.nn.BatchNorm1d)) -@register_fusion_pattern((torch.nn.ReLU, torch.nn.BatchNorm1d)) @register_fusion_pattern((torch.nn.functional.relu, torch.nn.BatchNorm2d)) @register_fusion_pattern((torch.nn.ReLU, torch.nn.BatchNorm2d)) @register_fusion_pattern((torch.nn.functional.relu, torch.nn.BatchNorm3d)) diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py index c588354a3192..e69c6c5ea33f 100644 --- a/torch/quantization/fx/quantization_patterns.py +++ b/torch/quantization/fx/quantization_patterns.py @@ -25,6 +25,7 @@ weight_is_quantized, weight_dtype, get_linear_prepack_op_for_dtype, + get_qconfig_dtypes, ) from abc import ABC, abstractmethod @@ -283,7 +284,21 @@ def __init__(self, quantizer, node): self.linear = quantizer.modules[self.linear_node.target] def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None): + # Supported combinations are: + # quant_type | activation (compute_type) | weight + # static quint8 qint8 + # dynamic float32 (quint8) qint8 + # weight_only float32 float16 + # tuple (activation_dtype, weight_dtype, compute_dtype) + supported_dtypes = [ + (torch.quint8, torch.qint8, None), + (torch.float32, torch.qint8, torch.quint8), + (torch.float16, torch.float16, None), + ] qconfig = quantizer.qconfig_map[node.name] + dtypes = get_qconfig_dtypes(qconfig) + assert dtypes in supported_dtypes, "qconfig dtype pair not supported:" \ + " {}, supported dtypes are: {}".format(dtypes, supported_dtypes) activation_statically_quantized = activation_is_statically_quantized(qconfig) # TODO: debug option for linear module if self.linear_node.op == 'call_module': @@ -423,11 +438,22 @@ def __init__(self, quantizer, node): super().__init__(quantizer, node) def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None): + # Supported combinations are: + # quant_type | activation (compute_type) | weight + # weight_only | float32 (torch.uint8) | quint8 + # weight_only | float32 (torch.uint8) | quint4x2 + # tuple (activation_dtype, weight_dtype, compute_dtype) + supported_dtypes = [ + (torch.float32, torch.quint8, torch.quint8), + (torch.float32, torch.quint4x2, torch.quint8), + ] assert node.op == 'call_module' emb_node = node emb = quantizer.modules[emb_node.target] qconfig = quantizer.qconfig_map[node.name] - assert not activation_is_statically_quantized(qconfig) + dtypes = get_qconfig_dtypes(qconfig) + assert dtypes in supported_dtypes, "qconfig dtype pair not supported:" \ + " {}, supported dtypes are: {}".format(dtypes, supported_dtypes) qemb = get_static_quant_module_class(type(emb)) quantized = qemb.from_float(emb) parent_name, name = _parent_name(emb_node.target) diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py index 5536f5ef58b0..732f2efdedfe 100644 --- a/torch/quantization/fx/quantize.py +++ b/torch/quantization/fx/quantize.py @@ -49,7 +49,6 @@ from collections import OrderedDict import warnings -import copy import re from typing import Optional @@ -562,7 +561,7 @@ def _run_weight_observers(self, observed): weight_observer_module() return - def _convert(self, model, inplace=False, debug=False, convert_custom_config_dict=None, is_standalone_module=False): + def _convert(self, model, debug=False, convert_custom_config_dict=None, is_standalone_module=False): """ standalone_module means it a submodule that is not inlined in parent module, and will be quantized separately as one unit. For standalone module: the inputs will be quantized by parent module, @@ -575,8 +574,6 @@ def _convert(self, model, inplace=False, debug=False, convert_custom_config_dict if convert_custom_config_dict is None: convert_custom_config_dict = {} self.restore_state(model) - if not inplace: - model = copy.deepcopy(model) # always run weight observers in the top level forward method # for dynamic quant ops or weight only quant ops self._run_weight_observers(model) @@ -824,8 +821,8 @@ def load_arg(a): quantized = GraphModule(quantized_root, folded_graph) return quantized - def convert(self, model, inplace=False, debug=False, convert_custom_config_dict=None, is_standalone_module=False): - quantized = self._convert(model, inplace, debug, convert_custom_config_dict, is_standalone_module) + def convert(self, model, debug=False, convert_custom_config_dict=None, is_standalone_module=False): + quantized = self._convert(model, debug, convert_custom_config_dict, is_standalone_module) if not debug: quantized = self._fold_weight(quantized) return quantized diff --git a/torch/quantization/fx/utils.py b/torch/quantization/fx/utils.py index 366970cec4c0..57f557e09048 100644 --- a/torch/quantization/fx/utils.py +++ b/torch/quantization/fx/utils.py @@ -208,6 +208,16 @@ def weight_is_quantized(qconfig): """ return weight_dtype(qconfig) in [torch.quint8, torch.qint8] +def get_qconfig_dtypes(qconfig): + r""" returns the qconfig tuple for qconfig: + (activation_dtype, weight_dtype, activation_compute_dtype) + """ + assert qconfig is not None + activation = qconfig.activation() + weight = qconfig.weight() + compute_dtype = activation.compute_dtype if hasattr(activation, 'compute_dtype') else None + return (activation.dtype, weight.dtype, compute_dtype) + def get_quant_type(qconfig): assert qconfig is not None activation = qconfig.activation() diff --git a/torch/quantization/quantize_fx.py b/torch/quantization/quantize_fx.py index 1794c3ac5a2d..93043559bf48 100644 --- a/torch/quantization/quantize_fx.py +++ b/torch/quantization/quantize_fx.py @@ -103,10 +103,8 @@ def _prepare_standalone_module_fx(model, qconfig_dict, prepare_custom_config_dic custom module is observed or not """ - torch._C._log_api_usage_once("quantization_api.quantize_fx._prepare_standalone_module_fx") return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict, is_standalone_module=True) - def fuse_fx(model, fuse_custom_config_dict=None): r""" Fuse modules like conv+bn, conv+bn+relu etc, model must be in eval mode. Fusion rules are defined in torch.quantization.fx.fusion_pattern.py @@ -280,19 +278,17 @@ def train_loop(model, train_data): 'train mode' return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict) -def _convert_fx(graph_module, inplace, debug, convert_custom_config_dict=None, is_standalone_module=False): +def _convert_fx(graph_module, debug, convert_custom_config_dict=None, is_standalone_module=False): """ `is_standalone_module`: see docs in :func:`~torch.quantization.prepare_standalone_module_fx` """ _check_is_graph_module(graph_module) quantizer = Quantizer() - return quantizer.convert(graph_module, inplace, debug, convert_custom_config_dict, is_standalone_module) + return quantizer.convert(graph_module, debug, convert_custom_config_dict, is_standalone_module) -def convert_fx(graph_module, inplace=False, debug=False, convert_custom_config_dict=None): +def convert_fx(graph_module, debug=False, convert_custom_config_dict=None): r""" Convert a calibrated or trained model to a quantized model Args: `graph_module`: A prepared and calibrated/trained model (GraphModule) - `inplace`: flag for carry out model transformations in-place, - the original module is mutated `debug`: flag for producing a debug friendly model (preserve weight attribute) `convert_custom_config_dict`: dictionary for custom configurations for convert function: convert_custom_config_dict = { @@ -334,9 +330,9 @@ def convert_fx(graph_module, inplace=False, debug=False, convert_custom_config_d ``` """ torch._C._log_api_usage_once("quantization_api.quantize_fx.convert_fx") - return _convert_fx(graph_module, inplace, debug, convert_custom_config_dict) + return _convert_fx(graph_module, debug, convert_custom_config_dict) -def _convert_standalone_module_fx(graph_module, inplace=False, debug=False, convert_custom_config_dict=None): +def _convert_standalone_module_fx(graph_module, debug=False, convert_custom_config_dict=None): r""" [Internal use only] Convert a model produced by :func:`~torch.quantization.prepare_standalone_module_fx` and convert it to a quantized model @@ -347,5 +343,4 @@ def _convert_standalone_module_fx(graph_module, inplace=False, debug=False, conv A quantized standalone module which accepts quantized input(if needed) and produces quantized output (if needed). """ - torch._C._log_api_usage_once("quantization_api.quantize_fx._convert_standalone_module_fx") - return _convert_fx(graph_module, inplace, debug, convert_custom_config_dict, is_standalone_module=True) + return _convert_fx(graph_module, debug, convert_custom_config_dict, is_standalone_module=True) diff --git a/torch/tensor.py b/torch/tensor.py index 1f6dce4ecb34..64e7d9ee44c0 100644 --- a/torch/tensor.py +++ b/torch/tensor.py @@ -178,7 +178,7 @@ def __repr__(self): # All strings are unicode in Python 3. return torch._tensor_str._str(self) - def backward(self, gradient=None, retain_graph=None, create_graph=False): + def backward(self, gradient=None, retain_graph=None, create_graph=False, inputs=None): r"""Computes the gradient of current tensor w.r.t. graph leaves. The graph is differentiated using the chain rule. If the tensor is @@ -213,6 +213,11 @@ def backward(self, gradient=None, retain_graph=None, create_graph=False): create_graph (bool, optional): If ``True``, graph of the derivative will be constructed, allowing to compute higher order derivative products. Defaults to ``False``. + inputs (sequence of Tensor): Inputs w.r.t. which the gradient will be + accumulated into ``.grad``. All other Tensors will be ignored. If not + provided, the gradient is accumulated into all the leaf Tensors that were + used to compute the attr::tensors. All the provided inputs must be leaf + Tensors. """ relevant_args = (self,) from torch.overrides import has_torch_function, handle_torch_function @@ -223,8 +228,9 @@ def backward(self, gradient=None, retain_graph=None, create_graph=False): self, gradient=gradient, retain_graph=retain_graph, - create_graph=create_graph) - torch.autograd.backward(self, gradient, retain_graph, create_graph) + create_graph=create_graph, + inputs=inputs) + torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs) def register_hook(self, hook): r"""Registers a backward hook. diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py index a896c4ca0af5..8564db80d87c 100644 --- a/torch/testing/_internal/common_device_type.py +++ b/torch/testing/_internal/common_device_type.py @@ -9,7 +9,8 @@ import os import torch from torch.testing._internal.common_utils import TestCase, TEST_WITH_ROCM, TEST_MKL, \ - skipCUDANonDefaultStreamIf, TEST_WITH_ASAN, TEST_WITH_UBSAN, TEST_WITH_TSAN + skipCUDANonDefaultStreamIf, TEST_WITH_ASAN, TEST_WITH_UBSAN, TEST_WITH_TSAN, \ + IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU from torch.testing._internal.common_cuda import _get_torch_cuda_version from torch.testing import \ (get_all_dtypes) @@ -166,9 +167,6 @@ # See below for how this list is populated. If you're adding a device type # you should check if it's available and (if it is) add it to this list. -# set type to List[Any] due to mypy list-of-union issue: -# https://github.com/python/mypy/issues/3351 -device_type_test_bases: List[Any] = list() def _construct_test_name(test_name, op, device_type, dtype): if op is not None: @@ -361,9 +359,25 @@ def setUpClass(cls): # Adds available device-type-specific test base classes -device_type_test_bases.append(CPUTestBase) -if torch.cuda.is_available(): - device_type_test_bases.append(CUDATestBase) +def get_device_type_test_bases(): + # set type to List[Any] due to mypy list-of-union issue: + # https://github.com/python/mypy/issues/3351 + test_bases: List[Any] = list() + + if IS_SANDCASTLE or IS_FBCODE: + if IS_REMOTE_GPU: + test_bases.append(CUDATestBase) + else: + test_bases.append(CPUTestBase) + else: + test_bases.append(CPUTestBase) + if torch.cuda.is_available(): + test_bases.append(CUDATestBase) + + return test_bases + + +device_type_test_bases = get_device_type_test_bases() # Note [How to extend DeviceTypeTestBase to add new test device] diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 48b24f1ae499..24ad7531700f 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -274,8 +274,11 @@ def sample_inputs(self, device, dtype, requires_grad=False): test_inplace_grad=False), UnaryUfuncInfo('cos', ref=np.cos, - dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), handles_large_floats=False, + promotes_integers_to_float=True, decorators=(precisionOverride({torch.bfloat16: 1e-2}),), skips=( SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', @@ -380,6 +383,10 @@ def sample_inputs(self, device, dtype, requires_grad=False): )), UnaryUfuncInfo('tan', ref=np.tan, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half), + promotes_integers_to_float=True, skips=( SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', device_type='cuda', dtypes=[torch.cfloat, torch.cdouble]), @@ -531,6 +538,7 @@ def method_tests(): ('add', (), ((S, S, S),), 'scalar_broadcast_lhs', (True,)), ('add', (S, S, S), (3.14,), 'constant', (True,)), ('add', (), (3.14,), 'scalar_constant', (True,)), + ('add', (S, S, S), (3.14j,), 'complex_scalar_constant', (True,)), ('asinh', (S, S, S), NO_ARGS, ''), ('asinh', (), NO_ARGS, 'scalar'), ('atanh', torch.rand(S, S, S), NO_ARGS, ''), @@ -545,6 +553,7 @@ def method_tests(): ('sub', (), ((S, S, S),), 'scalar_broadcast_lhs', (True,)), ('sub', (S, S, S), (3.14,), 'constant', (True,)), ('sub', (), (3.14,), 'scalar_constant', (True,)), + ('sub', (S, S, S), (3.14j,), 'complex_scalar_constant', (True,)), ('__rsub__', (S, S, S), (3.14,), 'constant', (True, 'aten::rsub')), ('__rsub__', (), (3.14,), 'scalar_constant', (True, 'aten::rsub')), ('mul', (S, S, S), ((S, S, S),), '', (True,)), diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py index 440e59cf9174..1b2b4165b044 100644 --- a/torch/testing/_internal/common_quantization.py +++ b/torch/testing/_internal/common_quantization.py @@ -12,22 +12,23 @@ from torch.quantization import QuantWrapper, QuantStub, DeQuantStub, \ default_qconfig, default_dynamic_qconfig, default_per_channel_qconfig, QConfig, default_observer, default_weight_observer, \ propagate_qconfig_, convert, get_default_qconfig, quantize_dynamic_jit, quantize_jit, float_qparams_dynamic_qconfig, \ - get_default_qat_qconfig, PerChannelMinMaxObserver, default_dynamic_quant_observer, QConfigDynamic + get_default_qat_qconfig, PerChannelMinMaxObserver, default_dynamic_quant_observer, QConfigDynamic, QuantType from torch.quantization.quantization_mappings import ( get_default_dynamic_quant_module_mappings, get_default_qconfig_propagation_list, get_default_qat_module_mappings, ) -# symbolic trace -from torch.fx import symbolic_trace - -# graph mode quantization based on fx -from torch.quantization import ( - QuantType, - prepare_fx, - prepare_qat_fx, - convert_fx, -) + +try: + # graph mode quantization based on fx + from torch.quantization.quantize_fx import ( + prepare_fx, + prepare_qat_fx, + convert_fx, + ) + HAS_FX = True +except ImportError: + HAS_FX = False import copy import io @@ -599,77 +600,77 @@ def printGraphModule(self, graph_module, print_str=True): print(str_to_print) return str_to_print - def checkGraphModeFxOp(self, model, inputs, quant_type, - expected_node=None, - expected_node_occurrence=None, - expected_node_list=None, - debug=False, - print_debug_info=False, - custom_qconfig=None): - """ Quantizes model with graph mode quantization on fx and check if the - quantized model contains the quantized_node + if HAS_FX: + def checkGraphModeFxOp(self, model, inputs, quant_type, + expected_node=None, + expected_node_occurrence=None, + expected_node_list=None, + debug=False, + print_debug_info=False, + custom_qconfig=None): + """ Quantizes model with graph mode quantization on fx and check if the + quantized model contains the quantized_node + + Args: + model: floating point torch.nn.Module + inputs: one positional sample input arguments for model + expected_node: NodeSpec + e.g. NodeSpec.call_function(torch.quantize_per_tensor) + expected_node_occurrence: a dict from NodeSpec to + expected number of occurences (int) + e.g. {NodeSpec.call_function(torch.quantize_per_tensor) : 1, + NodeSpec.call_method('dequantize'): 1} + expected_node_list: a list of NodeSpec, used to check the order + of the occurrence of Node + e.g. [NodeSpec.call_function(torch.quantize_per_tensor), + NodeSpec.call_module(nnq.Conv2d), + NodeSpec.call_function(F.hardtanh_), + NodeSpec.call_method('dequantize')] + """ + # TODO: make img_data a single example instead of a list + if type(inputs) == list: + inputs = inputs[0] - Args: - model: floating point torch.nn.Module - inputs: one positional sample input arguments for model - expected_node: NodeSpec - e.g. NodeSpec.call_function(torch.quantize_per_tensor) - expected_node_occurrence: a dict from NodeSpec to - expected number of occurences (int) - e.g. {NodeSpec.call_function(torch.quantize_per_tensor) : 1, - NodeSpec.call_method('dequantize'): 1} - expected_node_list: a list of NodeSpec, used to check the order - of the occurrence of Node - e.g. [NodeSpec.call_function(torch.quantize_per_tensor), - NodeSpec.call_module(nnq.Conv2d), - NodeSpec.call_function(F.hardtanh_), - NodeSpec.call_method('dequantize')] - """ - # TODO: make img_data a single example instead of a list - if type(inputs) == list: - inputs = inputs[0] - if custom_qconfig is None: if quant_type == QuantType.QAT: qconfig = get_default_qat_qconfig(torch.backends.quantized.engine) + model.train() elif quant_type == QuantType.STATIC: qconfig = get_default_qconfig(torch.backends.quantized.engine) + model.eval() else: qconfig = default_dynamic_qconfig - else: - qconfig = custom_qconfig + model.eval() - if quant_type == QuantType.QAT: - model.train() - else: - model.eval() + # overwrite qconfig with custom_qconfig + if custom_qconfig is not None: + qconfig = custom_qconfig - original = symbolic_trace(model) - if quant_type == QuantType.QAT: - prepare = prepare_qat_fx - else: - prepare = prepare_fx - - qconfig_dict = {'': qconfig} - prepared = prepare(original, qconfig_dict) - if not quant_type == QuantType.DYNAMIC: - prepared(*inputs) - qgraph = convert_fx(prepared) - qgraph_debug = convert_fx(prepared, debug=True) - result = qgraph(*inputs) - result_debug = qgraph_debug(*inputs) - - qgraph_to_check = qgraph_debug if debug else qgraph - if print_debug_info: - print() - print('quant type:', quant_type) - print('origianl graph module:', type(model)) - self.printGraphModule(original) - print() - print('quantized graph module:', type(qgraph_to_check)) - self.printGraphModule(qgraph_to_check) - print() - self.checkGraphModuleNodes( - qgraph_to_check, expected_node, expected_node_occurrence, expected_node_list) + if quant_type == QuantType.QAT: + prepare = prepare_qat_fx + else: + prepare = prepare_fx + + qconfig_dict = {'': qconfig} + prepared = prepare(model, qconfig_dict) + if not quant_type == QuantType.DYNAMIC: + prepared(*inputs) + prepared_copy = copy.deepcopy(prepared) + qgraph = convert_fx(prepared) + qgraph_debug = convert_fx(prepared_copy, debug=True) + result = qgraph(*inputs) + result_debug = qgraph_debug(*inputs) + + qgraph_to_check = qgraph_debug if debug else qgraph + if print_debug_info: + print() + print('quant type:', quant_type) + print('original model:', model) + print() + print('quantized model:', qgraph_to_check) + self.printGraphModule(qgraph_to_check) + print() + self.checkGraphModuleNodes( + qgraph_to_check, expected_node, expected_node_occurrence, expected_node_list) def checkEmbeddingSerialization(self, qemb, num_embeddings, embedding_dim, indices, offsets, diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index bfb61c0a6981..92c5c52cf947 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -60,6 +60,7 @@ IS_SANDCASTLE = os.getenv('SANDCASTLE') == '1' or os.getenv('TW_JOB_USER') == 'sandcastle' IS_FBCODE = os.getenv('PYTORCH_TEST_FBCODE') == '1' +IS_REMOTE_GPU = os.getenv('PYTORCH_TEST_REMOTE_GPU') == '1' class ProfilingMode(Enum): LEGACY = 1 diff --git a/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py b/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py index 305b0fcb82bf..84768496b5ff 100644 --- a/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py +++ b/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py @@ -20,7 +20,7 @@ skip_if_lt_x_gpu, skip_if_rocm, ) -from torch.testing._internal.dist_utils import dist_init, INIT_METHOD_TEMPLATE +from torch.testing._internal.dist_utils import INIT_METHOD_TEMPLATE, dist_init from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import ( RpcAgentTestFixture, ) @@ -619,35 +619,38 @@ def test_ddp_dist_autograd_local_vs_remote(self): rank=self.rank, ) - remote_layer1 = RemoteModule( - "worker0", device="cpu", module_cls=nn.Linear, args=(10, 5, False) - ) - layer1 = nn.Linear(10, 5, False) - # Start with the same parameters for remote and local - layer1.weight = remote_layer1.module_rref.to_here().weight - - # Run local case. - layer2 = nn.Linear(5, 1) - inputs = torch.rand((10, 10)) - ddp_model = DistributedDataParallel(layer2) - loss = ddp_model(layer1(inputs)).sum() - loss.backward() - - # Run remote case. - with dist_autograd.context() as context_id: - loss = ddp_model(remote_layer1(inputs)).sum() - dist_autograd.backward(context_id, [loss]) - grads_dict = dist_autograd.get_gradients(context_id) - dist.barrier() - self.assertEqual(layer2.weight.grad, grads_dict[layer2.weight]) - self.assertEqual( - layer1.weight.grad, - rpc.rpc_sync( - "worker0", - DdpComparisonTest.get_remote_grads, - args=(remote_layer1.module_rref, context_id), - ), + # Use two different remote device input string, w/ and w/o the default + # device string "cpu", respectively. + for remote_device in ["worker0/cpu", "worker0"]: + remote_layer1 = RemoteModule( + remote_device=remote_device, module_cls=nn.Linear, args=(10, 5, False) ) + layer1 = nn.Linear(10, 5, False) + # Start with the same parameters for remote and local + layer1.weight = remote_layer1.module_rref.to_here().weight + + # Run local case. + layer2 = nn.Linear(5, 1) + inputs = torch.rand((10, 10)) + ddp_model = DistributedDataParallel(layer2) + loss = ddp_model(layer1(inputs)).sum() + loss.backward() + + # Run remote case. + with dist_autograd.context() as context_id: + loss = ddp_model(remote_layer1(inputs)).sum() + dist_autograd.backward(context_id, [loss]) + grads_dict = dist_autograd.get_gradients(context_id) + dist.barrier() + self.assertEqual(layer2.weight.grad, grads_dict[layer2.weight]) + self.assertEqual( + layer1.weight.grad, + rpc.rpc_sync( + "worker0", + DdpComparisonTest.get_remote_grads, + args=(remote_layer1.module_rref, context_id), + ), + ) @skip_if_lt_x_gpu(NUM_TRAINERS) @requires_nccl() @@ -667,7 +670,7 @@ def test_ddp_dist_autograd_local_vs_remote_gpu(self): ) remote_layer1 = RemoteModule( - "worker0", device="cpu", module_cls=nn.Linear, args=(10, 7, False) + remote_device="worker0/cpu", module_cls=nn.Linear, args=(10, 7, False) ) layer1 = nn.Linear(10, 7, False) # Start with the same parameters for remote and local @@ -677,7 +680,7 @@ def test_ddp_dist_autograd_local_vs_remote_gpu(self): ddp_layer2 = DistributedDataParallel(layer2, device_ids=[self.rank]) remote_layer3 = RemoteModule( - "worker0", device="cpu", module_cls=nn.Linear, args=(5, 3, False) + remote_device="worker0/cpu", module_cls=nn.Linear, args=(5, 3, False) ) layer3 = nn.Linear(5, 3, False) # Start with the same parameters for remote and local diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index f304e37389b5..2dc0d37c6975 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -1001,7 +1001,6 @@ def test_broadcast_full_group(self): "Only NCCL backend supports high priority stream", ) @skip_if_no_gpu - @skip_if_rocm def test_nccl_high_priority_stream(self): group, _, rank = self._init_global_test() rank_to_GPU = self._init_multigpu_helper() @@ -3140,6 +3139,13 @@ def validate_global_samples(local_num_samples): @require_backend({"nccl", "gloo"}) @require_n_gpus_for_nccl_backend(int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"]) def test_allgather_object(self): + # Only set device for NCCL backend since it must use GPUs. + backend = os.environ["BACKEND"] + if backend == "nccl": + # Case where rank != GPU device. + next_rank = (self.rank + 1) % int(self.world_size) + torch.cuda.set_device(next_rank) + gather_objects = collectives_object_test_list output_gathered = [None for _ in range(dist.get_world_size())] dist.all_gather_object( @@ -3194,7 +3200,10 @@ class Bar: def test_nccl_gather_object_err(self): output_gathered = [None for _ in range(dist.get_world_size())] gather_on_rank = 0 + # Case where rank != GPU device. my_rank = dist.get_rank() + next_rank = (my_rank + 1) % dist.get_world_size() + torch.cuda.set_device(next_rank) with self.assertRaisesRegex( RuntimeError, "ProcessGroupNCCL does not support gather" ): @@ -3665,6 +3674,13 @@ def test_ddp_uneven_inputs_replicated_error(self): @require_backend({"nccl", "gloo"}) @require_n_gpus_for_nccl_backend(int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"]) def test_broadcast_object_list(self): + # Only set device for NCCL backend since it must use GPUs. + backend = os.environ["BACKEND"] + if backend == "nccl": + # Case where rank != GPU device. + next_rank = (self.rank + 1) % int(self.world_size) + torch.cuda.set_device(next_rank) + src_rank = 0 objects = collectives_object_test_list if self.rank == src_rank else [None for _ in collectives_object_test_list] @@ -3800,6 +3816,39 @@ def forward(self, x): else: ddp(inp).sum().backward() + @require_backend({"gloo", "nccl"}) + @require_backends_available({"gloo", "nccl"}) + @skip_if_lt_x_gpu(2) + @skip_if_rocm + def test_ddp_shared_grad_acc_unused_params(self): + # When find_unused_parameters=True, ensure we mark unused parameters + # even if they share gradient accumulators. + class ToyModel(nn.Module): + def __init__(self): + super(ToyModel, self).__init__() + # net1, bias, and net1.bias are all unused params. + self.net1 = nn.Linear(10, 5, bias=False) + self.bias = nn.Parameter(torch.zeros(5)) + # net1.bias and self.bias are names for the same underlying + # parameter, so they share the same grad acc. This caused + # the bug reported in https://github.com/pytorch/pytorch/issues/41324. + self.net1.bias = self.bias + self.net2 = nn.Linear(10, 5) + + def forward(self, x): + return self.net2(x) + + torch.cuda.set_device(self.rank) + model = ToyModel().to(torch.cuda.current_device()) + ddp_model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[self.rank], find_unused_parameters=True + ) + inp = torch.randn(20, 10, device=self.rank) + for i in range(6): + out = ddp_model(inp) + loss = out.sum() + loss.backward() + @require_backend({"gloo", "nccl"}) @require_backends_available({"gloo", "nccl"}) @skip_if_lt_x_gpu(2) diff --git a/torch/testing/_internal/distributed/nn/api/remote_module_test.py b/torch/testing/_internal/distributed/nn/api/remote_module_test.py index da81b3b16e53..d6b3d816fe68 100644 --- a/torch/testing/_internal/distributed/nn/api/remote_module_test.py +++ b/torch/testing/_internal/distributed/nn/api/remote_module_test.py @@ -78,7 +78,7 @@ def world_size(self): # Override setting in RpcAgentTestFixture return 2 @staticmethod - def _create_remote_module_iter(dst_worker_name, device="cpu", modes=None): + def _create_remote_module_iter(remote_device, modes=None): if modes is None: modes = ModuleCreationMode.__members__.values() @@ -86,15 +86,12 @@ def _create_remote_module_iter(dst_worker_name, device="cpu", modes=None): kwargs = dict(first_kwarg=2) if ModuleCreationMode.MODULE_CTOR in modes: - remote_module = RemoteModule( - dst_worker_name, device, MyModule, args, kwargs - ) + remote_module = RemoteModule(remote_device, MyModule, args, kwargs) yield remote_module if ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE in modes: remote_module = _RemoteModule( - dst_worker_name, - device, + remote_device, create_scripted_module, args, kwargs, @@ -108,6 +105,7 @@ def test_bad_module(self): if self.rank != 0: return dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) + remote_device = "{}/cpu".format(dst_worker_name) args = (1,) kwargs = dict(first_kwarg=2) @@ -115,13 +113,13 @@ def test_bad_module(self): ValueError, r"Expect `module_cls\(\*args, \*\*kwargs\)` returns an instance of ,", ): - RemoteModule(dst_worker_name, "cpu", BadModule, args, kwargs) + RemoteModule(remote_device, BadModule, args, kwargs) with self.assertRaisesRegex( ValueError, r"Expect `module_cls\(\*args, \*\*kwargs\)` returns an instance of ,", ): - RemoteModule(dst_worker_name, "cpu", BadModule, args, kwargs) + RemoteModule(remote_device, BadModule, args, kwargs) @dist_utils.dist_init def test_forward_async(self): @@ -227,7 +225,7 @@ def test_valid_device(self): dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) for remote_module in self._create_remote_module_iter( - dst_worker_name, device="cuda:0", modes=[ModuleCreationMode.MODULE_CTOR] + "{}/cuda:0".format(dst_worker_name), modes=[ModuleCreationMode.MODULE_CTOR] ): device = rpc.rpc_sync( dst_worker_name, remote_device, (remote_module.module_rref,) @@ -249,8 +247,7 @@ def test_invalid_devices(self): ): list( self._create_remote_module_iter( - dst_worker_name, - device="foo", + "{}/foo".format(dst_worker_name), modes=[ModuleCreationMode.MODULE_CTOR], ) ) @@ -260,8 +257,7 @@ def test_invalid_devices(self): ): list( self._create_remote_module_iter( - dst_worker_name, - device="cuda:100", + "{}/cuda:100".format(dst_worker_name), modes=[ModuleCreationMode.MODULE_CTOR], ) ) @@ -269,9 +265,8 @@ def test_invalid_devices(self): with self.assertRaisesRegex(RuntimeError, r"Invalid device string: 'cpu2'"): list( self._create_remote_module_iter( - dst_worker_name, + "{}/cpu2".format(dst_worker_name), modes=[ModuleCreationMode.MODULE_CTOR], - device="cpu2", ) ) @@ -280,8 +275,48 @@ def test_invalid_devices(self): ): list( self._create_remote_module_iter( - dst_worker_name, - device="cpu:2", + "{}/cpu:2".format(dst_worker_name), + modes=[ModuleCreationMode.MODULE_CTOR], + ) + ) + + with self.assertRaisesRegex(RuntimeError, r"Device string must not be empty"): + list( + self._create_remote_module_iter( + "{}/".format(dst_worker_name), + modes=[ModuleCreationMode.MODULE_CTOR], + ) + ) + + with self.assertRaisesRegex( + RuntimeError, + r"Could not parse remote_device: worker1/cuda:0/cuda:1. The valid format is '/'", + ): + list( + self._create_remote_module_iter( + "{}/cuda:0/cuda:1".format(dst_worker_name), + modes=[ModuleCreationMode.MODULE_CTOR], + ) + ) + + with self.assertRaisesRegex( + RuntimeError, + r"The workername in remote_device '/' cannot be empty. The valid format is '/'", + ): + list( + self._create_remote_module_iter( + "/", + modes=[ModuleCreationMode.MODULE_CTOR], + ) + ) + + with self.assertRaisesRegex( + RuntimeError, + r"The workername in remote_device '/cuda:0' cannot be empty. The valid format is '/'", + ): + list( + self._create_remote_module_iter( + "/cuda:0", modes=[ModuleCreationMode.MODULE_CTOR], ) ) diff --git a/torch/utils/benchmark/utils/timer.py b/torch/utils/benchmark/utils/timer.py index 982e212df216..e017dda2d4dd 100644 --- a/torch/utils/benchmark/utils/timer.py +++ b/torch/utils/benchmark/utils/timer.py @@ -54,11 +54,19 @@ class Timer(object): In addition to wall times, Timer can run a statement under Callgrind and report instructions executed. + Directly analogous to `timeit.Timer` constructor arguments: + + `stmt`, `setup`, `timer`, `globals` + + PyTorch Timer specific constructor arguments: + + `label`, `sub_label`, `description`, `env`, `num_threads` + Arguments: - Directly analogous to `timeit.Timer` constructor arguments: - ----------------------------------------------------------------------- stmt: Code snippet to be run in a loop and timed. + setup: Optional setup code. Used to define variables used in `stmt` + timer: Callable which returns the current time. If PyTorch was built without CUDA or there is no GPU present, this defaults to @@ -70,11 +78,9 @@ class Timer(object): executed. This is the other method for providing variables which `stmt` needs. - PyTorch Timer specific constructor arguments: - ----------------------------------------------------------------------- label: String which summarizes `stmt`. For instance, if `stmt` is - "torch.nn.functional.relu(torch.add(x, 1, out=out))" + "torch.nn.functional.relu(torch.add(x, 1, out=out))" one might set label to "ReLU(x + 1)" to improve readability. sub_label: @@ -82,20 +88,22 @@ class Timer(object): with identical stmt or label. For instance, in our example above sub_label might be "float" or "int", so that it is easy to differentiate: - "ReLU(x + 1): (float)" - "ReLU(x + 1): (int)" + "ReLU(x + 1): (float)" + + "ReLU(x + 1): (int)" when printing Measurements or summarizing using `Compare`. description: String to distinguish measurements with identical label and sub_label. The principal use of `description` is to signal to `Compare` the columns of data. For instance one might set it - based on the input size to create a table of the form: + based on the input size to create a table of the form: :: + + | n=1 | n=4 | ... + ------------- ... + ReLU(x + 1): (float) | ... | ... | ... + ReLU(x + 1): (int) | ... | ... | ... - | n=1 | n=4 | ... - ------------- ... - ReLU(x + 1): (float) | ... | ... | ... - ReLU(x + 1): (int) | ... | ... | ... using `Compare`. It is also included when printing a Measurement. @@ -269,23 +277,24 @@ def blocked_autorange( ) -> common.Measurement: """Measure many replicates while keeping timer overhead to a minimum. - At a high level, blocked_autorange executes the following pseudo-code: - ``` - `setup` + At a high level, blocked_autorange executes the following pseudo-code:: + + `setup` - total_time = 0 - while total_time < min_run_time - start = timer() - for _ in range(block_size): - `stmt` - total_time += (timer() - start) - ``` + total_time = 0 + while total_time < min_run_time + start = timer() + for _ in range(block_size): + `stmt` + total_time += (timer() - start) Note the variable `block_size` in the inner loop. The choice of block size is important to measurement quality, and must balance two competing objectives: + 1) A small block size results in more replicates and generally better statistics. + 2) A large block size better amortizes the cost of `timer` invocation, and results in a less biased measurement. This is important because CUDA syncronization time is non-trivial diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index 4948e6e33099..afd654c6a85b 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -50,7 +50,7 @@ def _find_cuda_home() -> Optional[str]: if not os.path.exists(cuda_home): cuda_home = None if cuda_home and not torch.cuda.is_available(): - print("No CUDA runtime is found, using CUDA_HOME='{}'".format(cuda_home)) + print(f"No CUDA runtime is found, using CUDA_HOME='{cuda_home}'") return cuda_home def _find_rocm_home() -> Optional[str]: @@ -72,7 +72,7 @@ def _find_rocm_home() -> Optional[str]: if not os.path.exists(rocm_home): rocm_home = None if rocm_home and torch.version.hip is None: - print("No ROCm runtime is found, using ROCM_HOME='{}'".format(rocm_home)) + print(f"No ROCm runtime is found, using ROCM_HOME='{rocm_home}'") return rocm_home @@ -275,13 +275,13 @@ def check_compiler_abi_compatibility(compiler) -> bool: version = (0, 0, 0) if match is None else match.groups() except Exception: _, error, _ = sys.exc_info() - warnings.warn('Error checking compiler version for {}: {}'.format(compiler, error)) + warnings.warn(f'Error checking compiler version for {compiler}: {error}') return False if tuple(map(int, version)) >= minimum_required_version: return True - compiler = '{} {}'.format(compiler, ".".join(version)) + compiler = f'{compiler} {".".join(version)}' warnings.warn(ABI_INCOMPATIBILITY_WARNING.format(compiler)) return False @@ -364,6 +364,11 @@ def build_extensions(self) -> None: extension.extra_compile_args[ext] = [] self._add_compile_flag(extension, '-DTORCH_API_INCLUDE_EXTENSION_H') + # See note [Pybind11 ABI constants] + for name in ["COMPILER_TYPE", "STDLIB", "BUILD_ABI"]: + val = getattr(torch._C, f"_PYBIND11_{name}") + if val is not None and not IS_WINDOWS: + self._add_compile_flag(extension, f'-DPYBIND11_{name}="{val}"') self._define_torch_extension_name(extension) self._add_gnu_cpp_abi_flag(extension) @@ -715,7 +720,7 @@ def _define_torch_extension_name(self, extension): # as the library name names = extension.name.split('.') name = names[-1] - define = '-DTORCH_EXTENSION_NAME={}'.format(name) + define = f'-DTORCH_EXTENSION_NAME={name}' self._add_compile_flag(extension, define) def _add_gnu_cpp_abi_flag(self, extension): @@ -1102,9 +1107,7 @@ def load_inline(name, # Make the function docstring the same as the function name. functions = dict((f, f) for f in functions) elif not isinstance(functions, dict): - raise ValueError( - "Expected 'functions' to be a list or dict, but was {}".format( - type(functions))) + raise ValueError(f"Expected 'functions' to be a list or dict, but was {type(functions)}") for function_name, docstring in functions.items(): if with_pytorch_error_handling: module_def.append( @@ -1170,9 +1173,9 @@ def _jit_compile(name, ) if version > 0: if version != old_version and verbose: - print('The input conditions for extension module {} have changed. '.format(name) + - 'Bumping to version {0} and re-building as {1}_v{0}...'.format(version, name)) - name = '{}_v{}'.format(name, version) + print(f'The input conditions for extension module {name} have changed. ' + + f'Bumping to version {version} and re-building as {name}_v{version}...') + name = f'{name}_v{version}' if version != old_version: baton = FileBaton(os.path.join(build_directory, 'lock')) @@ -1205,7 +1208,7 @@ def _jit_compile(name, baton.wait() elif verbose: print('No modifications detected for re-loaded extension ' - 'module {}, skipping build step...'.format(name)) + f'module {name}, skipping build step...') if verbose: print(f'Loading extension module {name}...') @@ -1292,11 +1295,11 @@ def _write_ninja_file_and_build_library( with_cuda=with_cuda) if verbose: - print('Building extension module {}...'.format(name)) + print(f'Building extension module {name}...') _run_ninja_build( build_directory, verbose, - error_prefix="Error building extension '{}'".format(name)) + error_prefix=f"Error building extension '{name}'") def is_ninja_available(): @@ -1342,10 +1345,10 @@ def _prepare_ldflags(extra_ldflags, with_cuda, verbose): extra_ldflags.append('-INCLUDE:?warp_size@cuda@at@@YAHXZ') extra_ldflags.append('torch.lib') extra_ldflags.append('torch_python.lib') - extra_ldflags.append('/LIBPATH:{}'.format(python_lib_path)) - extra_ldflags.append('/LIBPATH:{}'.format(lib_path)) + extra_ldflags.append(f'/LIBPATH:{python_lib_path}') + extra_ldflags.append(f'/LIBPATH:{lib_path}') else: - extra_ldflags.append('-L{}'.format(lib_path)) + extra_ldflags.append(f'-L{lib_path}') extra_ldflags.append('-lc10') if with_cuda: extra_ldflags.append('-lc10_hip' if IS_HIP_EXTENSION else '-lc10_cuda') @@ -1359,19 +1362,18 @@ def _prepare_ldflags(extra_ldflags, with_cuda, verbose): if verbose: print('Detected CUDA files, patching ldflags') if IS_WINDOWS: - extra_ldflags.append('/LIBPATH:{}'.format( - _join_cuda_home('lib/x64'))) + extra_ldflags.append(f'/LIBPATH:{_join_cuda_home("lib/x64")}') extra_ldflags.append('cudart.lib') if CUDNN_HOME is not None: extra_ldflags.append(os.path.join(CUDNN_HOME, 'lib/x64')) elif not IS_HIP_EXTENSION: - extra_ldflags.append('-L{}'.format(_join_cuda_home('lib64'))) + extra_ldflags.append(f'-L{_join_cuda_home("lib64")}') extra_ldflags.append('-lcudart') if CUDNN_HOME is not None: - extra_ldflags.append('-L{}'.format(os.path.join(CUDNN_HOME, 'lib64'))) + extra_ldflags.append(f'-L{os.path.join(CUDNN_HOME, "lib64")}') elif IS_HIP_EXTENSION: assert ROCM_VERSION is not None - extra_ldflags.append('-L{}'.format(_join_rocm_home('lib'))) + extra_ldflags.append(f'-L{_join_rocm_home("lib")}') extra_ldflags.append('-lamdhip64' if ROCM_VERSION >= (3, 5) else '-lhip_hcc') return extra_ldflags @@ -1421,7 +1423,7 @@ def _get_cuda_arch_flags(cflags: Optional[List[str]] = None) -> List[str]: # If not given, determine what's needed for the GPU that can be found if not _arch_list: capability = torch.cuda.get_device_capability() - arch_list = ['{}.{}'.format(capability[0], capability[1])] + arch_list = [f'{capability[0]}.{capability[1]}'] else: # Deal with lists that are ' ' separated (only deal with ';' after) _arch_list = _arch_list.replace(' ', ';') @@ -1434,12 +1436,12 @@ def _get_cuda_arch_flags(cflags: Optional[List[str]] = None) -> List[str]: flags = [] for arch in arch_list: if arch not in valid_arch_strings: - raise ValueError("Unknown CUDA arch ({}) or GPU not supported".format(arch)) + raise ValueError(f"Unknown CUDA arch ({arch}) or GPU not supported") else: num = arch[0] + arch[2] - flags.append('-gencode=arch=compute_{},code=sm_{}'.format(num, num)) + flags.append(f'-gencode=arch=compute_{num},code=sm_{num}') if arch.endswith('+PTX'): - flags.append('-gencode=arch=compute_{},code=compute_{}'.format(num, num)) + flags.append(f'-gencode=arch=compute_{num},code=compute_{num}') return list(set(flags)) @@ -1466,8 +1468,7 @@ def _get_build_directory(name: str, verbose: bool) -> str: root_extensions_directory = get_default_build_root() if verbose: - print('Using {} as PyTorch extensions root...'.format( - root_extensions_directory)) + print(f'Using {root_extensions_directory} as PyTorch extensions root...') build_directory = os.path.join(root_extensions_directory, name) if not os.path.exists(build_directory): @@ -1483,7 +1484,7 @@ def _get_num_workers(verbose: bool) -> Optional[int]: max_jobs = os.environ.get('MAX_JOBS') if max_jobs is not None and max_jobs.isdigit(): if verbose: - print('Using envvar MAX_JOBS ({}) as the number of workers...'.format(max_jobs)) + print(f'Using envvar MAX_JOBS ({max_jobs}) as the number of workers...') return int(max_jobs) if verbose: print('Allowing ninja to set a default number of workers... ' @@ -1550,7 +1551,7 @@ def _run_ninja_build(build_directory: str, verbose: bool, error_prefix: str) -> # `error` is a CalledProcessError (which has an `ouput`) attribute, but # mypy thinks it's Optional[BaseException] and doesn't narrow if hasattr(error, 'output') and error.output: # type: ignore - message += ": {}".format(error.output.decode()) # type: ignore + message += f": {error.output.decode()}" # type: ignore raise RuntimeError(message) from e @@ -1592,10 +1593,28 @@ def _write_ninja_file_to_build_library(path, user_includes += system_includes system_includes.clear() - common_cflags = ['-DTORCH_EXTENSION_NAME={}'.format(name)] + common_cflags = [f'-DTORCH_EXTENSION_NAME={name}'] common_cflags.append('-DTORCH_API_INCLUDE_EXTENSION_H') - common_cflags += ['-I{}'.format(include) for include in user_includes] - common_cflags += ['-isystem {}'.format(include) for include in system_includes] + + # Note [Pybind11 ABI constants] + # + # Pybind11 before 2.4 used to build an ABI strings using the following pattern: + # f"__pybind11_internals_v{PYBIND11_INTERNALS_VERSION}{PYBIND11_INTERNALS_KIND}{PYBIND11_BUILD_TYPE}__" + # Since 2.4 compier type, stdlib and build abi parameters are also encoded like this: + # f"__pybind11_internals_v{PYBIND11_INTERNALS_VERSION}{PYBIND11_INTERNALS_KIND}{PYBIND11_COMPILER_TYPE}{PYBIND11_STDLIB}{PYBIND11_BUILD_ABI}{PYBIND11_BUILD_TYPE}__" + # + # This was done in order to further narrow down the chances of compiler ABI incompatibility + # that can cause a hard to debug segfaults. + # For PyTorch extensions we want to relax those restrictions and pass compiler, stdlib and abi properties + # captured during PyTorch native library compilation in torch/csrc/Module.cpp + + for pname in ["COMPILER_TYPE", "STDLIB", "BUILD_ABI"]: + pval = getattr(torch._C, f"_PYBIND11_{pname}") + if pval is not None and not IS_WINDOWS: + common_cflags.append(f'-DPYBIND11_{pname}=\\"{pval}\\"') + + common_cflags += [f'-I{include}' for include in user_includes] + common_cflags += [f'-isystem {include}' for include in system_includes] common_cflags += ['-D_GLIBCXX_USE_CXX11_ABI=' + str(int(torch._C._GLIBCXX_USE_CXX11_ABI))] @@ -1639,9 +1658,9 @@ def object_file_path(source_file: str) -> str: if _is_cuda_file(source_file) and with_cuda: # Use a different object filename in case a C++ and CUDA file have # the same filename but different extension (.cpp vs. .cu). - target = '{}.cuda.o'.format(file_name) + target = f'{file_name}.cuda.o' else: - target = '{}.o'.format(file_name) + target = f'{file_name}.o' return target objects = [object_file_path(src) for src in sources] @@ -1657,7 +1676,7 @@ def object_file_path(source_file: str) -> str: ldflags = _nt_quote_args(ldflags) ext = 'pyd' if IS_WINDOWS else 'so' - library_target = '{}.{}'.format(name, ext) + library_target = f'{name}.{ext}' _write_ninja_file( path=path, @@ -1719,20 +1738,20 @@ def sanitize_flags(flags): # Version 1.3 is required for the `deps` directive. config = ['ninja_required_version = 1.3'] - config.append('cxx = {}'.format(compiler)) + config.append(f'cxx = {compiler}') if with_cuda: if IS_HIP_EXTENSION: nvcc = _join_rocm_home('bin', 'hipcc') else: nvcc = _join_cuda_home('bin', 'nvcc') - config.append('nvcc = {}'.format(nvcc)) + config.append(f'nvcc = {nvcc}') - flags = ['cflags = {}'.format(' '.join(cflags))] - flags.append('post_cflags = {}'.format(' '.join(post_cflags))) + flags = [f'cflags = {" ".join(cflags)}'] + flags.append(f'post_cflags = {" ".join(post_cflags)}') if with_cuda: - flags.append('cuda_cflags = {}'.format(' '.join(cuda_cflags))) - flags.append('cuda_post_cflags = {}'.format(' '.join(cuda_post_cflags))) - flags.append('ldflags = {}'.format(' '.join(ldflags))) + flags.append(f'cuda_cflags = {" ".join(cuda_cflags)}') + flags.append(f'cuda_post_cflags = {" ".join(cuda_post_cflags)}') + flags.append(f'ldflags = {" ".join(ldflags)}') # Turn into absolute paths so we can emit them into the ninja build # file wherever it is. @@ -1765,7 +1784,7 @@ def sanitize_flags(flags): object_file = object_file.replace(':', '$:') source_file = source_file.replace(" ", "$ ") object_file = object_file.replace(" ", "$ ") - build.append('build {}: {} {}'.format(object_file, rule, source_file)) + build.append(f'build {object_file}: {rule} {source_file}') if library_target is not None: link_rule = ['rule link'] @@ -1776,15 +1795,13 @@ def sanitize_flags(flags): cl_path = os.path.dirname(cl_paths[0]).replace(':', '$:') else: raise RuntimeError("MSVC is required to load C++ extensions") - link_rule.append( - ' command = "{}/link.exe" $in /nologo $ldflags /out:$out'.format( - cl_path)) + link_rule.append(f' command = "{cl_path}/link.exe" $in /nologo $ldflags /out:$out') else: link_rule.append(' command = $cxx $in $ldflags -o $out') - link = ['build {}: link {}'.format(library_target, ' '.join(objects))] + link = [f'build {library_target}: link {" ".join(objects)}'] - default = ['default {}'.format(library_target)] + default = [f'default {library_target}'] else: link_rule, link, default = [], [], [] @@ -1796,7 +1813,7 @@ def sanitize_flags(flags): with open(path, 'w') as build_file: for block in blocks: lines = '\n'.join(block) - build_file.write('{}\n\n'.format(lines)) + build_file.write(f'{lines}\n\n') def _join_cuda_home(*paths) -> str: diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py index 59c1827e1842..8d7726ebd129 100644 --- a/torch/utils/data/dataloader.py +++ b/torch/utils/data/dataloader.py @@ -5,6 +5,7 @@ in `./_utils/worker.py`. """ +import os import threading import itertools import warnings @@ -290,10 +291,13 @@ def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1, self._iterator = None + self.check_worker_number_rationality() + def _get_iterator(self) -> '_BaseDataLoaderIter': if self.num_workers == 0: return _SingleProcessDataLoaderIter(self) else: + self.check_worker_number_rationality() return _MultiProcessingDataLoaderIter(self) @property @@ -399,6 +403,83 @@ def __len__(self) -> int: else: return len(self._index_sampler) + def check_worker_number_rationality(self): + # This function check whether the dataloader's worker number is rational based on + # current system's resource. Current rule is that if the number of workers this + # Dataloader will create is bigger than the number of logical cpus that is allowed to + # use, than we will pop up a warning to let user pay attention. + # + # eg. If current system has 2 physical CPUs with 16 cores each. And each core support 2 + # threads, then the total logical cpus here is 2 * 16 * 2 = 64. Let's say current + # DataLoader process can use half of them which is 32, then the rational max number of + # worker that initiated from this process is 32. + # Now, let's say the created DataLoader has num_works = 40, which is bigger than 32. + # So the warning message is triggered to notify the user to lower the worker number if + # necessary. + # + # + # [Note] Please note that this function repects `cpuset` only when os.sched_getaffinity is + # available (available in most of Linux system, but not OSX and Windows). + # When os.sched_getaffinity is not available, os.cpu_count() is called instead, but + # it doesn't repect cpuset. + # We don't take threading into account since each worker process is single threaded + # at this time. + # + # We don't set any threading flags (eg. OMP_NUM_THREADS, MKL_NUM_THREADS, etc) + # other than `torch.set_num_threads` to 1 in the worker process, if the passing + # in functions use 3rd party modules that rely on those threading flags to determine + # how many thread to create (eg. numpy, etc), then it is caller's responsibility to + # set those flags correctly. + def _create_warning_msg(num_worker_suggest, num_worker_created, cpuset_checked): + + suggested_max_worker_msg = (( + "Our suggested max number of worker in current system is {}{}, which is smaller " + "than what this DataLoader is going to create.").format( + num_worker_suggest, + ("" if cpuset_checked else " (`cpuset` is not taken into account)")) + ) if num_worker_suggest is not None else ( + "DataLoader is not able to compute a suggested max number of worker in current system.") + + warn_msg = ( + "This DataLoader will create {} worker processes in total. {} " + "Please be aware that excessive worker creation might get DataLoader running slow or even freeze, " + "lower the worker number to avoid potential slowness/freeze if necessary.").format( + num_worker_created, + suggested_max_worker_msg) + return warn_msg + + if not self.num_workers or self.num_workers == 0: + return + + # try to compute a suggested max number of worker based on system's resource + max_num_worker_suggest = None + cpuset_checked = False + if hasattr(os, 'sched_getaffinity'): + try: + max_num_worker_suggest = len(os.sched_getaffinity(0)) + cpuset_checked = True + except Exception: + pass + if max_num_worker_suggest is None: + # os.cpu_count() could return Optional[int] + # get cpu count first and check None in order to satify mypy check + cpu_count = os.cpu_count() + if cpu_count is not None: + max_num_worker_suggest = cpu_count + + if max_num_worker_suggest is None: + warnings.warn(_create_warning_msg( + max_num_worker_suggest, + self.num_workers, + cpuset_checked)) + return + + if self.num_workers > max_num_worker_suggest: + warnings.warn(_create_warning_msg( + max_num_worker_suggest, + self.num_workers, + cpuset_checked)) + class _BaseDataLoaderIter(object): def __init__(self, loader: DataLoader) -> None: @@ -843,7 +924,7 @@ def _reset(self, loader, first_iter=False): # contains all `True`s if not using an iterable-style dataset # (i.e., if kind != Iterable). # Not that this indicates that a worker still has work to do *for this epoch*. - # It does not mean that a worker is dead. In case of `_persistent_workers`, + # It does not mean that a worker is dead. In case of `_persistent_workers`, # the worker will be reset to available in the next epoch. self._workers_status = [True for i in range(self._num_workers)] # We resume the prefetching in case it was enabled diff --git a/torch/utils/data/dataset.py b/torch/utils/data/dataset.py index c910cab9aef8..7c45c10dd812 100644 --- a/torch/utils/data/dataset.py +++ b/torch/utils/data/dataset.py @@ -164,7 +164,7 @@ class TensorDataset(Dataset[Tuple[Tensor, ...]]): tensors: Tuple[Tensor, ...] def __init__(self, *tensors: Tensor) -> None: - assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors) + assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors), "Size mismatch between tensors" self.tensors = tensors def __getitem__(self, index):