diff --git a/.bazelrc b/.bazelrc index ecfe8fd0efcd0..310eb293389dc 100644 --- a/.bazelrc +++ b/.bazelrc @@ -3,7 +3,11 @@ build --copt=-I. build --copt=-isystem --copt bazel-out/k8-fastbuild/bin # Configuration to disable tty features for environments like CI - build:no-tty --curses no build:no-tty --progress_report_interval 10 build:no-tty --show_progress_rate_limit 10 + +# Configuration to build with GPU support +build:gpu --define=cuda=true +# define a separate build folder for faster switching between configs +build:gpu --platform_suffix=-gpu diff --git a/.circleci/cimodel/data/pytorch_build_data.py b/.circleci/cimodel/data/pytorch_build_data.py index 5a85674d74fe9..46527c1168891 100644 --- a/.circleci/cimodel/data/pytorch_build_data.py +++ b/.circleci/cimodel/data/pytorch_build_data.py @@ -8,8 +8,6 @@ ("3.6", [ ("important", [X(True)]), ("parallel_tbb", [X(True)]), - ("parallel_native", [X(True)]), - ("pure_torch", [X(True)]), ]), ]), # TODO: bring back libtorch test @@ -47,26 +45,10 @@ # ]), ]), ]), - ("11.1", [ - ("3.8", [ - ("shard_test", [XImportant(True)]), - # UNCOMMENT THE BELOW TO REENABLE LIBTORCH - # ("libtorch", [ - # (True, [ - # ('build_only', [X(True)]), - # ]), - # ]), - ]), - ]), ]), ]), ("bionic", [ ("clang", [ - ("9", [ - ("3.6", [ - ("noarch", [XImportant(True)]), - ]), - ]), ("9", [ ("3.6", [ ("xla", [XImportant(True)]), @@ -74,20 +56,14 @@ ]), ]), ]), - ("cuda", [ - ("10.2", [ - ("3.9", [ - ("shard_test", [XImportant(True)]), - ]), - ]), - ]), - ("rocm", [ - ("3.9", [ - ("3.6", [ - ('build_only', [XImportant(True)]), - ]), - ]), - ]), + # @jithunnair-amd believes Jenkins builds are sufficient + # ("rocm", [ + # ("3.9", [ + # ("3.6", [ + # ('build_only', [XImportant(True)]), + # ]), + # ]), + # ]), ]), ] diff --git a/.circleci/cimodel/data/pytorch_build_definitions.py b/.circleci/cimodel/data/pytorch_build_definitions.py index bdc977270c22e..305bbb4d354bb 100644 --- a/.circleci/cimodel/data/pytorch_build_definitions.py +++ b/.circleci/cimodel/data/pytorch_build_definitions.py @@ -214,7 +214,7 @@ def gen_docs_configs(xenial_parent_config): HiddenConf( "pytorch_python_doc_build", parent_build=xenial_parent_config, - filters=gen_filter_dict(branches_list=r"/.*/", + filters=gen_filter_dict(branches_list=["master"], tags_list=RC_PATTERN), ) ) @@ -230,7 +230,7 @@ def gen_docs_configs(xenial_parent_config): HiddenConf( "pytorch_cpp_doc_build", parent_build=xenial_parent_config, - filters=gen_filter_dict(branches_list=r"/.*/", + filters=gen_filter_dict(branches_list=["master"], tags_list=RC_PATTERN), ) ) @@ -241,13 +241,6 @@ def gen_docs_configs(xenial_parent_config): branch="master", ) ) - - configs.append( - HiddenConf( - "pytorch_doc_test", - parent_build=xenial_parent_config - ) - ) return configs @@ -396,16 +389,19 @@ def instantiate_configs(only_slow_gradcheck): if cuda_version == "10.2" and python_version == "3.6" and not is_libtorch and not is_slow_gradcheck: c.dependent_tests = gen_dependent_configs(c) + if ( - compiler_name == "gcc" - and compiler_version == "5.4" + compiler_name != "clang" + and not rocm_version and not is_libtorch and not is_vulkan and not is_pure_torch - and parallel_backend is None + and not is_noarch + and not is_slow_gradcheck + and not only_slow_gradcheck ): - bc_breaking_check = Conf( - "backward-compatibility-check", + distributed_test = Conf( + c.gen_build_name("") + "distributed", [], is_xla=False, restrict_phases=["test"], @@ -413,7 +409,7 @@ def instantiate_configs(only_slow_gradcheck): is_important=True, parent_build=c, ) - c.dependent_tests.append(bc_breaking_check) + c.dependent_tests.append(distributed_test) config_list.append(c) diff --git a/.circleci/config.yml b/.circleci/config.yml index 3a64240fcf8bb..c57eb26c032e1 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -644,6 +644,15 @@ jobs: set -ex export SCRIBE_GRAPHQL_ACCESS_TOKEN="${SCRIBE_GRAPHQL_ACCESS_TOKEN}" export JOB_BASE_NAME="$CIRCLE_JOB" + # temporary fix for https://github.com/pytorch/pytorch/issues/60746 + if [ -z "$CIRCLE_PR_NUMBER" ]; then + if [[ $CIRCLE_BRANCH =~ .*pull.* ]]; then + export PR_NUMBER="$(echo $CIRCLE_BRANCH | sed 's/[^0-9]//g')" + export CIRCLE_PR_NUMBER="$PR_NUMBER" + fi + else + export PR_NUMBER="$CIRCLE_PR_NUMBER" + fi ${PARALLEL_FLAGS} cd workspace EOL @@ -7103,7 +7112,8 @@ workflows: - pytorch_python_doc_build: filters: branches: - only: /.*/ + only: + - master tags: only: /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ requires: @@ -7123,7 +7133,8 @@ workflows: - pytorch_cpp_doc_build: filters: branches: - only: /.*/ + only: + - master tags: only: /v[0-9]+(\.[0-9]+)*-rc[0-9]+/ requires: @@ -7140,14 +7151,11 @@ workflows: name: pytorch_cpp_doc_push requires: - pytorch_cpp_doc_build - - pytorch_doc_test: - requires: - - pytorch_linux_xenial_py3_6_gcc5_4_build - pytorch_linux_test: - name: pytorch_linux_backward_compatibility_check_test + name: pytorch_linux_pytorch_linux_xenial_py3_6_gcc5_4_distributed_test requires: - pytorch_linux_xenial_py3_6_gcc5_4_build - build_environment: "pytorch-linux-backward-compatibility-check-test" + build_environment: "pytorch-linux-pytorch_linux_xenial_py3_6_gcc5_4_distributed-test" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4" resource_class: large - pytorch_linux_build: @@ -7175,43 +7183,13 @@ workflows: build_environment: "pytorch-paralleltbb-linux-xenial-py3.6-gcc5.4-test" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4" resource_class: large - - pytorch_linux_build: - name: pytorch_parallelnative_linux_xenial_py3_6_gcc5_4_build - requires: - - "docker-pytorch-linux-xenial-py3.6-gcc5.4" - filters: - branches: - only: - - master - - /ci-all\/.*/ - - /release\/.*/ - build_environment: "pytorch-parallelnative-linux-xenial-py3.6-gcc5.4-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4" - pytorch_linux_test: - name: pytorch_parallelnative_linux_xenial_py3_6_gcc5_4_test + name: pytorch_linux_pytorch_paralleltbb_linux_xenial_py3_6_gcc5_4_distributed_test requires: - - pytorch_parallelnative_linux_xenial_py3_6_gcc5_4_build - filters: - branches: - only: - - master - - /ci-all\/.*/ - - /release\/.*/ - build_environment: "pytorch-parallelnative-linux-xenial-py3.6-gcc5.4-test" + - pytorch_paralleltbb_linux_xenial_py3_6_gcc5_4_build + build_environment: "pytorch-linux-pytorch_paralleltbb_linux_xenial_py3_6_gcc5_4_distributed-test" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4" resource_class: large - - pytorch_linux_build: - name: pytorch_pure_torch_linux_xenial_py3_6_gcc5_4_build - requires: - - "docker-pytorch-linux-xenial-py3.6-gcc5.4" - filters: - branches: - only: - - master - - /ci-all\/.*/ - - /release\/.*/ - build_environment: "pytorch-pure_torch-linux-xenial-py3.6-gcc5.4-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4" - pytorch_linux_build: name: pytorch_linux_xenial_py3_6_gcc7_build requires: @@ -7237,6 +7215,13 @@ workflows: build_environment: "pytorch-linux-xenial-py3.6-gcc7-test" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc7" resource_class: large + - pytorch_linux_test: + name: pytorch_linux_pytorch_linux_xenial_py3_6_gcc7_distributed_test + requires: + - pytorch_linux_xenial_py3_6_gcc7_build + build_environment: "pytorch-linux-pytorch_linux_xenial_py3_6_gcc7_distributed-test" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc7" + resource_class: large - pytorch_linux_build: name: pytorch_linux_xenial_py3_clang7_asan_build requires: @@ -7371,40 +7356,12 @@ workflows: docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7" use_cuda_docker_runtime: "1" resource_class: gpu.medium - - pytorch_linux_build: - name: pytorch_linux_xenial_cuda11_1_cudnn8_py3_gcc7_build - requires: - - "docker-pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7" - build_environment: "pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7" - - pytorch_linux_test: - name: pytorch_linux_xenial_cuda11_1_cudnn8_py3_gcc7_test1 - requires: - - pytorch_linux_xenial_cuda11_1_cudnn8_py3_gcc7_build - build_environment: "pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7-test1" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7" - use_cuda_docker_runtime: "1" - resource_class: gpu.medium - pytorch_linux_test: - name: pytorch_linux_xenial_cuda11_1_cudnn8_py3_gcc7_test2 + name: pytorch_linux_pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_distributed_test requires: - - pytorch_linux_xenial_cuda11_1_cudnn8_py3_gcc7_build - build_environment: "pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7-test2" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7" - use_cuda_docker_runtime: "1" - resource_class: gpu.medium - - pytorch_linux_build: - name: pytorch_linux_bionic_py3_6_clang9_noarch_build - requires: - - "docker-pytorch-linux-bionic-py3.6-clang9" - build_environment: "pytorch-linux-bionic-py3.6-clang9-noarch-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-bionic-py3.6-clang9" - - pytorch_linux_test: - name: pytorch_linux_bionic_py3_6_clang9_noarch_test - requires: - - pytorch_linux_bionic_py3_6_clang9_noarch_build - build_environment: "pytorch-linux-bionic-py3.6-clang9-noarch-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-bionic-py3.6-clang9" + - pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_build + build_environment: "pytorch-linux-pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_distributed-test" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7" resource_class: large - pytorch_linux_build: name: pytorch_xla_linux_bionic_py3_6_clang9_build @@ -7432,36 +7389,6 @@ workflows: build_environment: "pytorch-vulkan-linux-bionic-py3.6-clang9-test" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-bionic-py3.6-clang9" resource_class: large - - pytorch_linux_build: - name: pytorch_linux_bionic_cuda10_2_cudnn7_py3_9_gcc7_build - requires: - - "docker-pytorch-linux-bionic-cuda10.2-cudnn7-py3.9-gcc7" - build_environment: "pytorch-linux-bionic-cuda10.2-cudnn7-py3.9-gcc7-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-bionic-cuda10.2-cudnn7-py3.9-gcc7" - - pytorch_linux_test: - name: pytorch_linux_bionic_cuda10_2_cudnn7_py3_9_gcc7_test1 - requires: - - pytorch_linux_bionic_cuda10_2_cudnn7_py3_9_gcc7_build - build_environment: "pytorch-linux-bionic-cuda10.2-cudnn7-py3.9-gcc7-test1" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-bionic-cuda10.2-cudnn7-py3.9-gcc7" - use_cuda_docker_runtime: "1" - resource_class: gpu.medium - - pytorch_linux_test: - name: pytorch_linux_bionic_cuda10_2_cudnn7_py3_9_gcc7_test2 - requires: - - pytorch_linux_bionic_cuda10_2_cudnn7_py3_9_gcc7_build - build_environment: "pytorch-linux-bionic-cuda10.2-cudnn7-py3.9-gcc7-test2" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-bionic-cuda10.2-cudnn7-py3.9-gcc7" - use_cuda_docker_runtime: "1" - resource_class: gpu.medium - - pytorch_linux_build: - name: pytorch_linux_bionic_rocm3_9_py3_6_build - requires: - - "docker-pytorch-linux-bionic-rocm3.9-py3.6" - build_environment: "pytorch-linux-bionic-rocm3.9-py3.6-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-bionic-rocm3.9-py3.6" - resource_class: xlarge - build_only: "1" - pytorch_macos_10_15_py3_build: name: pytorch_macos_10_15_py3_build - pytorch_macos_10_13_py3_build: @@ -9329,37 +9256,30 @@ workflows: name: "docker-pytorch-linux-xenial-py3.6-gcc7" image_name: "pytorch-linux-xenial-py3.6-gcc7" - pytorch_linux_build: - name: pytorch_paralleltbb_linux_xenial_py3_6_gcc5_4_build + name: pytorch_linux_xenial_py3_6_gcc5_4_build requires: - "docker-pytorch-linux-xenial-py3.6-gcc5.4" - build_environment: "pytorch-paralleltbb-linux-xenial-py3.6-gcc5.4-build" + build_environment: "pytorch-linux-xenial-py3.6-gcc5.4-build" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4" - - pytorch_linux_test: - name: pytorch_paralleltbb_linux_xenial_py3_6_gcc5_4_test + - pytorch_python_doc_build: requires: - - pytorch_paralleltbb_linux_xenial_py3_6_gcc5_4_build - build_environment: "pytorch-paralleltbb-linux-xenial-py3.6-gcc5.4-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4" - resource_class: large + - pytorch_linux_xenial_py3_6_gcc5_4_build + - pytorch_cpp_doc_build: + requires: + - pytorch_linux_xenial_py3_6_gcc5_4_build - pytorch_linux_build: - name: pytorch_parallelnative_linux_xenial_py3_6_gcc5_4_build + name: pytorch_paralleltbb_linux_xenial_py3_6_gcc5_4_build requires: - "docker-pytorch-linux-xenial-py3.6-gcc5.4" - build_environment: "pytorch-parallelnative-linux-xenial-py3.6-gcc5.4-build" + build_environment: "pytorch-paralleltbb-linux-xenial-py3.6-gcc5.4-build" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4" - pytorch_linux_test: - name: pytorch_parallelnative_linux_xenial_py3_6_gcc5_4_test + name: pytorch_paralleltbb_linux_xenial_py3_6_gcc5_4_test requires: - - pytorch_parallelnative_linux_xenial_py3_6_gcc5_4_build - build_environment: "pytorch-parallelnative-linux-xenial-py3.6-gcc5.4-test" + - pytorch_paralleltbb_linux_xenial_py3_6_gcc5_4_build + build_environment: "pytorch-paralleltbb-linux-xenial-py3.6-gcc5.4-test" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4" resource_class: large - - pytorch_linux_build: - name: pytorch_pure_torch_linux_xenial_py3_6_gcc5_4_build - requires: - - "docker-pytorch-linux-xenial-py3.6-gcc5.4" - build_environment: "pytorch-pure_torch-linux-xenial-py3.6-gcc5.4-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4" - pytorch_linux_build: name: pytorch_linux_xenial_py3_6_gcc7_build requires: diff --git a/.circleci/docker/README.md b/.circleci/docker/README.md index a87522f622ccd..cc4f97cfae748 100644 --- a/.circleci/docker/README.md +++ b/.circleci/docker/README.md @@ -27,5 +27,5 @@ Docker builds are now defined with `.circleci/cimodel/data/simple/docker_definit ./build.sh pytorch-linux-bionic-py3.8-gcc9 -t myimage:latest # Set flags (see build.sh) and build image -sudo bash -c 'BREAKPAD=1 ./build.sh pytorch-linux-bionic-py3.8-gcc9 -t myimage:latest +sudo bash -c 'PROTOBUF=1 ./build.sh pytorch-linux-bionic-py3.8-gcc9 -t myimage:latest ``` diff --git a/.circleci/docker/build.sh b/.circleci/docker/build.sh index 7c8477349981a..18d19ae5d586f 100755 --- a/.circleci/docker/build.sh +++ b/.circleci/docker/build.sh @@ -78,106 +78,108 @@ TRAVIS_DL_URL_PREFIX="https://s3.amazonaws.com/travis-python-archives/binaries/u case "$image" in pytorch-linux-xenial-py3.8) ANACONDA_PYTHON_VERSION=3.8 + CMAKE_VERSION=3.10.3 GCC_VERSION=7 # Do not install PROTOBUF, DB, and VISION as a test ;; pytorch-linux-xenial-py3.6-gcc5.4) ANACONDA_PYTHON_VERSION=3.6 + CMAKE_VERSION=3.10.3 GCC_VERSION=5 PROTOBUF=yes DB=yes VISION=yes KATEX=yes - BREAKPAD=yes ;; pytorch-linux-xenial-py3.6-gcc7.2) ANACONDA_PYTHON_VERSION=3.6 + CMAKE_VERSION=3.10.3 GCC_VERSION=7 # Do not install PROTOBUF, DB, and VISION as a test ;; pytorch-linux-xenial-py3.6-gcc7) ANACONDA_PYTHON_VERSION=3.6 + CMAKE_VERSION=3.10.3 GCC_VERSION=7 PROTOBUF=yes DB=yes VISION=yes - BREAKPAD=yes ;; pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7) CUDA_VERSION=10.2 CUDNN_VERSION=7 ANACONDA_PYTHON_VERSION=3.6 + CMAKE_VERSION=3.10.3 GCC_VERSION=7 PROTOBUF=yes DB=yes VISION=yes KATEX=yes - BREAKPAD=yes ;; pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7) CUDA_VERSION=11.1 CUDNN_VERSION=8 ANACONDA_PYTHON_VERSION=3.6 + CMAKE_VERSION=3.10.3 GCC_VERSION=7 PROTOBUF=yes DB=yes VISION=yes KATEX=yes - BREAKPAD=yes ;; pytorch-linux-xenial-cuda11.3-cudnn8-py3-gcc7) CUDA_VERSION=11.3.0 # Deviating from major.minor to conform to nvidia's Docker image names CUDNN_VERSION=8 ANACONDA_PYTHON_VERSION=3.6 + CMAKE_VERSION=3.10.3 GCC_VERSION=7 PROTOBUF=yes DB=yes VISION=yes KATEX=yes - BREAKPAD=yes ;; pytorch-linux-xenial-py3-clang5-asan) ANACONDA_PYTHON_VERSION=3.6 CLANG_VERSION=5.0 + CMAKE_VERSION=3.10.3 PROTOBUF=yes DB=yes VISION=yes - BREAKPAD=yes ;; pytorch-linux-xenial-py3-clang7-asan) ANACONDA_PYTHON_VERSION=3.6 CLANG_VERSION=7 + CMAKE_VERSION=3.10.3 PROTOBUF=yes DB=yes VISION=yes - BREAKPAD=yes ;; pytorch-linux-xenial-py3-clang7-onnx) ANACONDA_PYTHON_VERSION=3.6 CLANG_VERSION=7 + CMAKE_VERSION=3.10.3 PROTOBUF=yes DB=yes VISION=yes - BREAKPAD=yes ;; pytorch-linux-xenial-py3-clang5-android-ndk-r19c) ANACONDA_PYTHON_VERSION=3.6 CLANG_VERSION=5.0 + CMAKE_VERSION=3.10.3 LLVMDEV=yes PROTOBUF=yes ANDROID=yes ANDROID_NDK_VERSION=r19c GRADLE_VERSION=6.8.3 - CMAKE_VERSION=3.7.0 NINJA_VERSION=1.9.0 ;; pytorch-linux-xenial-py3.6-clang7) ANACONDA_PYTHON_VERSION=3.6 + CMAKE_VERSION=3.10.3 CLANG_VERSION=7 PROTOBUF=yes DB=yes VISION=yes - BREAKPAD=yes ;; pytorch-linux-bionic-py3.6-clang9) ANACONDA_PYTHON_VERSION=3.6 @@ -185,7 +187,6 @@ case "$image" in PROTOBUF=yes DB=yes VISION=yes - BREAKPAD=yes VULKAN_SDK_VERSION=1.2.162.1 SWIFTSHADER=yes ;; @@ -195,8 +196,6 @@ case "$image" in PROTOBUF=yes DB=yes VISION=yes - BREAKPAD=yes - BREAKPAD=yes ;; pytorch-linux-bionic-cuda10.2-cudnn7-py3.6-clang9) CUDA_VERSION=10.2 @@ -206,7 +205,6 @@ case "$image" in PROTOBUF=yes DB=yes VISION=yes - BREAKPAD=yes ;; pytorch-linux-bionic-cuda10.2-cudnn7-py3.9-gcc7) CUDA_VERSION=10.2 @@ -216,7 +214,6 @@ case "$image" in PROTOBUF=yes DB=yes VISION=yes - BREAKPAD=yes ;; pytorch-linux-bionic-cuda11.0-cudnn8-py3.6-gcc9) CUDA_VERSION=11.0 @@ -226,7 +223,6 @@ case "$image" in PROTOBUF=yes DB=yes VISION=yes - BREAKPAD=yes ROCM_VERSION=3.9 ;; pytorch-linux-bionic-rocm4.0.1-py3.6) @@ -235,7 +231,6 @@ case "$image" in PROTOBUF=yes DB=yes VISION=yes - BREAKPAD=yes ROCM_VERSION=4.0.1 ;; pytorch-linux-bionic-rocm4.1-py3.6) @@ -244,7 +239,6 @@ case "$image" in PROTOBUF=yes DB=yes VISION=yes - BREAKPAD=yes ROCM_VERSION=4.1 ;; pytorch-linux-bionic-rocm4.2-py3.6) @@ -253,7 +247,6 @@ case "$image" in PROTOBUF=yes DB=yes VISION=yes - BREAKPAD=yes ROCM_VERSION=4.2 ;; *) @@ -261,8 +254,10 @@ case "$image" in PROTOBUF=yes DB=yes VISION=yes - BREAKPAD=yes echo "image '$image' did not match an existing build configuration" + if [[ "$image" == *xenial* ]]; then + CMAKE_VERSION=3.10.3 + fi if [[ "$image" == *py* ]]; then extract_version_from_image_name py ANACONDA_PYTHON_VERSION fi @@ -325,7 +320,6 @@ docker build \ --build-arg "GCC_VERSION=${GCC_VERSION}" \ --build-arg "CUDA_VERSION=${CUDA_VERSION}" \ --build-arg "CUDNN_VERSION=${CUDNN_VERSION}" \ - --build-arg "BREAKPAD=${BREAKPAD}" \ --build-arg "ANDROID=${ANDROID}" \ --build-arg "ANDROID_NDK=${ANDROID_NDK_VERSION}" \ --build-arg "GRADLE_VERSION=${GRADLE_VERSION}" \ diff --git a/.circleci/docker/common/install_breakpad.sh b/.circleci/docker/common/install_breakpad.sh deleted file mode 100644 index f49f1fb325e2a..0000000000000 --- a/.circleci/docker/common/install_breakpad.sh +++ /dev/null @@ -1,25 +0,0 @@ -#!/bin/bash - -set -ex - -git clone https://github.com/driazati/breakpad.git -pushd breakpad - -# breakpad has no actual releases, so this is pinned to the top commit from -# main when this was forked (including the one patch commit). This uses a fork -# of the breakpad mainline that automatically daisy-chains out to any previously -# installed signal handlers (instead of overwriting them). -git checkout 5485e473ed46d065e05489e50dfc59d90dfd7e22 - -git clone https://chromium.googlesource.com/linux-syscall-support src/third_party/lss -pushd src/third_party/lss -# same as with breakpad, there are no real releases for this repo so use a -# commit as the pin -git checkout e1e7b0ad8ee99a875b272c8e33e308472e897660 -popd - -./configure -make -make install -popd -rm -rf breakpad diff --git a/.circleci/docker/common/install_cmake.sh b/.circleci/docker/common/install_cmake.sh index 3ef71031db38f..5aa564d7c478c 100755 --- a/.circleci/docker/common/install_cmake.sh +++ b/.circleci/docker/common/install_cmake.sh @@ -4,6 +4,9 @@ set -ex [ -n "$CMAKE_VERSION" ] +# Remove system cmake install so it won't get used instead +apt-get remove cmake -y + # Turn 3.6.3 into v3.6 path=$(echo "${CMAKE_VERSION}" | sed -e 's/\([0-9].[0-9]\+\).*/v\1/') file="cmake-${CMAKE_VERSION}-Linux-x86_64.tar.gz" diff --git a/.circleci/docker/common/install_conda.sh b/.circleci/docker/common/install_conda.sh index 86dbb153b2925..f12ae38aa58bd 100755 --- a/.circleci/docker/common/install_conda.sh +++ b/.circleci/docker/common/install_conda.sh @@ -69,8 +69,8 @@ if [ -n "$ANACONDA_PYTHON_VERSION" ]; then } # Install PyTorch conda deps, as per https://github.com/pytorch/pytorch README - # DO NOT install cmake here as it would install a version newer than 3.5, but - # we want to pin to version 3.5. + # DO NOT install cmake here as it would install a version newer than 3.10, but + # we want to pin to version 3.10. SCIPY_VERSION=1.1.0 if [ "$ANACONDA_PYTHON_VERSION" = "3.9" ]; then # Install llvm-8 as it is required to compile llvmlite-0.30.0 from source diff --git a/.circleci/docker/common/install_openmpi.sh b/.circleci/docker/common/install_openmpi.sh index 7bd32c71f16fb..8c45279b8b464 100644 --- a/.circleci/docker/common/install_openmpi.sh +++ b/.circleci/docker/common/install_openmpi.sh @@ -1,4 +1,10 @@ #!/bin/bash sudo apt-get update +# also install ssh to avoid error of: +# -------------------------------------------------------------------------- +# The value of the MCA parameter "plm_rsh_agent" was set to a path +# that could not be found: +# plm_rsh_agent: ssh : rsh +sudo apt-get install -y ssh sudo apt-get install -y --allow-downgrades --allow-change-held-packages openmpi-bin libopenmpi-dev diff --git a/.circleci/docker/ubuntu-cuda/Dockerfile b/.circleci/docker/ubuntu-cuda/Dockerfile index e0e7dc9b6e5bf..84075db161358 100644 --- a/.circleci/docker/ubuntu-cuda/Dockerfile +++ b/.circleci/docker/ubuntu-cuda/Dockerfile @@ -61,6 +61,16 @@ RUN if [ -n "${VISION}" ]; then bash ./install_vision.sh; fi RUN rm install_vision.sh ENV INSTALLED_VISION ${VISION} +ADD ./common/install_openssl.sh install_openssl.sh +ENV OPENSSL_ROOT_DIR /opt/openssl +RUN bash ./install_openssl.sh + +# (optional) Install non-default CMake version +ARG CMAKE_VERSION +ADD ./common/install_cmake.sh install_cmake.sh +RUN if [ -n "${CMAKE_VERSION}" ]; then bash ./install_cmake.sh; fi +RUN rm install_cmake.sh + # Install ccache/sccache (do this last, so we get priority in PATH) ADD ./common/install_cache.sh install_cache.sh ENV PATH /opt/cache/bin:$PATH @@ -88,9 +98,5 @@ ENV TORCH_NVCC_FLAGS "-Xfatbin -compress-all" # Install LLVM dev version (Defined in the pytorch/builder github repository) COPY --from=pytorch/llvm:9.0.1 /opt/llvm /opt/llvm -ADD ./common/install_openssl.sh install_openssl.sh -ENV OPENSSL_ROOT_DIR /opt/openssl -RUN bash ./install_openssl.sh - USER jenkins CMD ["bash"] diff --git a/.circleci/docker/ubuntu/Dockerfile b/.circleci/docker/ubuntu/Dockerfile index ea00c083c3d02..76a64bc0ea10d 100644 --- a/.circleci/docker/ubuntu/Dockerfile +++ b/.circleci/docker/ubuntu/Dockerfile @@ -82,13 +82,6 @@ RUN rm AndroidManifest.xml RUN rm build.gradle ENV INSTALLED_ANDROID ${ANDROID} -# (optional) Install breakpad -ARG BREAKPAD -ADD ./common/install_breakpad.sh install_breakpad.sh -RUN if [ -n "${BREAKPAD}" ]; then bash ./install_breakpad.sh; fi -RUN rm install_breakpad.sh -ENV INSTALLED_BREAKPAD ${BREAKPAD} - # (optional) Install Vulkan SDK ARG VULKAN_SDK_VERSION ADD ./common/install_vulkan_sdk.sh install_vulkan_sdk.sh @@ -113,6 +106,10 @@ ADD ./common/install_ninja.sh install_ninja.sh RUN if [ -n "${NINJA_VERSION}" ]; then bash ./install_ninja.sh; fi RUN rm install_ninja.sh +ADD ./common/install_openssl.sh install_openssl.sh +RUN bash ./install_openssl.sh +ENV OPENSSL_ROOT_DIR /opt/openssl + # Install ccache/sccache (do this last, so we get priority in PATH) ADD ./common/install_cache.sh install_cache.sh ENV PATH /opt/cache/bin:$PATH @@ -130,9 +127,5 @@ ENV BUILD_ENVIRONMENT ${BUILD_ENVIRONMENT} # Install LLVM dev version (Defined in the pytorch/builder github repository) COPY --from=pytorch/llvm:9.0.1 /opt/llvm /opt/llvm -ADD ./common/install_openssl.sh install_openssl.sh -RUN bash ./install_openssl.sh -ENV OPENSSL_ROOT_DIR /opt/openssl - USER jenkins CMD ["bash"] diff --git a/.circleci/scripts/binary_macos_build.sh b/.circleci/scripts/binary_macos_build.sh index c402cdd008013..c5cdfa9f09080 100755 --- a/.circleci/scripts/binary_macos_build.sh +++ b/.circleci/scripts/binary_macos_build.sh @@ -14,6 +14,9 @@ chmod +x "$build_script" # Build cat >"$build_script" < str: workflow_name = filename.with_suffix("").name.replace("_", "-") if workflow_name.startswith("generated-"): workflow_name = workflow_name[len("generated-"):] - return f"{workflow_name}-${{{{ github.event.pull_request.number || github.sha }}}}" + return f"{workflow_name}-${{{{ github.event.pull_request.number || github.sha }}}}" \ + "-${{ github.event_name == 'workflow_dispatch' }}" def should_check(filename: Path) -> bool: diff --git a/.github/scripts/generate_ci_workflows.py b/.github/scripts/generate_ci_workflows.py index 0d6844bf8dadc..b5146114054a6 100755 --- a/.github/scripts/generate_ci_workflows.py +++ b/.github/scripts/generate_ci_workflows.py @@ -29,24 +29,30 @@ LINUX_CUDA_TEST_RUNNER, } +CUDA_RUNNERS = { + WINDOWS_CUDA_TEST_RUNNER, + LINUX_CUDA_TEST_RUNNER, +} +CPU_RUNNERS = { + WINDOWS_CPU_TEST_RUNNER, + LINUX_CPU_TEST_RUNNER, +} + +LABEL_CIFLOW_ALL = "ciflow/all" +LABEL_CIFLOW_BAZEL = "ciflow/bazel" +LABEL_CIFLOW_COVERAGE = "ciflow/coverage" +LABEL_CIFLOW_CPU = "ciflow/cpu" +LABEL_CIFLOW_CUDA = "ciflow/cuda" +LABEL_CIFLOW_DEFAULT = "ciflow/default" +LABEL_CIFLOW_LIBTORCH = "ciflow/libtorch" +LABEL_CIFLOW_LINUX = "ciflow/linux" +LABEL_CIFLOW_SCHEDULED = "ciflow/scheduled" +LABEL_CIFLOW_SLOW = "ciflow/slow" +LABEL_CIFLOW_WIN = "ciflow/win" +LABEL_CIFLOW_XLA = "ciflow/xla" +LABEL_CIFLOW_NOARCH = "ciflow/noarch" + -# TODO: ------------- Remove the comment once fully rollout ------------------- -# Rollout Strategy: -# 1. Manual Phase -# step 1. Add 'ciflow/default' label to the PR -# step 2. Once there's an [unassigned] event from PR, it should rerun -# step 3. Remove 'ciflow/default' label -# step 4. Trigger the [unassigned] event again, it should not rerun -# 2. Probot Phase 1 (manual on 1 workflow) -# step 1. Probot automatically add labels based on the context -# step 2. Manually let probot trigger [unassigned] event -# 3. Probot Phase 2 (auto on 1 workflows) -# step 1. Modify the workflows so that they only listen on [unassigned] events -# step 2. Probot automatically adds labels automatically based on the context -# step 3. Probot automatically triggers [unassigned] event -# 4. Probot Phase 3 (auto on many workflows) -# step 1. Enable it for all workflows -# ----------------------------------------------------------------------- @dataclass class CIFlowConfig: enabled: bool = False @@ -67,11 +73,11 @@ def gen_root_job_condition(self) -> None: # Once fully rollout, we can have strict constraints # e.g. ADD env.GITHUB_ACTOR == '{self.trigger_actor} # REMOVE github.event.action !='{self.trigger_action}' - label_conditions = [f"github.event.action == '{self.trigger_action}'"] + \ - [f"contains(github.event.pull_request.labels.*.name, '{label}')" for label in self.labels] + label_conditions = [ + f"contains(github.event.pull_request.labels.*.name, '{label}')" for label in sorted(self.labels)] self.root_job_condition = f"(github.event_name != 'pull_request') || " \ f"(github.event.action !='{self.trigger_action}') || " \ - f"({' && '.join(label_conditions)})" + f"({' || '.join(label_conditions)})" def reset_root_job(self) -> None: self.root_job_name = '' @@ -81,6 +87,7 @@ def __post_init__(self) -> None: if not self.enabled: self.reset_root_job() return + self.labels.add(LABEL_CIFLOW_ALL) self.gen_root_job_condition() @@ -98,7 +105,9 @@ def add_label_rule(self, labels: Set[str], workflow_name: str) -> None: self.label_rules[label] = {workflow_name} def generate_json(self) -> None: + GENERATED = "generated" # Note that please keep the variable GENERATED otherwise phabricator will hide the whole file output = { + "__comment": f"@{GENERATED} DO NOT EDIT MANUALLY, Generation script: .github/scripts/generate_ci_workflows.py", "version": self.version, "label_rules": { label: sorted(list(workflows)) @@ -131,14 +140,20 @@ class CIWorkflow: only_build_on_pull_request: bool = False only_run_smoke_tests_on_pull_request: bool = False num_test_shards_on_pull_request: int = -1 + distributed_test: bool = True # The following variables will be set as environment variables, # so it's easier for both shell and Python scripts to consume it if false is represented as the empty string. enable_jit_legacy_test: YamlShellBool = "''" + enable_distributed_test: YamlShellBool = "''" enable_multigpu_test: YamlShellBool = "''" enable_nogpu_no_avx_test: YamlShellBool = "''" enable_nogpu_no_avx2_test: YamlShellBool = "''" enable_slow_test: YamlShellBool = "''" + enable_docs_test: YamlShellBool = "''" + enable_backwards_compat_test: YamlShellBool = "''" + enable_xla_test: YamlShellBool = "''" + enable_noarch_test: YamlShellBool = "''" def __post_init__(self) -> None: if self.is_libtorch: @@ -147,6 +162,9 @@ def __post_init__(self) -> None: if not self.on_pull_request: self.only_build_on_pull_request = False + if self.distributed_test: + self.enable_distributed_test = 1 + # If num_test_shards_on_pull_request is not user-defined, default to num_test_shards unless we are # only running smoke tests on the pull request. if self.num_test_shards_on_pull_request == -1: @@ -155,7 +173,6 @@ def __post_init__(self) -> None: self.num_test_shards_on_pull_request = 1 else: self.num_test_shards_on_pull_request = self.num_test_shards - self.assert_valid() def assert_valid(self) -> None: @@ -165,13 +182,30 @@ def assert_valid(self) -> None: if self.arch == 'windows': assert self.test_runner_type in WINDOWS_RUNNERS, err_message + if self.ciflow_config.enabled: + # make sure if LABEL_CIFLOW_DEFAULT is set, we then need to set trigger_action_only to False + assert self.ciflow_config.trigger_action_only != (LABEL_CIFLOW_DEFAULT in self.ciflow_config.labels) + assert self.on_pull_request + assert LABEL_CIFLOW_ALL in self.ciflow_config.labels + assert LABEL_CIFLOW_ALL in self.ciflow_config.root_job_condition + if self.arch == 'linux': + assert LABEL_CIFLOW_LINUX in self.ciflow_config.labels + if self.arch == 'windows': + assert LABEL_CIFLOW_WIN in self.ciflow_config.labels + if self.test_runner_type in CUDA_RUNNERS: + assert LABEL_CIFLOW_CUDA in self.ciflow_config.labels + if self.test_runner_type in CPU_RUNNERS: + assert LABEL_CIFLOW_CPU in self.ciflow_config.labels + def generate_workflow_file(self, workflow_template: jinja2.Template) -> None: output_file_path = GITHUB_DIR / f"workflows/generated-{self.build_environment}.yml" with open(output_file_path, "w") as output_file: GENERATED = "generated" # Note that please keep the variable GENERATED otherwise phabricator will hide the whole file output_file.writelines([f"# @{GENERATED} DO NOT EDIT MANUALLY\n"]) - output_file.write(workflow_template.render(asdict(self))) - output_file.write("\n") + content = workflow_template.render(asdict(self)) + output_file.write(content) + if content[-1] != "\n": + output_file.write("\n") print(output_file_path) @@ -183,6 +217,10 @@ def generate_workflow_file(self, workflow_template: jinja2.Template) -> None: test_runner_type=WINDOWS_CPU_TEST_RUNNER, on_pull_request=True, num_test_shards=2, + ciflow_config=CIFlowConfig( + enabled=True, + labels={LABEL_CIFLOW_DEFAULT, LABEL_CIFLOW_CPU, LABEL_CIFLOW_WIN} + ), ), CIWorkflow( arch="windows", @@ -192,18 +230,28 @@ def generate_workflow_file(self, workflow_template: jinja2.Template) -> None: on_pull_request=True, only_run_smoke_tests_on_pull_request=True, num_test_shards=2, + ciflow_config=CIFlowConfig( + enabled=True, + labels={LABEL_CIFLOW_DEFAULT, LABEL_CIFLOW_CUDA, LABEL_CIFLOW_WIN} + ), ), CIWorkflow( arch="windows", - build_environment="win-vs2019-cuda11.1-py3", - cuda_version="11.1", + build_environment="win-vs2019-cuda11.3-py3", + cuda_version="11.3", test_runner_type=WINDOWS_CUDA_TEST_RUNNER, num_test_shards=2, + on_pull_request=True, + ciflow_config=CIFlowConfig( + enabled=True, + trigger_action_only=True, + labels={LABEL_CIFLOW_CUDA, LABEL_CIFLOW_WIN} + ), ), CIWorkflow( arch="windows", - build_environment="periodic-win-vs2019-cuda11.3-py3", - cuda_version="11.3", + build_environment="periodic-win-vs2019-cuda11.1-py3", + cuda_version="11.1", test_runner_type=WINDOWS_CUDA_TEST_RUNNER, num_test_shards=2, is_scheduled="45 0,4,8,12,16,20 * * *", @@ -211,7 +259,7 @@ def generate_workflow_file(self, workflow_template: jinja2.Template) -> None: ciflow_config=CIFlowConfig( enabled=True, trigger_action_only=True, - labels={'ciflow/scheduled'} + labels={LABEL_CIFLOW_SCHEDULED, LABEL_CIFLOW_WIN, LABEL_CIFLOW_CUDA} ), ), ] @@ -224,26 +272,54 @@ def generate_workflow_file(self, workflow_template: jinja2.Template) -> None: test_runner_type=LINUX_CPU_TEST_RUNNER, on_pull_request=True, enable_doc_jobs=True, + enable_docs_test=1, + enable_backwards_compat_test=1, num_test_shards=2, + ciflow_config=CIFlowConfig( + enabled=True, + labels={LABEL_CIFLOW_DEFAULT, LABEL_CIFLOW_LINUX, LABEL_CIFLOW_CPU} + ), ), # CIWorkflow( # arch="linux", # build_environment="paralleltbb-linux-xenial-py3.6-gcc5.4", # docker_image_base=f"{DOCKER_REGISTRY}/pytorch/pytorch-linux-xenial-py3.6-gcc5.4", # test_runner_type=LINUX_CPU_TEST_RUNNER, + # on_pull_request=True, + # ciflow_config=CIFlowConfig( + # enabled=True, + # trigger_action_only=True, + # labels={LABEL_CIFLOW_LINUX, LABEL_CIFLOW_CPU}, + # ), # ), - # CIWorkflow( - # arch="linux", - # build_environment="parallelnative-linux-xenial-py3.6-gcc5.4", - # docker_image_base=f"{DOCKER_REGISTRY}/pytorch/pytorch-linux-xenial-py3.6-gcc5.4", - # test_runner_type=LINUX_CPU_TEST_RUNNER, - # ), - # CIWorkflow( - # arch="linux", - # build_environment="pure_torch-linux-xenial-py3.6-gcc5.4", - # docker_image_base=f"{DOCKER_REGISTRY}/pytorch/pytorch-linux-xenial-py3.6-gcc5.4", - # test_runner_type=LINUX_CPU_TEST_RUNNER, - # ), + CIWorkflow( + arch="linux", + build_environment="parallelnative-linux-xenial-py3.6-gcc5.4", + docker_image_base=f"{DOCKER_REGISTRY}/pytorch/pytorch-linux-xenial-py3.6-gcc5.4", + test_runner_type=LINUX_CPU_TEST_RUNNER, + # This is a master only job despite on_pull_request is set to True + on_pull_request=True, + ciflow_config=CIFlowConfig( + enabled=True, + trigger_action_only=True, + labels={LABEL_CIFLOW_LINUX, LABEL_CIFLOW_CPU}, + ), + ), + # Build PyTorch with BUILD_CAFFE2=OFF + CIWorkflow( + arch="linux", + build_environment="puretorch-linux-xenial-py3.6-gcc5.4", + docker_image_base=f"{DOCKER_REGISTRY}/pytorch/pytorch-linux-xenial-py3.6-gcc5.4", + test_runner_type=LINUX_CPU_TEST_RUNNER, + exclude_test=True, + # This is a master only job despite on_pull_request is set to True + on_pull_request=True, + ciflow_config=CIFlowConfig( + enabled=True, + trigger_action_only=True, + labels={LABEL_CIFLOW_LINUX, LABEL_CIFLOW_CPU}, + ), + ), # CIWorkflow( # arch="linux", # build_environment="linux-xenial-py3.6-gcc7", @@ -268,6 +344,12 @@ def generate_workflow_file(self, workflow_template: jinja2.Template) -> None: docker_image_base=f"{DOCKER_REGISTRY}/pytorch/pytorch-linux-bionic-cuda10.2-cudnn7-py3.9-gcc7", test_runner_type=LINUX_CUDA_TEST_RUNNER, num_test_shards=2, + on_pull_request=True, + ciflow_config=CIFlowConfig( + enabled=True, + trigger_action_only=True, + labels={LABEL_CIFLOW_SLOW, LABEL_CIFLOW_LINUX, LABEL_CIFLOW_CUDA} + ), ), CIWorkflow( arch="linux", @@ -284,7 +366,7 @@ def generate_workflow_file(self, workflow_template: jinja2.Template) -> None: ciflow_config=CIFlowConfig( enabled=True, trigger_action_only=True, - labels=set(['ciflow/slow']), + labels=set([LABEL_CIFLOW_SLOW, LABEL_CIFLOW_LINUX, LABEL_CIFLOW_CUDA]), ), ), CIWorkflow( @@ -293,25 +375,42 @@ def generate_workflow_file(self, workflow_template: jinja2.Template) -> None: docker_image_base=f"{DOCKER_REGISTRY}/pytorch/pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7", test_runner_type=LINUX_CUDA_TEST_RUNNER, is_libtorch=True, + on_pull_request=True, + ciflow_config=CIFlowConfig( + enabled=True, + trigger_action_only=True, + labels=set([LABEL_CIFLOW_LIBTORCH, LABEL_CIFLOW_LINUX, LABEL_CIFLOW_CUDA]), + ), ), CIWorkflow( arch="linux", - build_environment="linux-xenial-cuda11.1-py3.6-gcc7", - docker_image_base=f"{DOCKER_REGISTRY}/pytorch/pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7", + build_environment="linux-xenial-cuda11.3-py3.6-gcc7", + docker_image_base=f"{DOCKER_REGISTRY}/pytorch/pytorch-linux-xenial-cuda11.3-cudnn8-py3-gcc7", test_runner_type=LINUX_CUDA_TEST_RUNNER, num_test_shards=2, + on_pull_request=True, + ciflow_config=CIFlowConfig( + enabled=True, + labels=set([LABEL_CIFLOW_DEFAULT, LABEL_CIFLOW_LINUX, LABEL_CIFLOW_CUDA]), + ), ), CIWorkflow( arch="linux", - build_environment="libtorch-linux-xenial-cuda11.1-py3.6-gcc7", - docker_image_base=f"{DOCKER_REGISTRY}/pytorch/pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7", + build_environment="libtorch-linux-xenial-cuda11.3-py3.6-gcc7", + docker_image_base=f"{DOCKER_REGISTRY}/pytorch/pytorch-linux-xenial-cuda11.3-cudnn8-py3-gcc7", test_runner_type=LINUX_CUDA_TEST_RUNNER, is_libtorch=True, + on_pull_request=True, + ciflow_config=CIFlowConfig( + enabled=True, + trigger_action_only=True, + labels=set([LABEL_CIFLOW_LIBTORCH, LABEL_CIFLOW_LINUX, LABEL_CIFLOW_CUDA]), + ), ), CIWorkflow( arch="linux", - build_environment="periodic-linux-xenial-cuda11.3-py3.6-gcc7", - docker_image_base=f"{DOCKER_REGISTRY}/pytorch/pytorch-linux-xenial-cuda11.3-cudnn8-py3-gcc7", + build_environment="periodic-linux-xenial-cuda11.1-py3.6-gcc7", + docker_image_base=f"{DOCKER_REGISTRY}/pytorch/pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7", test_runner_type=LINUX_CUDA_TEST_RUNNER, num_test_shards=2, is_scheduled="45 0,4,8,12,16,20 * * *", @@ -319,13 +418,13 @@ def generate_workflow_file(self, workflow_template: jinja2.Template) -> None: ciflow_config=CIFlowConfig( enabled=True, trigger_action_only=True, - labels={'ciflow/scheduled'} + labels={LABEL_CIFLOW_SCHEDULED, LABEL_CIFLOW_LINUX, LABEL_CIFLOW_CUDA} ), ), CIWorkflow( arch="linux", - build_environment="periodic-libtorch-linux-xenial-cuda11.3-py3.6-gcc7", - docker_image_base=f"{DOCKER_REGISTRY}/pytorch/pytorch-linux-xenial-cuda11.3-cudnn8-py3-gcc7", + build_environment="periodic-libtorch-linux-xenial-cuda11.1-py3.6-gcc7", + docker_image_base=f"{DOCKER_REGISTRY}/pytorch/pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7", test_runner_type=LINUX_CUDA_TEST_RUNNER, is_libtorch=True, is_scheduled="45 0,4,8,12,16,20 * * *", @@ -333,27 +432,9 @@ def generate_workflow_file(self, workflow_template: jinja2.Template) -> None: ciflow_config=CIFlowConfig( enabled=True, trigger_action_only=True, - labels={'ciflow/scheduled'}, + labels={LABEL_CIFLOW_SCHEDULED, LABEL_CIFLOW_LINUX, LABEL_CIFLOW_LIBTORCH, LABEL_CIFLOW_CUDA}, ), ), - # CIWorkflow( - # arch="linux", - # build_environment="linux-bionic-py3.6-clang9-noarch", - # docker_image_base=f"{DOCKER_REGISTRY}/pytorch/pytorch-linux-bionic-py3.6-clang9", - # test_runner_type=LINUX_CPU_TEST_RUNNER, - # ), - # CIWorkflow( - # arch="linux", - # build_environment="xla-linux-bionic-py3.6-clang9", - # docker_image_base=f"{DOCKER_REGISTRY}/pytorch/pytorch-linux-bionic-py3.6-clang9", - # test_runner_type=LINUX_CPU_TEST_RUNNER, - # ), - # CIWorkflow( - # arch="linux", - # build_environment="vulkan-linux-bionic-py3.6-clang9", - # docker_image_base=f"{DOCKER_REGISTRY}/pytorch/pytorch-linux-bionic-py3.6-clang9", - # test_runner_type=LINUX_CPU_TEST_RUNNER, - # ), CIWorkflow( arch="linux", build_environment="linux-bionic-py3.8-gcc9-coverage", @@ -364,7 +445,21 @@ def generate_workflow_file(self, workflow_template: jinja2.Template) -> None: num_test_shards=2, ciflow_config=CIFlowConfig( enabled=True, - labels=set(['ciflow/default']), + labels={LABEL_CIFLOW_DEFAULT, LABEL_CIFLOW_COVERAGE, LABEL_CIFLOW_LINUX, LABEL_CIFLOW_CPU}, + ), + ), + CIWorkflow( + arch="linux", + build_environment="linux-bionic-py3.6-clang9", + docker_image_base=f"{DOCKER_REGISTRY}/pytorch/pytorch-linux-bionic-py3.6-clang9", + test_runner_type=LINUX_CPU_TEST_RUNNER, + on_pull_request=True, + num_test_shards=2, + distributed_test=False, + enable_noarch_test=1, + ciflow_config=CIFlowConfig( + enabled=True, + labels={LABEL_CIFLOW_DEFAULT, LABEL_CIFLOW_LINUX, LABEL_CIFLOW_CPU, LABEL_CIFLOW_XLA, LABEL_CIFLOW_NOARCH}, ), ), # CIWorkflow( @@ -428,12 +523,12 @@ def generate_workflow_file(self, workflow_template: jinja2.Template) -> None: CIWorkflow( arch="linux", build_environment="linux-xenial-py3.6-gcc7-bazel-test", - docker_image_base=f"{DOCKER_REGISTRY}/pytorch/pytorch-linux-xenial-py3.6-gcc7", + docker_image_base=f"{DOCKER_REGISTRY}/pytorch/pytorch-linux-bionic-cuda10.2-cudnn7-py3.9-gcc7", test_runner_type=LINUX_CPU_TEST_RUNNER, on_pull_request=True, ciflow_config=CIFlowConfig( enabled=True, - labels=set(['ciflow/default']), + labels={LABEL_CIFLOW_DEFAULT, LABEL_CIFLOW_BAZEL, LABEL_CIFLOW_CPU, LABEL_CIFLOW_LINUX}, ), ), ] @@ -442,6 +537,7 @@ def generate_workflow_file(self, workflow_template: jinja2.Template) -> None: jinja_env = jinja2.Environment( variable_start_string="!{{", loader=jinja2.FileSystemLoader(str(GITHUB_DIR.joinpath("templates"))), + undefined=jinja2.StrictUndefined, ) template_and_workflows = [ (jinja_env.get_template("linux_ci_workflow.yml.j2"), LINUX_WORKFLOWS), @@ -465,8 +561,8 @@ def generate_workflow_file(self, workflow_template: jinja2.Template) -> None: ciflow_ruleset.add_label_rule(workflow.ciflow_config.labels, workflow.build_environment) elif workflow.on_pull_request: # If ciflow is disabled but still on_pull_request, we can denote - # it as a special label 'ciflow/default' in the ruleset, which will be later - # turned into an actual 'ciflow/default' label in the workflow. - # During the rollout phase, it has the same effect as 'ciflow/default' - ciflow_ruleset.add_label_rule({'ciflow/default'}, workflow.build_environment) + # it as a special label LABEL_CIFLOW_DEFAULT in the ruleset, which will be later + # turned into an actual LABEL_CIFLOW_DEFAULT label in the workflow. + # During the rollout phase, it has the same effect as LABEL_CIFLOW_DEFAULT + ciflow_ruleset.add_label_rule({LABEL_CIFLOW_DEFAULT}, workflow.build_environment) ciflow_ruleset.generate_json() diff --git a/.github/scripts/generate_pytorch_test_matrix.py b/.github/scripts/generate_pytorch_test_matrix.py index d8860a02a5c37..beb1b9d90e62f 100755 --- a/.github/scripts/generate_pytorch_test_matrix.py +++ b/.github/scripts/generate_pytorch_test_matrix.py @@ -51,8 +51,18 @@ def main() -> None: configs['nogpu_NO_AVX'] = {'num_shards': 1, 'runner': NOGPU_RUNNER_TYPE} if NOGPU_RUNNER_TYPE is not None and os.getenv('ENABLE_NOGPU_NO_AVX2_TEST'): configs['nogpu_NO_AVX2'] = {'num_shards': 1, 'runner': NOGPU_RUNNER_TYPE} + if os.getenv('ENABLE_DISTRIBUTED_TEST'): + configs['distributed'] = {'num_shards': 1, 'runner': TEST_RUNNER_TYPE} if os.getenv('ENABLE_SLOW_TEST'): configs['slow'] = {'num_shards': 1, 'runner': TEST_RUNNER_TYPE} + if os.getenv('ENABLE_DOCS_TEST'): + configs['docs_test'] = {'num_shards': 1, 'runner': TEST_RUNNER_TYPE} + if os.getenv('ENABLE_BACKWARDS_COMPAT_TEST'): + configs['backwards_compat'] = {'num_shards': 1, 'runner': TEST_RUNNER_TYPE} + if os.getenv('ENABLE_XLA_TEST'): + configs['xla'] = {'num_shards': 1, 'runner': TEST_RUNNER_TYPE} + if os.getenv('ENABLE_NOARCH_TEST'): + configs['noarch'] = {'num_shards': 1, 'runner': TEST_RUNNER_TYPE} matrix = { 'include': [ { diff --git a/.github/scripts/kill_active_ssh_sessions.ps1 b/.github/scripts/kill_active_ssh_sessions.ps1 new file mode 100644 index 0000000000000..09cc63e94bc1f --- /dev/null +++ b/.github/scripts/kill_active_ssh_sessions.ps1 @@ -0,0 +1,11 @@ +function Get-SSH-Sessions { + Get-Process sshd -IncludeUserName | + Where-Object UserName -notLike "*SYSTEM*" | + Select-Object Id +} + +$runningSessions = Get-SSH-Sessions + +foreach ($session in $runningSessions) { + Stop-Process -id $session.Id +} diff --git a/.github/scripts/wait_for_ssh_to_drain.ps1 b/.github/scripts/wait_for_ssh_to_drain.ps1 new file mode 100644 index 0000000000000..ab3ab41f355ce --- /dev/null +++ b/.github/scripts/wait_for_ssh_to_drain.ps1 @@ -0,0 +1,17 @@ +function Get-SSH-Users { + # Gets ssh sessions for all users not named SYSTEM + Get-CimInstance -ClassName Win32_Process -Filter "Name = 'sshd.exe'" | + Get-CimAssociatedInstance -Association Win32_SessionProcess | + Get-CimAssociatedInstance -Association Win32_LoggedOnUser | + Where-Object {$_.Name -ne 'SYSTEM'} | + Measure-Object +} + +$usersLoggedOn = Get-SSH-Users + +Write-Output "Holding runner until all ssh sessions have logged out" +while ($usersLoggedOn.Count -gt 0) { + $usersLoggedOn = Get-SSH-Users + Write-Output "." + Start-Sleep -s 5 +} diff --git a/.github/templates/bazel_ci_workflow.yml.j2 b/.github/templates/bazel_ci_workflow.yml.j2 index 016a11bc39277..9f982cdd5cb61 100644 --- a/.github/templates/bazel_ci_workflow.yml.j2 +++ b/.github/templates/bazel_ci_workflow.yml.j2 @@ -29,21 +29,10 @@ on: DOCKER_IMAGE: ${{ needs.calculate-docker-image.outputs.docker_image }} JOB_BASE_NAME: !{{ build_environment }}-build-and-test NUM_TEST_SHARDS: !{{ num_test_shards }} + CONTINUE_THROUGH_ERROR: ${{ github.repository_owner == 'pytorch' && (github.event_name == 'push' || github.event_name == 'schedule') }} steps: - - name: Log in to ECR - run: | - aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh - bash /tmp/ecr-login.sh - rm /tmp/ecr-login.sh - - name: Chown workspace - run: | - # Ensure the working directory gets chowned back to the current user - docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - - name: Checkout PyTorch - uses: actions/checkout@v2 - with: - fetch-depth: 0 # deep clone, to allow sharding to use git rev-list - submodules: recursive + !{{ common.setup_ec2_linux() }} + !{{ common.checkout_pytorch("recursive") }} - name: Pull docker image run: | docker pull "${DOCKER_IMAGE}" @@ -73,9 +62,10 @@ on: -e MAX_JOBS="$(nproc --ignore=2)" \ -e SCCACHE_BUCKET \ -e CUSTOM_TEST_ARTIFACT_BUILD_DIR \ + -e PR_LABELS \ -e SKIP_SCCACHE_INITIALIZATION=1 \ -e TORCH_CUDA_ARCH_LIST \ - -e http_proxy="!{{squid_proxy}}" -e https_proxy="!{{squid_proxy}}" -e no_proxy="!{{squid_no_proxy}}" \ + -e http_proxy="!{{ common.squid_proxy }}" -e https_proxy="!{{ common.squid_proxy }}" -e no_proxy="!{{ common.squid_no_proxy }}" \ --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ --security-opt seccomp=unconfined \ --cap-add=SYS_PTRACE \ @@ -85,9 +75,7 @@ on: -w /var/lib/jenkins/workspace \ "${DOCKER_IMAGE}" \ sh -c 'sudo chown -R jenkins . && sudo chown -R jenkins /dev && .jenkins/pytorch/build.sh' - - name: Parse ref - id: parse-ref - run: .github/scripts/parse_ref.py + !{{ common.parse_ref() }} - name: Display and upload binary build size statistics (Click Me) # temporary hack: set CIRCLE_* vars, until we update # tools/stats/print_test_stats.py to natively support GitHub Actions @@ -122,7 +110,9 @@ on: -e JOB_BASE_NAME \ -e MAX_JOBS="$(nproc --ignore=2)" \ -e SCCACHE_BUCKET \ - -e http_proxy="!{{squid_proxy}}" -e https_proxy="!{{squid_proxy}}" -e no_proxy="!{{squid_no_proxy}}" \ + -e CONTINUE_THROUGH_ERROR \ + -e PR_LABELS \ + -e http_proxy="!{{ common.squid_proxy }}" -e https_proxy="!{{ common.squid_proxy }}" -e no_proxy="!{{ common.squid_no_proxy }}" \ --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ --security-opt seccomp=unconfined \ --cap-add=SYS_PTRACE \ @@ -156,71 +146,6 @@ on: if-no-files-found: error path: test-reports-*.zip - - name: Clean up docker images - if: always() - run: | - # Prune all of the docker images - docker system prune -af -{%- endblock %} -{% block render_test_results +%} - # this is a separate step from test because the log files from test are too - # long: basically, GitHub tries to render all of the log files when you click - # through an action causing extreme slowdown on actions that contain too many - # logs (like test); we can always move it back to the other one, but it - # doesn't create the best experience - render_test_results: - needs: [build-and-test, !{{ ciflow_config.root_job_name }}] - if: ${{ needs.build-and-test.result != 'skipped' || failure() }} - runs-on: linux.2xlarge - steps: - - name: Log in to ECR - run: | - aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh - bash /tmp/ecr-login.sh - rm /tmp/ecr-login.sh - - name: Chown workspace - run: | - # Ensure the working directory gets chowned back to the current user - docker run --rm -v "$(pwd)/../":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - - name: Clean workspace - run: | - rm -rf "${GITHUB_WORKSPACE:?}/*" - - name: Checkout PyTorch - uses: actions/checkout@v2 - with: - # deep clone, to allow tools/stats/print_test_stats.py to use Git commands - fetch-depth: 0 - - uses: actions/download-artifact@v2 - name: Download PyTorch Test Reports - with: - name: test-reports - path: . - - name: Unzip test reports - run: | - # Should preserve paths so reports should still be in test/test-reports - unzip -o 'test-reports-*.zip' - - name: Install dependencies - # boto3 version copied from .circleci/docker/common/install_conda.sh - run: | - pip3 install -r requirements.txt - pip3 install boto3==1.16.34 junitparser rich - - name: Output Test Results (Click Me) - run: | - python3 tools/render_junit.py test - - name: Parse ref - id: parse-ref - run: .github/scripts/parse_ref.py - - name: Display and upload test statistics (Click Me) - # temporary hack: set CIRCLE_* vars, until we update - # tools/stats/print_test_stats.py to natively support GitHub Actions - env: - AWS_DEFAULT_REGION: us-east-1 - CIRCLE_BRANCH: ${{ steps.parse-ref.outputs.branch }} - JOB_BASE_NAME: !{{ build_environment }}-build-and-test - CIRCLE_PR_NUMBER: ${{ github.event.pull_request.number }} - CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }} - CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }} - CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}' - run: | - python3 -m tools.stats.print_test_stats --upload-to-s3 --compare-with-s3 test + !{{ common.upload_test_statistics(build_environment) }} + !{{ common.teardown_ec2_linux() }} {%- endblock %} diff --git a/.github/templates/common.yml.j2 b/.github/templates/common.yml.j2 new file mode 100644 index 0000000000000..aff01377ff665 --- /dev/null +++ b/.github/templates/common.yml.j2 @@ -0,0 +1,109 @@ +{%- set upload_artifact_s3_action = "seemethere/upload-artifact-s3@v3" -%} + +{# squid_proxy is an private ELB that only available for GHA custom runners #} +{%- set squid_proxy = "http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -%} +{# squid_no_proxy is a list of common set of fixed domains or IPs that we don't need to proxy. See https://docs.aws.amazon.com/AmazonECS/latest/developerguide/http_proxy_config.html#windows-proxy #} +{%- set squid_no_proxy = "localhost,127.0.0.1,amazonaws.com,s3.amazonaws.com,169.254.169.254,169.254.170.2,/var/run/docker.sock" -%} + +{%- macro concurrency(build_environment) -%} +concurrency: + group: !{{ build_environment }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} + cancel-in-progress: true +{%- endmacro -%} + +{%- macro display_ec2_information() -%} + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" +{%- endmacro -%} + +{%- macro parse_ref() -%} + - name: Parse ref + id: parse-ref + run: .github/scripts/parse_ref.py +{%- endmacro -%} + +{%- macro upload_test_statistics(build_environment) -%} + - name: Display and upload test statistics (Click Me) + # temporary hack: set CIRCLE_* vars, until we update + # tools/stats/print_test_stats.py to natively support GitHub Actions + env: + AWS_DEFAULT_REGION: us-east-1 + CIRCLE_BRANCH: ${{ steps.parse-ref.outputs.branch }} + JOB_BASE_NAME: !{{ build_environment }}-test + CIRCLE_PR_NUMBER: ${{ github.event.pull_request.number }} + CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }} + CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }} + CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}' + shell: bash + run: | + python3 -m pip install -r requirements.txt + python3 -m pip install boto3==1.16.34 + python3 -m tools.stats.print_test_stats --upload-to-s3 --compare-with-s3 test +{%- endmacro -%} + +{%- macro setup_ec2_linux() -%} + !{{ display_ec2_information() }} + - name: Log in to ECR + run: | + aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh + bash /tmp/ecr-login.sh + rm /tmp/ecr-login.sh + - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . + - name: Clean workspace + run: | + rm -rf "${GITHUB_WORKSPACE:?}/*" + rm -f ~/.ssh/authorized_keys + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: seemethere/add-github-ssh-key@v1 + with: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Preserve github env variables for use in docker + run: | + env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" +{%- endmacro -%} + +{%- macro teardown_ec2_linux() -%} + - name: Hold runner for 2 hours or until ssh sessions have drained + # Always hold for active ssh sessions + if: always() + run: .github/scripts/wait_for_ssh_to_drain.sh + - name: Kill containers, clean up images + if: always() + run: | + # ignore expansion of "docker ps -q" since it could be empty + # shellcheck disable=SC2046 + docker stop $(docker ps -q) || true + # Prune all of the docker images + docker system prune -af + - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . +{%- endmacro -%} + +{%- macro checkout_pytorch(submodules) -%} + - name: Checkout PyTorch + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 + with: + # deep clone, to allow use of git merge-base + fetch-depth: 0 + submodules: !{{ submodules }} +{%- endmacro -%} diff --git a/.github/templates/linux_ci_workflow.yml.j2 b/.github/templates/linux_ci_workflow.yml.j2 index ec39ef6f5f260..520a6a00a19f6 100644 --- a/.github/templates/linux_ci_workflow.yml.j2 +++ b/.github/templates/linux_ci_workflow.yml.j2 @@ -1,7 +1,4 @@ -{# squid_proxy is an private ELB that only available for GHA custom runners #} -{%- set squid_proxy = "http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -%} -{# squid_no_proxy is a list of common set of fixed domains or IPs that we don't need to proxy. See https://docs.aws.amazon.com/AmazonECS/latest/developerguide/http_proxy_config.html#windows-proxy #} -{%- set squid_no_proxy = "localhost,127.0.0.1,amazonaws.com,s3.amazonaws.com,169.254.169.254,169.254.170.2,/var/run/docker.sock" -%} +{% import 'common.yml.j2' as common %} {%- block name -%} # Template is at: .github/templates/linux_ci_workflow.yml.j2 @@ -38,6 +35,7 @@ env: BUILD_ENVIRONMENT: !{{ build_environment }} DOCKER_IMAGE_BASE: !{{ docker_image_base }} SCCACHE_BUCKET: ossci-compiler-cache-circleci-v2 + XLA_CLANG_CACHE_S3_BUCKET_NAME: ossci-compiler-clang-cache-circleci-xla TORCH_CUDA_ARCH_LIST: 5.2 IN_CI: 1 # This is used for the phase of adding wheel tests only, will be removed once completed @@ -45,19 +43,22 @@ env: # Used for custom_opertor, jit_hooks, custom_backend, see .jenkins/pytorch/build.sh CUSTOM_TEST_ARTIFACT_BUILD_DIR: build/custom_test_artifacts ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + PR_LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }} -concurrency: - group: !{{ build_environment }}-${{ github.event.pull_request.number || github.sha }} - cancel-in-progress: true +!{{ common.concurrency(build_environment) }} jobs: {%- if ciflow_config.enabled %} !{{ ciflow_config.root_job_name }}: runs-on: ubuntu-18.04 if: ${{ !{{ ciflow_config.root_job_condition }} }} + env: + LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }} steps: - name: noop run: echo running !{{ ciflow_config.root_job_name }} + - name: print labels + run: echo "${LABELS}" {%- endif %} calculate-docker-image: if: ${{ github.repository_owner == 'pytorch' }} @@ -71,24 +72,8 @@ jobs: outputs: docker_image: ${{ steps.calculate-tag.outputs.docker_image }} steps: - - name: Log in to ECR - run: | - aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh - bash /tmp/ecr-login.sh - rm /tmp/ecr-login.sh - - name: Chown workspace - run: | - # Ensure the working directory gets chowned back to the current user - docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - - name: Clean workspace - run: | - rm -rf "${GITHUB_WORKSPACE:?}/*" - rm -f ~/.ssh/authorized_keys - - name: Checkout PyTorch - uses: actions/checkout@v2 - with: - # deep clone, to allow use of git merge-base - fetch-depth: 0 + !{{ common.setup_ec2_linux() }} + !{{ common.checkout_pytorch("false") }} - name: Calculate docker image tag id: calculate-tag run: | @@ -143,34 +128,11 @@ jobs: DOCKER_IMAGE: ${{ needs.calculate-docker-image.outputs.docker_image }} JOB_BASE_NAME: !{{ build_environment }}-build steps: - - name: Log in to ECR - run: | - aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh - bash /tmp/ecr-login.sh - rm /tmp/ecr-login.sh - - name: Chown workspace - run: | - # Ensure the working directory gets chowned back to the current user - docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - - name: Clean workspace - run: | - rm -rf "${GITHUB_WORKSPACE:?}/*" - rm -f ~/.ssh/authorized_keys - - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" - uses: seemethere/add-github-ssh-key@v1 - with: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - name: Checkout PyTorch - uses: actions/checkout@v2 - with: - fetch-depth: 0 # deep clone, to allow sharding to use git rev-list - submodules: recursive + !{{ common.setup_ec2_linux() }} + !{{ common.checkout_pytorch("recursive") }} - name: Pull docker image run: | docker pull "${DOCKER_IMAGE}" - - name: Preserve github env variables for use in docker - run: | - env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Build PyTorch run: | docker run \ @@ -178,10 +140,12 @@ jobs: -e JOB_BASE_NAME \ -e MAX_JOBS="$(nproc --ignore=2)" \ -e SCCACHE_BUCKET \ + -e XLA_CLANG_CACHE_S3_BUCKET_NAME \ -e CUSTOM_TEST_ARTIFACT_BUILD_DIR \ -e SKIP_SCCACHE_INITIALIZATION=1 \ -e TORCH_CUDA_ARCH_LIST \ - -e http_proxy="!{{squid_proxy}}" -e https_proxy="!{{squid_proxy}}" -e no_proxy="!{{squid_no_proxy}}" \ + -e PR_LABELS \ + -e http_proxy="!{{ common.squid_proxy }}" -e https_proxy="!{{ common.squid_proxy }}" -e no_proxy="!{{ common.squid_no_proxy }}" \ --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ --security-opt seccomp=unconfined \ --cap-add=SYS_PTRACE \ @@ -191,9 +155,7 @@ jobs: -w /var/lib/jenkins/workspace \ "${DOCKER_IMAGE}" \ sh -c 'sudo chown -R jenkins . && .jenkins/pytorch/build.sh' - - name: Parse ref - id: parse-ref - run: .github/scripts/parse_ref.py + !{{ common.parse_ref() }} - name: Display and upload binary build size statistics (Click Me) # temporary hack: set CIRCLE_* vars, until we update # tools/stats/print_test_stats.py to natively support GitHub Actions @@ -217,19 +179,8 @@ jobs: {%- if not is_libtorch %} - name: Archive artifacts into zip run: | - zip -r artifacts.zip dist/ build/ .pytorch-test-times.json - # Upload to github so that people can click and download artifacts - - uses: actions/upload-artifact@v2 - # Don't fail on upload to GH since it's only for user convenience - continue-on-error: true - name: Store PyTorch Build Artifacts on Github - with: - name: ${{ env.BUILD_ENVIRONMENT }} - retention-days: 14 - if-no-files-found: error - path: - artifacts.zip - - uses: seemethere/upload-artifact-s3@9d7ceb0ab39c2c88d93ef7792b27425b27d59162 + zip -1 -r artifacts.zip dist/ build/custom_test_artifacts build/lib build/bin .pytorch-test-times.json + - uses: !{{ common.upload_artifact_s3_action }} name: Store PyTorch Build Artifacts on S3 with: name: ${{ env.BUILD_ENVIRONMENT }} @@ -238,6 +189,7 @@ jobs: path: artifacts.zip {%- endif %} + !{{ common.teardown_ec2_linux() }} - name: Hold runner for 2 hours or until ssh sessions have drained # Always hold for active ssh sessions if: always() @@ -258,11 +210,16 @@ jobs: {%- endif %} env: TEST_RUNNER_TYPE: !{{ test_runner_type }} + ENABLE_DISTRIBUTED_TEST: !{{ enable_distributed_test }} ENABLE_JIT_LEGACY_TEST: !{{ enable_jit_legacy_test }} ENABLE_MULTIGPU_TEST: !{{ enable_multigpu_test }} ENABLE_NOGPU_NO_AVX_TEST: !{{ enable_nogpu_no_avx_test }} ENABLE_NOGPU_NO_AVX2_TEST: !{{ enable_nogpu_no_avx2_test }} ENABLE_SLOW_TEST: !{{ enable_slow_test }} + ENABLE_DOCS_TEST: !{{ enable_docs_test }} + ENABLE_BACKWARDS_COMPAT_TEST: !{{ enable_backwards_compat_test }} + ENABLE_XLA_TEST: !{{ enable_xla_test }} + ENABLE_NOARCH_TEST: !{{ enable_noarch_test }} NUM_TEST_SHARDS: !{{ num_test_shards }} MULTIGPU_RUNNER_TYPE: linux.16xlarge.nvidia.gpu NOGPU_RUNNER_TYPE: linux.2xlarge @@ -277,7 +234,7 @@ jobs: - name: Install dependencies run: pip install typing-extensions - name: Clone pytorch/pytorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 - name: Generating test matrix id: set-matrix run: .github/scripts/generate_pytorch_test_matrix.py @@ -295,28 +252,10 @@ jobs: SHARD_NUMBER: ${{ matrix.shard }} NUM_TEST_SHARDS: ${{ matrix.num_shards }} PYTORCH_IGNORE_DISABLED_ISSUES: ${{ needs.generate-test-matrix.outputs.ignore-disabled-issues }} + CONTINUE_THROUGH_ERROR: ${{ github.repository_owner == 'pytorch' && (github.event_name == 'push' || github.event_name == 'schedule') }} steps: - - name: Log in to ECR - run: | - aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh - bash /tmp/ecr-login.sh - rm /tmp/ecr-login.sh - - name: Chown workspace - run: | - # Ensure the working directory gets chowned back to the current user - docker run --rm -v "$(pwd)/../":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - - name: Clean workspace - run: | - rm -rf "${GITHUB_WORKSPACE:?}/*" - rm -f ~/.ssh/authorized_keys - - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" - uses: seemethere/add-github-ssh-key@v1 - with: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - name: Checkout PyTorch - uses: actions/checkout@v2 - with: - submodules: recursive + !{{ common.setup_ec2_linux() }} + !{{ common.checkout_pytorch("recursive") }} - name: Pull docker image run: | docker pull "${DOCKER_IMAGE}" @@ -347,12 +286,9 @@ jobs: - name: Output disk space left run: | sudo df -H - - name: Preserve github env variables for use in docker - run: | - env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Test PyTorch env: - BUILD_ENVIRONMENT: !{{ build_environment }}-${{ matrix.config }} + PR_NUMBER: ${{ github.event.pull_request.number }} run: | if [[ $TEST_CONFIG == 'multigpu' ]]; then TEST_COMMAND=.jenkins/pytorch/multigpu-test.sh @@ -368,6 +304,7 @@ jobs: docker run \ ${GPU_FLAG:-} \ -e BUILD_ENVIRONMENT \ + -e PR_NUMBER \ -e CUSTOM_TEST_ARTIFACT_BUILD_DIR \ -e GITHUB_ACTIONS \ -e IN_CI \ @@ -377,9 +314,12 @@ jobs: -e TEST_CONFIG \ -e NUM_TEST_SHARDS \ -e PYTORCH_IGNORE_DISABLED_ISSUES \ + -e PR_LABELS \ + -e CONTINUE_THROUGH_ERROR \ -e MAX_JOBS="$(nproc --ignore=2)" \ -e SCCACHE_BUCKET \ - -e http_proxy="!{{squid_proxy}}" -e https_proxy="!{{squid_proxy}}" -e no_proxy="!{{squid_no_proxy}}" \ + -e http_proxy="!{{ common.squid_proxy }}" -e https_proxy="!{{ common.squid_proxy }}" -e no_proxy="!{{ common.squid_no_proxy }}" \ + -e XLA_CLANG_CACHE_S3_BUCKET_NAME \ --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ --security-opt seccomp=unconfined \ --cap-add=SYS_PTRACE \ @@ -419,7 +359,7 @@ jobs: if-no-files-found: error path: test-reports-*.zip - - uses: seemethere/upload-artifact-s3@9d7ceb0ab39c2c88d93ef7792b27425b27d59162 + - uses: !{{ common.upload_artifact_s3_action }} name: Store PyTorch Test Reports on S3 if: always() with: @@ -428,122 +368,27 @@ jobs: if-no-files-found: error path: test-reports-*.zip - - name: Hold runner for 2 hours or until ssh sessions have drained - # Always hold for active ssh sessions - if: always() - run: .github/scripts/wait_for_ssh_to_drain.sh - - name: Clean up docker images - if: always() - run: | - # Ensure the working directory gets chowned back to the current user - docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - # Prune all of the docker images - docker system prune -af + !{{ common.parse_ref() }} + !{{ common.upload_test_statistics(build_environment) }} + !{{ common.teardown_ec2_linux() }} {% endblock %} {%- endif -%} -{%- if not is_libtorch %} -{% block render_test_results +%} - # this is a separate step from test because the log files from test are too - # long: basically, GitHub tries to render all of the log files when you click - # through an action causing extreme slowdown on actions that contain too many - # logs (like test); we can always move it back to the other one, but it - # doesn't create the best experience - render_test_results: - needs: [generate-test-matrix, test, !{{ ciflow_config.root_job_name }}] - if: ${{ needs.test.result != 'skipped' || failure() }} +{%- if enable_doc_jobs %} + pytorch_doc_build: runs-on: linux.2xlarge strategy: - matrix: ${{ fromJson(needs.generate-test-matrix.outputs.render-matrix) }} - fail-fast: false - steps: - - name: Log in to ECR - run: | - aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh - bash /tmp/ecr-login.sh - rm /tmp/ecr-login.sh - - name: Chown workspace - run: | - # Ensure the working directory gets chowned back to the current user - docker run --rm -v "$(pwd)/../":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - - name: Clean workspace - run: | - rm -rf "${GITHUB_WORKSPACE:?}/*" - - name: Checkout PyTorch - uses: actions/checkout@v2 - with: - # deep clone, to allow tools/stats/print_test_stats.py to use Git commands - fetch-depth: 0 - - uses: actions/download-artifact@v2 - name: Download PyTorch Test Reports - with: - name: test-reports-${{ matrix.config }} - path: . - - name: Unzip test reports - run: | - # Should preserve paths so reports should still be in test/test-reports - unzip -o 'test-reports-*.zip' - - name: Install dependencies - # boto3 version copied from .circleci/docker/common/install_conda.sh - run: | - pip3 install -r requirements.txt - pip3 install boto3==1.16.34 junitparser rich - - name: Output Test Results (Click Me) - run: | - python3 tools/render_junit.py test - - name: Parse ref - id: parse-ref - run: .github/scripts/parse_ref.py - - name: Display and upload test statistics (Click Me) - # temporary hack: set CIRCLE_* vars, until we update - # tools/stats/print_test_stats.py to natively support GitHub Actions - env: - AWS_DEFAULT_REGION: us-east-1 - CIRCLE_BRANCH: ${{ steps.parse-ref.outputs.branch }} - JOB_BASE_NAME: !{{ build_environment }}-test - CIRCLE_PR_NUMBER: ${{ github.event.pull_request.number }} - CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }} - CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }} - CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}' - run: | - python3 -m tools.stats.print_test_stats --upload-to-s3 --compare-with-s3 test -{%- endblock %} -{%- endif -%} - {%- if enable_doc_jobs %} - - pytorch_python_doc_build: - runs-on: linux.2xlarge + matrix: + docs_type: [cpp, python] needs: [calculate-docker-image, build, !{{ ciflow_config.root_job_name }}] env: DOCKER_IMAGE: ${{ needs.calculate-docker-image.outputs.docker_image }} + DOCS_TYPE: ${{ matrix.docs_type }} steps: - - name: Log in to ECR - run: | - aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh - bash /tmp/ecr-login.sh - rm /tmp/ecr-login.sh - - name: Chown workspace - run: | - # Ensure the working directory gets chowned back to the current user - docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - - name: Clean workspace - run: | - rm -rf "${GITHUB_WORKSPACE:?}/*" - rm -f ~/.ssh/authorized_keys - - name: "[FB EMPLOYEES] Enables SSH (Click me for login details)" - uses: seemethere/add-github-ssh-key@v1 - with: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - name: Checkout PyTorch - uses: actions/checkout@v2 - with: - fetch-depth: 0 # deep clone, to allow sharding to use git rev-list - submodules: recursive + !{{ common.setup_ec2_linux() }} + !{{ common.checkout_pytorch("recursive") }} - name: Pull docker image run: | docker pull "${DOCKER_IMAGE}" - - name: Preserve github env variables for use in docker - run: | - env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - uses: seemethere/download-artifact-s3@0504774707cbc8603d7dca922e8026eb8bf3b47b name: Download PyTorch Build Artifacts with: @@ -551,7 +396,7 @@ jobs: - name: Unzip artifacts run: | unzip -o artifacts.zip - - name: Build Python Doc in Docker + - name: Build ${{ matrix.docs_type }} docs run: | set -ex time docker pull "${DOCKER_IMAGE}" > /dev/null @@ -564,6 +409,9 @@ jobs: -e IN_CI \ -e MAX_JOBS="$(nproc --ignore=2)" \ -e CIRCLE_SHA1="$GITHUB_SHA" \ + -e DOCS_VERSION="${target}" \ + -e DOCS_TYPE \ + -e PR_LABELS \ --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ --security-opt seccomp=unconfined \ --cap-add=SYS_PTRACE \ @@ -573,42 +421,35 @@ jobs: -v "${GITHUB_WORKSPACE}:/var/lib/jenkins/workspace" \ -w /var/lib/jenkins/workspace \ "${DOCKER_IMAGE}" \ - bash -c "sudo chown -R jenkins . && pip install dist/*.whl && ./.circleci/scripts/python_doc_push_script.sh docs/$target $target site" + bash -c "sudo chown -R jenkins . && pip install dist/*.whl && ./.circleci/scripts/${DOCS_TYPE}_doc_push_script.sh" - name: Chown workspace run: | # Ensure the working directory gets chowned back to the current user docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - - uses: driazati/upload-artifact-s3@21c31d0a7bcb056ca50bd6ce197ba6507c26a1be - if: ${{ github.event_name == 'pull_request' }} - name: Upload Docs Preview + - uses: !{{ common.upload_artifact_s3_action }} + name: Upload Python Docs Preview + if: ${{ github.event_name == 'pull_request' && matrix.docs_type == 'python' }} with: - name: deploy retention-days: 14 if-no-files-found: error - path: pytorch.github.io/docs/merge - - name: Show Docs Preview URL (Click Me) - if: ${{ github.event_name == 'pull_request' }} - env: - PR_NUMBER: ${{ github.event.pull_request.number }} - run: | - echo "See rendered docs at https://docs-preview.pytorch.org/$PR_NUMBER/" + path: pytorch.github.io/docs/merge/ + s3-prefix: ${{ github.repository }}/pr-previews/pr/${{ github.event.pull_request.number }} + - uses: !{{ common.upload_artifact_s3_action }} + name: Upload C++ Docs Preview + if: ${{ github.event_name == 'pull_request' && matrix.docs_type == 'cppdocs' }} + with: + retention-days: 14 + if-no-files-found: error + path: cppdocs/ + s3-prefix: ${{ github.repository }}/pr-previews/pr/${{ github.event.pull_request.number }}/cppdocs - name: Archive artifacts into zip run: | - zip -r pytorch_github_io.zip "${GITHUB_WORKSPACE}/pytorch.github.io" + zip -r "docs_${DOCS_TYPE}.zip" "${GITHUB_WORKSPACE}/pytorch.github.io" "${GITHUB_WORKSPACE}/cppdocs" - uses: actions/upload-artifact@v2 name: Store PyTorch Build Artifacts with: - name: pytorch_github_io + name: docs_${{ matrix.docs_type }} + path: docs_${{ matrix.docs_type }}.zip if-no-files-found: error - path: pytorch_github_io.zip - - name: Hold runner for 2 hours or until ssh sessions have drained - # Always hold for active ssh sessions - if: always() - run: .github/scripts/wait_for_ssh_to_drain.sh - - name: Clean up docker images - if: always() - run: | - # Prune all of the docker images - docker system prune -af - - {%- endif -%} + !{{ common.teardown_ec2_linux() }} +{%- endif -%} diff --git a/.github/templates/windows_ci_workflow.yml.j2 b/.github/templates/windows_ci_workflow.yml.j2 index c1160fe32de60..20fe72238ffeb 100644 --- a/.github/templates/windows_ci_workflow.yml.j2 +++ b/.github/templates/windows_ci_workflow.yml.j2 @@ -1,7 +1,18 @@ -{# squid_proxy is an private ELB that only available for GHA custom runners #} -{%- set squid_proxy = "http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -%} -{# squid_no_proxy is a list of common set of fixed domains or IPs that we don't need to proxy. See https://docs.aws.amazon.com/AmazonECS/latest/developerguide/http_proxy_config.html#windows-proxy #} -{%- set squid_no_proxy = "localhost,127.0.0.1,amazonaws.com,s3.amazonaws.com,169.254.169.254,169.254.170.2,/var/run/docker.sock" -%} +{% import 'common.yml.j2' as common %} + +{%- macro wait_and_kill_ssh() -%} + - name: Wait until all sessions have drained + shell: powershell + if: always() + timeout-minutes: 120 + run: | + .github\scripts\wait_for_ssh_to_drain.ps1 + - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) + shell: powershell + if: always() + run: | + .github\scripts\kill_active_ssh_sessions.ps1 +{%- endmacro -%} # Template is at: .github/templates/windows_ci_workflow.yml.j2 # Generation script: .github/scripts/generate_ci_workflows.py @@ -36,21 +47,20 @@ env: IN_CI: 1 INSTALL_WINDOWS_SDK: 1 PYTHON_VERSION: "3.8" + PR_LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }} SCCACHE_BUCKET: "ossci-compiler-cache" VC_PRODUCT: "BuildTools" VC_VERSION: "" VS_VERSION: "16.8.6" VC_YEAR: "2019" ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" - no_proxy: !{{ squid_no_proxy }} + no_proxy: !{{ common.squid_no_proxy }} {%- if cuda_version != "cpu" %} TORCH_CUDA_ARCH_LIST: "7.0" USE_CUDA: 1 {%- endif %} -concurrency: - group: !{{ build_environment }}-${{ github.event.pull_request.number || github.sha }} - cancel-in-progress: true +!{{ common.concurrency(build_environment) }} jobs: {%- if ciflow_config.enabled %} @@ -72,16 +82,21 @@ jobs: {%- endif %} env: JOB_BASE_NAME: !{{ build_environment }}-build - http_proxy: "!{{ squid_proxy }}" - https_proxy: "!{{ squid_proxy }}" + http_proxy: "!{{ common. squid_proxy }}" + https_proxy: "!{{ common.squid_proxy }}" steps: + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: seemethere/add-github-ssh-key@v1 + with: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: submodules: recursive path: pytorch-${{ github.run_id }} # deep clone, to allow use of git merge-base fetch-depth: 0 + !{{ common.display_ec2_information() }} - name: Install Visual Studio 2019 toolchain shell: powershell run: | @@ -115,12 +130,13 @@ jobs: path: C:\${{ github.run_id }}\build-results - name: Upload artifacts to s3 if: always() - uses: seemethere/upload-artifact-s3@9d7ceb0ab39c2c88d93ef7792b27425b27d59162 + uses: !{{ common.upload_artifact_s3_action }} with: retention-days: 14 if-no-files-found: error name: ${{ env.BUILD_ENVIRONMENT }} path: C:\${{ github.run_id }}\build-results + !{{ wait_and_kill_ssh() }} - name: Cleanup build-results and workspaces if: always() shell: bash @@ -156,7 +172,7 @@ jobs: - name: Install dependencies run: pip install typing-extensions - name: Clone pytorch/pytorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 - name: Generating test matrix id: set-matrix run: .github/scripts/generate_pytorch_test_matrix.py @@ -170,10 +186,11 @@ jobs: SHARD_NUMBER: ${{ matrix.shard }} NUM_TEST_SHARDS: ${{ matrix.num_shards }} TEST_CONFIG: ${{ matrix.config }} - http_proxy: "!{{ squid_proxy }}" - https_proxy: "!{{ squid_proxy }}" + http_proxy: "!{{ common.squid_proxy }}" + https_proxy: "!{{ common.squid_proxy }}" RUN_SMOKE_TESTS_ONLY_ON_PR: !{{ only_run_smoke_tests_on_pull_request }} PYTORCH_IGNORE_DISABLED_ISSUES: ${{ needs.generate-test-matrix.outputs.ignore-disabled-issues }} + CONTINUE_THROUGH_ERROR: ${{ github.repository_owner == 'pytorch' && (github.event_name == 'push' || github.event_name == 'schedule') }} needs: [build, generate-test-matrix, !{{ ciflow_config.root_job_name }}] strategy: matrix: ${{ fromJson(needs.generate-test-matrix.outputs.matrix) }} @@ -184,12 +201,17 @@ jobs: working-directory: pytorch-${{ github.run_id }} steps: - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: submodules: recursive path: pytorch-${{ github.run_id }} # deep clone, to allow use of git merge-base fetch-depth: 0 + !{{ common.display_ec2_information() }} + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: seemethere/add-github-ssh-key@v1 + with: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - name: Install Visual Studio 2019 toolchain shell: powershell run: | @@ -248,77 +270,12 @@ jobs: if-no-files-found: error path: pytorch-${{ github.run_id }}/test-reports-*.zip + !{{ wait_and_kill_ssh() }} + !{{ common.parse_ref() }} + !{{ common.upload_test_statistics(build_environment) }} - name: Cleanup workspace if: always() shell: bash # Should remove the entirety of pytorch-${{ github.run_id }} run: | rm -rf ./* - - # this is a separate step from test because the log files from test are too - # long: basically, GitHub tries to render all of the log files when you click - # through an action causing extreme slowdown on actions that contain too many - # logs (like test); we can always move it back to the other one, but it - # doesn't create the best experience - render_test_results: - needs: [generate-test-matrix, test, !{{ ciflow_config.root_job_name }}] -{%- if only_build_on_pull_request %} - if: ${{ github.event_name == 'push' && (needs.test.result != 'skipped' || failure()) }} -{%- else %} - if: ${{ needs.test.result != 'skipped' || failure() }} -{%- endif %} - runs-on: linux.2xlarge - strategy: - matrix: ${{ fromJson(needs.generate-test-matrix.outputs.render-matrix) }} - fail-fast: false - # TODO: Make this into a composite step - steps: - - name: Log in to ECR - run: | - aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh - bash /tmp/ecr-login.sh - rm /tmp/ecr-login.sh - - name: Chown workspace - run: | - # Ensure the working directory gets chowned back to the current user - docker run --rm -v "$(pwd)/../":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - - name: Clean workspace - run: | - rm -rf "${GITHUB_WORKSPACE:?}/*" - - name: Checkout PyTorch - uses: actions/checkout@v2 - with: - # deep clone, to allow tools/stats/print_test_stats.py to use Git commands - fetch-depth: 0 - - uses: actions/download-artifact@v2 - name: Download PyTorch Test Reports - with: - name: test-reports-${{ matrix.config }} - path: . - - name: Unzip test reports - run: | - unzip -o 'test-reports-*.zip' - - name: Install dependencies - # boto3 version copied from .circleci/docker/common/install_conda.sh - run: | - pip3 install -r requirements.txt - pip3 install boto3==1.16.34 junitparser rich - - name: Output Test Results (Click Me) - run: | - python3 tools/render_junit.py test - - name: Parse ref - id: parse-ref - run: .github/scripts/parse_ref.py - - name: Display and upload test statistics (Click Me) - # temporary hack: set CIRCLE_* vars, until we update - # tools/stats/print_test_stats.py to natively support GitHub Actions - env: - AWS_DEFAULT_REGION: us-east-1 - CIRCLE_BRANCH: ${{ steps.parse-ref.outputs.branch }} - JOB_BASE_NAME: !{{ build_environment }}-test - CIRCLE_PR_NUMBER: ${{ github.event.pull_request.number }} - CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }} - CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }} - CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}' - run: | - python3 -m tools.stats.print_test_stats --upload-to-s3 --compare-with-s3 test diff --git a/.github/workflows/add_annotations.yml b/.github/workflows/add_annotations.yml index 40c2677aaf80d..76f7307e3fb77 100644 --- a/.github/workflows/add_annotations.yml +++ b/.github/workflows/add_annotations.yml @@ -7,6 +7,12 @@ on: workflows: - Lint + +concurrency: + group: add-annotations-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} + cancel-in-progress: true + + jobs: annotate: if: ${{ github.repository_owner == 'pytorch' }} diff --git a/.github/workflows/auto_label.yml b/.github/workflows/auto_label.yml index 24fc02eff1439..6dcb29a70f57a 100644 --- a/.github/workflows/auto_label.yml +++ b/.github/workflows/auto_label.yml @@ -6,6 +6,12 @@ on: pull_request_target: types: [edited, opened, synchronize, reopened] + +concurrency: + group: auto-label-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} + cancel-in-progress: true + + jobs: auto-label-rocm: if: ${{ github.repository_owner == 'pytorch' }} diff --git a/.github/workflows/build_linux_conda.yml b/.github/workflows/build_linux_conda.yml index 2037f0c1cf561..b43c2013327ba 100644 --- a/.github/workflows/build_linux_conda.yml +++ b/.github/workflows/build_linux_conda.yml @@ -16,7 +16,7 @@ jobs: image: python:3.9 steps: - name: Clone pytorch/pytorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 - name: Generating build matrix id: set-matrix run: | @@ -57,12 +57,12 @@ jobs: - name: Clean runner workspace run: rm -rf "$GITHUB_WORKSPACE" - name: Clone pytorch/pytorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: path: pytorch submodules: recursive - name: Clone pytorch/builder - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: repository: pytorch/builder path: builder @@ -111,5 +111,5 @@ jobs: python3 -m tools.stats.upload_binary_size_to_scuba || exit 0 concurrency: - group: build-linux-conda-${{ github.event.pull_request.number || github.sha }} + group: build-linux-conda-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true diff --git a/.github/workflows/build_linux_libtorch.yml b/.github/workflows/build_linux_libtorch.yml index 9d4964a8594b1..0a1c653375f9c 100644 --- a/.github/workflows/build_linux_libtorch.yml +++ b/.github/workflows/build_linux_libtorch.yml @@ -16,7 +16,7 @@ jobs: image: python:3.9 steps: - name: Clone pytorch/pytorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 - name: Generating build matrix id: set-matrix run: | @@ -51,12 +51,12 @@ jobs: - name: Clean runner workspace run: rm -rf "$GITHUB_WORKSPACE" - name: Clone pytorch/pytorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: path: pytorch submodules: recursive - name: Clone pytorch/builder - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: repository: pytorch/builder path: builder @@ -110,5 +110,5 @@ jobs: python3 -m tools.stats.upload_binary_size_to_scuba || exit 0 concurrency: - group: build-linux-libtorch-${{ github.event.pull_request.number || github.sha }} + group: build-linux-libtorch-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true diff --git a/.github/workflows/build_linux_wheels.yml b/.github/workflows/build_linux_wheels.yml index c32eee6892033..1f8e5f02e2220 100644 --- a/.github/workflows/build_linux_wheels.yml +++ b/.github/workflows/build_linux_wheels.yml @@ -16,7 +16,7 @@ jobs: image: python:3.9 steps: - name: Clone pytorch/pytorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 - name: Generating build matrix id: set-matrix run: | @@ -46,12 +46,12 @@ jobs: - name: Clean runner workspace run: rm -rf "$GITHUB_WORKSPACE" - name: Clone pytorch/pytorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: path: pytorch submodules: recursive - name: Clone pytorch/builder - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: repository: pytorch/builder path: builder @@ -109,5 +109,5 @@ jobs: python3 -m tools.stats.upload_binary_size_to_scuba || exit 0 concurrency: - group: build-linux-wheels-${{ github.event.pull_request.number || github.sha }} + group: build-linux-wheels-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true diff --git a/.github/workflows/create_release.yml b/.github/workflows/create_release.yml index fa65168a4709c..eea423c00505c 100644 --- a/.github/workflows/create_release.yml +++ b/.github/workflows/create_release.yml @@ -14,7 +14,7 @@ jobs: name: Create Release runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: submodules: 'recursive' - name: Fake name for PRs @@ -48,5 +48,5 @@ jobs: files: ${{env.PT_RELEASE_FILE}} concurrency: - group: create-release-${{ github.event.pull_request.number || github.sha }} + group: create-release-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true diff --git a/.github/workflows/generated-libtorch-linux-xenial-cuda10.2-py3.6-gcc7.yml b/.github/workflows/generated-libtorch-linux-xenial-cuda10.2-py3.6-gcc7.yml index a783b9b1886ec..477fe1bac6fe2 100644 --- a/.github/workflows/generated-libtorch-linux-xenial-cuda10.2-py3.6-gcc7.yml +++ b/.github/workflows/generated-libtorch-linux-xenial-cuda10.2-py3.6-gcc7.yml @@ -4,7 +4,8 @@ name: libtorch-linux-xenial-cuda10.2-py3.6-gcc7 on: - # TODO: Enable pull_request builds when we can verify capacity can be met by auto-scalers + pull_request: + types: [unassigned] push: branches: - master @@ -15,6 +16,7 @@ env: BUILD_ENVIRONMENT: libtorch-linux-xenial-cuda10.2-py3.6-gcc7 DOCKER_IMAGE_BASE: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7 SCCACHE_BUCKET: ossci-compiler-cache-circleci-v2 + XLA_CLANG_CACHE_S3_BUCKET_NAME: ossci-compiler-clang-cache-circleci-xla TORCH_CUDA_ARCH_LIST: 5.2 IN_CI: 1 # This is used for the phase of adding wheel tests only, will be removed once completed @@ -22,27 +24,54 @@ env: # Used for custom_opertor, jit_hooks, custom_backend, see .jenkins/pytorch/build.sh CUSTOM_TEST_ARTIFACT_BUILD_DIR: build/custom_test_artifacts ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + PR_LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }} concurrency: - group: libtorch-linux-xenial-cuda10.2-py3.6-gcc7-${{ github.event.pull_request.number || github.sha }} + group: libtorch-linux-xenial-cuda10.2-py3.6-gcc7-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true jobs: + ciflow_should_run: + runs-on: ubuntu-18.04 + if: ${{ (github.event_name != 'pull_request') || (github.event.action !='unassigned') || (contains(github.event.pull_request.labels.*.name, 'ciflow/all') || contains(github.event.pull_request.labels.*.name, 'ciflow/cuda') || contains(github.event.pull_request.labels.*.name, 'ciflow/libtorch') || contains(github.event.pull_request.labels.*.name, 'ciflow/linux')) }} + env: + LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }} + steps: + - name: noop + run: echo running ciflow_should_run + - name: print labels + run: echo "${LABELS}" calculate-docker-image: if: ${{ github.repository_owner == 'pytorch' }} runs-on: linux.2xlarge + needs: [ciflow_should_run] env: DOCKER_BUILDKIT: 1 timeout-minutes: 90 outputs: docker_image: ${{ steps.calculate-tag.outputs.docker_image }} steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" - name: Log in to ECR run: | aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh bash /tmp/ecr-login.sh rm /tmp/ecr-login.sh - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" run: | # Ensure the working directory gets chowned back to the current user docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . @@ -50,11 +79,19 @@ jobs: run: | rm -rf "${GITHUB_WORKSPACE:?}/*" rm -f ~/.ssh/authorized_keys + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: seemethere/add-github-ssh-key@v1 + with: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Preserve github env variables for use in docker + run: | + env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: # deep clone, to allow use of git merge-base fetch-depth: 0 + submodules: false - name: Calculate docker image tag id: calculate-tag run: | @@ -104,17 +141,32 @@ jobs: build: runs-on: linux.2xlarge - needs: [calculate-docker-image, ] + needs: [calculate-docker-image, ciflow_should_run] env: DOCKER_IMAGE: ${{ needs.calculate-docker-image.outputs.docker_image }} JOB_BASE_NAME: libtorch-linux-xenial-cuda10.2-py3.6-gcc7-build steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" - name: Log in to ECR run: | aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh bash /tmp/ecr-login.sh rm /tmp/ecr-login.sh - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" run: | # Ensure the working directory gets chowned back to the current user docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . @@ -126,17 +178,18 @@ jobs: uses: seemethere/add-github-ssh-key@v1 with: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Preserve github env variables for use in docker + run: | + env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: - fetch-depth: 0 # deep clone, to allow sharding to use git rev-list + # deep clone, to allow use of git merge-base + fetch-depth: 0 submodules: recursive - name: Pull docker image run: | docker pull "${DOCKER_IMAGE}" - - name: Preserve github env variables for use in docker - run: | - env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Build PyTorch run: | docker run \ @@ -144,9 +197,11 @@ jobs: -e JOB_BASE_NAME \ -e MAX_JOBS="$(nproc --ignore=2)" \ -e SCCACHE_BUCKET \ + -e XLA_CLANG_CACHE_S3_BUCKET_NAME \ -e CUSTOM_TEST_ARTIFACT_BUILD_DIR \ -e SKIP_SCCACHE_INITIALIZATION=1 \ -e TORCH_CUDA_ARCH_LIST \ + -e PR_LABELS \ -e http_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e https_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e no_proxy="localhost,127.0.0.1,amazonaws.com,s3.amazonaws.com,169.254.169.254,169.254.170.2,/var/run/docker.sock" \ --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ --security-opt seccomp=unconfined \ @@ -184,6 +239,24 @@ jobs: # Always hold for active ssh sessions if: always() run: .github/scripts/wait_for_ssh_to_drain.sh + - name: Kill containers, clean up images + if: always() + run: | + # ignore expansion of "docker ps -q" since it could be empty + # shellcheck disable=SC2046 + docker stop $(docker ps -q) || true + # Prune all of the docker images + docker system prune -af + - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . + - name: Hold runner for 2 hours or until ssh sessions have drained + # Always hold for active ssh sessions + if: always() + run: .github/scripts/wait_for_ssh_to_drain.sh - name: Clean up docker images if: always() run: | diff --git a/.github/workflows/generated-periodic-libtorch-linux-xenial-cuda11.3-py3.6-gcc7.yml b/.github/workflows/generated-libtorch-linux-xenial-cuda11.3-py3.6-gcc7.yml similarity index 68% rename from .github/workflows/generated-periodic-libtorch-linux-xenial-cuda11.3-py3.6-gcc7.yml rename to .github/workflows/generated-libtorch-linux-xenial-cuda11.3-py3.6-gcc7.yml index 4aa29abb09d6d..9fd6d7ff8d140 100644 --- a/.github/workflows/generated-periodic-libtorch-linux-xenial-cuda11.3-py3.6-gcc7.yml +++ b/.github/workflows/generated-libtorch-linux-xenial-cuda11.3-py3.6-gcc7.yml @@ -1,19 +1,22 @@ # @generated DO NOT EDIT MANUALLY # Template is at: .github/templates/linux_ci_workflow.yml.j2 # Generation script: .github/scripts/generate_ci_workflows.py -name: periodic-libtorch-linux-xenial-cuda11.3-py3.6-gcc7 +name: libtorch-linux-xenial-cuda11.3-py3.6-gcc7 on: pull_request: types: [unassigned] - schedule: - - cron: 45 0,4,8,12,16,20 * * * + push: + branches: + - master + - release/* workflow_dispatch: env: - BUILD_ENVIRONMENT: periodic-libtorch-linux-xenial-cuda11.3-py3.6-gcc7 + BUILD_ENVIRONMENT: libtorch-linux-xenial-cuda11.3-py3.6-gcc7 DOCKER_IMAGE_BASE: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.3-cudnn8-py3-gcc7 SCCACHE_BUCKET: ossci-compiler-cache-circleci-v2 + XLA_CLANG_CACHE_S3_BUCKET_NAME: ossci-compiler-clang-cache-circleci-xla TORCH_CUDA_ARCH_LIST: 5.2 IN_CI: 1 # This is used for the phase of adding wheel tests only, will be removed once completed @@ -21,18 +24,23 @@ env: # Used for custom_opertor, jit_hooks, custom_backend, see .jenkins/pytorch/build.sh CUSTOM_TEST_ARTIFACT_BUILD_DIR: build/custom_test_artifacts ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + PR_LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }} concurrency: - group: periodic-libtorch-linux-xenial-cuda11.3-py3.6-gcc7-${{ github.event.pull_request.number || github.sha }} + group: libtorch-linux-xenial-cuda11.3-py3.6-gcc7-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true jobs: ciflow_should_run: runs-on: ubuntu-18.04 - if: ${{ (github.event_name != 'pull_request') || (github.event.action !='unassigned') || (github.event.action == 'unassigned' && contains(github.event.pull_request.labels.*.name, 'ciflow/scheduled')) }} + if: ${{ (github.event_name != 'pull_request') || (github.event.action !='unassigned') || (contains(github.event.pull_request.labels.*.name, 'ciflow/all') || contains(github.event.pull_request.labels.*.name, 'ciflow/cuda') || contains(github.event.pull_request.labels.*.name, 'ciflow/libtorch') || contains(github.event.pull_request.labels.*.name, 'ciflow/linux')) }} + env: + LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }} steps: - name: noop run: echo running ciflow_should_run + - name: print labels + run: echo "${LABELS}" calculate-docker-image: if: ${{ github.repository_owner == 'pytorch' }} runs-on: linux.2xlarge @@ -43,12 +51,27 @@ jobs: outputs: docker_image: ${{ steps.calculate-tag.outputs.docker_image }} steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" - name: Log in to ECR run: | aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh bash /tmp/ecr-login.sh rm /tmp/ecr-login.sh - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" run: | # Ensure the working directory gets chowned back to the current user docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . @@ -56,11 +79,19 @@ jobs: run: | rm -rf "${GITHUB_WORKSPACE:?}/*" rm -f ~/.ssh/authorized_keys + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: seemethere/add-github-ssh-key@v1 + with: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Preserve github env variables for use in docker + run: | + env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: # deep clone, to allow use of git merge-base fetch-depth: 0 + submodules: false - name: Calculate docker image tag id: calculate-tag run: | @@ -113,14 +144,29 @@ jobs: needs: [calculate-docker-image, ciflow_should_run] env: DOCKER_IMAGE: ${{ needs.calculate-docker-image.outputs.docker_image }} - JOB_BASE_NAME: periodic-libtorch-linux-xenial-cuda11.3-py3.6-gcc7-build + JOB_BASE_NAME: libtorch-linux-xenial-cuda11.3-py3.6-gcc7-build steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" - name: Log in to ECR run: | aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh bash /tmp/ecr-login.sh rm /tmp/ecr-login.sh - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" run: | # Ensure the working directory gets chowned back to the current user docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . @@ -132,17 +178,18 @@ jobs: uses: seemethere/add-github-ssh-key@v1 with: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Preserve github env variables for use in docker + run: | + env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: - fetch-depth: 0 # deep clone, to allow sharding to use git rev-list + # deep clone, to allow use of git merge-base + fetch-depth: 0 submodules: recursive - name: Pull docker image run: | docker pull "${DOCKER_IMAGE}" - - name: Preserve github env variables for use in docker - run: | - env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Build PyTorch run: | docker run \ @@ -150,9 +197,11 @@ jobs: -e JOB_BASE_NAME \ -e MAX_JOBS="$(nproc --ignore=2)" \ -e SCCACHE_BUCKET \ + -e XLA_CLANG_CACHE_S3_BUCKET_NAME \ -e CUSTOM_TEST_ARTIFACT_BUILD_DIR \ -e SKIP_SCCACHE_INITIALIZATION=1 \ -e TORCH_CUDA_ARCH_LIST \ + -e PR_LABELS \ -e http_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e https_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e no_proxy="localhost,127.0.0.1,amazonaws.com,s3.amazonaws.com,169.254.169.254,169.254.170.2,/var/run/docker.sock" \ --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ --security-opt seccomp=unconfined \ @@ -190,6 +239,24 @@ jobs: # Always hold for active ssh sessions if: always() run: .github/scripts/wait_for_ssh_to_drain.sh + - name: Kill containers, clean up images + if: always() + run: | + # ignore expansion of "docker ps -q" since it could be empty + # shellcheck disable=SC2046 + docker stop $(docker ps -q) || true + # Prune all of the docker images + docker system prune -af + - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . + - name: Hold runner for 2 hours or until ssh sessions have drained + # Always hold for active ssh sessions + if: always() + run: .github/scripts/wait_for_ssh_to_drain.sh - name: Clean up docker images if: always() run: | diff --git a/.github/workflows/generated-linux-bionic-cuda10.2-py3.9-gcc7.yml b/.github/workflows/generated-linux-bionic-cuda10.2-py3.9-gcc7.yml index a9011b7047832..ee0ca4cf76ce0 100644 --- a/.github/workflows/generated-linux-bionic-cuda10.2-py3.9-gcc7.yml +++ b/.github/workflows/generated-linux-bionic-cuda10.2-py3.9-gcc7.yml @@ -4,7 +4,8 @@ name: linux-bionic-cuda10.2-py3.9-gcc7 on: - # TODO: Enable pull_request builds when we can verify capacity can be met by auto-scalers + pull_request: + types: [unassigned] push: branches: - master @@ -15,6 +16,7 @@ env: BUILD_ENVIRONMENT: linux-bionic-cuda10.2-py3.9-gcc7 DOCKER_IMAGE_BASE: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-bionic-cuda10.2-cudnn7-py3.9-gcc7 SCCACHE_BUCKET: ossci-compiler-cache-circleci-v2 + XLA_CLANG_CACHE_S3_BUCKET_NAME: ossci-compiler-clang-cache-circleci-xla TORCH_CUDA_ARCH_LIST: 5.2 IN_CI: 1 # This is used for the phase of adding wheel tests only, will be removed once completed @@ -22,27 +24,54 @@ env: # Used for custom_opertor, jit_hooks, custom_backend, see .jenkins/pytorch/build.sh CUSTOM_TEST_ARTIFACT_BUILD_DIR: build/custom_test_artifacts ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + PR_LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }} concurrency: - group: linux-bionic-cuda10.2-py3.9-gcc7-${{ github.event.pull_request.number || github.sha }} + group: linux-bionic-cuda10.2-py3.9-gcc7-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true jobs: + ciflow_should_run: + runs-on: ubuntu-18.04 + if: ${{ (github.event_name != 'pull_request') || (github.event.action !='unassigned') || (contains(github.event.pull_request.labels.*.name, 'ciflow/all') || contains(github.event.pull_request.labels.*.name, 'ciflow/cuda') || contains(github.event.pull_request.labels.*.name, 'ciflow/linux') || contains(github.event.pull_request.labels.*.name, 'ciflow/slow')) }} + env: + LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }} + steps: + - name: noop + run: echo running ciflow_should_run + - name: print labels + run: echo "${LABELS}" calculate-docker-image: if: ${{ github.repository_owner == 'pytorch' }} runs-on: linux.2xlarge + needs: [ciflow_should_run] env: DOCKER_BUILDKIT: 1 timeout-minutes: 90 outputs: docker_image: ${{ steps.calculate-tag.outputs.docker_image }} steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" - name: Log in to ECR run: | aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh bash /tmp/ecr-login.sh rm /tmp/ecr-login.sh - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" run: | # Ensure the working directory gets chowned back to the current user docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . @@ -50,11 +79,19 @@ jobs: run: | rm -rf "${GITHUB_WORKSPACE:?}/*" rm -f ~/.ssh/authorized_keys + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: seemethere/add-github-ssh-key@v1 + with: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Preserve github env variables for use in docker + run: | + env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: # deep clone, to allow use of git merge-base fetch-depth: 0 + submodules: false - name: Calculate docker image tag id: calculate-tag run: | @@ -104,17 +141,32 @@ jobs: build: runs-on: linux.2xlarge - needs: [calculate-docker-image, ] + needs: [calculate-docker-image, ciflow_should_run] env: DOCKER_IMAGE: ${{ needs.calculate-docker-image.outputs.docker_image }} JOB_BASE_NAME: linux-bionic-cuda10.2-py3.9-gcc7-build steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" - name: Log in to ECR run: | aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh bash /tmp/ecr-login.sh rm /tmp/ecr-login.sh - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" run: | # Ensure the working directory gets chowned back to the current user docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . @@ -126,17 +178,18 @@ jobs: uses: seemethere/add-github-ssh-key@v1 with: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Preserve github env variables for use in docker + run: | + env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: - fetch-depth: 0 # deep clone, to allow sharding to use git rev-list + # deep clone, to allow use of git merge-base + fetch-depth: 0 submodules: recursive - name: Pull docker image run: | docker pull "${DOCKER_IMAGE}" - - name: Preserve github env variables for use in docker - run: | - env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Build PyTorch run: | docker run \ @@ -144,9 +197,11 @@ jobs: -e JOB_BASE_NAME \ -e MAX_JOBS="$(nproc --ignore=2)" \ -e SCCACHE_BUCKET \ + -e XLA_CLANG_CACHE_S3_BUCKET_NAME \ -e CUSTOM_TEST_ARTIFACT_BUILD_DIR \ -e SKIP_SCCACHE_INITIALIZATION=1 \ -e TORCH_CUDA_ARCH_LIST \ + -e PR_LABELS \ -e http_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e https_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e no_proxy="localhost,127.0.0.1,amazonaws.com,s3.amazonaws.com,169.254.169.254,169.254.170.2,/var/run/docker.sock" \ --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ --security-opt seccomp=unconfined \ @@ -182,19 +237,8 @@ jobs: docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - name: Archive artifacts into zip run: | - zip -r artifacts.zip dist/ build/ .pytorch-test-times.json - # Upload to github so that people can click and download artifacts - - uses: actions/upload-artifact@v2 - # Don't fail on upload to GH since it's only for user convenience - continue-on-error: true - name: Store PyTorch Build Artifacts on Github - with: - name: ${{ env.BUILD_ENVIRONMENT }} - retention-days: 14 - if-no-files-found: error - path: - artifacts.zip - - uses: seemethere/upload-artifact-s3@9d7ceb0ab39c2c88d93ef7792b27425b27d59162 + zip -1 -r artifacts.zip dist/ build/custom_test_artifacts build/lib build/bin .pytorch-test-times.json + - uses: seemethere/upload-artifact-s3@v3 name: Store PyTorch Build Artifacts on S3 with: name: ${{ env.BUILD_ENVIRONMENT }} @@ -206,6 +250,24 @@ jobs: # Always hold for active ssh sessions if: always() run: .github/scripts/wait_for_ssh_to_drain.sh + - name: Kill containers, clean up images + if: always() + run: | + # ignore expansion of "docker ps -q" since it could be empty + # shellcheck disable=SC2046 + docker stop $(docker ps -q) || true + # Prune all of the docker images + docker system prune -af + - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . + - name: Hold runner for 2 hours or until ssh sessions have drained + # Always hold for active ssh sessions + if: always() + run: .github/scripts/wait_for_ssh_to_drain.sh - name: Clean up docker images if: always() run: | @@ -215,13 +277,19 @@ jobs: generate-test-matrix: if: ${{ github.repository_owner == 'pytorch' }} runs-on: ubuntu-18.04 + needs: [ciflow_should_run] env: TEST_RUNNER_TYPE: linux.8xlarge.nvidia.gpu + ENABLE_DISTRIBUTED_TEST: 1 ENABLE_JIT_LEGACY_TEST: '' ENABLE_MULTIGPU_TEST: '' ENABLE_NOGPU_NO_AVX_TEST: '' ENABLE_NOGPU_NO_AVX2_TEST: '' ENABLE_SLOW_TEST: '' + ENABLE_DOCS_TEST: '' + ENABLE_BACKWARDS_COMPAT_TEST: '' + ENABLE_XLA_TEST: '' + ENABLE_NOARCH_TEST: '' NUM_TEST_SHARDS: 2 MULTIGPU_RUNNER_TYPE: linux.16xlarge.nvidia.gpu NOGPU_RUNNER_TYPE: linux.2xlarge @@ -236,13 +304,13 @@ jobs: - name: Install dependencies run: pip install typing-extensions - name: Clone pytorch/pytorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 - name: Generating test matrix id: set-matrix run: .github/scripts/generate_pytorch_test_matrix.py test: - needs: [calculate-docker-image, build, generate-test-matrix, ] + needs: [calculate-docker-image, build, generate-test-matrix, ciflow_should_run] strategy: matrix: ${{ fromJson(needs.generate-test-matrix.outputs.matrix) }} fail-fast: false @@ -254,16 +322,32 @@ jobs: SHARD_NUMBER: ${{ matrix.shard }} NUM_TEST_SHARDS: ${{ matrix.num_shards }} PYTORCH_IGNORE_DISABLED_ISSUES: ${{ needs.generate-test-matrix.outputs.ignore-disabled-issues }} + CONTINUE_THROUGH_ERROR: ${{ github.repository_owner == 'pytorch' && (github.event_name == 'push' || github.event_name == 'schedule') }} steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" - name: Log in to ECR run: | aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh bash /tmp/ecr-login.sh rm /tmp/ecr-login.sh - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" run: | # Ensure the working directory gets chowned back to the current user - docker run --rm -v "$(pwd)/../":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - name: Clean workspace run: | rm -rf "${GITHUB_WORKSPACE:?}/*" @@ -272,9 +356,14 @@ jobs: uses: seemethere/add-github-ssh-key@v1 with: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Preserve github env variables for use in docker + run: | + env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: + # deep clone, to allow use of git merge-base + fetch-depth: 0 submodules: recursive - name: Pull docker image run: | @@ -306,12 +395,9 @@ jobs: - name: Output disk space left run: | sudo df -H - - name: Preserve github env variables for use in docker - run: | - env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Test PyTorch env: - BUILD_ENVIRONMENT: linux-bionic-cuda10.2-py3.9-gcc7-${{ matrix.config }} + PR_NUMBER: ${{ github.event.pull_request.number }} run: | if [[ $TEST_CONFIG == 'multigpu' ]]; then TEST_COMMAND=.jenkins/pytorch/multigpu-test.sh @@ -327,6 +413,7 @@ jobs: docker run \ ${GPU_FLAG:-} \ -e BUILD_ENVIRONMENT \ + -e PR_NUMBER \ -e CUSTOM_TEST_ARTIFACT_BUILD_DIR \ -e GITHUB_ACTIONS \ -e IN_CI \ @@ -336,9 +423,12 @@ jobs: -e TEST_CONFIG \ -e NUM_TEST_SHARDS \ -e PYTORCH_IGNORE_DISABLED_ISSUES \ + -e PR_LABELS \ + -e CONTINUE_THROUGH_ERROR \ -e MAX_JOBS="$(nproc --ignore=2)" \ -e SCCACHE_BUCKET \ -e http_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e https_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e no_proxy="localhost,127.0.0.1,amazonaws.com,s3.amazonaws.com,169.254.169.254,169.254.170.2,/var/run/docker.sock" \ + -e XLA_CLANG_CACHE_S3_BUCKET_NAME \ --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ --security-opt seccomp=unconfined \ --cap-add=SYS_PTRACE \ @@ -372,7 +462,7 @@ jobs: if-no-files-found: error path: test-reports-*.zip - - uses: seemethere/upload-artifact-s3@9d7ceb0ab39c2c88d93ef7792b27425b27d59162 + - uses: seemethere/upload-artifact-s3@v3 name: Store PyTorch Test Reports on S3 if: always() with: @@ -381,66 +471,6 @@ jobs: if-no-files-found: error path: test-reports-*.zip - - name: Hold runner for 2 hours or until ssh sessions have drained - # Always hold for active ssh sessions - if: always() - run: .github/scripts/wait_for_ssh_to_drain.sh - - name: Clean up docker images - if: always() - run: | - # Ensure the working directory gets chowned back to the current user - docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - # Prune all of the docker images - docker system prune -af - - - # this is a separate step from test because the log files from test are too - # long: basically, GitHub tries to render all of the log files when you click - # through an action causing extreme slowdown on actions that contain too many - # logs (like test); we can always move it back to the other one, but it - # doesn't create the best experience - render_test_results: - needs: [generate-test-matrix, test, ] - if: ${{ needs.test.result != 'skipped' || failure() }} - runs-on: linux.2xlarge - strategy: - matrix: ${{ fromJson(needs.generate-test-matrix.outputs.render-matrix) }} - fail-fast: false - steps: - - name: Log in to ECR - run: | - aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh - bash /tmp/ecr-login.sh - rm /tmp/ecr-login.sh - - name: Chown workspace - run: | - # Ensure the working directory gets chowned back to the current user - docker run --rm -v "$(pwd)/../":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - - name: Clean workspace - run: | - rm -rf "${GITHUB_WORKSPACE:?}/*" - - name: Checkout PyTorch - uses: actions/checkout@v2 - with: - # deep clone, to allow tools/stats/print_test_stats.py to use Git commands - fetch-depth: 0 - - uses: actions/download-artifact@v2 - name: Download PyTorch Test Reports - with: - name: test-reports-${{ matrix.config }} - path: . - - name: Unzip test reports - run: | - # Should preserve paths so reports should still be in test/test-reports - unzip -o 'test-reports-*.zip' - - name: Install dependencies - # boto3 version copied from .circleci/docker/common/install_conda.sh - run: | - pip3 install -r requirements.txt - pip3 install boto3==1.16.34 junitparser rich - - name: Output Test Results (Click Me) - run: | - python3 tools/render_junit.py test - name: Parse ref id: parse-ref run: .github/scripts/parse_ref.py @@ -455,5 +485,26 @@ jobs: CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }} CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }} CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}' + shell: bash run: | + python3 -m pip install -r requirements.txt + python3 -m pip install boto3==1.16.34 python3 -m tools.stats.print_test_stats --upload-to-s3 --compare-with-s3 test + - name: Hold runner for 2 hours or until ssh sessions have drained + # Always hold for active ssh sessions + if: always() + run: .github/scripts/wait_for_ssh_to_drain.sh + - name: Kill containers, clean up images + if: always() + run: | + # ignore expansion of "docker ps -q" since it could be empty + # shellcheck disable=SC2046 + docker stop $(docker ps -q) || true + # Prune all of the docker images + docker system prune -af + - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . diff --git a/.github/workflows/generated-linux-bionic-py3.6-clang9.yml b/.github/workflows/generated-linux-bionic-py3.6-clang9.yml new file mode 100644 index 0000000000000..3aedb76b3e665 --- /dev/null +++ b/.github/workflows/generated-linux-bionic-py3.6-clang9.yml @@ -0,0 +1,510 @@ +# @generated DO NOT EDIT MANUALLY +# Template is at: .github/templates/linux_ci_workflow.yml.j2 +# Generation script: .github/scripts/generate_ci_workflows.py +name: linux-bionic-py3.6-clang9 + +on: + pull_request: + types: [opened, synchronize, reopened, unassigned] + push: + branches: + - master + - release/* + workflow_dispatch: + +env: + BUILD_ENVIRONMENT: linux-bionic-py3.6-clang9 + DOCKER_IMAGE_BASE: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-bionic-py3.6-clang9 + SCCACHE_BUCKET: ossci-compiler-cache-circleci-v2 + XLA_CLANG_CACHE_S3_BUCKET_NAME: ossci-compiler-clang-cache-circleci-xla + TORCH_CUDA_ARCH_LIST: 5.2 + IN_CI: 1 + # This is used for the phase of adding wheel tests only, will be removed once completed + IN_WHEEL_TEST: 1 + # Used for custom_opertor, jit_hooks, custom_backend, see .jenkins/pytorch/build.sh + CUSTOM_TEST_ARTIFACT_BUILD_DIR: build/custom_test_artifacts + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + PR_LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }} + +concurrency: + group: linux-bionic-py3.6-clang9-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} + cancel-in-progress: true + +jobs: + ciflow_should_run: + runs-on: ubuntu-18.04 + if: ${{ (github.event_name != 'pull_request') || (github.event.action !='unassigned') || (contains(github.event.pull_request.labels.*.name, 'ciflow/all') || contains(github.event.pull_request.labels.*.name, 'ciflow/cpu') || contains(github.event.pull_request.labels.*.name, 'ciflow/default') || contains(github.event.pull_request.labels.*.name, 'ciflow/linux') || contains(github.event.pull_request.labels.*.name, 'ciflow/noarch') || contains(github.event.pull_request.labels.*.name, 'ciflow/xla')) }} + env: + LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }} + steps: + - name: noop + run: echo running ciflow_should_run + - name: print labels + run: echo "${LABELS}" + calculate-docker-image: + if: ${{ github.repository_owner == 'pytorch' }} + runs-on: linux.2xlarge + needs: [ciflow_should_run] + env: + DOCKER_BUILDKIT: 1 + timeout-minutes: 90 + outputs: + docker_image: ${{ steps.calculate-tag.outputs.docker_image }} + steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" + - name: Log in to ECR + run: | + aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh + bash /tmp/ecr-login.sh + rm /tmp/ecr-login.sh + - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . + - name: Clean workspace + run: | + rm -rf "${GITHUB_WORKSPACE:?}/*" + rm -f ~/.ssh/authorized_keys + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: seemethere/add-github-ssh-key@v1 + with: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Preserve github env variables for use in docker + run: | + env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" + - name: Checkout PyTorch + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 + with: + # deep clone, to allow use of git merge-base + fetch-depth: 0 + submodules: false + - name: Calculate docker image tag + id: calculate-tag + run: | + DOCKER_TAG=$(git rev-parse HEAD:.circleci/docker) + echo "::set-output name=docker_tag::${DOCKER_TAG}" + echo "::set-output name=docker_image::${DOCKER_IMAGE_BASE}:${DOCKER_TAG}" + - name: Check if image should be built + id: check + env: + DOCKER_TAG: ${{ steps.calculate-tag.outputs.docker_tag }} + BASE_REVISION: ${{ github.event.pull_request.base.sha || github.sha }} + run: | + set -x + # Check if image already exists, if it does then skip building it + if docker manifest inspect "${DOCKER_IMAGE_BASE}:${DOCKER_TAG}"; then + exit 0 + fi + if [[ "$BASE_REVISION" = "$(git rev-parse HEAD)" ]]; then + # if we're on the base branch then use the parent commit + MERGE_BASE=$(git rev-parse HEAD~) + else + # otherwise we're on a PR, so use the most recent base commit + MERGE_BASE=$(git merge-base HEAD "$BASE_REVISION") + fi + # Covers the case where a previous tag doesn't exist for the tree + # this is only really applicable on trees that don't have `.circleci/docker` at its merge base, i.e. nightly + if ! git rev-parse "$MERGE_BASE:.circleci/docker"; then + echo "Directory '.circleci/docker' not found in commit $MERGE_BASE, you should probably rebase onto a more recent commit" + exit 1 + fi + PREVIOUS_DOCKER_TAG=$(git rev-parse "$MERGE_BASE:.circleci/docker") + # If no image exists but the hash is the same as the previous hash then we should error out here + if [[ "${PREVIOUS_DOCKER_TAG}" = "${DOCKER_TAG}" ]]; then + echo "ERROR: Something has gone wrong and the previous image isn't available for the merge-base of your branch" + echo " contact the PyTorch team to restore the original images" + exit 1 + fi + echo ::set-output name=rebuild::yes + - name: Build and push docker image + if: ${{ steps.check.outputs.rebuild }} + env: + DOCKER_TAG: ${{ steps.calculate-tag.outputs.docker_tag }} + DOCKER_SKIP_S3_UPLOAD: 1 + run: | + export IMAGE_NAME=${DOCKER_IMAGE_BASE#308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/} + cd .circleci/docker && ./build_docker.sh + + build: + runs-on: linux.2xlarge + needs: [calculate-docker-image, ciflow_should_run] + env: + DOCKER_IMAGE: ${{ needs.calculate-docker-image.outputs.docker_image }} + JOB_BASE_NAME: linux-bionic-py3.6-clang9-build + steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" + - name: Log in to ECR + run: | + aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh + bash /tmp/ecr-login.sh + rm /tmp/ecr-login.sh + - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . + - name: Clean workspace + run: | + rm -rf "${GITHUB_WORKSPACE:?}/*" + rm -f ~/.ssh/authorized_keys + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: seemethere/add-github-ssh-key@v1 + with: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Preserve github env variables for use in docker + run: | + env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" + - name: Checkout PyTorch + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 + with: + # deep clone, to allow use of git merge-base + fetch-depth: 0 + submodules: recursive + - name: Pull docker image + run: | + docker pull "${DOCKER_IMAGE}" + - name: Build PyTorch + run: | + docker run \ + -e BUILD_ENVIRONMENT \ + -e JOB_BASE_NAME \ + -e MAX_JOBS="$(nproc --ignore=2)" \ + -e SCCACHE_BUCKET \ + -e XLA_CLANG_CACHE_S3_BUCKET_NAME \ + -e CUSTOM_TEST_ARTIFACT_BUILD_DIR \ + -e SKIP_SCCACHE_INITIALIZATION=1 \ + -e TORCH_CUDA_ARCH_LIST \ + -e PR_LABELS \ + -e http_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e https_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e no_proxy="localhost,127.0.0.1,amazonaws.com,s3.amazonaws.com,169.254.169.254,169.254.170.2,/var/run/docker.sock" \ + --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ + --security-opt seccomp=unconfined \ + --cap-add=SYS_PTRACE \ + --tty \ + --user jenkins \ + -v "${GITHUB_WORKSPACE}:/var/lib/jenkins/workspace" \ + -w /var/lib/jenkins/workspace \ + "${DOCKER_IMAGE}" \ + sh -c 'sudo chown -R jenkins . && .jenkins/pytorch/build.sh' + - name: Parse ref + id: parse-ref + run: .github/scripts/parse_ref.py + - name: Display and upload binary build size statistics (Click Me) + # temporary hack: set CIRCLE_* vars, until we update + # tools/stats/print_test_stats.py to natively support GitHub Actions + env: + AWS_DEFAULT_REGION: us-east-1 + SCRIBE_GRAPHQL_ACCESS_TOKEN: ${{ secrets.SCRIBE_GRAPHQL_ACCESS_TOKEN }} + CIRCLE_BRANCH: ${{ steps.parse-ref.outputs.branch }} + CIRCLE_PR_NUMBER: ${{ github.event.pull_request.number }} + CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }} + CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }} + CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}' + run: | + COMMIT_TIME=$(git log --max-count=1 --format=%ct || echo 0) + export COMMIT_TIME + pip3 install requests + python3 -m tools.stats.upload_binary_size_to_scuba || exit 0 + - name: Chown workspace + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . + - name: Archive artifacts into zip + run: | + zip -1 -r artifacts.zip dist/ build/custom_test_artifacts build/lib build/bin .pytorch-test-times.json + - uses: seemethere/upload-artifact-s3@v3 + name: Store PyTorch Build Artifacts on S3 + with: + name: ${{ env.BUILD_ENVIRONMENT }} + retention-days: 14 + if-no-files-found: error + path: + artifacts.zip + - name: Hold runner for 2 hours or until ssh sessions have drained + # Always hold for active ssh sessions + if: always() + run: .github/scripts/wait_for_ssh_to_drain.sh + - name: Kill containers, clean up images + if: always() + run: | + # ignore expansion of "docker ps -q" since it could be empty + # shellcheck disable=SC2046 + docker stop $(docker ps -q) || true + # Prune all of the docker images + docker system prune -af + - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . + - name: Hold runner for 2 hours or until ssh sessions have drained + # Always hold for active ssh sessions + if: always() + run: .github/scripts/wait_for_ssh_to_drain.sh + - name: Clean up docker images + if: always() + run: | + # Prune all of the docker images + docker system prune -af + + generate-test-matrix: + if: ${{ github.repository_owner == 'pytorch' }} + runs-on: ubuntu-18.04 + needs: [ciflow_should_run] + env: + TEST_RUNNER_TYPE: linux.2xlarge + ENABLE_DISTRIBUTED_TEST: '' + ENABLE_JIT_LEGACY_TEST: '' + ENABLE_MULTIGPU_TEST: '' + ENABLE_NOGPU_NO_AVX_TEST: '' + ENABLE_NOGPU_NO_AVX2_TEST: '' + ENABLE_SLOW_TEST: '' + ENABLE_DOCS_TEST: '' + ENABLE_BACKWARDS_COMPAT_TEST: '' + ENABLE_XLA_TEST: '' + ENABLE_NOARCH_TEST: 1 + NUM_TEST_SHARDS: 2 + MULTIGPU_RUNNER_TYPE: linux.16xlarge.nvidia.gpu + NOGPU_RUNNER_TYPE: linux.2xlarge + PR_BODY: ${{ github.event.pull_request.body }} + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + render-matrix: ${{ steps.set-matrix.outputs.render-matrix }} + ignore-disabled-issues: ${{ steps.set-matrix.outputs.ignore-disabled-issues }} + container: + image: python:3.9 + steps: + - name: Install dependencies + run: pip install typing-extensions + - name: Clone pytorch/pytorch + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 + - name: Generating test matrix + id: set-matrix + run: .github/scripts/generate_pytorch_test_matrix.py + + test: + needs: [calculate-docker-image, build, generate-test-matrix, ciflow_should_run] + strategy: + matrix: ${{ fromJson(needs.generate-test-matrix.outputs.matrix) }} + fail-fast: false + runs-on: ${{ matrix.runner }} + env: + DOCKER_IMAGE: ${{ needs.calculate-docker-image.outputs.docker_image }} + JOB_BASE_NAME: linux-bionic-py3.6-clang9-test + TEST_CONFIG: ${{ matrix.config }} + SHARD_NUMBER: ${{ matrix.shard }} + NUM_TEST_SHARDS: ${{ matrix.num_shards }} + PYTORCH_IGNORE_DISABLED_ISSUES: ${{ needs.generate-test-matrix.outputs.ignore-disabled-issues }} + CONTINUE_THROUGH_ERROR: ${{ github.repository_owner == 'pytorch' && (github.event_name == 'push' || github.event_name == 'schedule') }} + steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" + - name: Log in to ECR + run: | + aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh + bash /tmp/ecr-login.sh + rm /tmp/ecr-login.sh + - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . + - name: Clean workspace + run: | + rm -rf "${GITHUB_WORKSPACE:?}/*" + rm -f ~/.ssh/authorized_keys + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: seemethere/add-github-ssh-key@v1 + with: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Preserve github env variables for use in docker + run: | + env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" + - name: Checkout PyTorch + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 + with: + # deep clone, to allow use of git merge-base + fetch-depth: 0 + submodules: recursive + - name: Pull docker image + run: | + docker pull "${DOCKER_IMAGE}" + - name: Install nvidia driver, nvidia-docker runtime, set GPU_FLAG + if: ${{ contains(env.BUILD_ENVIRONMENT, 'cuda') && !contains(matrix.config, 'nogpu') }} + run: | + bash .github/scripts/install_nvidia_utils_linux.sh + echo "GPU_FLAG=--gpus all" >> "${GITHUB_ENV}" + - name: Determine shm-size + run: | + shm_size="1g" + case "${BUILD_ENVIRONMENT}" in + *cuda*) + shm_size="2g" + ;; + *rocm*) + shm_size="8g" + ;; + esac + echo "SHM_SIZE=${shm_size}" >> "${GITHUB_ENV}" + - uses: seemethere/download-artifact-s3@0504774707cbc8603d7dca922e8026eb8bf3b47b + name: Download PyTorch Build Artifacts + with: + name: ${{ env.BUILD_ENVIRONMENT }} + - name: Unzip artifacts + run: | + unzip -o artifacts.zip + - name: Output disk space left + run: | + sudo df -H + - name: Test PyTorch + env: + PR_NUMBER: ${{ github.event.pull_request.number }} + run: | + if [[ $TEST_CONFIG == 'multigpu' ]]; then + TEST_COMMAND=.jenkins/pytorch/multigpu-test.sh + else + TEST_COMMAND=.jenkins/pytorch/test.sh + fi + if [[ $NUM_TEST_SHARDS -ne 2 ]]; then + export SHARD_NUMBER=0 + fi + # TODO: Stop building test binaries as part of the build phase + # Used for GPU_FLAG since that doesn't play nice + # shellcheck disable=SC2086 + docker run \ + ${GPU_FLAG:-} \ + -e BUILD_ENVIRONMENT \ + -e PR_NUMBER \ + -e CUSTOM_TEST_ARTIFACT_BUILD_DIR \ + -e GITHUB_ACTIONS \ + -e IN_CI \ + -e IN_WHEEL_TEST \ + -e SHARD_NUMBER \ + -e JOB_BASE_NAME \ + -e TEST_CONFIG \ + -e NUM_TEST_SHARDS \ + -e PYTORCH_IGNORE_DISABLED_ISSUES \ + -e PR_LABELS \ + -e CONTINUE_THROUGH_ERROR \ + -e MAX_JOBS="$(nproc --ignore=2)" \ + -e SCCACHE_BUCKET \ + -e http_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e https_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e no_proxy="localhost,127.0.0.1,amazonaws.com,s3.amazonaws.com,169.254.169.254,169.254.170.2,/var/run/docker.sock" \ + -e XLA_CLANG_CACHE_S3_BUCKET_NAME \ + --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ + --security-opt seccomp=unconfined \ + --cap-add=SYS_PTRACE \ + --shm-size="${SHM_SIZE}" \ + --tty \ + --user jenkins \ + -v "${GITHUB_WORKSPACE}:/var/lib/jenkins/workspace" \ + -w /var/lib/jenkins/workspace \ + "${DOCKER_IMAGE}" \ + sh -c 'sudo chown -R jenkins . && pip install dist/*.whl && '$TEST_COMMAND + - name: Chown workspace + if: always() + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . + - name: Zip test reports for upload + if: always() + env: + COMMIT_SHA1: ${{ github.event.pull_request.head.sha || github.sha }} + WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}' + run: | + # Remove any previous test reports if they exist + rm -f test-reports-*.zip + zip -r "test-reports-${COMMIT_SHA1}-${WORKFLOW_ID}.zip" test -i '*.xml' + - uses: actions/upload-artifact@v2 + name: Store PyTorch Test Reports + if: always() + with: + name: test-reports-${{ matrix.config }} + retention-days: 14 + if-no-files-found: error + path: + test-reports-*.zip + - uses: seemethere/upload-artifact-s3@v3 + name: Store PyTorch Test Reports on S3 + if: always() + with: + name: test-reports-${{ matrix.config }} + retention-days: 14 + if-no-files-found: error + path: + test-reports-*.zip + - name: Parse ref + id: parse-ref + run: .github/scripts/parse_ref.py + - name: Display and upload test statistics (Click Me) + # temporary hack: set CIRCLE_* vars, until we update + # tools/stats/print_test_stats.py to natively support GitHub Actions + env: + AWS_DEFAULT_REGION: us-east-1 + CIRCLE_BRANCH: ${{ steps.parse-ref.outputs.branch }} + JOB_BASE_NAME: linux-bionic-py3.6-clang9-test + CIRCLE_PR_NUMBER: ${{ github.event.pull_request.number }} + CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }} + CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }} + CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}' + shell: bash + run: | + python3 -m pip install -r requirements.txt + python3 -m pip install boto3==1.16.34 + python3 -m tools.stats.print_test_stats --upload-to-s3 --compare-with-s3 test + - name: Hold runner for 2 hours or until ssh sessions have drained + # Always hold for active ssh sessions + if: always() + run: .github/scripts/wait_for_ssh_to_drain.sh + - name: Kill containers, clean up images + if: always() + run: | + # ignore expansion of "docker ps -q" since it could be empty + # shellcheck disable=SC2046 + docker stop $(docker ps -q) || true + # Prune all of the docker images + docker system prune -af + - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . diff --git a/.github/workflows/generated-linux-bionic-py3.8-gcc9-coverage.yml b/.github/workflows/generated-linux-bionic-py3.8-gcc9-coverage.yml index 3663c591ab806..2103f2b66bdbf 100644 --- a/.github/workflows/generated-linux-bionic-py3.8-gcc9-coverage.yml +++ b/.github/workflows/generated-linux-bionic-py3.8-gcc9-coverage.yml @@ -16,6 +16,7 @@ env: BUILD_ENVIRONMENT: linux-bionic-py3.8-gcc9-coverage DOCKER_IMAGE_BASE: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-bionic-py3.8-gcc9 SCCACHE_BUCKET: ossci-compiler-cache-circleci-v2 + XLA_CLANG_CACHE_S3_BUCKET_NAME: ossci-compiler-clang-cache-circleci-xla TORCH_CUDA_ARCH_LIST: 5.2 IN_CI: 1 # This is used for the phase of adding wheel tests only, will be removed once completed @@ -23,18 +24,23 @@ env: # Used for custom_opertor, jit_hooks, custom_backend, see .jenkins/pytorch/build.sh CUSTOM_TEST_ARTIFACT_BUILD_DIR: build/custom_test_artifacts ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + PR_LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }} concurrency: - group: linux-bionic-py3.8-gcc9-coverage-${{ github.event.pull_request.number || github.sha }} + group: linux-bionic-py3.8-gcc9-coverage-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true jobs: ciflow_should_run: runs-on: ubuntu-18.04 - if: ${{ (github.event_name != 'pull_request') || (github.event.action !='unassigned') || (github.event.action == 'unassigned' && contains(github.event.pull_request.labels.*.name, 'ciflow/default')) }} + if: ${{ (github.event_name != 'pull_request') || (github.event.action !='unassigned') || (contains(github.event.pull_request.labels.*.name, 'ciflow/all') || contains(github.event.pull_request.labels.*.name, 'ciflow/coverage') || contains(github.event.pull_request.labels.*.name, 'ciflow/cpu') || contains(github.event.pull_request.labels.*.name, 'ciflow/default') || contains(github.event.pull_request.labels.*.name, 'ciflow/linux')) }} + env: + LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }} steps: - name: noop run: echo running ciflow_should_run + - name: print labels + run: echo "${LABELS}" calculate-docker-image: if: ${{ github.repository_owner == 'pytorch' }} runs-on: linux.2xlarge @@ -45,12 +51,27 @@ jobs: outputs: docker_image: ${{ steps.calculate-tag.outputs.docker_image }} steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" - name: Log in to ECR run: | aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh bash /tmp/ecr-login.sh rm /tmp/ecr-login.sh - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" run: | # Ensure the working directory gets chowned back to the current user docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . @@ -58,11 +79,19 @@ jobs: run: | rm -rf "${GITHUB_WORKSPACE:?}/*" rm -f ~/.ssh/authorized_keys + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: seemethere/add-github-ssh-key@v1 + with: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Preserve github env variables for use in docker + run: | + env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: # deep clone, to allow use of git merge-base fetch-depth: 0 + submodules: false - name: Calculate docker image tag id: calculate-tag run: | @@ -117,12 +146,27 @@ jobs: DOCKER_IMAGE: ${{ needs.calculate-docker-image.outputs.docker_image }} JOB_BASE_NAME: linux-bionic-py3.8-gcc9-coverage-build steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" - name: Log in to ECR run: | aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh bash /tmp/ecr-login.sh rm /tmp/ecr-login.sh - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" run: | # Ensure the working directory gets chowned back to the current user docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . @@ -134,17 +178,18 @@ jobs: uses: seemethere/add-github-ssh-key@v1 with: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Preserve github env variables for use in docker + run: | + env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: - fetch-depth: 0 # deep clone, to allow sharding to use git rev-list + # deep clone, to allow use of git merge-base + fetch-depth: 0 submodules: recursive - name: Pull docker image run: | docker pull "${DOCKER_IMAGE}" - - name: Preserve github env variables for use in docker - run: | - env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Build PyTorch run: | docker run \ @@ -152,9 +197,11 @@ jobs: -e JOB_BASE_NAME \ -e MAX_JOBS="$(nproc --ignore=2)" \ -e SCCACHE_BUCKET \ + -e XLA_CLANG_CACHE_S3_BUCKET_NAME \ -e CUSTOM_TEST_ARTIFACT_BUILD_DIR \ -e SKIP_SCCACHE_INITIALIZATION=1 \ -e TORCH_CUDA_ARCH_LIST \ + -e PR_LABELS \ -e http_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e https_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e no_proxy="localhost,127.0.0.1,amazonaws.com,s3.amazonaws.com,169.254.169.254,169.254.170.2,/var/run/docker.sock" \ --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ --security-opt seccomp=unconfined \ @@ -190,19 +237,8 @@ jobs: docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - name: Archive artifacts into zip run: | - zip -r artifacts.zip dist/ build/ .pytorch-test-times.json - # Upload to github so that people can click and download artifacts - - uses: actions/upload-artifact@v2 - # Don't fail on upload to GH since it's only for user convenience - continue-on-error: true - name: Store PyTorch Build Artifacts on Github - with: - name: ${{ env.BUILD_ENVIRONMENT }} - retention-days: 14 - if-no-files-found: error - path: - artifacts.zip - - uses: seemethere/upload-artifact-s3@9d7ceb0ab39c2c88d93ef7792b27425b27d59162 + zip -1 -r artifacts.zip dist/ build/custom_test_artifacts build/lib build/bin .pytorch-test-times.json + - uses: seemethere/upload-artifact-s3@v3 name: Store PyTorch Build Artifacts on S3 with: name: ${{ env.BUILD_ENVIRONMENT }} @@ -214,6 +250,24 @@ jobs: # Always hold for active ssh sessions if: always() run: .github/scripts/wait_for_ssh_to_drain.sh + - name: Kill containers, clean up images + if: always() + run: | + # ignore expansion of "docker ps -q" since it could be empty + # shellcheck disable=SC2046 + docker stop $(docker ps -q) || true + # Prune all of the docker images + docker system prune -af + - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . + - name: Hold runner for 2 hours or until ssh sessions have drained + # Always hold for active ssh sessions + if: always() + run: .github/scripts/wait_for_ssh_to_drain.sh - name: Clean up docker images if: always() run: | @@ -226,11 +280,16 @@ jobs: needs: [ciflow_should_run] env: TEST_RUNNER_TYPE: linux.2xlarge + ENABLE_DISTRIBUTED_TEST: 1 ENABLE_JIT_LEGACY_TEST: '' ENABLE_MULTIGPU_TEST: '' ENABLE_NOGPU_NO_AVX_TEST: '' ENABLE_NOGPU_NO_AVX2_TEST: '' ENABLE_SLOW_TEST: '' + ENABLE_DOCS_TEST: '' + ENABLE_BACKWARDS_COMPAT_TEST: '' + ENABLE_XLA_TEST: '' + ENABLE_NOARCH_TEST: '' NUM_TEST_SHARDS: 2 MULTIGPU_RUNNER_TYPE: linux.16xlarge.nvidia.gpu NOGPU_RUNNER_TYPE: linux.2xlarge @@ -245,7 +304,7 @@ jobs: - name: Install dependencies run: pip install typing-extensions - name: Clone pytorch/pytorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 - name: Generating test matrix id: set-matrix run: .github/scripts/generate_pytorch_test_matrix.py @@ -263,16 +322,32 @@ jobs: SHARD_NUMBER: ${{ matrix.shard }} NUM_TEST_SHARDS: ${{ matrix.num_shards }} PYTORCH_IGNORE_DISABLED_ISSUES: ${{ needs.generate-test-matrix.outputs.ignore-disabled-issues }} + CONTINUE_THROUGH_ERROR: ${{ github.repository_owner == 'pytorch' && (github.event_name == 'push' || github.event_name == 'schedule') }} steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" - name: Log in to ECR run: | aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh bash /tmp/ecr-login.sh rm /tmp/ecr-login.sh - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" run: | # Ensure the working directory gets chowned back to the current user - docker run --rm -v "$(pwd)/../":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - name: Clean workspace run: | rm -rf "${GITHUB_WORKSPACE:?}/*" @@ -281,9 +356,14 @@ jobs: uses: seemethere/add-github-ssh-key@v1 with: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Preserve github env variables for use in docker + run: | + env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: + # deep clone, to allow use of git merge-base + fetch-depth: 0 submodules: recursive - name: Pull docker image run: | @@ -315,12 +395,9 @@ jobs: - name: Output disk space left run: | sudo df -H - - name: Preserve github env variables for use in docker - run: | - env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Test PyTorch env: - BUILD_ENVIRONMENT: linux-bionic-py3.8-gcc9-coverage-${{ matrix.config }} + PR_NUMBER: ${{ github.event.pull_request.number }} run: | if [[ $TEST_CONFIG == 'multigpu' ]]; then TEST_COMMAND=.jenkins/pytorch/multigpu-test.sh @@ -336,6 +413,7 @@ jobs: docker run \ ${GPU_FLAG:-} \ -e BUILD_ENVIRONMENT \ + -e PR_NUMBER \ -e CUSTOM_TEST_ARTIFACT_BUILD_DIR \ -e GITHUB_ACTIONS \ -e IN_CI \ @@ -345,9 +423,12 @@ jobs: -e TEST_CONFIG \ -e NUM_TEST_SHARDS \ -e PYTORCH_IGNORE_DISABLED_ISSUES \ + -e PR_LABELS \ + -e CONTINUE_THROUGH_ERROR \ -e MAX_JOBS="$(nproc --ignore=2)" \ -e SCCACHE_BUCKET \ -e http_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e https_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e no_proxy="localhost,127.0.0.1,amazonaws.com,s3.amazonaws.com,169.254.169.254,169.254.170.2,/var/run/docker.sock" \ + -e XLA_CLANG_CACHE_S3_BUCKET_NAME \ --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ --security-opt seccomp=unconfined \ --cap-add=SYS_PTRACE \ @@ -385,7 +466,7 @@ jobs: if-no-files-found: error path: test-reports-*.zip - - uses: seemethere/upload-artifact-s3@9d7ceb0ab39c2c88d93ef7792b27425b27d59162 + - uses: seemethere/upload-artifact-s3@v3 name: Store PyTorch Test Reports on S3 if: always() with: @@ -394,66 +475,6 @@ jobs: if-no-files-found: error path: test-reports-*.zip - - name: Hold runner for 2 hours or until ssh sessions have drained - # Always hold for active ssh sessions - if: always() - run: .github/scripts/wait_for_ssh_to_drain.sh - - name: Clean up docker images - if: always() - run: | - # Ensure the working directory gets chowned back to the current user - docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - # Prune all of the docker images - docker system prune -af - - - # this is a separate step from test because the log files from test are too - # long: basically, GitHub tries to render all of the log files when you click - # through an action causing extreme slowdown on actions that contain too many - # logs (like test); we can always move it back to the other one, but it - # doesn't create the best experience - render_test_results: - needs: [generate-test-matrix, test, ciflow_should_run] - if: ${{ needs.test.result != 'skipped' || failure() }} - runs-on: linux.2xlarge - strategy: - matrix: ${{ fromJson(needs.generate-test-matrix.outputs.render-matrix) }} - fail-fast: false - steps: - - name: Log in to ECR - run: | - aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh - bash /tmp/ecr-login.sh - rm /tmp/ecr-login.sh - - name: Chown workspace - run: | - # Ensure the working directory gets chowned back to the current user - docker run --rm -v "$(pwd)/../":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - - name: Clean workspace - run: | - rm -rf "${GITHUB_WORKSPACE:?}/*" - - name: Checkout PyTorch - uses: actions/checkout@v2 - with: - # deep clone, to allow tools/stats/print_test_stats.py to use Git commands - fetch-depth: 0 - - uses: actions/download-artifact@v2 - name: Download PyTorch Test Reports - with: - name: test-reports-${{ matrix.config }} - path: . - - name: Unzip test reports - run: | - # Should preserve paths so reports should still be in test/test-reports - unzip -o 'test-reports-*.zip' - - name: Install dependencies - # boto3 version copied from .circleci/docker/common/install_conda.sh - run: | - pip3 install -r requirements.txt - pip3 install boto3==1.16.34 junitparser rich - - name: Output Test Results (Click Me) - run: | - python3 tools/render_junit.py test - name: Parse ref id: parse-ref run: .github/scripts/parse_ref.py @@ -468,5 +489,26 @@ jobs: CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }} CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }} CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}' + shell: bash run: | + python3 -m pip install -r requirements.txt + python3 -m pip install boto3==1.16.34 python3 -m tools.stats.print_test_stats --upload-to-s3 --compare-with-s3 test + - name: Hold runner for 2 hours or until ssh sessions have drained + # Always hold for active ssh sessions + if: always() + run: .github/scripts/wait_for_ssh_to_drain.sh + - name: Kill containers, clean up images + if: always() + run: | + # ignore expansion of "docker ps -q" since it could be empty + # shellcheck disable=SC2046 + docker stop $(docker ps -q) || true + # Prune all of the docker images + docker system prune -af + - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . diff --git a/.github/workflows/generated-linux-xenial-cuda10.2-py3.6-gcc7.yml b/.github/workflows/generated-linux-xenial-cuda10.2-py3.6-gcc7.yml index 38fe8593fc3eb..187f9c1ccfdfb 100644 --- a/.github/workflows/generated-linux-xenial-cuda10.2-py3.6-gcc7.yml +++ b/.github/workflows/generated-linux-xenial-cuda10.2-py3.6-gcc7.yml @@ -16,6 +16,7 @@ env: BUILD_ENVIRONMENT: linux-xenial-cuda10.2-py3.6-gcc7 DOCKER_IMAGE_BASE: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7 SCCACHE_BUCKET: ossci-compiler-cache-circleci-v2 + XLA_CLANG_CACHE_S3_BUCKET_NAME: ossci-compiler-clang-cache-circleci-xla TORCH_CUDA_ARCH_LIST: 5.2 IN_CI: 1 # This is used for the phase of adding wheel tests only, will be removed once completed @@ -23,18 +24,23 @@ env: # Used for custom_opertor, jit_hooks, custom_backend, see .jenkins/pytorch/build.sh CUSTOM_TEST_ARTIFACT_BUILD_DIR: build/custom_test_artifacts ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + PR_LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }} concurrency: - group: linux-xenial-cuda10.2-py3.6-gcc7-${{ github.event.pull_request.number || github.sha }} + group: linux-xenial-cuda10.2-py3.6-gcc7-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true jobs: ciflow_should_run: runs-on: ubuntu-18.04 - if: ${{ (github.event_name != 'pull_request') || (github.event.action !='unassigned') || (github.event.action == 'unassigned' && contains(github.event.pull_request.labels.*.name, 'ciflow/slow')) }} + if: ${{ (github.event_name != 'pull_request') || (github.event.action !='unassigned') || (contains(github.event.pull_request.labels.*.name, 'ciflow/all') || contains(github.event.pull_request.labels.*.name, 'ciflow/cuda') || contains(github.event.pull_request.labels.*.name, 'ciflow/linux') || contains(github.event.pull_request.labels.*.name, 'ciflow/slow')) }} + env: + LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }} steps: - name: noop run: echo running ciflow_should_run + - name: print labels + run: echo "${LABELS}" calculate-docker-image: if: ${{ github.repository_owner == 'pytorch' }} runs-on: linux.2xlarge @@ -45,12 +51,27 @@ jobs: outputs: docker_image: ${{ steps.calculate-tag.outputs.docker_image }} steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" - name: Log in to ECR run: | aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh bash /tmp/ecr-login.sh rm /tmp/ecr-login.sh - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" run: | # Ensure the working directory gets chowned back to the current user docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . @@ -58,11 +79,19 @@ jobs: run: | rm -rf "${GITHUB_WORKSPACE:?}/*" rm -f ~/.ssh/authorized_keys + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: seemethere/add-github-ssh-key@v1 + with: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Preserve github env variables for use in docker + run: | + env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: # deep clone, to allow use of git merge-base fetch-depth: 0 + submodules: false - name: Calculate docker image tag id: calculate-tag run: | @@ -117,12 +146,27 @@ jobs: DOCKER_IMAGE: ${{ needs.calculate-docker-image.outputs.docker_image }} JOB_BASE_NAME: linux-xenial-cuda10.2-py3.6-gcc7-build steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" - name: Log in to ECR run: | aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh bash /tmp/ecr-login.sh rm /tmp/ecr-login.sh - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" run: | # Ensure the working directory gets chowned back to the current user docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . @@ -134,17 +178,18 @@ jobs: uses: seemethere/add-github-ssh-key@v1 with: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Preserve github env variables for use in docker + run: | + env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: - fetch-depth: 0 # deep clone, to allow sharding to use git rev-list + # deep clone, to allow use of git merge-base + fetch-depth: 0 submodules: recursive - name: Pull docker image run: | docker pull "${DOCKER_IMAGE}" - - name: Preserve github env variables for use in docker - run: | - env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Build PyTorch run: | docker run \ @@ -152,9 +197,11 @@ jobs: -e JOB_BASE_NAME \ -e MAX_JOBS="$(nproc --ignore=2)" \ -e SCCACHE_BUCKET \ + -e XLA_CLANG_CACHE_S3_BUCKET_NAME \ -e CUSTOM_TEST_ARTIFACT_BUILD_DIR \ -e SKIP_SCCACHE_INITIALIZATION=1 \ -e TORCH_CUDA_ARCH_LIST \ + -e PR_LABELS \ -e http_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e https_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e no_proxy="localhost,127.0.0.1,amazonaws.com,s3.amazonaws.com,169.254.169.254,169.254.170.2,/var/run/docker.sock" \ --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ --security-opt seccomp=unconfined \ @@ -190,19 +237,8 @@ jobs: docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - name: Archive artifacts into zip run: | - zip -r artifacts.zip dist/ build/ .pytorch-test-times.json - # Upload to github so that people can click and download artifacts - - uses: actions/upload-artifact@v2 - # Don't fail on upload to GH since it's only for user convenience - continue-on-error: true - name: Store PyTorch Build Artifacts on Github - with: - name: ${{ env.BUILD_ENVIRONMENT }} - retention-days: 14 - if-no-files-found: error - path: - artifacts.zip - - uses: seemethere/upload-artifact-s3@9d7ceb0ab39c2c88d93ef7792b27425b27d59162 + zip -1 -r artifacts.zip dist/ build/custom_test_artifacts build/lib build/bin .pytorch-test-times.json + - uses: seemethere/upload-artifact-s3@v3 name: Store PyTorch Build Artifacts on S3 with: name: ${{ env.BUILD_ENVIRONMENT }} @@ -214,6 +250,24 @@ jobs: # Always hold for active ssh sessions if: always() run: .github/scripts/wait_for_ssh_to_drain.sh + - name: Kill containers, clean up images + if: always() + run: | + # ignore expansion of "docker ps -q" since it could be empty + # shellcheck disable=SC2046 + docker stop $(docker ps -q) || true + # Prune all of the docker images + docker system prune -af + - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . + - name: Hold runner for 2 hours or until ssh sessions have drained + # Always hold for active ssh sessions + if: always() + run: .github/scripts/wait_for_ssh_to_drain.sh - name: Clean up docker images if: always() run: | @@ -226,11 +280,16 @@ jobs: needs: [ciflow_should_run] env: TEST_RUNNER_TYPE: linux.8xlarge.nvidia.gpu + ENABLE_DISTRIBUTED_TEST: 1 ENABLE_JIT_LEGACY_TEST: 1 ENABLE_MULTIGPU_TEST: 1 ENABLE_NOGPU_NO_AVX_TEST: 1 ENABLE_NOGPU_NO_AVX2_TEST: 1 ENABLE_SLOW_TEST: 1 + ENABLE_DOCS_TEST: '' + ENABLE_BACKWARDS_COMPAT_TEST: '' + ENABLE_XLA_TEST: '' + ENABLE_NOARCH_TEST: '' NUM_TEST_SHARDS: 2 MULTIGPU_RUNNER_TYPE: linux.16xlarge.nvidia.gpu NOGPU_RUNNER_TYPE: linux.2xlarge @@ -245,7 +304,7 @@ jobs: - name: Install dependencies run: pip install typing-extensions - name: Clone pytorch/pytorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 - name: Generating test matrix id: set-matrix run: .github/scripts/generate_pytorch_test_matrix.py @@ -263,16 +322,32 @@ jobs: SHARD_NUMBER: ${{ matrix.shard }} NUM_TEST_SHARDS: ${{ matrix.num_shards }} PYTORCH_IGNORE_DISABLED_ISSUES: ${{ needs.generate-test-matrix.outputs.ignore-disabled-issues }} + CONTINUE_THROUGH_ERROR: ${{ github.repository_owner == 'pytorch' && (github.event_name == 'push' || github.event_name == 'schedule') }} steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" - name: Log in to ECR run: | aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh bash /tmp/ecr-login.sh rm /tmp/ecr-login.sh - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" run: | # Ensure the working directory gets chowned back to the current user - docker run --rm -v "$(pwd)/../":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - name: Clean workspace run: | rm -rf "${GITHUB_WORKSPACE:?}/*" @@ -281,9 +356,14 @@ jobs: uses: seemethere/add-github-ssh-key@v1 with: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Preserve github env variables for use in docker + run: | + env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: + # deep clone, to allow use of git merge-base + fetch-depth: 0 submodules: recursive - name: Pull docker image run: | @@ -315,12 +395,9 @@ jobs: - name: Output disk space left run: | sudo df -H - - name: Preserve github env variables for use in docker - run: | - env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Test PyTorch env: - BUILD_ENVIRONMENT: linux-xenial-cuda10.2-py3.6-gcc7-${{ matrix.config }} + PR_NUMBER: ${{ github.event.pull_request.number }} run: | if [[ $TEST_CONFIG == 'multigpu' ]]; then TEST_COMMAND=.jenkins/pytorch/multigpu-test.sh @@ -336,6 +413,7 @@ jobs: docker run \ ${GPU_FLAG:-} \ -e BUILD_ENVIRONMENT \ + -e PR_NUMBER \ -e CUSTOM_TEST_ARTIFACT_BUILD_DIR \ -e GITHUB_ACTIONS \ -e IN_CI \ @@ -345,9 +423,12 @@ jobs: -e TEST_CONFIG \ -e NUM_TEST_SHARDS \ -e PYTORCH_IGNORE_DISABLED_ISSUES \ + -e PR_LABELS \ + -e CONTINUE_THROUGH_ERROR \ -e MAX_JOBS="$(nproc --ignore=2)" \ -e SCCACHE_BUCKET \ -e http_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e https_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e no_proxy="localhost,127.0.0.1,amazonaws.com,s3.amazonaws.com,169.254.169.254,169.254.170.2,/var/run/docker.sock" \ + -e XLA_CLANG_CACHE_S3_BUCKET_NAME \ --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ --security-opt seccomp=unconfined \ --cap-add=SYS_PTRACE \ @@ -381,7 +462,7 @@ jobs: if-no-files-found: error path: test-reports-*.zip - - uses: seemethere/upload-artifact-s3@9d7ceb0ab39c2c88d93ef7792b27425b27d59162 + - uses: seemethere/upload-artifact-s3@v3 name: Store PyTorch Test Reports on S3 if: always() with: @@ -390,66 +471,6 @@ jobs: if-no-files-found: error path: test-reports-*.zip - - name: Hold runner for 2 hours or until ssh sessions have drained - # Always hold for active ssh sessions - if: always() - run: .github/scripts/wait_for_ssh_to_drain.sh - - name: Clean up docker images - if: always() - run: | - # Ensure the working directory gets chowned back to the current user - docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - # Prune all of the docker images - docker system prune -af - - - # this is a separate step from test because the log files from test are too - # long: basically, GitHub tries to render all of the log files when you click - # through an action causing extreme slowdown on actions that contain too many - # logs (like test); we can always move it back to the other one, but it - # doesn't create the best experience - render_test_results: - needs: [generate-test-matrix, test, ciflow_should_run] - if: ${{ needs.test.result != 'skipped' || failure() }} - runs-on: linux.2xlarge - strategy: - matrix: ${{ fromJson(needs.generate-test-matrix.outputs.render-matrix) }} - fail-fast: false - steps: - - name: Log in to ECR - run: | - aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh - bash /tmp/ecr-login.sh - rm /tmp/ecr-login.sh - - name: Chown workspace - run: | - # Ensure the working directory gets chowned back to the current user - docker run --rm -v "$(pwd)/../":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - - name: Clean workspace - run: | - rm -rf "${GITHUB_WORKSPACE:?}/*" - - name: Checkout PyTorch - uses: actions/checkout@v2 - with: - # deep clone, to allow tools/stats/print_test_stats.py to use Git commands - fetch-depth: 0 - - uses: actions/download-artifact@v2 - name: Download PyTorch Test Reports - with: - name: test-reports-${{ matrix.config }} - path: . - - name: Unzip test reports - run: | - # Should preserve paths so reports should still be in test/test-reports - unzip -o 'test-reports-*.zip' - - name: Install dependencies - # boto3 version copied from .circleci/docker/common/install_conda.sh - run: | - pip3 install -r requirements.txt - pip3 install boto3==1.16.34 junitparser rich - - name: Output Test Results (Click Me) - run: | - python3 tools/render_junit.py test - name: Parse ref id: parse-ref run: .github/scripts/parse_ref.py @@ -464,5 +485,26 @@ jobs: CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }} CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }} CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}' + shell: bash run: | + python3 -m pip install -r requirements.txt + python3 -m pip install boto3==1.16.34 python3 -m tools.stats.print_test_stats --upload-to-s3 --compare-with-s3 test + - name: Hold runner for 2 hours or until ssh sessions have drained + # Always hold for active ssh sessions + if: always() + run: .github/scripts/wait_for_ssh_to_drain.sh + - name: Kill containers, clean up images + if: always() + run: | + # ignore expansion of "docker ps -q" since it could be empty + # shellcheck disable=SC2046 + docker stop $(docker ps -q) || true + # Prune all of the docker images + docker system prune -af + - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . diff --git a/.github/workflows/generated-linux-xenial-cuda11.1-py3.6-gcc7.yml b/.github/workflows/generated-linux-xenial-cuda11.3-py3.6-gcc7.yml similarity index 72% rename from .github/workflows/generated-linux-xenial-cuda11.1-py3.6-gcc7.yml rename to .github/workflows/generated-linux-xenial-cuda11.3-py3.6-gcc7.yml index a5f0488644596..9fff700c56e84 100644 --- a/.github/workflows/generated-linux-xenial-cuda11.1-py3.6-gcc7.yml +++ b/.github/workflows/generated-linux-xenial-cuda11.3-py3.6-gcc7.yml @@ -1,10 +1,11 @@ # @generated DO NOT EDIT MANUALLY # Template is at: .github/templates/linux_ci_workflow.yml.j2 # Generation script: .github/scripts/generate_ci_workflows.py -name: linux-xenial-cuda11.1-py3.6-gcc7 +name: linux-xenial-cuda11.3-py3.6-gcc7 on: - # TODO: Enable pull_request builds when we can verify capacity can be met by auto-scalers + pull_request: + types: [opened, synchronize, reopened, unassigned] push: branches: - master @@ -12,9 +13,10 @@ on: workflow_dispatch: env: - BUILD_ENVIRONMENT: linux-xenial-cuda11.1-py3.6-gcc7 - DOCKER_IMAGE_BASE: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7 + BUILD_ENVIRONMENT: linux-xenial-cuda11.3-py3.6-gcc7 + DOCKER_IMAGE_BASE: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.3-cudnn8-py3-gcc7 SCCACHE_BUCKET: ossci-compiler-cache-circleci-v2 + XLA_CLANG_CACHE_S3_BUCKET_NAME: ossci-compiler-clang-cache-circleci-xla TORCH_CUDA_ARCH_LIST: 5.2 IN_CI: 1 # This is used for the phase of adding wheel tests only, will be removed once completed @@ -22,27 +24,54 @@ env: # Used for custom_opertor, jit_hooks, custom_backend, see .jenkins/pytorch/build.sh CUSTOM_TEST_ARTIFACT_BUILD_DIR: build/custom_test_artifacts ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + PR_LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }} concurrency: - group: linux-xenial-cuda11.1-py3.6-gcc7-${{ github.event.pull_request.number || github.sha }} + group: linux-xenial-cuda11.3-py3.6-gcc7-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true jobs: + ciflow_should_run: + runs-on: ubuntu-18.04 + if: ${{ (github.event_name != 'pull_request') || (github.event.action !='unassigned') || (contains(github.event.pull_request.labels.*.name, 'ciflow/all') || contains(github.event.pull_request.labels.*.name, 'ciflow/cuda') || contains(github.event.pull_request.labels.*.name, 'ciflow/default') || contains(github.event.pull_request.labels.*.name, 'ciflow/linux')) }} + env: + LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }} + steps: + - name: noop + run: echo running ciflow_should_run + - name: print labels + run: echo "${LABELS}" calculate-docker-image: if: ${{ github.repository_owner == 'pytorch' }} runs-on: linux.2xlarge + needs: [ciflow_should_run] env: DOCKER_BUILDKIT: 1 timeout-minutes: 90 outputs: docker_image: ${{ steps.calculate-tag.outputs.docker_image }} steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" - name: Log in to ECR run: | aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh bash /tmp/ecr-login.sh rm /tmp/ecr-login.sh - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" run: | # Ensure the working directory gets chowned back to the current user docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . @@ -50,11 +79,19 @@ jobs: run: | rm -rf "${GITHUB_WORKSPACE:?}/*" rm -f ~/.ssh/authorized_keys + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: seemethere/add-github-ssh-key@v1 + with: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Preserve github env variables for use in docker + run: | + env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: # deep clone, to allow use of git merge-base fetch-depth: 0 + submodules: false - name: Calculate docker image tag id: calculate-tag run: | @@ -104,17 +141,32 @@ jobs: build: runs-on: linux.2xlarge - needs: [calculate-docker-image, ] + needs: [calculate-docker-image, ciflow_should_run] env: DOCKER_IMAGE: ${{ needs.calculate-docker-image.outputs.docker_image }} - JOB_BASE_NAME: linux-xenial-cuda11.1-py3.6-gcc7-build + JOB_BASE_NAME: linux-xenial-cuda11.3-py3.6-gcc7-build steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" - name: Log in to ECR run: | aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh bash /tmp/ecr-login.sh rm /tmp/ecr-login.sh - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" run: | # Ensure the working directory gets chowned back to the current user docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . @@ -126,17 +178,18 @@ jobs: uses: seemethere/add-github-ssh-key@v1 with: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Preserve github env variables for use in docker + run: | + env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: - fetch-depth: 0 # deep clone, to allow sharding to use git rev-list + # deep clone, to allow use of git merge-base + fetch-depth: 0 submodules: recursive - name: Pull docker image run: | docker pull "${DOCKER_IMAGE}" - - name: Preserve github env variables for use in docker - run: | - env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Build PyTorch run: | docker run \ @@ -144,9 +197,11 @@ jobs: -e JOB_BASE_NAME \ -e MAX_JOBS="$(nproc --ignore=2)" \ -e SCCACHE_BUCKET \ + -e XLA_CLANG_CACHE_S3_BUCKET_NAME \ -e CUSTOM_TEST_ARTIFACT_BUILD_DIR \ -e SKIP_SCCACHE_INITIALIZATION=1 \ -e TORCH_CUDA_ARCH_LIST \ + -e PR_LABELS \ -e http_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e https_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e no_proxy="localhost,127.0.0.1,amazonaws.com,s3.amazonaws.com,169.254.169.254,169.254.170.2,/var/run/docker.sock" \ --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ --security-opt seccomp=unconfined \ @@ -182,19 +237,8 @@ jobs: docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - name: Archive artifacts into zip run: | - zip -r artifacts.zip dist/ build/ .pytorch-test-times.json - # Upload to github so that people can click and download artifacts - - uses: actions/upload-artifact@v2 - # Don't fail on upload to GH since it's only for user convenience - continue-on-error: true - name: Store PyTorch Build Artifacts on Github - with: - name: ${{ env.BUILD_ENVIRONMENT }} - retention-days: 14 - if-no-files-found: error - path: - artifacts.zip - - uses: seemethere/upload-artifact-s3@9d7ceb0ab39c2c88d93ef7792b27425b27d59162 + zip -1 -r artifacts.zip dist/ build/custom_test_artifacts build/lib build/bin .pytorch-test-times.json + - uses: seemethere/upload-artifact-s3@v3 name: Store PyTorch Build Artifacts on S3 with: name: ${{ env.BUILD_ENVIRONMENT }} @@ -206,6 +250,24 @@ jobs: # Always hold for active ssh sessions if: always() run: .github/scripts/wait_for_ssh_to_drain.sh + - name: Kill containers, clean up images + if: always() + run: | + # ignore expansion of "docker ps -q" since it could be empty + # shellcheck disable=SC2046 + docker stop $(docker ps -q) || true + # Prune all of the docker images + docker system prune -af + - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . + - name: Hold runner for 2 hours or until ssh sessions have drained + # Always hold for active ssh sessions + if: always() + run: .github/scripts/wait_for_ssh_to_drain.sh - name: Clean up docker images if: always() run: | @@ -215,13 +277,19 @@ jobs: generate-test-matrix: if: ${{ github.repository_owner == 'pytorch' }} runs-on: ubuntu-18.04 + needs: [ciflow_should_run] env: TEST_RUNNER_TYPE: linux.8xlarge.nvidia.gpu + ENABLE_DISTRIBUTED_TEST: 1 ENABLE_JIT_LEGACY_TEST: '' ENABLE_MULTIGPU_TEST: '' ENABLE_NOGPU_NO_AVX_TEST: '' ENABLE_NOGPU_NO_AVX2_TEST: '' ENABLE_SLOW_TEST: '' + ENABLE_DOCS_TEST: '' + ENABLE_BACKWARDS_COMPAT_TEST: '' + ENABLE_XLA_TEST: '' + ENABLE_NOARCH_TEST: '' NUM_TEST_SHARDS: 2 MULTIGPU_RUNNER_TYPE: linux.16xlarge.nvidia.gpu NOGPU_RUNNER_TYPE: linux.2xlarge @@ -236,34 +304,50 @@ jobs: - name: Install dependencies run: pip install typing-extensions - name: Clone pytorch/pytorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 - name: Generating test matrix id: set-matrix run: .github/scripts/generate_pytorch_test_matrix.py test: - needs: [calculate-docker-image, build, generate-test-matrix, ] + needs: [calculate-docker-image, build, generate-test-matrix, ciflow_should_run] strategy: matrix: ${{ fromJson(needs.generate-test-matrix.outputs.matrix) }} fail-fast: false runs-on: ${{ matrix.runner }} env: DOCKER_IMAGE: ${{ needs.calculate-docker-image.outputs.docker_image }} - JOB_BASE_NAME: linux-xenial-cuda11.1-py3.6-gcc7-test + JOB_BASE_NAME: linux-xenial-cuda11.3-py3.6-gcc7-test TEST_CONFIG: ${{ matrix.config }} SHARD_NUMBER: ${{ matrix.shard }} NUM_TEST_SHARDS: ${{ matrix.num_shards }} PYTORCH_IGNORE_DISABLED_ISSUES: ${{ needs.generate-test-matrix.outputs.ignore-disabled-issues }} + CONTINUE_THROUGH_ERROR: ${{ github.repository_owner == 'pytorch' && (github.event_name == 'push' || github.event_name == 'schedule') }} steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" - name: Log in to ECR run: | aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh bash /tmp/ecr-login.sh rm /tmp/ecr-login.sh - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" run: | # Ensure the working directory gets chowned back to the current user - docker run --rm -v "$(pwd)/../":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - name: Clean workspace run: | rm -rf "${GITHUB_WORKSPACE:?}/*" @@ -272,9 +356,14 @@ jobs: uses: seemethere/add-github-ssh-key@v1 with: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Preserve github env variables for use in docker + run: | + env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: + # deep clone, to allow use of git merge-base + fetch-depth: 0 submodules: recursive - name: Pull docker image run: | @@ -306,12 +395,9 @@ jobs: - name: Output disk space left run: | sudo df -H - - name: Preserve github env variables for use in docker - run: | - env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Test PyTorch env: - BUILD_ENVIRONMENT: linux-xenial-cuda11.1-py3.6-gcc7-${{ matrix.config }} + PR_NUMBER: ${{ github.event.pull_request.number }} run: | if [[ $TEST_CONFIG == 'multigpu' ]]; then TEST_COMMAND=.jenkins/pytorch/multigpu-test.sh @@ -327,6 +413,7 @@ jobs: docker run \ ${GPU_FLAG:-} \ -e BUILD_ENVIRONMENT \ + -e PR_NUMBER \ -e CUSTOM_TEST_ARTIFACT_BUILD_DIR \ -e GITHUB_ACTIONS \ -e IN_CI \ @@ -336,9 +423,12 @@ jobs: -e TEST_CONFIG \ -e NUM_TEST_SHARDS \ -e PYTORCH_IGNORE_DISABLED_ISSUES \ + -e PR_LABELS \ + -e CONTINUE_THROUGH_ERROR \ -e MAX_JOBS="$(nproc --ignore=2)" \ -e SCCACHE_BUCKET \ -e http_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e https_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e no_proxy="localhost,127.0.0.1,amazonaws.com,s3.amazonaws.com,169.254.169.254,169.254.170.2,/var/run/docker.sock" \ + -e XLA_CLANG_CACHE_S3_BUCKET_NAME \ --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ --security-opt seccomp=unconfined \ --cap-add=SYS_PTRACE \ @@ -372,7 +462,7 @@ jobs: if-no-files-found: error path: test-reports-*.zip - - uses: seemethere/upload-artifact-s3@9d7ceb0ab39c2c88d93ef7792b27425b27d59162 + - uses: seemethere/upload-artifact-s3@v3 name: Store PyTorch Test Reports on S3 if: always() with: @@ -381,66 +471,6 @@ jobs: if-no-files-found: error path: test-reports-*.zip - - name: Hold runner for 2 hours or until ssh sessions have drained - # Always hold for active ssh sessions - if: always() - run: .github/scripts/wait_for_ssh_to_drain.sh - - name: Clean up docker images - if: always() - run: | - # Ensure the working directory gets chowned back to the current user - docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - # Prune all of the docker images - docker system prune -af - - - # this is a separate step from test because the log files from test are too - # long: basically, GitHub tries to render all of the log files when you click - # through an action causing extreme slowdown on actions that contain too many - # logs (like test); we can always move it back to the other one, but it - # doesn't create the best experience - render_test_results: - needs: [generate-test-matrix, test, ] - if: ${{ needs.test.result != 'skipped' || failure() }} - runs-on: linux.2xlarge - strategy: - matrix: ${{ fromJson(needs.generate-test-matrix.outputs.render-matrix) }} - fail-fast: false - steps: - - name: Log in to ECR - run: | - aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh - bash /tmp/ecr-login.sh - rm /tmp/ecr-login.sh - - name: Chown workspace - run: | - # Ensure the working directory gets chowned back to the current user - docker run --rm -v "$(pwd)/../":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - - name: Clean workspace - run: | - rm -rf "${GITHUB_WORKSPACE:?}/*" - - name: Checkout PyTorch - uses: actions/checkout@v2 - with: - # deep clone, to allow tools/stats/print_test_stats.py to use Git commands - fetch-depth: 0 - - uses: actions/download-artifact@v2 - name: Download PyTorch Test Reports - with: - name: test-reports-${{ matrix.config }} - path: . - - name: Unzip test reports - run: | - # Should preserve paths so reports should still be in test/test-reports - unzip -o 'test-reports-*.zip' - - name: Install dependencies - # boto3 version copied from .circleci/docker/common/install_conda.sh - run: | - pip3 install -r requirements.txt - pip3 install boto3==1.16.34 junitparser rich - - name: Output Test Results (Click Me) - run: | - python3 tools/render_junit.py test - name: Parse ref id: parse-ref run: .github/scripts/parse_ref.py @@ -450,10 +480,31 @@ jobs: env: AWS_DEFAULT_REGION: us-east-1 CIRCLE_BRANCH: ${{ steps.parse-ref.outputs.branch }} - JOB_BASE_NAME: linux-xenial-cuda11.1-py3.6-gcc7-test + JOB_BASE_NAME: linux-xenial-cuda11.3-py3.6-gcc7-test CIRCLE_PR_NUMBER: ${{ github.event.pull_request.number }} CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }} CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }} CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}' + shell: bash run: | + python3 -m pip install -r requirements.txt + python3 -m pip install boto3==1.16.34 python3 -m tools.stats.print_test_stats --upload-to-s3 --compare-with-s3 test + - name: Hold runner for 2 hours or until ssh sessions have drained + # Always hold for active ssh sessions + if: always() + run: .github/scripts/wait_for_ssh_to_drain.sh + - name: Kill containers, clean up images + if: always() + run: | + # ignore expansion of "docker ps -q" since it could be empty + # shellcheck disable=SC2046 + docker stop $(docker ps -q) || true + # Prune all of the docker images + docker system prune -af + - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . diff --git a/.github/workflows/generated-linux-xenial-py3.6-gcc5.4.yml b/.github/workflows/generated-linux-xenial-py3.6-gcc5.4.yml index 14e22d85edc26..d1187de624f17 100644 --- a/.github/workflows/generated-linux-xenial-py3.6-gcc5.4.yml +++ b/.github/workflows/generated-linux-xenial-py3.6-gcc5.4.yml @@ -5,6 +5,7 @@ name: linux-xenial-py3.6-gcc5.4 on: pull_request: + types: [opened, synchronize, reopened, unassigned] push: branches: - master @@ -15,6 +16,7 @@ env: BUILD_ENVIRONMENT: linux-xenial-py3.6-gcc5.4 DOCKER_IMAGE_BASE: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4 SCCACHE_BUCKET: ossci-compiler-cache-circleci-v2 + XLA_CLANG_CACHE_S3_BUCKET_NAME: ossci-compiler-clang-cache-circleci-xla TORCH_CUDA_ARCH_LIST: 5.2 IN_CI: 1 # This is used for the phase of adding wheel tests only, will be removed once completed @@ -22,27 +24,54 @@ env: # Used for custom_opertor, jit_hooks, custom_backend, see .jenkins/pytorch/build.sh CUSTOM_TEST_ARTIFACT_BUILD_DIR: build/custom_test_artifacts ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + PR_LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }} concurrency: - group: linux-xenial-py3.6-gcc5.4-${{ github.event.pull_request.number || github.sha }} + group: linux-xenial-py3.6-gcc5.4-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true jobs: + ciflow_should_run: + runs-on: ubuntu-18.04 + if: ${{ (github.event_name != 'pull_request') || (github.event.action !='unassigned') || (contains(github.event.pull_request.labels.*.name, 'ciflow/all') || contains(github.event.pull_request.labels.*.name, 'ciflow/cpu') || contains(github.event.pull_request.labels.*.name, 'ciflow/default') || contains(github.event.pull_request.labels.*.name, 'ciflow/linux')) }} + env: + LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }} + steps: + - name: noop + run: echo running ciflow_should_run + - name: print labels + run: echo "${LABELS}" calculate-docker-image: if: ${{ github.repository_owner == 'pytorch' }} runs-on: linux.2xlarge + needs: [ciflow_should_run] env: DOCKER_BUILDKIT: 1 timeout-minutes: 90 outputs: docker_image: ${{ steps.calculate-tag.outputs.docker_image }} steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" - name: Log in to ECR run: | aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh bash /tmp/ecr-login.sh rm /tmp/ecr-login.sh - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" run: | # Ensure the working directory gets chowned back to the current user docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . @@ -50,11 +79,19 @@ jobs: run: | rm -rf "${GITHUB_WORKSPACE:?}/*" rm -f ~/.ssh/authorized_keys + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: seemethere/add-github-ssh-key@v1 + with: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Preserve github env variables for use in docker + run: | + env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: # deep clone, to allow use of git merge-base fetch-depth: 0 + submodules: false - name: Calculate docker image tag id: calculate-tag run: | @@ -104,17 +141,32 @@ jobs: build: runs-on: linux.2xlarge - needs: [calculate-docker-image, ] + needs: [calculate-docker-image, ciflow_should_run] env: DOCKER_IMAGE: ${{ needs.calculate-docker-image.outputs.docker_image }} JOB_BASE_NAME: linux-xenial-py3.6-gcc5.4-build steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" - name: Log in to ECR run: | aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh bash /tmp/ecr-login.sh rm /tmp/ecr-login.sh - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" run: | # Ensure the working directory gets chowned back to the current user docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . @@ -126,17 +178,18 @@ jobs: uses: seemethere/add-github-ssh-key@v1 with: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Preserve github env variables for use in docker + run: | + env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: - fetch-depth: 0 # deep clone, to allow sharding to use git rev-list + # deep clone, to allow use of git merge-base + fetch-depth: 0 submodules: recursive - name: Pull docker image run: | docker pull "${DOCKER_IMAGE}" - - name: Preserve github env variables for use in docker - run: | - env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Build PyTorch run: | docker run \ @@ -144,9 +197,11 @@ jobs: -e JOB_BASE_NAME \ -e MAX_JOBS="$(nproc --ignore=2)" \ -e SCCACHE_BUCKET \ + -e XLA_CLANG_CACHE_S3_BUCKET_NAME \ -e CUSTOM_TEST_ARTIFACT_BUILD_DIR \ -e SKIP_SCCACHE_INITIALIZATION=1 \ -e TORCH_CUDA_ARCH_LIST \ + -e PR_LABELS \ -e http_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e https_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e no_proxy="localhost,127.0.0.1,amazonaws.com,s3.amazonaws.com,169.254.169.254,169.254.170.2,/var/run/docker.sock" \ --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ --security-opt seccomp=unconfined \ @@ -182,19 +237,8 @@ jobs: docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - name: Archive artifacts into zip run: | - zip -r artifacts.zip dist/ build/ .pytorch-test-times.json - # Upload to github so that people can click and download artifacts - - uses: actions/upload-artifact@v2 - # Don't fail on upload to GH since it's only for user convenience - continue-on-error: true - name: Store PyTorch Build Artifacts on Github - with: - name: ${{ env.BUILD_ENVIRONMENT }} - retention-days: 14 - if-no-files-found: error - path: - artifacts.zip - - uses: seemethere/upload-artifact-s3@9d7ceb0ab39c2c88d93ef7792b27425b27d59162 + zip -1 -r artifacts.zip dist/ build/custom_test_artifacts build/lib build/bin .pytorch-test-times.json + - uses: seemethere/upload-artifact-s3@v3 name: Store PyTorch Build Artifacts on S3 with: name: ${{ env.BUILD_ENVIRONMENT }} @@ -206,6 +250,24 @@ jobs: # Always hold for active ssh sessions if: always() run: .github/scripts/wait_for_ssh_to_drain.sh + - name: Kill containers, clean up images + if: always() + run: | + # ignore expansion of "docker ps -q" since it could be empty + # shellcheck disable=SC2046 + docker stop $(docker ps -q) || true + # Prune all of the docker images + docker system prune -af + - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . + - name: Hold runner for 2 hours or until ssh sessions have drained + # Always hold for active ssh sessions + if: always() + run: .github/scripts/wait_for_ssh_to_drain.sh - name: Clean up docker images if: always() run: | @@ -215,13 +277,19 @@ jobs: generate-test-matrix: if: ${{ github.repository_owner == 'pytorch' }} runs-on: ubuntu-18.04 + needs: [ciflow_should_run] env: TEST_RUNNER_TYPE: linux.2xlarge + ENABLE_DISTRIBUTED_TEST: 1 ENABLE_JIT_LEGACY_TEST: '' ENABLE_MULTIGPU_TEST: '' ENABLE_NOGPU_NO_AVX_TEST: '' ENABLE_NOGPU_NO_AVX2_TEST: '' ENABLE_SLOW_TEST: '' + ENABLE_DOCS_TEST: 1 + ENABLE_BACKWARDS_COMPAT_TEST: 1 + ENABLE_XLA_TEST: '' + ENABLE_NOARCH_TEST: '' NUM_TEST_SHARDS: 2 MULTIGPU_RUNNER_TYPE: linux.16xlarge.nvidia.gpu NOGPU_RUNNER_TYPE: linux.2xlarge @@ -236,13 +304,13 @@ jobs: - name: Install dependencies run: pip install typing-extensions - name: Clone pytorch/pytorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 - name: Generating test matrix id: set-matrix run: .github/scripts/generate_pytorch_test_matrix.py test: - needs: [calculate-docker-image, build, generate-test-matrix, ] + needs: [calculate-docker-image, build, generate-test-matrix, ciflow_should_run] strategy: matrix: ${{ fromJson(needs.generate-test-matrix.outputs.matrix) }} fail-fast: false @@ -254,16 +322,32 @@ jobs: SHARD_NUMBER: ${{ matrix.shard }} NUM_TEST_SHARDS: ${{ matrix.num_shards }} PYTORCH_IGNORE_DISABLED_ISSUES: ${{ needs.generate-test-matrix.outputs.ignore-disabled-issues }} + CONTINUE_THROUGH_ERROR: ${{ github.repository_owner == 'pytorch' && (github.event_name == 'push' || github.event_name == 'schedule') }} steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" - name: Log in to ECR run: | aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh bash /tmp/ecr-login.sh rm /tmp/ecr-login.sh - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" run: | # Ensure the working directory gets chowned back to the current user - docker run --rm -v "$(pwd)/../":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - name: Clean workspace run: | rm -rf "${GITHUB_WORKSPACE:?}/*" @@ -272,9 +356,14 @@ jobs: uses: seemethere/add-github-ssh-key@v1 with: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Preserve github env variables for use in docker + run: | + env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: + # deep clone, to allow use of git merge-base + fetch-depth: 0 submodules: recursive - name: Pull docker image run: | @@ -306,12 +395,9 @@ jobs: - name: Output disk space left run: | sudo df -H - - name: Preserve github env variables for use in docker - run: | - env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Test PyTorch env: - BUILD_ENVIRONMENT: linux-xenial-py3.6-gcc5.4-${{ matrix.config }} + PR_NUMBER: ${{ github.event.pull_request.number }} run: | if [[ $TEST_CONFIG == 'multigpu' ]]; then TEST_COMMAND=.jenkins/pytorch/multigpu-test.sh @@ -327,6 +413,7 @@ jobs: docker run \ ${GPU_FLAG:-} \ -e BUILD_ENVIRONMENT \ + -e PR_NUMBER \ -e CUSTOM_TEST_ARTIFACT_BUILD_DIR \ -e GITHUB_ACTIONS \ -e IN_CI \ @@ -336,9 +423,12 @@ jobs: -e TEST_CONFIG \ -e NUM_TEST_SHARDS \ -e PYTORCH_IGNORE_DISABLED_ISSUES \ + -e PR_LABELS \ + -e CONTINUE_THROUGH_ERROR \ -e MAX_JOBS="$(nproc --ignore=2)" \ -e SCCACHE_BUCKET \ -e http_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e https_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e no_proxy="localhost,127.0.0.1,amazonaws.com,s3.amazonaws.com,169.254.169.254,169.254.170.2,/var/run/docker.sock" \ + -e XLA_CLANG_CACHE_S3_BUCKET_NAME \ --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ --security-opt seccomp=unconfined \ --cap-add=SYS_PTRACE \ @@ -372,7 +462,7 @@ jobs: if-no-files-found: error path: test-reports-*.zip - - uses: seemethere/upload-artifact-s3@9d7ceb0ab39c2c88d93ef7792b27425b27d59162 + - uses: seemethere/upload-artifact-s3@v3 name: Store PyTorch Test Reports on S3 if: always() with: @@ -381,66 +471,6 @@ jobs: if-no-files-found: error path: test-reports-*.zip - - name: Hold runner for 2 hours or until ssh sessions have drained - # Always hold for active ssh sessions - if: always() - run: .github/scripts/wait_for_ssh_to_drain.sh - - name: Clean up docker images - if: always() - run: | - # Ensure the working directory gets chowned back to the current user - docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - # Prune all of the docker images - docker system prune -af - - - # this is a separate step from test because the log files from test are too - # long: basically, GitHub tries to render all of the log files when you click - # through an action causing extreme slowdown on actions that contain too many - # logs (like test); we can always move it back to the other one, but it - # doesn't create the best experience - render_test_results: - needs: [generate-test-matrix, test, ] - if: ${{ needs.test.result != 'skipped' || failure() }} - runs-on: linux.2xlarge - strategy: - matrix: ${{ fromJson(needs.generate-test-matrix.outputs.render-matrix) }} - fail-fast: false - steps: - - name: Log in to ECR - run: | - aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh - bash /tmp/ecr-login.sh - rm /tmp/ecr-login.sh - - name: Chown workspace - run: | - # Ensure the working directory gets chowned back to the current user - docker run --rm -v "$(pwd)/../":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - - name: Clean workspace - run: | - rm -rf "${GITHUB_WORKSPACE:?}/*" - - name: Checkout PyTorch - uses: actions/checkout@v2 - with: - # deep clone, to allow tools/stats/print_test_stats.py to use Git commands - fetch-depth: 0 - - uses: actions/download-artifact@v2 - name: Download PyTorch Test Reports - with: - name: test-reports-${{ matrix.config }} - path: . - - name: Unzip test reports - run: | - # Should preserve paths so reports should still be in test/test-reports - unzip -o 'test-reports-*.zip' - - name: Install dependencies - # boto3 version copied from .circleci/docker/common/install_conda.sh - run: | - pip3 install -r requirements.txt - pip3 install boto3==1.16.34 junitparser rich - - name: Output Test Results (Click Me) - run: | - python3 tools/render_junit.py test - name: Parse ref id: parse-ref run: .github/scripts/parse_ref.py @@ -455,21 +485,61 @@ jobs: CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }} CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }} CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}' + shell: bash run: | + python3 -m pip install -r requirements.txt + python3 -m pip install boto3==1.16.34 python3 -m tools.stats.print_test_stats --upload-to-s3 --compare-with-s3 test + - name: Hold runner for 2 hours or until ssh sessions have drained + # Always hold for active ssh sessions + if: always() + run: .github/scripts/wait_for_ssh_to_drain.sh + - name: Kill containers, clean up images + if: always() + run: | + # ignore expansion of "docker ps -q" since it could be empty + # shellcheck disable=SC2046 + docker stop $(docker ps -q) || true + # Prune all of the docker images + docker system prune -af + - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - pytorch_python_doc_build: + pytorch_doc_build: runs-on: linux.2xlarge - needs: [calculate-docker-image, build, ] + strategy: + matrix: + docs_type: [cpp, python] + needs: [calculate-docker-image, build, ciflow_should_run] env: DOCKER_IMAGE: ${{ needs.calculate-docker-image.outputs.docker_image }} + DOCS_TYPE: ${{ matrix.docs_type }} steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" - name: Log in to ECR run: | aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh bash /tmp/ecr-login.sh rm /tmp/ecr-login.sh - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" run: | # Ensure the working directory gets chowned back to the current user docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . @@ -477,21 +547,22 @@ jobs: run: | rm -rf "${GITHUB_WORKSPACE:?}/*" rm -f ~/.ssh/authorized_keys - - name: "[FB EMPLOYEES] Enables SSH (Click me for login details)" + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" uses: seemethere/add-github-ssh-key@v1 with: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Preserve github env variables for use in docker + run: | + env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: - fetch-depth: 0 # deep clone, to allow sharding to use git rev-list + # deep clone, to allow use of git merge-base + fetch-depth: 0 submodules: recursive - name: Pull docker image run: | docker pull "${DOCKER_IMAGE}" - - name: Preserve github env variables for use in docker - run: | - env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - uses: seemethere/download-artifact-s3@0504774707cbc8603d7dca922e8026eb8bf3b47b name: Download PyTorch Build Artifacts with: @@ -499,7 +570,7 @@ jobs: - name: Unzip artifacts run: | unzip -o artifacts.zip - - name: Build Python Doc in Docker + - name: Build ${{ matrix.docs_type }} docs run: | set -ex time docker pull "${DOCKER_IMAGE}" > /dev/null @@ -512,6 +583,9 @@ jobs: -e IN_CI \ -e MAX_JOBS="$(nproc --ignore=2)" \ -e CIRCLE_SHA1="$GITHUB_SHA" \ + -e DOCS_VERSION="${target}" \ + -e DOCS_TYPE \ + -e PR_LABELS \ --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ --security-opt seccomp=unconfined \ --cap-add=SYS_PTRACE \ @@ -521,40 +595,51 @@ jobs: -v "${GITHUB_WORKSPACE}:/var/lib/jenkins/workspace" \ -w /var/lib/jenkins/workspace \ "${DOCKER_IMAGE}" \ - bash -c "sudo chown -R jenkins . && pip install dist/*.whl && ./.circleci/scripts/python_doc_push_script.sh docs/$target $target site" + bash -c "sudo chown -R jenkins . && pip install dist/*.whl && ./.circleci/scripts/${DOCS_TYPE}_doc_push_script.sh" - name: Chown workspace run: | # Ensure the working directory gets chowned back to the current user docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - - uses: driazati/upload-artifact-s3@21c31d0a7bcb056ca50bd6ce197ba6507c26a1be - if: ${{ github.event_name == 'pull_request' }} - name: Upload Docs Preview + - uses: seemethere/upload-artifact-s3@v3 + name: Upload Python Docs Preview + if: ${{ github.event_name == 'pull_request' && matrix.docs_type == 'python' }} with: - name: deploy retention-days: 14 if-no-files-found: error - path: pytorch.github.io/docs/merge - - name: Show Docs Preview URL (Click Me) - if: ${{ github.event_name == 'pull_request' }} - env: - PR_NUMBER: ${{ github.event.pull_request.number }} - run: | - echo "See rendered docs at https://docs-preview.pytorch.org/$PR_NUMBER/" + path: pytorch.github.io/docs/merge/ + s3-prefix: ${{ github.repository }}/pr-previews/pr/${{ github.event.pull_request.number }} + - uses: seemethere/upload-artifact-s3@v3 + name: Upload C++ Docs Preview + if: ${{ github.event_name == 'pull_request' && matrix.docs_type == 'cppdocs' }} + with: + retention-days: 14 + if-no-files-found: error + path: cppdocs/ + s3-prefix: ${{ github.repository }}/pr-previews/pr/${{ github.event.pull_request.number }}/cppdocs - name: Archive artifacts into zip run: | - zip -r pytorch_github_io.zip "${GITHUB_WORKSPACE}/pytorch.github.io" + zip -r "docs_${DOCS_TYPE}.zip" "${GITHUB_WORKSPACE}/pytorch.github.io" "${GITHUB_WORKSPACE}/cppdocs" - uses: actions/upload-artifact@v2 name: Store PyTorch Build Artifacts with: - name: pytorch_github_io + name: docs_${{ matrix.docs_type }} + path: docs_${{ matrix.docs_type }}.zip if-no-files-found: error - path: pytorch_github_io.zip - name: Hold runner for 2 hours or until ssh sessions have drained # Always hold for active ssh sessions if: always() run: .github/scripts/wait_for_ssh_to_drain.sh - - name: Clean up docker images + - name: Kill containers, clean up images if: always() run: | + # ignore expansion of "docker ps -q" since it could be empty + # shellcheck disable=SC2046 + docker stop $(docker ps -q) || true # Prune all of the docker images docker system prune -af + - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . diff --git a/.github/workflows/generated-linux-xenial-py3.6-gcc7-bazel-test.yml b/.github/workflows/generated-linux-xenial-py3.6-gcc7-bazel-test.yml index 87c6df024b6e4..49d2cd2f2267c 100644 --- a/.github/workflows/generated-linux-xenial-py3.6-gcc7-bazel-test.yml +++ b/.github/workflows/generated-linux-xenial-py3.6-gcc7-bazel-test.yml @@ -14,8 +14,9 @@ on: env: BUILD_ENVIRONMENT: linux-xenial-py3.6-gcc7-bazel-test - DOCKER_IMAGE_BASE: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc7 + DOCKER_IMAGE_BASE: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-bionic-cuda10.2-cudnn7-py3.9-gcc7 SCCACHE_BUCKET: ossci-compiler-cache-circleci-v2 + XLA_CLANG_CACHE_S3_BUCKET_NAME: ossci-compiler-clang-cache-circleci-xla TORCH_CUDA_ARCH_LIST: 5.2 IN_CI: 1 # This is used for the phase of adding wheel tests only, will be removed once completed @@ -23,18 +24,23 @@ env: # Used for custom_opertor, jit_hooks, custom_backend, see .jenkins/pytorch/build.sh CUSTOM_TEST_ARTIFACT_BUILD_DIR: build/custom_test_artifacts ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + PR_LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }} concurrency: - group: linux-xenial-py3.6-gcc7-bazel-test-${{ github.event.pull_request.number || github.sha }} + group: linux-xenial-py3.6-gcc7-bazel-test-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true jobs: ciflow_should_run: runs-on: ubuntu-18.04 - if: ${{ (github.event_name != 'pull_request') || (github.event.action !='unassigned') || (github.event.action == 'unassigned' && contains(github.event.pull_request.labels.*.name, 'ciflow/default')) }} + if: ${{ (github.event_name != 'pull_request') || (github.event.action !='unassigned') || (contains(github.event.pull_request.labels.*.name, 'ciflow/all') || contains(github.event.pull_request.labels.*.name, 'ciflow/bazel') || contains(github.event.pull_request.labels.*.name, 'ciflow/cpu') || contains(github.event.pull_request.labels.*.name, 'ciflow/default') || contains(github.event.pull_request.labels.*.name, 'ciflow/linux')) }} + env: + LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }} steps: - name: noop run: echo running ciflow_should_run + - name: print labels + run: echo "${LABELS}" calculate-docker-image: if: ${{ github.repository_owner == 'pytorch' }} runs-on: linux.2xlarge @@ -45,12 +51,27 @@ jobs: outputs: docker_image: ${{ steps.calculate-tag.outputs.docker_image }} steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" - name: Log in to ECR run: | aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh bash /tmp/ecr-login.sh rm /tmp/ecr-login.sh - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" run: | # Ensure the working directory gets chowned back to the current user docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . @@ -58,11 +79,19 @@ jobs: run: | rm -rf "${GITHUB_WORKSPACE:?}/*" rm -f ~/.ssh/authorized_keys + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: seemethere/add-github-ssh-key@v1 + with: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Preserve github env variables for use in docker + run: | + env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: # deep clone, to allow use of git merge-base fetch-depth: 0 + submodules: false - name: Calculate docker image tag id: calculate-tag run: | @@ -118,20 +147,48 @@ jobs: DOCKER_IMAGE: ${{ needs.calculate-docker-image.outputs.docker_image }} JOB_BASE_NAME: linux-xenial-py3.6-gcc7-bazel-test-build-and-test NUM_TEST_SHARDS: 1 + CONTINUE_THROUGH_ERROR: ${{ github.repository_owner == 'pytorch' && (github.event_name == 'push' || github.event_name == 'schedule') }} steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" - name: Log in to ECR run: | aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh bash /tmp/ecr-login.sh rm /tmp/ecr-login.sh - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" run: | # Ensure the working directory gets chowned back to the current user docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . + - name: Clean workspace + run: | + rm -rf "${GITHUB_WORKSPACE:?}/*" + rm -f ~/.ssh/authorized_keys + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: seemethere/add-github-ssh-key@v1 + with: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Preserve github env variables for use in docker + run: | + env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: - fetch-depth: 0 # deep clone, to allow sharding to use git rev-list + # deep clone, to allow use of git merge-base + fetch-depth: 0 submodules: recursive - name: Pull docker image run: | @@ -162,6 +219,7 @@ jobs: -e MAX_JOBS="$(nproc --ignore=2)" \ -e SCCACHE_BUCKET \ -e CUSTOM_TEST_ARTIFACT_BUILD_DIR \ + -e PR_LABELS \ -e SKIP_SCCACHE_INITIALIZATION=1 \ -e TORCH_CUDA_ARCH_LIST \ -e http_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e https_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e no_proxy="localhost,127.0.0.1,amazonaws.com,s3.amazonaws.com,169.254.169.254,169.254.170.2,/var/run/docker.sock" \ @@ -211,6 +269,8 @@ jobs: -e JOB_BASE_NAME \ -e MAX_JOBS="$(nproc --ignore=2)" \ -e SCCACHE_BUCKET \ + -e CONTINUE_THROUGH_ERROR \ + -e PR_LABELS \ -e http_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e https_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e no_proxy="localhost,127.0.0.1,amazonaws.com,s3.amazonaws.com,169.254.169.254,169.254.170.2,/var/run/docker.sock" \ --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ --security-opt seccomp=unconfined \ @@ -245,69 +305,37 @@ jobs: if-no-files-found: error path: test-reports-*.zip - - name: Clean up docker images - if: always() - run: | - # Prune all of the docker images - docker system prune -af - - # this is a separate step from test because the log files from test are too - # long: basically, GitHub tries to render all of the log files when you click - # through an action causing extreme slowdown on actions that contain too many - # logs (like test); we can always move it back to the other one, but it - # doesn't create the best experience - render_test_results: - needs: [build-and-test, ciflow_should_run] - if: ${{ needs.build-and-test.result != 'skipped' || failure() }} - runs-on: linux.2xlarge - steps: - - name: Log in to ECR - run: | - aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh - bash /tmp/ecr-login.sh - rm /tmp/ecr-login.sh - - name: Chown workspace - run: | - # Ensure the working directory gets chowned back to the current user - docker run --rm -v "$(pwd)/../":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - - name: Clean workspace - run: | - rm -rf "${GITHUB_WORKSPACE:?}/*" - - name: Checkout PyTorch - uses: actions/checkout@v2 - with: - # deep clone, to allow tools/stats/print_test_stats.py to use Git commands - fetch-depth: 0 - - uses: actions/download-artifact@v2 - name: Download PyTorch Test Reports - with: - name: test-reports - path: . - - name: Unzip test reports - run: | - # Should preserve paths so reports should still be in test/test-reports - unzip -o 'test-reports-*.zip' - - name: Install dependencies - # boto3 version copied from .circleci/docker/common/install_conda.sh - run: | - pip3 install -r requirements.txt - pip3 install boto3==1.16.34 junitparser rich - - name: Output Test Results (Click Me) - run: | - python3 tools/render_junit.py test - - name: Parse ref - id: parse-ref - run: .github/scripts/parse_ref.py - name: Display and upload test statistics (Click Me) # temporary hack: set CIRCLE_* vars, until we update # tools/stats/print_test_stats.py to natively support GitHub Actions env: AWS_DEFAULT_REGION: us-east-1 CIRCLE_BRANCH: ${{ steps.parse-ref.outputs.branch }} - JOB_BASE_NAME: linux-xenial-py3.6-gcc7-bazel-test-build-and-test + JOB_BASE_NAME: linux-xenial-py3.6-gcc7-bazel-test-test CIRCLE_PR_NUMBER: ${{ github.event.pull_request.number }} CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }} CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }} CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}' + shell: bash run: | + python3 -m pip install -r requirements.txt + python3 -m pip install boto3==1.16.34 python3 -m tools.stats.print_test_stats --upload-to-s3 --compare-with-s3 test + - name: Hold runner for 2 hours or until ssh sessions have drained + # Always hold for active ssh sessions + if: always() + run: .github/scripts/wait_for_ssh_to_drain.sh + - name: Kill containers, clean up images + if: always() + run: | + # ignore expansion of "docker ps -q" since it could be empty + # shellcheck disable=SC2046 + docker stop $(docker ps -q) || true + # Prune all of the docker images + docker system prune -af + - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . diff --git a/.github/workflows/generated-parallelnative-linux-xenial-py3.6-gcc5.4.yml b/.github/workflows/generated-parallelnative-linux-xenial-py3.6-gcc5.4.yml new file mode 100644 index 0000000000000..1b352f6b8cd80 --- /dev/null +++ b/.github/workflows/generated-parallelnative-linux-xenial-py3.6-gcc5.4.yml @@ -0,0 +1,510 @@ +# @generated DO NOT EDIT MANUALLY +# Template is at: .github/templates/linux_ci_workflow.yml.j2 +# Generation script: .github/scripts/generate_ci_workflows.py +name: parallelnative-linux-xenial-py3.6-gcc5.4 + +on: + pull_request: + types: [unassigned] + push: + branches: + - master + - release/* + workflow_dispatch: + +env: + BUILD_ENVIRONMENT: parallelnative-linux-xenial-py3.6-gcc5.4 + DOCKER_IMAGE_BASE: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4 + SCCACHE_BUCKET: ossci-compiler-cache-circleci-v2 + XLA_CLANG_CACHE_S3_BUCKET_NAME: ossci-compiler-clang-cache-circleci-xla + TORCH_CUDA_ARCH_LIST: 5.2 + IN_CI: 1 + # This is used for the phase of adding wheel tests only, will be removed once completed + IN_WHEEL_TEST: 1 + # Used for custom_opertor, jit_hooks, custom_backend, see .jenkins/pytorch/build.sh + CUSTOM_TEST_ARTIFACT_BUILD_DIR: build/custom_test_artifacts + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + PR_LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }} + +concurrency: + group: parallelnative-linux-xenial-py3.6-gcc5.4-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} + cancel-in-progress: true + +jobs: + ciflow_should_run: + runs-on: ubuntu-18.04 + if: ${{ (github.event_name != 'pull_request') || (github.event.action !='unassigned') || (contains(github.event.pull_request.labels.*.name, 'ciflow/all') || contains(github.event.pull_request.labels.*.name, 'ciflow/cpu') || contains(github.event.pull_request.labels.*.name, 'ciflow/linux')) }} + env: + LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }} + steps: + - name: noop + run: echo running ciflow_should_run + - name: print labels + run: echo "${LABELS}" + calculate-docker-image: + if: ${{ github.repository_owner == 'pytorch' }} + runs-on: linux.2xlarge + needs: [ciflow_should_run] + env: + DOCKER_BUILDKIT: 1 + timeout-minutes: 90 + outputs: + docker_image: ${{ steps.calculate-tag.outputs.docker_image }} + steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" + - name: Log in to ECR + run: | + aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh + bash /tmp/ecr-login.sh + rm /tmp/ecr-login.sh + - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . + - name: Clean workspace + run: | + rm -rf "${GITHUB_WORKSPACE:?}/*" + rm -f ~/.ssh/authorized_keys + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: seemethere/add-github-ssh-key@v1 + with: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Preserve github env variables for use in docker + run: | + env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" + - name: Checkout PyTorch + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 + with: + # deep clone, to allow use of git merge-base + fetch-depth: 0 + submodules: false + - name: Calculate docker image tag + id: calculate-tag + run: | + DOCKER_TAG=$(git rev-parse HEAD:.circleci/docker) + echo "::set-output name=docker_tag::${DOCKER_TAG}" + echo "::set-output name=docker_image::${DOCKER_IMAGE_BASE}:${DOCKER_TAG}" + - name: Check if image should be built + id: check + env: + DOCKER_TAG: ${{ steps.calculate-tag.outputs.docker_tag }} + BASE_REVISION: ${{ github.event.pull_request.base.sha || github.sha }} + run: | + set -x + # Check if image already exists, if it does then skip building it + if docker manifest inspect "${DOCKER_IMAGE_BASE}:${DOCKER_TAG}"; then + exit 0 + fi + if [[ "$BASE_REVISION" = "$(git rev-parse HEAD)" ]]; then + # if we're on the base branch then use the parent commit + MERGE_BASE=$(git rev-parse HEAD~) + else + # otherwise we're on a PR, so use the most recent base commit + MERGE_BASE=$(git merge-base HEAD "$BASE_REVISION") + fi + # Covers the case where a previous tag doesn't exist for the tree + # this is only really applicable on trees that don't have `.circleci/docker` at its merge base, i.e. nightly + if ! git rev-parse "$MERGE_BASE:.circleci/docker"; then + echo "Directory '.circleci/docker' not found in commit $MERGE_BASE, you should probably rebase onto a more recent commit" + exit 1 + fi + PREVIOUS_DOCKER_TAG=$(git rev-parse "$MERGE_BASE:.circleci/docker") + # If no image exists but the hash is the same as the previous hash then we should error out here + if [[ "${PREVIOUS_DOCKER_TAG}" = "${DOCKER_TAG}" ]]; then + echo "ERROR: Something has gone wrong and the previous image isn't available for the merge-base of your branch" + echo " contact the PyTorch team to restore the original images" + exit 1 + fi + echo ::set-output name=rebuild::yes + - name: Build and push docker image + if: ${{ steps.check.outputs.rebuild }} + env: + DOCKER_TAG: ${{ steps.calculate-tag.outputs.docker_tag }} + DOCKER_SKIP_S3_UPLOAD: 1 + run: | + export IMAGE_NAME=${DOCKER_IMAGE_BASE#308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/} + cd .circleci/docker && ./build_docker.sh + + build: + runs-on: linux.2xlarge + needs: [calculate-docker-image, ciflow_should_run] + env: + DOCKER_IMAGE: ${{ needs.calculate-docker-image.outputs.docker_image }} + JOB_BASE_NAME: parallelnative-linux-xenial-py3.6-gcc5.4-build + steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" + - name: Log in to ECR + run: | + aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh + bash /tmp/ecr-login.sh + rm /tmp/ecr-login.sh + - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . + - name: Clean workspace + run: | + rm -rf "${GITHUB_WORKSPACE:?}/*" + rm -f ~/.ssh/authorized_keys + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: seemethere/add-github-ssh-key@v1 + with: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Preserve github env variables for use in docker + run: | + env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" + - name: Checkout PyTorch + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 + with: + # deep clone, to allow use of git merge-base + fetch-depth: 0 + submodules: recursive + - name: Pull docker image + run: | + docker pull "${DOCKER_IMAGE}" + - name: Build PyTorch + run: | + docker run \ + -e BUILD_ENVIRONMENT \ + -e JOB_BASE_NAME \ + -e MAX_JOBS="$(nproc --ignore=2)" \ + -e SCCACHE_BUCKET \ + -e XLA_CLANG_CACHE_S3_BUCKET_NAME \ + -e CUSTOM_TEST_ARTIFACT_BUILD_DIR \ + -e SKIP_SCCACHE_INITIALIZATION=1 \ + -e TORCH_CUDA_ARCH_LIST \ + -e PR_LABELS \ + -e http_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e https_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e no_proxy="localhost,127.0.0.1,amazonaws.com,s3.amazonaws.com,169.254.169.254,169.254.170.2,/var/run/docker.sock" \ + --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ + --security-opt seccomp=unconfined \ + --cap-add=SYS_PTRACE \ + --tty \ + --user jenkins \ + -v "${GITHUB_WORKSPACE}:/var/lib/jenkins/workspace" \ + -w /var/lib/jenkins/workspace \ + "${DOCKER_IMAGE}" \ + sh -c 'sudo chown -R jenkins . && .jenkins/pytorch/build.sh' + - name: Parse ref + id: parse-ref + run: .github/scripts/parse_ref.py + - name: Display and upload binary build size statistics (Click Me) + # temporary hack: set CIRCLE_* vars, until we update + # tools/stats/print_test_stats.py to natively support GitHub Actions + env: + AWS_DEFAULT_REGION: us-east-1 + SCRIBE_GRAPHQL_ACCESS_TOKEN: ${{ secrets.SCRIBE_GRAPHQL_ACCESS_TOKEN }} + CIRCLE_BRANCH: ${{ steps.parse-ref.outputs.branch }} + CIRCLE_PR_NUMBER: ${{ github.event.pull_request.number }} + CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }} + CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }} + CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}' + run: | + COMMIT_TIME=$(git log --max-count=1 --format=%ct || echo 0) + export COMMIT_TIME + pip3 install requests + python3 -m tools.stats.upload_binary_size_to_scuba || exit 0 + - name: Chown workspace + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . + - name: Archive artifacts into zip + run: | + zip -1 -r artifacts.zip dist/ build/custom_test_artifacts build/lib build/bin .pytorch-test-times.json + - uses: seemethere/upload-artifact-s3@v3 + name: Store PyTorch Build Artifacts on S3 + with: + name: ${{ env.BUILD_ENVIRONMENT }} + retention-days: 14 + if-no-files-found: error + path: + artifacts.zip + - name: Hold runner for 2 hours or until ssh sessions have drained + # Always hold for active ssh sessions + if: always() + run: .github/scripts/wait_for_ssh_to_drain.sh + - name: Kill containers, clean up images + if: always() + run: | + # ignore expansion of "docker ps -q" since it could be empty + # shellcheck disable=SC2046 + docker stop $(docker ps -q) || true + # Prune all of the docker images + docker system prune -af + - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . + - name: Hold runner for 2 hours or until ssh sessions have drained + # Always hold for active ssh sessions + if: always() + run: .github/scripts/wait_for_ssh_to_drain.sh + - name: Clean up docker images + if: always() + run: | + # Prune all of the docker images + docker system prune -af + + generate-test-matrix: + if: ${{ github.repository_owner == 'pytorch' }} + runs-on: ubuntu-18.04 + needs: [ciflow_should_run] + env: + TEST_RUNNER_TYPE: linux.2xlarge + ENABLE_DISTRIBUTED_TEST: 1 + ENABLE_JIT_LEGACY_TEST: '' + ENABLE_MULTIGPU_TEST: '' + ENABLE_NOGPU_NO_AVX_TEST: '' + ENABLE_NOGPU_NO_AVX2_TEST: '' + ENABLE_SLOW_TEST: '' + ENABLE_DOCS_TEST: '' + ENABLE_BACKWARDS_COMPAT_TEST: '' + ENABLE_XLA_TEST: '' + ENABLE_NOARCH_TEST: '' + NUM_TEST_SHARDS: 1 + MULTIGPU_RUNNER_TYPE: linux.16xlarge.nvidia.gpu + NOGPU_RUNNER_TYPE: linux.2xlarge + PR_BODY: ${{ github.event.pull_request.body }} + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + render-matrix: ${{ steps.set-matrix.outputs.render-matrix }} + ignore-disabled-issues: ${{ steps.set-matrix.outputs.ignore-disabled-issues }} + container: + image: python:3.9 + steps: + - name: Install dependencies + run: pip install typing-extensions + - name: Clone pytorch/pytorch + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 + - name: Generating test matrix + id: set-matrix + run: .github/scripts/generate_pytorch_test_matrix.py + + test: + needs: [calculate-docker-image, build, generate-test-matrix, ciflow_should_run] + strategy: + matrix: ${{ fromJson(needs.generate-test-matrix.outputs.matrix) }} + fail-fast: false + runs-on: ${{ matrix.runner }} + env: + DOCKER_IMAGE: ${{ needs.calculate-docker-image.outputs.docker_image }} + JOB_BASE_NAME: parallelnative-linux-xenial-py3.6-gcc5.4-test + TEST_CONFIG: ${{ matrix.config }} + SHARD_NUMBER: ${{ matrix.shard }} + NUM_TEST_SHARDS: ${{ matrix.num_shards }} + PYTORCH_IGNORE_DISABLED_ISSUES: ${{ needs.generate-test-matrix.outputs.ignore-disabled-issues }} + CONTINUE_THROUGH_ERROR: ${{ github.repository_owner == 'pytorch' && (github.event_name == 'push' || github.event_name == 'schedule') }} + steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" + - name: Log in to ECR + run: | + aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh + bash /tmp/ecr-login.sh + rm /tmp/ecr-login.sh + - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . + - name: Clean workspace + run: | + rm -rf "${GITHUB_WORKSPACE:?}/*" + rm -f ~/.ssh/authorized_keys + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: seemethere/add-github-ssh-key@v1 + with: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Preserve github env variables for use in docker + run: | + env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" + - name: Checkout PyTorch + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 + with: + # deep clone, to allow use of git merge-base + fetch-depth: 0 + submodules: recursive + - name: Pull docker image + run: | + docker pull "${DOCKER_IMAGE}" + - name: Install nvidia driver, nvidia-docker runtime, set GPU_FLAG + if: ${{ contains(env.BUILD_ENVIRONMENT, 'cuda') && !contains(matrix.config, 'nogpu') }} + run: | + bash .github/scripts/install_nvidia_utils_linux.sh + echo "GPU_FLAG=--gpus all" >> "${GITHUB_ENV}" + - name: Determine shm-size + run: | + shm_size="1g" + case "${BUILD_ENVIRONMENT}" in + *cuda*) + shm_size="2g" + ;; + *rocm*) + shm_size="8g" + ;; + esac + echo "SHM_SIZE=${shm_size}" >> "${GITHUB_ENV}" + - uses: seemethere/download-artifact-s3@0504774707cbc8603d7dca922e8026eb8bf3b47b + name: Download PyTorch Build Artifacts + with: + name: ${{ env.BUILD_ENVIRONMENT }} + - name: Unzip artifacts + run: | + unzip -o artifacts.zip + - name: Output disk space left + run: | + sudo df -H + - name: Test PyTorch + env: + PR_NUMBER: ${{ github.event.pull_request.number }} + run: | + if [[ $TEST_CONFIG == 'multigpu' ]]; then + TEST_COMMAND=.jenkins/pytorch/multigpu-test.sh + else + TEST_COMMAND=.jenkins/pytorch/test.sh + fi + if [[ $NUM_TEST_SHARDS -ne 2 ]]; then + export SHARD_NUMBER=0 + fi + # TODO: Stop building test binaries as part of the build phase + # Used for GPU_FLAG since that doesn't play nice + # shellcheck disable=SC2086 + docker run \ + ${GPU_FLAG:-} \ + -e BUILD_ENVIRONMENT \ + -e PR_NUMBER \ + -e CUSTOM_TEST_ARTIFACT_BUILD_DIR \ + -e GITHUB_ACTIONS \ + -e IN_CI \ + -e IN_WHEEL_TEST \ + -e SHARD_NUMBER \ + -e JOB_BASE_NAME \ + -e TEST_CONFIG \ + -e NUM_TEST_SHARDS \ + -e PYTORCH_IGNORE_DISABLED_ISSUES \ + -e PR_LABELS \ + -e CONTINUE_THROUGH_ERROR \ + -e MAX_JOBS="$(nproc --ignore=2)" \ + -e SCCACHE_BUCKET \ + -e http_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e https_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e no_proxy="localhost,127.0.0.1,amazonaws.com,s3.amazonaws.com,169.254.169.254,169.254.170.2,/var/run/docker.sock" \ + -e XLA_CLANG_CACHE_S3_BUCKET_NAME \ + --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ + --security-opt seccomp=unconfined \ + --cap-add=SYS_PTRACE \ + --shm-size="${SHM_SIZE}" \ + --tty \ + --user jenkins \ + -v "${GITHUB_WORKSPACE}:/var/lib/jenkins/workspace" \ + -w /var/lib/jenkins/workspace \ + "${DOCKER_IMAGE}" \ + sh -c 'sudo chown -R jenkins . && pip install dist/*.whl && '$TEST_COMMAND + - name: Chown workspace + if: always() + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . + - name: Zip test reports for upload + if: always() + env: + COMMIT_SHA1: ${{ github.event.pull_request.head.sha || github.sha }} + WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}' + run: | + # Remove any previous test reports if they exist + rm -f test-reports-*.zip + zip -r "test-reports-${COMMIT_SHA1}-${WORKFLOW_ID}.zip" test -i '*.xml' + - uses: actions/upload-artifact@v2 + name: Store PyTorch Test Reports + if: always() + with: + name: test-reports-${{ matrix.config }} + retention-days: 14 + if-no-files-found: error + path: + test-reports-*.zip + - uses: seemethere/upload-artifact-s3@v3 + name: Store PyTorch Test Reports on S3 + if: always() + with: + name: test-reports-${{ matrix.config }} + retention-days: 14 + if-no-files-found: error + path: + test-reports-*.zip + - name: Parse ref + id: parse-ref + run: .github/scripts/parse_ref.py + - name: Display and upload test statistics (Click Me) + # temporary hack: set CIRCLE_* vars, until we update + # tools/stats/print_test_stats.py to natively support GitHub Actions + env: + AWS_DEFAULT_REGION: us-east-1 + CIRCLE_BRANCH: ${{ steps.parse-ref.outputs.branch }} + JOB_BASE_NAME: parallelnative-linux-xenial-py3.6-gcc5.4-test + CIRCLE_PR_NUMBER: ${{ github.event.pull_request.number }} + CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }} + CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }} + CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}' + shell: bash + run: | + python3 -m pip install -r requirements.txt + python3 -m pip install boto3==1.16.34 + python3 -m tools.stats.print_test_stats --upload-to-s3 --compare-with-s3 test + - name: Hold runner for 2 hours or until ssh sessions have drained + # Always hold for active ssh sessions + if: always() + run: .github/scripts/wait_for_ssh_to_drain.sh + - name: Kill containers, clean up images + if: always() + run: | + # ignore expansion of "docker ps -q" since it could be empty + # shellcheck disable=SC2046 + docker stop $(docker ps -q) || true + # Prune all of the docker images + docker system prune -af + - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . diff --git a/.github/workflows/generated-libtorch-linux-xenial-cuda11.1-py3.6-gcc7.yml b/.github/workflows/generated-periodic-libtorch-linux-xenial-cuda11.1-py3.6-gcc7.yml similarity index 65% rename from .github/workflows/generated-libtorch-linux-xenial-cuda11.1-py3.6-gcc7.yml rename to .github/workflows/generated-periodic-libtorch-linux-xenial-cuda11.1-py3.6-gcc7.yml index da2bbc1400388..620e4c3d2d318 100644 --- a/.github/workflows/generated-libtorch-linux-xenial-cuda11.1-py3.6-gcc7.yml +++ b/.github/workflows/generated-periodic-libtorch-linux-xenial-cuda11.1-py3.6-gcc7.yml @@ -1,20 +1,20 @@ # @generated DO NOT EDIT MANUALLY # Template is at: .github/templates/linux_ci_workflow.yml.j2 # Generation script: .github/scripts/generate_ci_workflows.py -name: libtorch-linux-xenial-cuda11.1-py3.6-gcc7 +name: periodic-libtorch-linux-xenial-cuda11.1-py3.6-gcc7 on: - # TODO: Enable pull_request builds when we can verify capacity can be met by auto-scalers - push: - branches: - - master - - release/* + pull_request: + types: [unassigned] + schedule: + - cron: 45 0,4,8,12,16,20 * * * workflow_dispatch: env: - BUILD_ENVIRONMENT: libtorch-linux-xenial-cuda11.1-py3.6-gcc7 + BUILD_ENVIRONMENT: periodic-libtorch-linux-xenial-cuda11.1-py3.6-gcc7 DOCKER_IMAGE_BASE: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7 SCCACHE_BUCKET: ossci-compiler-cache-circleci-v2 + XLA_CLANG_CACHE_S3_BUCKET_NAME: ossci-compiler-clang-cache-circleci-xla TORCH_CUDA_ARCH_LIST: 5.2 IN_CI: 1 # This is used for the phase of adding wheel tests only, will be removed once completed @@ -22,27 +22,54 @@ env: # Used for custom_opertor, jit_hooks, custom_backend, see .jenkins/pytorch/build.sh CUSTOM_TEST_ARTIFACT_BUILD_DIR: build/custom_test_artifacts ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + PR_LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }} concurrency: - group: libtorch-linux-xenial-cuda11.1-py3.6-gcc7-${{ github.event.pull_request.number || github.sha }} + group: periodic-libtorch-linux-xenial-cuda11.1-py3.6-gcc7-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true jobs: + ciflow_should_run: + runs-on: ubuntu-18.04 + if: ${{ (github.event_name != 'pull_request') || (github.event.action !='unassigned') || (contains(github.event.pull_request.labels.*.name, 'ciflow/all') || contains(github.event.pull_request.labels.*.name, 'ciflow/cuda') || contains(github.event.pull_request.labels.*.name, 'ciflow/libtorch') || contains(github.event.pull_request.labels.*.name, 'ciflow/linux') || contains(github.event.pull_request.labels.*.name, 'ciflow/scheduled')) }} + env: + LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }} + steps: + - name: noop + run: echo running ciflow_should_run + - name: print labels + run: echo "${LABELS}" calculate-docker-image: if: ${{ github.repository_owner == 'pytorch' }} runs-on: linux.2xlarge + needs: [ciflow_should_run] env: DOCKER_BUILDKIT: 1 timeout-minutes: 90 outputs: docker_image: ${{ steps.calculate-tag.outputs.docker_image }} steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" - name: Log in to ECR run: | aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh bash /tmp/ecr-login.sh rm /tmp/ecr-login.sh - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" run: | # Ensure the working directory gets chowned back to the current user docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . @@ -50,11 +77,19 @@ jobs: run: | rm -rf "${GITHUB_WORKSPACE:?}/*" rm -f ~/.ssh/authorized_keys + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: seemethere/add-github-ssh-key@v1 + with: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Preserve github env variables for use in docker + run: | + env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: # deep clone, to allow use of git merge-base fetch-depth: 0 + submodules: false - name: Calculate docker image tag id: calculate-tag run: | @@ -104,17 +139,32 @@ jobs: build: runs-on: linux.2xlarge - needs: [calculate-docker-image, ] + needs: [calculate-docker-image, ciflow_should_run] env: DOCKER_IMAGE: ${{ needs.calculate-docker-image.outputs.docker_image }} - JOB_BASE_NAME: libtorch-linux-xenial-cuda11.1-py3.6-gcc7-build + JOB_BASE_NAME: periodic-libtorch-linux-xenial-cuda11.1-py3.6-gcc7-build steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" - name: Log in to ECR run: | aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh bash /tmp/ecr-login.sh rm /tmp/ecr-login.sh - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" run: | # Ensure the working directory gets chowned back to the current user docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . @@ -126,17 +176,18 @@ jobs: uses: seemethere/add-github-ssh-key@v1 with: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Preserve github env variables for use in docker + run: | + env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: - fetch-depth: 0 # deep clone, to allow sharding to use git rev-list + # deep clone, to allow use of git merge-base + fetch-depth: 0 submodules: recursive - name: Pull docker image run: | docker pull "${DOCKER_IMAGE}" - - name: Preserve github env variables for use in docker - run: | - env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Build PyTorch run: | docker run \ @@ -144,9 +195,11 @@ jobs: -e JOB_BASE_NAME \ -e MAX_JOBS="$(nproc --ignore=2)" \ -e SCCACHE_BUCKET \ + -e XLA_CLANG_CACHE_S3_BUCKET_NAME \ -e CUSTOM_TEST_ARTIFACT_BUILD_DIR \ -e SKIP_SCCACHE_INITIALIZATION=1 \ -e TORCH_CUDA_ARCH_LIST \ + -e PR_LABELS \ -e http_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e https_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e no_proxy="localhost,127.0.0.1,amazonaws.com,s3.amazonaws.com,169.254.169.254,169.254.170.2,/var/run/docker.sock" \ --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ --security-opt seccomp=unconfined \ @@ -184,6 +237,24 @@ jobs: # Always hold for active ssh sessions if: always() run: .github/scripts/wait_for_ssh_to_drain.sh + - name: Kill containers, clean up images + if: always() + run: | + # ignore expansion of "docker ps -q" since it could be empty + # shellcheck disable=SC2046 + docker stop $(docker ps -q) || true + # Prune all of the docker images + docker system prune -af + - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . + - name: Hold runner for 2 hours or until ssh sessions have drained + # Always hold for active ssh sessions + if: always() + run: .github/scripts/wait_for_ssh_to_drain.sh - name: Clean up docker images if: always() run: | diff --git a/.github/workflows/generated-periodic-linux-xenial-cuda11.3-py3.6-gcc7.yml b/.github/workflows/generated-periodic-linux-xenial-cuda11.1-py3.6-gcc7.yml similarity index 73% rename from .github/workflows/generated-periodic-linux-xenial-cuda11.3-py3.6-gcc7.yml rename to .github/workflows/generated-periodic-linux-xenial-cuda11.1-py3.6-gcc7.yml index 5ec1ddb8516eb..e318e665c9156 100644 --- a/.github/workflows/generated-periodic-linux-xenial-cuda11.3-py3.6-gcc7.yml +++ b/.github/workflows/generated-periodic-linux-xenial-cuda11.1-py3.6-gcc7.yml @@ -1,7 +1,7 @@ # @generated DO NOT EDIT MANUALLY # Template is at: .github/templates/linux_ci_workflow.yml.j2 # Generation script: .github/scripts/generate_ci_workflows.py -name: periodic-linux-xenial-cuda11.3-py3.6-gcc7 +name: periodic-linux-xenial-cuda11.1-py3.6-gcc7 on: pull_request: @@ -11,9 +11,10 @@ on: workflow_dispatch: env: - BUILD_ENVIRONMENT: periodic-linux-xenial-cuda11.3-py3.6-gcc7 - DOCKER_IMAGE_BASE: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.3-cudnn8-py3-gcc7 + BUILD_ENVIRONMENT: periodic-linux-xenial-cuda11.1-py3.6-gcc7 + DOCKER_IMAGE_BASE: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7 SCCACHE_BUCKET: ossci-compiler-cache-circleci-v2 + XLA_CLANG_CACHE_S3_BUCKET_NAME: ossci-compiler-clang-cache-circleci-xla TORCH_CUDA_ARCH_LIST: 5.2 IN_CI: 1 # This is used for the phase of adding wheel tests only, will be removed once completed @@ -21,18 +22,23 @@ env: # Used for custom_opertor, jit_hooks, custom_backend, see .jenkins/pytorch/build.sh CUSTOM_TEST_ARTIFACT_BUILD_DIR: build/custom_test_artifacts ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + PR_LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }} concurrency: - group: periodic-linux-xenial-cuda11.3-py3.6-gcc7-${{ github.event.pull_request.number || github.sha }} + group: periodic-linux-xenial-cuda11.1-py3.6-gcc7-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true jobs: ciflow_should_run: runs-on: ubuntu-18.04 - if: ${{ (github.event_name != 'pull_request') || (github.event.action !='unassigned') || (github.event.action == 'unassigned' && contains(github.event.pull_request.labels.*.name, 'ciflow/scheduled')) }} + if: ${{ (github.event_name != 'pull_request') || (github.event.action !='unassigned') || (contains(github.event.pull_request.labels.*.name, 'ciflow/all') || contains(github.event.pull_request.labels.*.name, 'ciflow/cuda') || contains(github.event.pull_request.labels.*.name, 'ciflow/linux') || contains(github.event.pull_request.labels.*.name, 'ciflow/scheduled')) }} + env: + LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }} steps: - name: noop run: echo running ciflow_should_run + - name: print labels + run: echo "${LABELS}" calculate-docker-image: if: ${{ github.repository_owner == 'pytorch' }} runs-on: linux.2xlarge @@ -43,12 +49,27 @@ jobs: outputs: docker_image: ${{ steps.calculate-tag.outputs.docker_image }} steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" - name: Log in to ECR run: | aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh bash /tmp/ecr-login.sh rm /tmp/ecr-login.sh - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" run: | # Ensure the working directory gets chowned back to the current user docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . @@ -56,11 +77,19 @@ jobs: run: | rm -rf "${GITHUB_WORKSPACE:?}/*" rm -f ~/.ssh/authorized_keys + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: seemethere/add-github-ssh-key@v1 + with: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Preserve github env variables for use in docker + run: | + env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: # deep clone, to allow use of git merge-base fetch-depth: 0 + submodules: false - name: Calculate docker image tag id: calculate-tag run: | @@ -113,14 +142,29 @@ jobs: needs: [calculate-docker-image, ciflow_should_run] env: DOCKER_IMAGE: ${{ needs.calculate-docker-image.outputs.docker_image }} - JOB_BASE_NAME: periodic-linux-xenial-cuda11.3-py3.6-gcc7-build + JOB_BASE_NAME: periodic-linux-xenial-cuda11.1-py3.6-gcc7-build steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" - name: Log in to ECR run: | aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh bash /tmp/ecr-login.sh rm /tmp/ecr-login.sh - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" run: | # Ensure the working directory gets chowned back to the current user docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . @@ -132,17 +176,18 @@ jobs: uses: seemethere/add-github-ssh-key@v1 with: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Preserve github env variables for use in docker + run: | + env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: - fetch-depth: 0 # deep clone, to allow sharding to use git rev-list + # deep clone, to allow use of git merge-base + fetch-depth: 0 submodules: recursive - name: Pull docker image run: | docker pull "${DOCKER_IMAGE}" - - name: Preserve github env variables for use in docker - run: | - env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Build PyTorch run: | docker run \ @@ -150,9 +195,11 @@ jobs: -e JOB_BASE_NAME \ -e MAX_JOBS="$(nproc --ignore=2)" \ -e SCCACHE_BUCKET \ + -e XLA_CLANG_CACHE_S3_BUCKET_NAME \ -e CUSTOM_TEST_ARTIFACT_BUILD_DIR \ -e SKIP_SCCACHE_INITIALIZATION=1 \ -e TORCH_CUDA_ARCH_LIST \ + -e PR_LABELS \ -e http_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e https_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e no_proxy="localhost,127.0.0.1,amazonaws.com,s3.amazonaws.com,169.254.169.254,169.254.170.2,/var/run/docker.sock" \ --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ --security-opt seccomp=unconfined \ @@ -188,19 +235,8 @@ jobs: docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - name: Archive artifacts into zip run: | - zip -r artifacts.zip dist/ build/ .pytorch-test-times.json - # Upload to github so that people can click and download artifacts - - uses: actions/upload-artifact@v2 - # Don't fail on upload to GH since it's only for user convenience - continue-on-error: true - name: Store PyTorch Build Artifacts on Github - with: - name: ${{ env.BUILD_ENVIRONMENT }} - retention-days: 14 - if-no-files-found: error - path: - artifacts.zip - - uses: seemethere/upload-artifact-s3@9d7ceb0ab39c2c88d93ef7792b27425b27d59162 + zip -1 -r artifacts.zip dist/ build/custom_test_artifacts build/lib build/bin .pytorch-test-times.json + - uses: seemethere/upload-artifact-s3@v3 name: Store PyTorch Build Artifacts on S3 with: name: ${{ env.BUILD_ENVIRONMENT }} @@ -212,6 +248,24 @@ jobs: # Always hold for active ssh sessions if: always() run: .github/scripts/wait_for_ssh_to_drain.sh + - name: Kill containers, clean up images + if: always() + run: | + # ignore expansion of "docker ps -q" since it could be empty + # shellcheck disable=SC2046 + docker stop $(docker ps -q) || true + # Prune all of the docker images + docker system prune -af + - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . + - name: Hold runner for 2 hours or until ssh sessions have drained + # Always hold for active ssh sessions + if: always() + run: .github/scripts/wait_for_ssh_to_drain.sh - name: Clean up docker images if: always() run: | @@ -224,11 +278,16 @@ jobs: needs: [ciflow_should_run] env: TEST_RUNNER_TYPE: linux.8xlarge.nvidia.gpu + ENABLE_DISTRIBUTED_TEST: 1 ENABLE_JIT_LEGACY_TEST: '' ENABLE_MULTIGPU_TEST: '' ENABLE_NOGPU_NO_AVX_TEST: '' ENABLE_NOGPU_NO_AVX2_TEST: '' ENABLE_SLOW_TEST: '' + ENABLE_DOCS_TEST: '' + ENABLE_BACKWARDS_COMPAT_TEST: '' + ENABLE_XLA_TEST: '' + ENABLE_NOARCH_TEST: '' NUM_TEST_SHARDS: 2 MULTIGPU_RUNNER_TYPE: linux.16xlarge.nvidia.gpu NOGPU_RUNNER_TYPE: linux.2xlarge @@ -243,7 +302,7 @@ jobs: - name: Install dependencies run: pip install typing-extensions - name: Clone pytorch/pytorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 - name: Generating test matrix id: set-matrix run: .github/scripts/generate_pytorch_test_matrix.py @@ -256,21 +315,37 @@ jobs: runs-on: ${{ matrix.runner }} env: DOCKER_IMAGE: ${{ needs.calculate-docker-image.outputs.docker_image }} - JOB_BASE_NAME: periodic-linux-xenial-cuda11.3-py3.6-gcc7-test + JOB_BASE_NAME: periodic-linux-xenial-cuda11.1-py3.6-gcc7-test TEST_CONFIG: ${{ matrix.config }} SHARD_NUMBER: ${{ matrix.shard }} NUM_TEST_SHARDS: ${{ matrix.num_shards }} PYTORCH_IGNORE_DISABLED_ISSUES: ${{ needs.generate-test-matrix.outputs.ignore-disabled-issues }} + CONTINUE_THROUGH_ERROR: ${{ github.repository_owner == 'pytorch' && (github.event_name == 'push' || github.event_name == 'schedule') }} steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" - name: Log in to ECR run: | aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh bash /tmp/ecr-login.sh rm /tmp/ecr-login.sh - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" run: | # Ensure the working directory gets chowned back to the current user - docker run --rm -v "$(pwd)/../":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - name: Clean workspace run: | rm -rf "${GITHUB_WORKSPACE:?}/*" @@ -279,9 +354,14 @@ jobs: uses: seemethere/add-github-ssh-key@v1 with: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Preserve github env variables for use in docker + run: | + env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: + # deep clone, to allow use of git merge-base + fetch-depth: 0 submodules: recursive - name: Pull docker image run: | @@ -313,12 +393,9 @@ jobs: - name: Output disk space left run: | sudo df -H - - name: Preserve github env variables for use in docker - run: | - env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" - name: Test PyTorch env: - BUILD_ENVIRONMENT: periodic-linux-xenial-cuda11.3-py3.6-gcc7-${{ matrix.config }} + PR_NUMBER: ${{ github.event.pull_request.number }} run: | if [[ $TEST_CONFIG == 'multigpu' ]]; then TEST_COMMAND=.jenkins/pytorch/multigpu-test.sh @@ -334,6 +411,7 @@ jobs: docker run \ ${GPU_FLAG:-} \ -e BUILD_ENVIRONMENT \ + -e PR_NUMBER \ -e CUSTOM_TEST_ARTIFACT_BUILD_DIR \ -e GITHUB_ACTIONS \ -e IN_CI \ @@ -343,9 +421,12 @@ jobs: -e TEST_CONFIG \ -e NUM_TEST_SHARDS \ -e PYTORCH_IGNORE_DISABLED_ISSUES \ + -e PR_LABELS \ + -e CONTINUE_THROUGH_ERROR \ -e MAX_JOBS="$(nproc --ignore=2)" \ -e SCCACHE_BUCKET \ -e http_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e https_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e no_proxy="localhost,127.0.0.1,amazonaws.com,s3.amazonaws.com,169.254.169.254,169.254.170.2,/var/run/docker.sock" \ + -e XLA_CLANG_CACHE_S3_BUCKET_NAME \ --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ --security-opt seccomp=unconfined \ --cap-add=SYS_PTRACE \ @@ -379,7 +460,7 @@ jobs: if-no-files-found: error path: test-reports-*.zip - - uses: seemethere/upload-artifact-s3@9d7ceb0ab39c2c88d93ef7792b27425b27d59162 + - uses: seemethere/upload-artifact-s3@v3 name: Store PyTorch Test Reports on S3 if: always() with: @@ -388,66 +469,6 @@ jobs: if-no-files-found: error path: test-reports-*.zip - - name: Hold runner for 2 hours or until ssh sessions have drained - # Always hold for active ssh sessions - if: always() - run: .github/scripts/wait_for_ssh_to_drain.sh - - name: Clean up docker images - if: always() - run: | - # Ensure the working directory gets chowned back to the current user - docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - # Prune all of the docker images - docker system prune -af - - - # this is a separate step from test because the log files from test are too - # long: basically, GitHub tries to render all of the log files when you click - # through an action causing extreme slowdown on actions that contain too many - # logs (like test); we can always move it back to the other one, but it - # doesn't create the best experience - render_test_results: - needs: [generate-test-matrix, test, ciflow_should_run] - if: ${{ needs.test.result != 'skipped' || failure() }} - runs-on: linux.2xlarge - strategy: - matrix: ${{ fromJson(needs.generate-test-matrix.outputs.render-matrix) }} - fail-fast: false - steps: - - name: Log in to ECR - run: | - aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh - bash /tmp/ecr-login.sh - rm /tmp/ecr-login.sh - - name: Chown workspace - run: | - # Ensure the working directory gets chowned back to the current user - docker run --rm -v "$(pwd)/../":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - - name: Clean workspace - run: | - rm -rf "${GITHUB_WORKSPACE:?}/*" - - name: Checkout PyTorch - uses: actions/checkout@v2 - with: - # deep clone, to allow tools/stats/print_test_stats.py to use Git commands - fetch-depth: 0 - - uses: actions/download-artifact@v2 - name: Download PyTorch Test Reports - with: - name: test-reports-${{ matrix.config }} - path: . - - name: Unzip test reports - run: | - # Should preserve paths so reports should still be in test/test-reports - unzip -o 'test-reports-*.zip' - - name: Install dependencies - # boto3 version copied from .circleci/docker/common/install_conda.sh - run: | - pip3 install -r requirements.txt - pip3 install boto3==1.16.34 junitparser rich - - name: Output Test Results (Click Me) - run: | - python3 tools/render_junit.py test - name: Parse ref id: parse-ref run: .github/scripts/parse_ref.py @@ -457,10 +478,31 @@ jobs: env: AWS_DEFAULT_REGION: us-east-1 CIRCLE_BRANCH: ${{ steps.parse-ref.outputs.branch }} - JOB_BASE_NAME: periodic-linux-xenial-cuda11.3-py3.6-gcc7-test + JOB_BASE_NAME: periodic-linux-xenial-cuda11.1-py3.6-gcc7-test CIRCLE_PR_NUMBER: ${{ github.event.pull_request.number }} CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }} CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }} CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}' + shell: bash run: | + python3 -m pip install -r requirements.txt + python3 -m pip install boto3==1.16.34 python3 -m tools.stats.print_test_stats --upload-to-s3 --compare-with-s3 test + - name: Hold runner for 2 hours or until ssh sessions have drained + # Always hold for active ssh sessions + if: always() + run: .github/scripts/wait_for_ssh_to_drain.sh + - name: Kill containers, clean up images + if: always() + run: | + # ignore expansion of "docker ps -q" since it could be empty + # shellcheck disable=SC2046 + docker stop $(docker ps -q) || true + # Prune all of the docker images + docker system prune -af + - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . diff --git a/.github/workflows/generated-win-vs2019-cuda11.1-py3.yml b/.github/workflows/generated-periodic-win-vs2019-cuda11.1-py3.yml similarity index 66% rename from .github/workflows/generated-win-vs2019-cuda11.1-py3.yml rename to .github/workflows/generated-periodic-win-vs2019-cuda11.1-py3.yml index 9c9b733aef445..360fdc38c86ad 100644 --- a/.github/workflows/generated-win-vs2019-cuda11.1-py3.yml +++ b/.github/workflows/generated-periodic-win-vs2019-cuda11.1-py3.yml @@ -1,22 +1,23 @@ # @generated DO NOT EDIT MANUALLY # Template is at: .github/templates/windows_ci_workflow.yml.j2 # Generation script: .github/scripts/generate_ci_workflows.py -name: win-vs2019-cuda11.1-py3 +name: periodic-win-vs2019-cuda11.1-py3 on: - push: - branches: - - master - - release/* + pull_request: + types: [unassigned] + schedule: + - cron: 45 0,4,8,12,16,20 * * * workflow_dispatch: env: - BUILD_ENVIRONMENT: win-vs2019-cuda11.1-py3 + BUILD_ENVIRONMENT: periodic-win-vs2019-cuda11.1-py3 BUILD_WHEEL: 1 CUDA_VERSION: "11.1" IN_CI: 1 INSTALL_WINDOWS_SDK: 1 PYTHON_VERSION: "3.8" + PR_LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }} SCCACHE_BUCKET: "ossci-compiler-cache" VC_PRODUCT: "BuildTools" VC_VERSION: "" @@ -28,28 +29,52 @@ env: USE_CUDA: 1 concurrency: - group: win-vs2019-cuda11.1-py3-${{ github.event.pull_request.number || github.sha }} + group: periodic-win-vs2019-cuda11.1-py3-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true jobs: + ciflow_should_run: + runs-on: ubuntu-18.04 + if: ${{ (github.event_name != 'pull_request') || (github.event.action !='unassigned') || (contains(github.event.pull_request.labels.*.name, 'ciflow/all') || contains(github.event.pull_request.labels.*.name, 'ciflow/cuda') || contains(github.event.pull_request.labels.*.name, 'ciflow/scheduled') || contains(github.event.pull_request.labels.*.name, 'ciflow/win')) }} + steps: + - name: noop + run: echo running ciflow_should_run build: if: ${{ github.repository_owner == 'pytorch' }} runs-on: "windows.4xlarge" defaults: run: working-directory: pytorch-${{ github.run_id }} + needs: [ciflow_should_run] env: - JOB_BASE_NAME: win-vs2019-cuda11.1-py3-build + JOB_BASE_NAME: periodic-win-vs2019-cuda11.1-py3-build http_proxy: "http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" https_proxy: "http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" steps: + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: seemethere/add-github-ssh-key@v1 + with: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: submodules: recursive path: pytorch-${{ github.run_id }} # deep clone, to allow use of git merge-base fetch-depth: 0 + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" - name: Install Visual Studio 2019 toolchain shell: powershell run: | @@ -81,12 +106,23 @@ jobs: path: C:\${{ github.run_id }}\build-results - name: Upload artifacts to s3 if: always() - uses: seemethere/upload-artifact-s3@9d7ceb0ab39c2c88d93ef7792b27425b27d59162 + uses: seemethere/upload-artifact-s3@v3 with: retention-days: 14 if-no-files-found: error name: ${{ env.BUILD_ENVIRONMENT }} path: C:\${{ github.run_id }}\build-results + - name: Wait until all sessions have drained + shell: powershell + if: always() + timeout-minutes: 120 + run: | + .github\scripts\wait_for_ssh_to_drain.ps1 + - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) + shell: powershell + if: always() + run: | + .github\scripts\kill_active_ssh_sessions.ps1 - name: Cleanup build-results and workspaces if: always() shell: bash @@ -99,6 +135,7 @@ jobs: generate-test-matrix: if: ${{ github.repository_owner == 'pytorch' }} + needs: [ciflow_should_run] runs-on: ubuntu-18.04 env: TEST_RUNNER_TYPE: windows.8xlarge.nvidia.gpu @@ -115,14 +152,14 @@ jobs: - name: Install dependencies run: pip install typing-extensions - name: Clone pytorch/pytorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 - name: Generating test matrix id: set-matrix run: .github/scripts/generate_pytorch_test_matrix.py test: env: - JOB_BASE_NAME: win-vs2019-cuda11.1-py3-test + JOB_BASE_NAME: periodic-win-vs2019-cuda11.1-py3-test SHARD_NUMBER: ${{ matrix.shard }} NUM_TEST_SHARDS: ${{ matrix.num_shards }} TEST_CONFIG: ${{ matrix.config }} @@ -130,7 +167,8 @@ jobs: https_proxy: "http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" RUN_SMOKE_TESTS_ONLY_ON_PR: False PYTORCH_IGNORE_DISABLED_ISSUES: ${{ needs.generate-test-matrix.outputs.ignore-disabled-issues }} - needs: [build, generate-test-matrix, ] + CONTINUE_THROUGH_ERROR: ${{ github.repository_owner == 'pytorch' && (github.event_name == 'push' || github.event_name == 'schedule') }} + needs: [build, generate-test-matrix, ciflow_should_run] strategy: matrix: ${{ fromJson(needs.generate-test-matrix.outputs.matrix) }} fail-fast: false @@ -140,12 +178,29 @@ jobs: working-directory: pytorch-${{ github.run_id }} steps: - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: submodules: recursive path: pytorch-${{ github.run_id }} # deep clone, to allow use of git merge-base fetch-depth: 0 + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: seemethere/add-github-ssh-key@v1 + with: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - name: Install Visual Studio 2019 toolchain shell: powershell run: | @@ -202,60 +257,17 @@ jobs: if-no-files-found: error path: pytorch-${{ github.run_id }}/test-reports-*.zip - - name: Cleanup workspace + - name: Wait until all sessions have drained + shell: powershell if: always() - shell: bash - # Should remove the entirety of pytorch-${{ github.run_id }} + timeout-minutes: 120 run: | - rm -rf ./* - - # this is a separate step from test because the log files from test are too - # long: basically, GitHub tries to render all of the log files when you click - # through an action causing extreme slowdown on actions that contain too many - # logs (like test); we can always move it back to the other one, but it - # doesn't create the best experience - render_test_results: - needs: [generate-test-matrix, test, ] - if: ${{ needs.test.result != 'skipped' || failure() }} - runs-on: linux.2xlarge - strategy: - matrix: ${{ fromJson(needs.generate-test-matrix.outputs.render-matrix) }} - fail-fast: false - # TODO: Make this into a composite step - steps: - - name: Log in to ECR - run: | - aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh - bash /tmp/ecr-login.sh - rm /tmp/ecr-login.sh - - name: Chown workspace - run: | - # Ensure the working directory gets chowned back to the current user - docker run --rm -v "$(pwd)/../":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - - name: Clean workspace - run: | - rm -rf "${GITHUB_WORKSPACE:?}/*" - - name: Checkout PyTorch - uses: actions/checkout@v2 - with: - # deep clone, to allow tools/stats/print_test_stats.py to use Git commands - fetch-depth: 0 - - uses: actions/download-artifact@v2 - name: Download PyTorch Test Reports - with: - name: test-reports-${{ matrix.config }} - path: . - - name: Unzip test reports - run: | - unzip -o 'test-reports-*.zip' - - name: Install dependencies - # boto3 version copied from .circleci/docker/common/install_conda.sh - run: | - pip3 install -r requirements.txt - pip3 install boto3==1.16.34 junitparser rich - - name: Output Test Results (Click Me) + .github\scripts\wait_for_ssh_to_drain.ps1 + - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) + shell: powershell + if: always() run: | - python3 tools/render_junit.py test + .github\scripts\kill_active_ssh_sessions.ps1 - name: Parse ref id: parse-ref run: .github/scripts/parse_ref.py @@ -265,10 +277,19 @@ jobs: env: AWS_DEFAULT_REGION: us-east-1 CIRCLE_BRANCH: ${{ steps.parse-ref.outputs.branch }} - JOB_BASE_NAME: win-vs2019-cuda11.1-py3-test + JOB_BASE_NAME: periodic-win-vs2019-cuda11.1-py3-test CIRCLE_PR_NUMBER: ${{ github.event.pull_request.number }} CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }} CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }} CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}' + shell: bash run: | + python3 -m pip install -r requirements.txt + python3 -m pip install boto3==1.16.34 python3 -m tools.stats.print_test_stats --upload-to-s3 --compare-with-s3 test + - name: Cleanup workspace + if: always() + shell: bash + # Should remove the entirety of pytorch-${{ github.run_id }} + run: | + rm -rf ./* diff --git a/.github/workflows/generated-puretorch-linux-xenial-py3.6-gcc5.4.yml b/.github/workflows/generated-puretorch-linux-xenial-py3.6-gcc5.4.yml new file mode 100644 index 0000000000000..af1228903b1f5 --- /dev/null +++ b/.github/workflows/generated-puretorch-linux-xenial-py3.6-gcc5.4.yml @@ -0,0 +1,275 @@ +# @generated DO NOT EDIT MANUALLY +# Template is at: .github/templates/linux_ci_workflow.yml.j2 +# Generation script: .github/scripts/generate_ci_workflows.py +name: puretorch-linux-xenial-py3.6-gcc5.4 + +on: + pull_request: + types: [unassigned] + push: + branches: + - master + - release/* + workflow_dispatch: + +env: + BUILD_ENVIRONMENT: puretorch-linux-xenial-py3.6-gcc5.4 + DOCKER_IMAGE_BASE: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4 + SCCACHE_BUCKET: ossci-compiler-cache-circleci-v2 + XLA_CLANG_CACHE_S3_BUCKET_NAME: ossci-compiler-clang-cache-circleci-xla + TORCH_CUDA_ARCH_LIST: 5.2 + IN_CI: 1 + # This is used for the phase of adding wheel tests only, will be removed once completed + IN_WHEEL_TEST: 1 + # Used for custom_opertor, jit_hooks, custom_backend, see .jenkins/pytorch/build.sh + CUSTOM_TEST_ARTIFACT_BUILD_DIR: build/custom_test_artifacts + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + PR_LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }} + +concurrency: + group: puretorch-linux-xenial-py3.6-gcc5.4-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} + cancel-in-progress: true + +jobs: + ciflow_should_run: + runs-on: ubuntu-18.04 + if: ${{ (github.event_name != 'pull_request') || (github.event.action !='unassigned') || (contains(github.event.pull_request.labels.*.name, 'ciflow/all') || contains(github.event.pull_request.labels.*.name, 'ciflow/cpu') || contains(github.event.pull_request.labels.*.name, 'ciflow/linux')) }} + env: + LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }} + steps: + - name: noop + run: echo running ciflow_should_run + - name: print labels + run: echo "${LABELS}" + calculate-docker-image: + if: ${{ github.repository_owner == 'pytorch' }} + runs-on: linux.2xlarge + needs: [ciflow_should_run] + env: + DOCKER_BUILDKIT: 1 + timeout-minutes: 90 + outputs: + docker_image: ${{ steps.calculate-tag.outputs.docker_image }} + steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" + - name: Log in to ECR + run: | + aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh + bash /tmp/ecr-login.sh + rm /tmp/ecr-login.sh + - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . + - name: Clean workspace + run: | + rm -rf "${GITHUB_WORKSPACE:?}/*" + rm -f ~/.ssh/authorized_keys + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: seemethere/add-github-ssh-key@v1 + with: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Preserve github env variables for use in docker + run: | + env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" + - name: Checkout PyTorch + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 + with: + # deep clone, to allow use of git merge-base + fetch-depth: 0 + submodules: false + - name: Calculate docker image tag + id: calculate-tag + run: | + DOCKER_TAG=$(git rev-parse HEAD:.circleci/docker) + echo "::set-output name=docker_tag::${DOCKER_TAG}" + echo "::set-output name=docker_image::${DOCKER_IMAGE_BASE}:${DOCKER_TAG}" + - name: Check if image should be built + id: check + env: + DOCKER_TAG: ${{ steps.calculate-tag.outputs.docker_tag }} + BASE_REVISION: ${{ github.event.pull_request.base.sha || github.sha }} + run: | + set -x + # Check if image already exists, if it does then skip building it + if docker manifest inspect "${DOCKER_IMAGE_BASE}:${DOCKER_TAG}"; then + exit 0 + fi + if [[ "$BASE_REVISION" = "$(git rev-parse HEAD)" ]]; then + # if we're on the base branch then use the parent commit + MERGE_BASE=$(git rev-parse HEAD~) + else + # otherwise we're on a PR, so use the most recent base commit + MERGE_BASE=$(git merge-base HEAD "$BASE_REVISION") + fi + # Covers the case where a previous tag doesn't exist for the tree + # this is only really applicable on trees that don't have `.circleci/docker` at its merge base, i.e. nightly + if ! git rev-parse "$MERGE_BASE:.circleci/docker"; then + echo "Directory '.circleci/docker' not found in commit $MERGE_BASE, you should probably rebase onto a more recent commit" + exit 1 + fi + PREVIOUS_DOCKER_TAG=$(git rev-parse "$MERGE_BASE:.circleci/docker") + # If no image exists but the hash is the same as the previous hash then we should error out here + if [[ "${PREVIOUS_DOCKER_TAG}" = "${DOCKER_TAG}" ]]; then + echo "ERROR: Something has gone wrong and the previous image isn't available for the merge-base of your branch" + echo " contact the PyTorch team to restore the original images" + exit 1 + fi + echo ::set-output name=rebuild::yes + - name: Build and push docker image + if: ${{ steps.check.outputs.rebuild }} + env: + DOCKER_TAG: ${{ steps.calculate-tag.outputs.docker_tag }} + DOCKER_SKIP_S3_UPLOAD: 1 + run: | + export IMAGE_NAME=${DOCKER_IMAGE_BASE#308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/} + cd .circleci/docker && ./build_docker.sh + + build: + runs-on: linux.2xlarge + needs: [calculate-docker-image, ciflow_should_run] + env: + DOCKER_IMAGE: ${{ needs.calculate-docker-image.outputs.docker_image }} + JOB_BASE_NAME: puretorch-linux-xenial-py3.6-gcc5.4-build + steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" + - name: Log in to ECR + run: | + aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh + bash /tmp/ecr-login.sh + rm /tmp/ecr-login.sh + - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . + - name: Clean workspace + run: | + rm -rf "${GITHUB_WORKSPACE:?}/*" + rm -f ~/.ssh/authorized_keys + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: seemethere/add-github-ssh-key@v1 + with: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Preserve github env variables for use in docker + run: | + env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}" + - name: Checkout PyTorch + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 + with: + # deep clone, to allow use of git merge-base + fetch-depth: 0 + submodules: recursive + - name: Pull docker image + run: | + docker pull "${DOCKER_IMAGE}" + - name: Build PyTorch + run: | + docker run \ + -e BUILD_ENVIRONMENT \ + -e JOB_BASE_NAME \ + -e MAX_JOBS="$(nproc --ignore=2)" \ + -e SCCACHE_BUCKET \ + -e XLA_CLANG_CACHE_S3_BUCKET_NAME \ + -e CUSTOM_TEST_ARTIFACT_BUILD_DIR \ + -e SKIP_SCCACHE_INITIALIZATION=1 \ + -e TORCH_CUDA_ARCH_LIST \ + -e PR_LABELS \ + -e http_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e https_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e no_proxy="localhost,127.0.0.1,amazonaws.com,s3.amazonaws.com,169.254.169.254,169.254.170.2,/var/run/docker.sock" \ + --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ + --security-opt seccomp=unconfined \ + --cap-add=SYS_PTRACE \ + --tty \ + --user jenkins \ + -v "${GITHUB_WORKSPACE}:/var/lib/jenkins/workspace" \ + -w /var/lib/jenkins/workspace \ + "${DOCKER_IMAGE}" \ + sh -c 'sudo chown -R jenkins . && .jenkins/pytorch/build.sh' + - name: Parse ref + id: parse-ref + run: .github/scripts/parse_ref.py + - name: Display and upload binary build size statistics (Click Me) + # temporary hack: set CIRCLE_* vars, until we update + # tools/stats/print_test_stats.py to natively support GitHub Actions + env: + AWS_DEFAULT_REGION: us-east-1 + SCRIBE_GRAPHQL_ACCESS_TOKEN: ${{ secrets.SCRIBE_GRAPHQL_ACCESS_TOKEN }} + CIRCLE_BRANCH: ${{ steps.parse-ref.outputs.branch }} + CIRCLE_PR_NUMBER: ${{ github.event.pull_request.number }} + CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }} + CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }} + CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}' + run: | + COMMIT_TIME=$(git log --max-count=1 --format=%ct || echo 0) + export COMMIT_TIME + pip3 install requests + python3 -m tools.stats.upload_binary_size_to_scuba || exit 0 + - name: Chown workspace + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . + - name: Archive artifacts into zip + run: | + zip -1 -r artifacts.zip dist/ build/custom_test_artifacts build/lib build/bin .pytorch-test-times.json + - uses: seemethere/upload-artifact-s3@v3 + name: Store PyTorch Build Artifacts on S3 + with: + name: ${{ env.BUILD_ENVIRONMENT }} + retention-days: 14 + if-no-files-found: error + path: + artifacts.zip + - name: Hold runner for 2 hours or until ssh sessions have drained + # Always hold for active ssh sessions + if: always() + run: .github/scripts/wait_for_ssh_to_drain.sh + - name: Kill containers, clean up images + if: always() + run: | + # ignore expansion of "docker ps -q" since it could be empty + # shellcheck disable=SC2046 + docker stop $(docker ps -q) || true + # Prune all of the docker images + docker system prune -af + - name: Chown workspace + env: + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + run: | + # Ensure the working directory gets chowned back to the current user + docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . + - name: Hold runner for 2 hours or until ssh sessions have drained + # Always hold for active ssh sessions + if: always() + run: .github/scripts/wait_for_ssh_to_drain.sh + - name: Clean up docker images + if: always() + run: | + # Prune all of the docker images + docker system prune -af diff --git a/.github/workflows/generated-win-vs2019-cpu-py3.yml b/.github/workflows/generated-win-vs2019-cpu-py3.yml index 2769f7c498eef..1277a69f1d13d 100644 --- a/.github/workflows/generated-win-vs2019-cpu-py3.yml +++ b/.github/workflows/generated-win-vs2019-cpu-py3.yml @@ -5,6 +5,7 @@ name: win-vs2019-cpu-py3 on: pull_request: + types: [opened, synchronize, reopened, unassigned] push: branches: - master @@ -18,6 +19,7 @@ env: IN_CI: 1 INSTALL_WINDOWS_SDK: 1 PYTHON_VERSION: "3.8" + PR_LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }} SCCACHE_BUCKET: "ossci-compiler-cache" VC_PRODUCT: "BuildTools" VC_VERSION: "" @@ -27,28 +29,52 @@ env: no_proxy: localhost,127.0.0.1,amazonaws.com,s3.amazonaws.com,169.254.169.254,169.254.170.2,/var/run/docker.sock concurrency: - group: win-vs2019-cpu-py3-${{ github.event.pull_request.number || github.sha }} + group: win-vs2019-cpu-py3-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true jobs: + ciflow_should_run: + runs-on: ubuntu-18.04 + if: ${{ (github.event_name != 'pull_request') || (github.event.action !='unassigned') || (contains(github.event.pull_request.labels.*.name, 'ciflow/all') || contains(github.event.pull_request.labels.*.name, 'ciflow/cpu') || contains(github.event.pull_request.labels.*.name, 'ciflow/default') || contains(github.event.pull_request.labels.*.name, 'ciflow/win')) }} + steps: + - name: noop + run: echo running ciflow_should_run build: if: ${{ github.repository_owner == 'pytorch' }} runs-on: "windows.4xlarge" defaults: run: working-directory: pytorch-${{ github.run_id }} + needs: [ciflow_should_run] env: JOB_BASE_NAME: win-vs2019-cpu-py3-build http_proxy: "http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" https_proxy: "http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" steps: + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: seemethere/add-github-ssh-key@v1 + with: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: submodules: recursive path: pytorch-${{ github.run_id }} # deep clone, to allow use of git merge-base fetch-depth: 0 + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" - name: Install Visual Studio 2019 toolchain shell: powershell run: | @@ -72,12 +98,23 @@ jobs: path: C:\${{ github.run_id }}\build-results - name: Upload artifacts to s3 if: always() - uses: seemethere/upload-artifact-s3@9d7ceb0ab39c2c88d93ef7792b27425b27d59162 + uses: seemethere/upload-artifact-s3@v3 with: retention-days: 14 if-no-files-found: error name: ${{ env.BUILD_ENVIRONMENT }} path: C:\${{ github.run_id }}\build-results + - name: Wait until all sessions have drained + shell: powershell + if: always() + timeout-minutes: 120 + run: | + .github\scripts\wait_for_ssh_to_drain.ps1 + - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) + shell: powershell + if: always() + run: | + .github\scripts\kill_active_ssh_sessions.ps1 - name: Cleanup build-results and workspaces if: always() shell: bash @@ -90,6 +127,7 @@ jobs: generate-test-matrix: if: ${{ github.repository_owner == 'pytorch' }} + needs: [ciflow_should_run] runs-on: ubuntu-18.04 env: TEST_RUNNER_TYPE: windows.4xlarge @@ -106,7 +144,7 @@ jobs: - name: Install dependencies run: pip install typing-extensions - name: Clone pytorch/pytorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 - name: Generating test matrix id: set-matrix run: .github/scripts/generate_pytorch_test_matrix.py @@ -121,7 +159,8 @@ jobs: https_proxy: "http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" RUN_SMOKE_TESTS_ONLY_ON_PR: False PYTORCH_IGNORE_DISABLED_ISSUES: ${{ needs.generate-test-matrix.outputs.ignore-disabled-issues }} - needs: [build, generate-test-matrix, ] + CONTINUE_THROUGH_ERROR: ${{ github.repository_owner == 'pytorch' && (github.event_name == 'push' || github.event_name == 'schedule') }} + needs: [build, generate-test-matrix, ciflow_should_run] strategy: matrix: ${{ fromJson(needs.generate-test-matrix.outputs.matrix) }} fail-fast: false @@ -131,12 +170,29 @@ jobs: working-directory: pytorch-${{ github.run_id }} steps: - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: submodules: recursive path: pytorch-${{ github.run_id }} # deep clone, to allow use of git merge-base fetch-depth: 0 + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: seemethere/add-github-ssh-key@v1 + with: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - name: Install Visual Studio 2019 toolchain shell: powershell run: | @@ -185,60 +241,17 @@ jobs: if-no-files-found: error path: pytorch-${{ github.run_id }}/test-reports-*.zip - - name: Cleanup workspace + - name: Wait until all sessions have drained + shell: powershell if: always() - shell: bash - # Should remove the entirety of pytorch-${{ github.run_id }} + timeout-minutes: 120 run: | - rm -rf ./* - - # this is a separate step from test because the log files from test are too - # long: basically, GitHub tries to render all of the log files when you click - # through an action causing extreme slowdown on actions that contain too many - # logs (like test); we can always move it back to the other one, but it - # doesn't create the best experience - render_test_results: - needs: [generate-test-matrix, test, ] - if: ${{ needs.test.result != 'skipped' || failure() }} - runs-on: linux.2xlarge - strategy: - matrix: ${{ fromJson(needs.generate-test-matrix.outputs.render-matrix) }} - fail-fast: false - # TODO: Make this into a composite step - steps: - - name: Log in to ECR - run: | - aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh - bash /tmp/ecr-login.sh - rm /tmp/ecr-login.sh - - name: Chown workspace - run: | - # Ensure the working directory gets chowned back to the current user - docker run --rm -v "$(pwd)/../":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - - name: Clean workspace - run: | - rm -rf "${GITHUB_WORKSPACE:?}/*" - - name: Checkout PyTorch - uses: actions/checkout@v2 - with: - # deep clone, to allow tools/stats/print_test_stats.py to use Git commands - fetch-depth: 0 - - uses: actions/download-artifact@v2 - name: Download PyTorch Test Reports - with: - name: test-reports-${{ matrix.config }} - path: . - - name: Unzip test reports - run: | - unzip -o 'test-reports-*.zip' - - name: Install dependencies - # boto3 version copied from .circleci/docker/common/install_conda.sh - run: | - pip3 install -r requirements.txt - pip3 install boto3==1.16.34 junitparser rich - - name: Output Test Results (Click Me) + .github\scripts\wait_for_ssh_to_drain.ps1 + - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) + shell: powershell + if: always() run: | - python3 tools/render_junit.py test + .github\scripts\kill_active_ssh_sessions.ps1 - name: Parse ref id: parse-ref run: .github/scripts/parse_ref.py @@ -253,5 +266,14 @@ jobs: CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }} CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }} CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}' + shell: bash run: | + python3 -m pip install -r requirements.txt + python3 -m pip install boto3==1.16.34 python3 -m tools.stats.print_test_stats --upload-to-s3 --compare-with-s3 test + - name: Cleanup workspace + if: always() + shell: bash + # Should remove the entirety of pytorch-${{ github.run_id }} + run: | + rm -rf ./* diff --git a/.github/workflows/generated-win-vs2019-cuda10.1-py3.yml b/.github/workflows/generated-win-vs2019-cuda10.1-py3.yml index d94ba7850ee32..185cb5903e189 100644 --- a/.github/workflows/generated-win-vs2019-cuda10.1-py3.yml +++ b/.github/workflows/generated-win-vs2019-cuda10.1-py3.yml @@ -5,6 +5,7 @@ name: win-vs2019-cuda10.1-py3 on: pull_request: + types: [opened, synchronize, reopened, unassigned] push: branches: - master @@ -18,6 +19,7 @@ env: IN_CI: 1 INSTALL_WINDOWS_SDK: 1 PYTHON_VERSION: "3.8" + PR_LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }} SCCACHE_BUCKET: "ossci-compiler-cache" VC_PRODUCT: "BuildTools" VC_VERSION: "" @@ -29,28 +31,52 @@ env: USE_CUDA: 1 concurrency: - group: win-vs2019-cuda10.1-py3-${{ github.event.pull_request.number || github.sha }} + group: win-vs2019-cuda10.1-py3-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true jobs: + ciflow_should_run: + runs-on: ubuntu-18.04 + if: ${{ (github.event_name != 'pull_request') || (github.event.action !='unassigned') || (contains(github.event.pull_request.labels.*.name, 'ciflow/all') || contains(github.event.pull_request.labels.*.name, 'ciflow/cuda') || contains(github.event.pull_request.labels.*.name, 'ciflow/default') || contains(github.event.pull_request.labels.*.name, 'ciflow/win')) }} + steps: + - name: noop + run: echo running ciflow_should_run build: if: ${{ github.repository_owner == 'pytorch' }} runs-on: "windows.4xlarge" defaults: run: working-directory: pytorch-${{ github.run_id }} + needs: [ciflow_should_run] env: JOB_BASE_NAME: win-vs2019-cuda10.1-py3-build http_proxy: "http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" https_proxy: "http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" steps: + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: seemethere/add-github-ssh-key@v1 + with: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: submodules: recursive path: pytorch-${{ github.run_id }} # deep clone, to allow use of git merge-base fetch-depth: 0 + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" - name: Install Visual Studio 2019 toolchain shell: powershell run: | @@ -82,12 +108,23 @@ jobs: path: C:\${{ github.run_id }}\build-results - name: Upload artifacts to s3 if: always() - uses: seemethere/upload-artifact-s3@9d7ceb0ab39c2c88d93ef7792b27425b27d59162 + uses: seemethere/upload-artifact-s3@v3 with: retention-days: 14 if-no-files-found: error name: ${{ env.BUILD_ENVIRONMENT }} path: C:\${{ github.run_id }}\build-results + - name: Wait until all sessions have drained + shell: powershell + if: always() + timeout-minutes: 120 + run: | + .github\scripts\wait_for_ssh_to_drain.ps1 + - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) + shell: powershell + if: always() + run: | + .github\scripts\kill_active_ssh_sessions.ps1 - name: Cleanup build-results and workspaces if: always() shell: bash @@ -100,6 +137,7 @@ jobs: generate-test-matrix: if: ${{ github.repository_owner == 'pytorch' }} + needs: [ciflow_should_run] runs-on: ubuntu-18.04 env: TEST_RUNNER_TYPE: windows.8xlarge.nvidia.gpu @@ -116,7 +154,7 @@ jobs: - name: Install dependencies run: pip install typing-extensions - name: Clone pytorch/pytorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 - name: Generating test matrix id: set-matrix run: .github/scripts/generate_pytorch_test_matrix.py @@ -131,7 +169,8 @@ jobs: https_proxy: "http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" RUN_SMOKE_TESTS_ONLY_ON_PR: True PYTORCH_IGNORE_DISABLED_ISSUES: ${{ needs.generate-test-matrix.outputs.ignore-disabled-issues }} - needs: [build, generate-test-matrix, ] + CONTINUE_THROUGH_ERROR: ${{ github.repository_owner == 'pytorch' && (github.event_name == 'push' || github.event_name == 'schedule') }} + needs: [build, generate-test-matrix, ciflow_should_run] strategy: matrix: ${{ fromJson(needs.generate-test-matrix.outputs.matrix) }} fail-fast: false @@ -141,12 +180,29 @@ jobs: working-directory: pytorch-${{ github.run_id }} steps: - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: submodules: recursive path: pytorch-${{ github.run_id }} # deep clone, to allow use of git merge-base fetch-depth: 0 + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: seemethere/add-github-ssh-key@v1 + with: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - name: Install Visual Studio 2019 toolchain shell: powershell run: | @@ -203,60 +259,17 @@ jobs: if-no-files-found: error path: pytorch-${{ github.run_id }}/test-reports-*.zip - - name: Cleanup workspace + - name: Wait until all sessions have drained + shell: powershell if: always() - shell: bash - # Should remove the entirety of pytorch-${{ github.run_id }} + timeout-minutes: 120 run: | - rm -rf ./* - - # this is a separate step from test because the log files from test are too - # long: basically, GitHub tries to render all of the log files when you click - # through an action causing extreme slowdown on actions that contain too many - # logs (like test); we can always move it back to the other one, but it - # doesn't create the best experience - render_test_results: - needs: [generate-test-matrix, test, ] - if: ${{ needs.test.result != 'skipped' || failure() }} - runs-on: linux.2xlarge - strategy: - matrix: ${{ fromJson(needs.generate-test-matrix.outputs.render-matrix) }} - fail-fast: false - # TODO: Make this into a composite step - steps: - - name: Log in to ECR - run: | - aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh - bash /tmp/ecr-login.sh - rm /tmp/ecr-login.sh - - name: Chown workspace - run: | - # Ensure the working directory gets chowned back to the current user - docker run --rm -v "$(pwd)/../":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - - name: Clean workspace - run: | - rm -rf "${GITHUB_WORKSPACE:?}/*" - - name: Checkout PyTorch - uses: actions/checkout@v2 - with: - # deep clone, to allow tools/stats/print_test_stats.py to use Git commands - fetch-depth: 0 - - uses: actions/download-artifact@v2 - name: Download PyTorch Test Reports - with: - name: test-reports-${{ matrix.config }} - path: . - - name: Unzip test reports - run: | - unzip -o 'test-reports-*.zip' - - name: Install dependencies - # boto3 version copied from .circleci/docker/common/install_conda.sh - run: | - pip3 install -r requirements.txt - pip3 install boto3==1.16.34 junitparser rich - - name: Output Test Results (Click Me) + .github\scripts\wait_for_ssh_to_drain.ps1 + - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) + shell: powershell + if: always() run: | - python3 tools/render_junit.py test + .github\scripts\kill_active_ssh_sessions.ps1 - name: Parse ref id: parse-ref run: .github/scripts/parse_ref.py @@ -271,5 +284,14 @@ jobs: CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }} CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }} CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}' + shell: bash run: | + python3 -m pip install -r requirements.txt + python3 -m pip install boto3==1.16.34 python3 -m tools.stats.print_test_stats --upload-to-s3 --compare-with-s3 test + - name: Cleanup workspace + if: always() + shell: bash + # Should remove the entirety of pytorch-${{ github.run_id }} + run: | + rm -rf ./* diff --git a/.github/workflows/generated-periodic-win-vs2019-cuda11.3-py3.yml b/.github/workflows/generated-win-vs2019-cuda11.3-py3.yml similarity index 69% rename from .github/workflows/generated-periodic-win-vs2019-cuda11.3-py3.yml rename to .github/workflows/generated-win-vs2019-cuda11.3-py3.yml index 78c536c0bbd11..b339e79926f53 100644 --- a/.github/workflows/generated-periodic-win-vs2019-cuda11.3-py3.yml +++ b/.github/workflows/generated-win-vs2019-cuda11.3-py3.yml @@ -1,22 +1,25 @@ # @generated DO NOT EDIT MANUALLY # Template is at: .github/templates/windows_ci_workflow.yml.j2 # Generation script: .github/scripts/generate_ci_workflows.py -name: periodic-win-vs2019-cuda11.3-py3 +name: win-vs2019-cuda11.3-py3 on: pull_request: types: [unassigned] - schedule: - - cron: 45 0,4,8,12,16,20 * * * + push: + branches: + - master + - release/* workflow_dispatch: env: - BUILD_ENVIRONMENT: periodic-win-vs2019-cuda11.3-py3 + BUILD_ENVIRONMENT: win-vs2019-cuda11.3-py3 BUILD_WHEEL: 1 CUDA_VERSION: "11.3" IN_CI: 1 INSTALL_WINDOWS_SDK: 1 PYTHON_VERSION: "3.8" + PR_LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }} SCCACHE_BUCKET: "ossci-compiler-cache" VC_PRODUCT: "BuildTools" VC_VERSION: "" @@ -28,13 +31,13 @@ env: USE_CUDA: 1 concurrency: - group: periodic-win-vs2019-cuda11.3-py3-${{ github.event.pull_request.number || github.sha }} + group: win-vs2019-cuda11.3-py3-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true jobs: ciflow_should_run: runs-on: ubuntu-18.04 - if: ${{ (github.event_name != 'pull_request') || (github.event.action !='unassigned') || (github.event.action == 'unassigned' && contains(github.event.pull_request.labels.*.name, 'ciflow/scheduled')) }} + if: ${{ (github.event_name != 'pull_request') || (github.event.action !='unassigned') || (contains(github.event.pull_request.labels.*.name, 'ciflow/all') || contains(github.event.pull_request.labels.*.name, 'ciflow/cuda') || contains(github.event.pull_request.labels.*.name, 'ciflow/win')) }} steps: - name: noop run: echo running ciflow_should_run @@ -46,17 +49,34 @@ jobs: working-directory: pytorch-${{ github.run_id }} needs: [ciflow_should_run] env: - JOB_BASE_NAME: periodic-win-vs2019-cuda11.3-py3-build + JOB_BASE_NAME: win-vs2019-cuda11.3-py3-build http_proxy: "http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" https_proxy: "http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" steps: + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: seemethere/add-github-ssh-key@v1 + with: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: submodules: recursive path: pytorch-${{ github.run_id }} # deep clone, to allow use of git merge-base fetch-depth: 0 + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" - name: Install Visual Studio 2019 toolchain shell: powershell run: | @@ -88,12 +108,23 @@ jobs: path: C:\${{ github.run_id }}\build-results - name: Upload artifacts to s3 if: always() - uses: seemethere/upload-artifact-s3@9d7ceb0ab39c2c88d93ef7792b27425b27d59162 + uses: seemethere/upload-artifact-s3@v3 with: retention-days: 14 if-no-files-found: error name: ${{ env.BUILD_ENVIRONMENT }} path: C:\${{ github.run_id }}\build-results + - name: Wait until all sessions have drained + shell: powershell + if: always() + timeout-minutes: 120 + run: | + .github\scripts\wait_for_ssh_to_drain.ps1 + - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) + shell: powershell + if: always() + run: | + .github\scripts\kill_active_ssh_sessions.ps1 - name: Cleanup build-results and workspaces if: always() shell: bash @@ -123,14 +154,14 @@ jobs: - name: Install dependencies run: pip install typing-extensions - name: Clone pytorch/pytorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 - name: Generating test matrix id: set-matrix run: .github/scripts/generate_pytorch_test_matrix.py test: env: - JOB_BASE_NAME: periodic-win-vs2019-cuda11.3-py3-test + JOB_BASE_NAME: win-vs2019-cuda11.3-py3-test SHARD_NUMBER: ${{ matrix.shard }} NUM_TEST_SHARDS: ${{ matrix.num_shards }} TEST_CONFIG: ${{ matrix.config }} @@ -138,6 +169,7 @@ jobs: https_proxy: "http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" RUN_SMOKE_TESTS_ONLY_ON_PR: False PYTORCH_IGNORE_DISABLED_ISSUES: ${{ needs.generate-test-matrix.outputs.ignore-disabled-issues }} + CONTINUE_THROUGH_ERROR: ${{ github.repository_owner == 'pytorch' && (github.event_name == 'push' || github.event_name == 'schedule') }} needs: [build, generate-test-matrix, ciflow_should_run] strategy: matrix: ${{ fromJson(needs.generate-test-matrix.outputs.matrix) }} @@ -148,12 +180,29 @@ jobs: working-directory: pytorch-${{ github.run_id }} steps: - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: submodules: recursive path: pytorch-${{ github.run_id }} # deep clone, to allow use of git merge-base fetch-depth: 0 + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: seemethere/add-github-ssh-key@v1 + with: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - name: Install Visual Studio 2019 toolchain shell: powershell run: | @@ -210,60 +259,17 @@ jobs: if-no-files-found: error path: pytorch-${{ github.run_id }}/test-reports-*.zip - - name: Cleanup workspace + - name: Wait until all sessions have drained + shell: powershell if: always() - shell: bash - # Should remove the entirety of pytorch-${{ github.run_id }} + timeout-minutes: 120 run: | - rm -rf ./* - - # this is a separate step from test because the log files from test are too - # long: basically, GitHub tries to render all of the log files when you click - # through an action causing extreme slowdown on actions that contain too many - # logs (like test); we can always move it back to the other one, but it - # doesn't create the best experience - render_test_results: - needs: [generate-test-matrix, test, ciflow_should_run] - if: ${{ needs.test.result != 'skipped' || failure() }} - runs-on: linux.2xlarge - strategy: - matrix: ${{ fromJson(needs.generate-test-matrix.outputs.render-matrix) }} - fail-fast: false - # TODO: Make this into a composite step - steps: - - name: Log in to ECR - run: | - aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh - bash /tmp/ecr-login.sh - rm /tmp/ecr-login.sh - - name: Chown workspace - run: | - # Ensure the working directory gets chowned back to the current user - docker run --rm -v "$(pwd)/../":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - - name: Clean workspace - run: | - rm -rf "${GITHUB_WORKSPACE:?}/*" - - name: Checkout PyTorch - uses: actions/checkout@v2 - with: - # deep clone, to allow tools/stats/print_test_stats.py to use Git commands - fetch-depth: 0 - - uses: actions/download-artifact@v2 - name: Download PyTorch Test Reports - with: - name: test-reports-${{ matrix.config }} - path: . - - name: Unzip test reports - run: | - unzip -o 'test-reports-*.zip' - - name: Install dependencies - # boto3 version copied from .circleci/docker/common/install_conda.sh - run: | - pip3 install -r requirements.txt - pip3 install boto3==1.16.34 junitparser rich - - name: Output Test Results (Click Me) + .github\scripts\wait_for_ssh_to_drain.ps1 + - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) + shell: powershell + if: always() run: | - python3 tools/render_junit.py test + .github\scripts\kill_active_ssh_sessions.ps1 - name: Parse ref id: parse-ref run: .github/scripts/parse_ref.py @@ -273,10 +279,19 @@ jobs: env: AWS_DEFAULT_REGION: us-east-1 CIRCLE_BRANCH: ${{ steps.parse-ref.outputs.branch }} - JOB_BASE_NAME: periodic-win-vs2019-cuda11.3-py3-test + JOB_BASE_NAME: win-vs2019-cuda11.3-py3-test CIRCLE_PR_NUMBER: ${{ github.event.pull_request.number }} CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }} CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }} CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}' + shell: bash run: | + python3 -m pip install -r requirements.txt + python3 -m pip install boto3==1.16.34 python3 -m tools.stats.print_test_stats --upload-to-s3 --compare-with-s3 test + - name: Cleanup workspace + if: always() + shell: bash + # Should remove the entirety of pytorch-${{ github.run_id }} + run: | + rm -rf ./* diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 720e76c4e6a5f..a1b6182aedaf4 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -16,7 +16,7 @@ jobs: python-version: 3.x architecture: x64 - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 - name: Install requirements id: requirements run: pip3 install -r requirements.txt --user @@ -101,7 +101,7 @@ jobs: python-version: 3.x architecture: x64 - name: Fetch PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: fetch-depth: 0 # deep clone, to allow us to use git merge-base - name: Run clang-format @@ -140,7 +140,7 @@ jobs: python-version: 2.x architecture: x64 - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 - name: Attempt to run setup.py run: | if ! python2 setup.py | grep -q "Python 2 has reached end-of-life and is no longer supported by PyTorch."; then @@ -159,7 +159,7 @@ jobs: python-version: 3.x architecture: x64 - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 - name: Install requirements id: requirements run: | @@ -168,7 +168,7 @@ jobs: run: | pip3 install Jinja2==3.0.1 --user - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 - name: Regenerate workflows id: generate_workflows run: .github/scripts/generate_ci_workflows.py @@ -238,7 +238,7 @@ jobs: - name: Setup Node uses: actions/setup-node@v2 - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 - name: Install markdown-toc run: npm install -g markdown-toc - name: Regenerate ToCs and check that they didn't change @@ -274,7 +274,7 @@ jobs: python-version: 3.x architecture: x64 - name: Fetch PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: fetch-depth: 2 # to allow us to use github.event.pull_request.head.sha - name: Prepare output dir with HEAD commit SHA @@ -326,7 +326,7 @@ jobs: image: ghcr.io/pytorch/cilint-clang-tidy:d8f0c777964d0dd8a147360de80aed1a13eb613a steps: - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: fetch-depth: 0 # to allow tools/linter/clang_tidy.py to do its thing - name: Prepare output dir with HEAD commit SHA @@ -367,7 +367,12 @@ jobs: cd "${GITHUB_WORKSPACE}" python3 -m tools.linter.clang_tidy \ - --paths torch/csrc/fx \ + --paths \ + torch/csrc/fx \ + torch/csrc/utils \ + torch/csrc/generic \ + torch/csrc/deploy \ + torch/csrc/tensor \ --clang-tidy-exe "$(which clang-tidy)" \ --disable-progress-bar 2>&1 | tee "${GITHUB_WORKSPACE}"/clang-tidy-output.txt @@ -407,7 +412,7 @@ jobs: python-version: 3.x architecture: x64 - name: Fetch PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 - name: Install dependencies run: | set -eux @@ -429,7 +434,7 @@ jobs: python-version: 3.8 architecture: x64 - name: Fetch PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 - name: Install dependencies run: | set -eux @@ -462,5 +467,5 @@ jobs: fi concurrency: - group: lint-${{ github.event.pull_request.number || github.sha }} + group: lint-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true diff --git a/.github/workflows/push_nightly_docker_ghcr.yml b/.github/workflows/push_nightly_docker_ghcr.yml index 311aa94601d6a..892cb5c17aa86 100644 --- a/.github/workflows/push_nightly_docker_ghcr.yml +++ b/.github/workflows/push_nightly_docker_ghcr.yml @@ -14,7 +14,7 @@ jobs: GHCR_PAT: ${{ secrets.GHCR_PAT }} steps: - name: Checkout - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: ref: master - name: Build and upload nightly docker diff --git a/.github/workflows/run_torchbench.yml b/.github/workflows/run_torchbench.yml index 0ae189e99f06a..cee27e1866282 100644 --- a/.github/workflows/run_torchbench.yml +++ b/.github/workflows/run_torchbench.yml @@ -18,11 +18,11 @@ jobs: timeout-minutes: 720 steps: - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: path: pytorch - name: Checkout TorchBench - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: repository: pytorch/benchmark path: benchmark @@ -66,5 +66,5 @@ jobs: path: ~/.torchbench/bisection/pr${{ github.event.number }} concurrency: - group: run-torchbench-${{ github.event.pull_request.number || github.sha }} + group: run-torchbench-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true diff --git a/.github/workflows/test_tools.yml b/.github/workflows/test_tools.yml index 19a0fd9d4e7e7..02ae0dd34e4fd 100644 --- a/.github/workflows/test_tools.yml +++ b/.github/workflows/test_tools.yml @@ -16,7 +16,7 @@ jobs: python-version: 3.x architecture: x64 - name: Checkout PyTorch - uses: actions/checkout@v2 + uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: fetch-depth: 0 # deep clone, to allow us to use git log - name: Install dependencies @@ -31,5 +31,5 @@ jobs: run: python -m unittest discover -vs tools/test -p 'test_*.py' concurrency: - group: test-tools-${{ github.event.pull_request.number || github.sha }} + group: test-tools-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true diff --git a/.gitmodules b/.gitmodules index 6836ccb49c881..a7cc437f43840 100644 --- a/.gitmodules +++ b/.gitmodules @@ -139,3 +139,6 @@ [submodule "third_party/pocketfft"] path = third_party/pocketfft url = https://github.com/mreineck/pocketfft +[submodule "third_party/breakpad"] + path = third_party/breakpad + url = https://github.com/driazati/breakpad.git diff --git a/.jenkins/caffe2/common.sh b/.jenkins/caffe2/common.sh index 026cb8349d3d9..168e823ba2cc4 100644 --- a/.jenkins/caffe2/common.sh +++ b/.jenkins/caffe2/common.sh @@ -18,7 +18,7 @@ if [[ "${BUILD_ENVIRONMENT}" == *rocm* ]]; then if which sccache > /dev/null; then # Save sccache logs to file sccache --stop-server || true - rm ~/sccache_error.log || true + rm -f ~/sccache_error.log || true SCCACHE_ERROR_LOG=~/sccache_error.log SCCACHE_IDLE_TIMEOUT=0 sccache --start-server # Report sccache stats for easier debugging diff --git a/.jenkins/caffe2/test.sh b/.jenkins/caffe2/test.sh index e66b7ae958a1e..75e269d6f6909 100755 --- a/.jenkins/caffe2/test.sh +++ b/.jenkins/caffe2/test.sh @@ -170,7 +170,9 @@ if [[ "$BUILD_ENVIRONMENT" == *onnx* ]]; then # JIT C++ extensions require ninja, so put it into PATH. export PATH="/var/lib/jenkins/.local/bin:$PATH" if [[ "$BUILD_ENVIRONMENT" == *py3* ]]; then - pip install -q --user onnxruntime==1.7.0 + pip install -q --user flatbuffers==2.0 + wget https://ortpypackage.blob.core.windows.net/ort-nightly/ort_nightly-1.8.0.dev202107131-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + pip install -q --user ort_nightly-1.8.0.dev202107131-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl fi "$ROOT_DIR/scripts/onnx/test.sh" fi diff --git a/.jenkins/pytorch/build-asan.sh b/.jenkins/pytorch/build-asan.sh index 37dfeebdbd332..8d0bcd2555342 100755 --- a/.jenkins/pytorch/build-asan.sh +++ b/.jenkins/pytorch/build-asan.sh @@ -16,6 +16,9 @@ clang --version # detect_leaks=0: Python is very leaky, so we need suppress it # symbolize=1: Gives us much better errors when things go wrong export ASAN_OPTIONS=detect_leaks=0:symbolize=1:detect_odr_violation=0 +if [ -n "$(which conda)" ]; then + export CMAKE_PREFIX_PATH=/opt/conda +fi # FIXME: Remove the hardcoded "-pthread" option. # With asan build, the cmake thread CMAKE_HAVE_LIBC_CREATE[1] checking will diff --git a/.jenkins/pytorch/build.sh b/.jenkins/pytorch/build.sh index f6ac52aed99c4..226b8521ee049 100755 --- a/.jenkins/pytorch/build.sh +++ b/.jenkins/pytorch/build.sh @@ -59,7 +59,7 @@ if [[ "$BUILD_ENVIRONMENT" == *cuda11* ]]; then export BUILD_SPLIT_CUDA=ON fi -if [[ ${BUILD_ENVIRONMENT} == *"pure_torch"* ]]; then +if [[ ${BUILD_ENVIRONMENT} == *"pure_torch"* || ${BUILD_ENVIRONMENT} == *"puretorch"* ]]; then export BUILD_CAFFE2=OFF fi @@ -88,6 +88,8 @@ if ! which conda; then else export USE_MKLDNN=0 fi +else + export CMAKE_PREFIX_PATH=/opt/conda fi if [[ "$BUILD_ENVIRONMENT" == *libtorch* ]]; then @@ -222,7 +224,11 @@ if [[ "$BUILD_ENVIRONMENT" == *-bazel-* ]]; then get_bazel + # first build the whole torch for CPU-only tools/bazel build --config=no-tty :torch + # then build selected set of targets with GPU-support. + # TODO: eventually this should converge to building the whole :torch with GPU-support + tools/bazel build --config=no-tty --config=gpu :c10 else # check that setup.py would fail with bad arguments echo "The next three invocations are expected to fail with invalid command error messages." diff --git a/.jenkins/pytorch/common.sh b/.jenkins/pytorch/common.sh index 52b91510c4029..09e814b07d62d 100644 --- a/.jenkins/pytorch/common.sh +++ b/.jenkins/pytorch/common.sh @@ -74,7 +74,7 @@ if [[ "$BUILD_ENVIRONMENT" != *win-* ]]; then if which sccache > /dev/null; then # Save sccache logs to file sccache --stop-server > /dev/null 2>&1 || true - rm ~/sccache_error.log || true + rm -f ~/sccache_error.log || true if [[ -n "${SKIP_SCCACHE_INITIALIZATION:-}" ]]; then # sccache --start-server seems to hang forever on self hosted runners for GHA # so let's just go ahead and skip the --start-server altogether since it seems diff --git a/.jenkins/pytorch/common_utils.sh b/.jenkins/pytorch/common_utils.sh index fd94ce14a1c5f..cb7ef207af47c 100644 --- a/.jenkins/pytorch/common_utils.sh +++ b/.jenkins/pytorch/common_utils.sh @@ -49,6 +49,17 @@ function get_exit_code() { return $retcode } +function get_pr_change_files() { + # The fetch may fail on Docker hosts, this fetch is necessary for GHA + # accepts PR_NUMBER and extract filename as arguments + set +e + tmp_file=$(mktemp) + wget -O "$tmp_file" "https://api.github.com/repos/pytorch/pytorch/pulls/$1/files" + # this regex extracts the filename list according to the GITHUB REST API result. + sed -n "s/.*\"filename\": \"\(.*\)\",/\1/p" "$tmp_file" | tee "$2" + set -e +} + function file_diff_from_base() { # The fetch may fail on Docker hosts, this fetch is necessary for GHA set +e @@ -59,9 +70,9 @@ function file_diff_from_base() { function get_bazel() { # download bazel version - wget https://github.com/bazelbuild/bazel/releases/download/3.1.0/bazel-3.1.0-linux-x86_64 -O tools/bazel + wget https://github.com/bazelbuild/bazel/releases/download/4.1.0/bazel-4.1.0-linux-x86_64 -O tools/bazel # verify content - echo '753434f4fa730266cf5ce21d1fdd425e1e167dd9347ad3e8adc19e8c0d54edca tools/bazel' | sha256sum --quiet -c + echo '0eb2e378d2782e7810753e2162245ad1179c1bb12f848c692b4a595b4edf779b tools/bazel' | sha256sum --quiet -c chmod +x tools/bazel } diff --git a/.jenkins/pytorch/multigpu-test.sh b/.jenkins/pytorch/multigpu-test.sh index 2b918dad31385..76975310843c4 100755 --- a/.jenkins/pytorch/multigpu-test.sh +++ b/.jenkins/pytorch/multigpu-test.sh @@ -19,7 +19,6 @@ fi python tools/download_mnist.py --quiet -d test/cpp/api/mnist OMP_NUM_THREADS=2 TORCH_CPP_TEST_MNIST_PATH="test/cpp/api/mnist" build/bin/test_api time python test/run_test.py --verbose -i distributed/test_jit_c10d -time python test/run_test.py --verbose -i distributed/test_distributed_fork time python test/run_test.py --verbose -i distributed/test_c10d_common time python test/run_test.py --verbose -i distributed/test_c10d_gloo time python test/run_test.py --verbose -i distributed/test_c10d_nccl diff --git a/.jenkins/pytorch/test.sh b/.jenkins/pytorch/test.sh index 93de6fbf68969..9710d3aafb35b 100755 --- a/.jenkins/pytorch/test.sh +++ b/.jenkins/pytorch/test.sh @@ -19,6 +19,11 @@ BUILD_DIR="build" BUILD_RENAMED_DIR="build_renamed" BUILD_BIN_DIR="$BUILD_DIR"/bin +# GHA has test config defined for the test job, so we need to add them. +if [[ -n "${TEST_CONFIG}" ]]; then + BUILD_ENVIRONMENT="${BUILD_ENVIRONMENT}-${TEST_CONFIG}" +fi + # shellcheck source=./common.sh source "$(dirname "${BASH_SOURCE[0]}")/common.sh" @@ -26,11 +31,7 @@ echo "Testing pytorch" export LANG=C.UTF-8 -# Try to pull value from CIRCLE_PULL_REQUEST first then GITHUB_HEAD_REF second -# CIRCLE_PULL_REQUEST comes from CircleCI -# NOTE: file_diff_from_base is currently bugged for GHA due to an issue finding a merge base for ghstack PRs -# see https://github.com/pytorch/pytorch/issues/60111 -IN_PULL_REQUEST=${CIRCLE_PULL_REQUEST:-} +PR_NUMBER=${PR_NUMBER:-${CIRCLE_PR_NUMBER:-}} if [[ "$BUILD_ENVIRONMENT" == *-slow-* || $TEST_CONFIG == 'slow' ]]; then export PYTORCH_TEST_WITH_SLOW=1 @@ -64,7 +65,7 @@ else export PYTORCH_TEST_SKIP_NOARCH=1 fi -if [[ -n "$IN_PULL_REQUEST" ]] && [[ -z "$CI_MASTER" || "$CI_MASTER" == "false" ]]; then +if [[ -n "$PR_NUMBER" ]] && [[ -z "$CI_MASTER" || "$CI_MASTER" == "false" ]]; then # skip expensive checks when on PR and CI_MASTER flag is not set export PYTORCH_TEST_SKIP_CUDA_MEM_LEAK_CHECK=1 else @@ -95,7 +96,7 @@ if [[ "$BUILD_ENVIRONMENT" == *asan* ]]; then export PYTORCH_TEST_WITH_ASAN=1 export PYTORCH_TEST_WITH_UBSAN=1 # TODO: Figure out how to avoid hard-coding these paths - export ASAN_SYMBOLIZER_PATH=/usr/lib/llvm-5.0/bin/llvm-symbolizer + export ASAN_SYMBOLIZER_PATH=/usr/lib/llvm-7/bin/llvm-symbolizer export TORCH_USE_RTLD_GLOBAL=1 # NB: We load libtorch.so with RTLD_GLOBAL for UBSAN, unlike our # default behavior. @@ -146,9 +147,14 @@ elif [[ "${BUILD_ENVIRONMENT}" == *-NO_AVX512-* || $TEST_CONFIG == 'nogpu_NO_AVX export ATEN_CPU_CAPABILITY=avx2 fi -if [ -n "$IN_PULL_REQUEST" ] && [[ "$BUILD_ENVIRONMENT" != *coverage* ]]; then +if [[ "$BUILD_ENVIRONMENT" != *coverage* ]]; then + # if PR_NUMBER exist, use it to grab PR contents. DETERMINE_FROM=$(mktemp) - file_diff_from_base "$DETERMINE_FROM" + if [ -n "$PR_NUMBER" ]; then + get_pr_change_files "$PR_NUMBER" "$DETERMINE_FROM" + else + file_diff_from_base "$DETERMINE_FROM" + fi fi test_python_legacy_jit() { @@ -157,17 +163,17 @@ test_python_legacy_jit() { } test_python_shard1() { - time python test/run_test.py --exclude-jit-executor --shard 1 2 --verbose --determine-from="$DETERMINE_FROM" + time python test/run_test.py --exclude-jit-executor --exclude-distributed-tests --shard 1 2 --verbose --determine-from="$DETERMINE_FROM" assert_git_not_dirty } test_python_shard2() { - time python test/run_test.py --exclude-jit-executor --shard 2 2 --verbose --determine-from="$DETERMINE_FROM" + time python test/run_test.py --exclude-jit-executor --exclude-distributed-tests --shard 2 2 --verbose --determine-from="$DETERMINE_FROM" assert_git_not_dirty } test_python() { - time python test/run_test.py --exclude-jit-executor --verbose --determine-from="$DETERMINE_FROM" + time python test/run_test.py --exclude-jit-executor --exclude-distributed-tests --verbose --determine-from="$DETERMINE_FROM" assert_git_not_dirty } @@ -247,6 +253,7 @@ test_libtorch() { ln -sf "$TORCH_LIB_DIR"/libbackend_with_compiler.so "$TORCH_BIN_DIR" ln -sf "$TORCH_LIB_DIR"/libjitbackend_test.so "$TORCH_BIN_DIR" ln -sf "$TORCH_LIB_DIR"/libc10* "$TORCH_BIN_DIR" + ln -sf "$TORCH_LIB_DIR"/libshm* "$TORCH_BIN_DIR" ln -sf "$TORCH_LIB_DIR"/libtorch* "$TORCH_BIN_DIR" ln -sf "$TORCH_LIB_DIR"/libtbb* "$TORCH_BIN_DIR" @@ -269,7 +276,8 @@ test_libtorch() { python test/cpp/jit/tests_setup.py shutdown # Wait for background download to finish wait - OMP_NUM_THREADS=2 TORCH_CPP_TEST_MNIST_PATH="test/cpp/api/mnist" "$TORCH_BIN_DIR"/test_api --gtest_output=xml:$TEST_REPORTS_DIR/test_api.xml + # Exclude IMethodTest that relies on torch::deploy, which will instead be ran in test_deploy. + OMP_NUM_THREADS=2 TORCH_CPP_TEST_MNIST_PATH="test/cpp/api/mnist" "$TORCH_BIN_DIR"/test_api --gtest_filter='-IMethodTest.*' --gtest_output=xml:$TEST_REPORTS_DIR/test_api.xml "$TORCH_BIN_DIR"/test_tensorexpr --gtest_output=xml:$TEST_REPORTS_DIR/test_tensorexpr.xml "$TORCH_BIN_DIR"/test_mobile_nnc --gtest_output=xml:$TEST_REPORTS_DIR/test_mobile_nnc.xml if [[ "${BUILD_ENVIRONMENT}" == pytorch-linux-xenial-py3* ]]; then @@ -296,6 +304,10 @@ test_vulkan() { } test_distributed() { + echo "Testing distributed python tests" + time python test/run_test.py --distributed-tests --verbose --determine-from="$DETERMINE_FROM" + assert_git_not_dirty + if [[ "$BUILD_ENVIRONMENT" == *cuda* ]]; then echo "Testing distributed C++ tests" ln -sf "$TORCH_LIB_DIR"/libtorch* "$TORCH_BIN_DIR" @@ -478,9 +490,14 @@ test_torch_deploy() { ln -sf "$TORCH_LIB_DIR"/libshm* "$TORCH_BIN_DIR" ln -sf "$TORCH_LIB_DIR"/libc10* "$TORCH_BIN_DIR" "$TORCH_BIN_DIR"/test_deploy + "$TORCH_BIN_DIR"/test_api --gtest_filter='IMethodTest.*' assert_git_not_dirty } +test_docs_test() { + .jenkins/pytorch/docs-test.sh +} + if ! [[ "${BUILD_ENVIRONMENT}" == *libtorch* || "${BUILD_ENVIRONMENT}" == *-bazel-* ]]; then (cd test && python -c "import torch; print(torch.__config__.show())") (cd test && python -c "import torch; print(torch.__config__.parallel_info())") @@ -516,6 +533,11 @@ elif [[ "${BUILD_ENVIRONMENT}" == *vulkan-linux* ]]; then test_vulkan elif [[ "${BUILD_ENVIRONMENT}" == *-bazel-* ]]; then test_bazel +elif [[ "${BUILD_ENVIRONMENT}" == *distributed* ]]; then + test_distributed + test_rpc +elif [[ "${TEST_CONFIG}" = docs_test ]]; then + test_docs_test else install_torchvision install_monkeytype @@ -526,9 +548,7 @@ else test_custom_script_ops test_custom_backend test_torch_function_benchmark - test_distributed test_benchmarks - test_rpc if [[ "${BUILD_ENVIRONMENT}" == *linux-xenial-py3.6-gcc7-test* || "${BUILD_ENVIRONMENT}" == *linux-xenial-py3.6-gcc5.4-test* ]]; then test_python_gloo_with_tls fi diff --git a/BUILD.bazel b/BUILD.bazel index ca8874d64e857..36b29379a5c2f 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -224,7 +224,9 @@ libtorch_python_generated_sources = [ "torch/csrc/autograd/generated/python_functions_3.cpp", "torch/csrc/autograd/generated/python_functions_4.cpp", "torch/csrc/autograd/generated/python_variable_methods.cpp", - "torch/csrc/autograd/generated/python_torch_functions.cpp", + "torch/csrc/autograd/generated/python_torch_functions_0.cpp", + "torch/csrc/autograd/generated/python_torch_functions_1.cpp", + "torch/csrc/autograd/generated/python_torch_functions_2.cpp", "torch/csrc/autograd/generated/python_nn_functions.cpp", "torch/csrc/autograd/generated/python_fft_functions.cpp", "torch/csrc/autograd/generated/python_linalg_functions.cpp", @@ -392,9 +394,6 @@ filegroup( "aten/src/THC/THCStorageCopy.cu.cc", "aten/src/THC/THCTensor.cu.cc", "aten/src/THC/THCTensorCopy.cu.cc", - "aten/src/THC/THCTensorMath.cu.cc", - "aten/src/THC/THCTensorMathMagma.cu.cc", - "aten/src/THC/THCTensorMathPairwise.cu.cc", "aten/src/THC/THCTensorMathScan.cu.cc", "aten/src/THC/THCTensorScatterGather.cu.cc", "aten/src/THC/THCTensorSort.cu.cc", @@ -409,22 +408,6 @@ filegroup( ], ) -filegroup( - name = "thcunn_srcs_cu", - srcs = [ - "aten/src/THCUNN/BCECriterion.cu.cc", - "aten/src/THCUNN/ELU.cu.cc", - "aten/src/THCUNN/HardTanh.cu.cc", - "aten/src/THCUNN/LeakyReLU.cu.cc", - "aten/src/THCUNN/MultiMarginCriterion.cu.cc", - "aten/src/THCUNN/SoftMarginCriterion.cu.cc", - "aten/src/THCUNN/SoftPlus.cu.cc", - "aten/src/THCUNN/SoftShrink.cu.cc", - "aten/src/THCUNN/SpatialConvolutionMM.cu.cc", - "aten/src/THCUNN/Tanh.cu.cc", - ], -) - filegroup( name = "aten_srcs_cu", srcs = [ @@ -574,8 +557,6 @@ cc_library( "aten/src/THC/**/*.cpp", "aten/src/THC/*.cuh", "aten/src/THC/generic/*.cu.cc", - "aten/src/THCUNN/*.cuh", - "aten/src/THCUNN/generic/*.cu.cc", ], exclude = [ "aten/src/ATen/Config.h", @@ -717,7 +698,6 @@ cu_library( srcs = [ ":aten_srcs_cu", ":thc_srcs_cu", - ":thcunn_srcs_cu", ], copts = ATEN_COPTS + torch_cuda_half_options, visibility = ["//visibility:public"], diff --git a/CMakeLists.txt b/CMakeLists.txt index 188f35a9981e0..f5eed7207f107 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.5 FATAL_ERROR) +cmake_minimum_required(VERSION 3.10 FATAL_ERROR) #cmake_policy(SET CMP0022 NEW) #cmake_policy(SET CMP0023 NEW) @@ -202,6 +202,7 @@ cmake_dependent_option( "USE_CUDNN" OFF) option(USE_FBGEMM "Use FBGEMM (quantized 8-bit server operators)" ON) option(USE_KINETO "Use Kineto profiling library" ON) +option(USE_BREAKPAD "Use breakpad crash dump library" ON) option(USE_CUPTI_SO "Use CUPTI as a shared library" OFF) option(USE_FAKELOWP "Use FakeLowp operators" OFF) option(USE_FFMPEG "Use ffmpeg" OFF) @@ -213,6 +214,7 @@ option(USE_LMDB "Use LMDB" OFF) option(USE_MAGMA "Use MAGMA" ON) option(USE_METAL "Use Metal for Caffe2 iOS build" ON) option(USE_PYTORCH_METAL "Use Metal for PyTorch iOS build" OFF) +option(USE_PYTORCH_METAL_EXPORT "Export Metal models on MacOSX desktop" OFF) option(USE_NATIVE_ARCH "Use -march=native" OFF) cmake_dependent_option( USE_MLCOMPUTE "Use ML Compute for macOS build" ON @@ -264,6 +266,10 @@ if(NOT DEFINED USE_VULKAN) "ANDROID" OFF) endif() +if(IOS) + set(USE_BREAKPAD OFF) +endif() + option(USE_SOURCE_DEBUG_ON_MOBILE "Enable " ON) option(USE_LITE_INTERPRETER_PROFILER "Enable " ON) option(USE_VULKAN_FP16_INFERENCE "Vulkan - Use fp16 inference" OFF) @@ -318,9 +324,9 @@ option(WERROR "Build with -Werror supported by the compiler" OFF) if(USE_CCACHE) find_program(CCACHE_PROGRAM ccache) if(CCACHE_PROGRAM) - set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}") - set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}") - set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}") + set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}" CACHE STRING "C compiler launcher") + set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}" CACHE STRING "CXX compiler launcher") + set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}" CACHE STRING "CUDA compiler launcher") else() message(STATUS "Could not find ccache. Consider installing ccache to speed up compilation.") endif() @@ -683,6 +689,10 @@ if(USE_PYTORCH_METAL) string(APPEND CMAKE_CXX_FLAGS " -DUSE_PYTORCH_METAL") endif() +if(USE_PYTORCH_METAL_EXPORT) + string(APPEND CMAKE_CXX_FLAGS " -DUSE_PYTORCH_METAL_EXPORT") +endif() + if(USE_SOURCE_DEBUG_ON_MOBILE) string(APPEND CMAKE_CXX_FLAGS " -DSYMBOLICATE_MOBILE_DEBUG_HANDLE") endif() diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 7d8659a8babff..e102de7be6334 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -87,7 +87,7 @@ lazy.) ```bash -conda -y uninstall pytorch +conda uninstall pytorch -y yes | pip uninstall torch ``` @@ -197,6 +197,7 @@ with `brew install cmake` if you are developing on MacOS or Linux system. Could not find .../pytorch/third_party/pybind11/CMakeLists.txt ``` remove any `submodule.*` settings in your local git config (`.git/config` of your pytorch repo) and try again. +* If you're a Windows contributor, please check out [Best Practices](https://github.com/pytorch/pytorch/wiki/Best-Practices-to-Edit-and-Compile-Pytorch-Source-Code-On-Windows). ## Nightly Checkout & Pull @@ -242,8 +243,7 @@ into the repo directory. * [aten](aten) - C++ tensor library for PyTorch (no autograd support) * [src](aten/src) - [README](aten/src/README.md) * [TH](aten/src/TH) - [THC](aten/src/THC) - [THCUNN](aten/src/THCUNN) - Legacy library code from the original + [THC](aten/src/THC) - Legacy library code from the original Torch. Try not to add things here; we're slowly porting these to [native](aten/src/ATen/native). * generic - Contains actual implementations of operators, @@ -435,12 +435,12 @@ is `./build/bin/FILENAME --gtest_filter=TESTSUITE.TESTNAME`, where `TESTNAME` is the name of the test you'd like to run and `TESTSUITE` is the suite that test is defined in. -For example, if you wanted to run the test ` MayContainAlias`, which +For example, if you wanted to run the test `MayContainAlias`, which is part of the test suite `ContainerAliasingTest` in the file `test/cpp/jit/test_alias_analysis.cpp`, the command would be: ```bash -./build/bin/test_jit --gtest_filter=ContainerAliasingTest.UnionAliasing +./build/bin/test_jit --gtest_filter=ContainerAliasingTest.MayContainAlias ``` @@ -735,116 +735,54 @@ succeed. #### Use CCache -Even when dependencies are tracked with file modification, -there are many situations where files get rebuilt when a previous -compilation was exactly the same. - -Using ccache in a situation like this is a real time-saver. The ccache manual -describes [two ways to use ccache](https://ccache.samba.org/manual/latest.html#_run_modes). -In the PyTorch project, currently only the latter method of masquerading as -the compiler via symlinks works for CUDA compilation. - -Here are the instructions for installing ccache from source (tested at commit -`3c302a7` of the `ccache` repo): +Even when dependencies are tracked with file modification, there are many +situations where files get rebuilt when a previous compilation was exactly the +same. Using ccache in a situation like this is a real time-saver. +Before building pytorch, install ccache from your package manager of choice: ```bash -#!/bin/bash - -if ! ls ~/ccache/bin/ccache -then - set -ex - sudo apt-get update - sudo apt-get install -y cmake - mkdir -p ~/ccache - pushd ~/ccache - rm -rf ccache - git clone https://github.com/ccache/ccache.git - mkdir -p ccache/build - pushd ccache/build - cmake -DCMAKE_INSTALL_PREFIX=${HOME}/ccache -DENABLE_TESTING=OFF -DZSTD_FROM_INTERNET=ON .. - make -j$(nproc) install - popd - popd - - mkdir -p ~/ccache/lib - mkdir -p ~/ccache/cuda - ln -s ~/ccache/bin/ccache ~/ccache/lib/cc - ln -s ~/ccache/bin/ccache ~/ccache/lib/c++ - ln -s ~/ccache/bin/ccache ~/ccache/lib/gcc - ln -s ~/ccache/bin/ccache ~/ccache/lib/g++ - ln -s ~/ccache/bin/ccache ~/ccache/cuda/nvcc - - ~/ccache/bin/ccache -M 25Gi -fi - -export PATH=~/ccache/lib:$PATH -export CUDA_NVCC_EXECUTABLE=~/ccache/cuda/nvcc +conda install ccache -f conda-forge +sudo apt install ccache +sudo yum install ccache +brew install ccache ``` -Alternatively, `ccache` provided by newer Linux distributions (e.g. Debian/sid) -also works, but the `nvcc` symlink to `ccache` as described above is still required. - -Note that the original `nvcc` binary (typically at `/usr/local/cuda/bin`) must -be on your `PATH`, otherwise `ccache` will emit the following error: - - ccache: error: Could not find compiler "nvcc" in PATH - -For example, here is how to install/configure `ccache` on Ubuntu: +You may also find the default cache size in ccache is too small to be useful. +The cache sizes can be increased from the command line: ```bash -# install ccache -sudo apt install ccache - -# update symlinks and create/re-create nvcc link -sudo /usr/sbin/update-ccache-symlinks -sudo ln -s /usr/bin/ccache /usr/lib/ccache/nvcc - # config: cache dir is ~/.ccache, conf file ~/.ccache/ccache.conf # max size of cache ccache -M 25Gi # -M 0 for unlimited # unlimited number of files ccache -F 0 - -# deploy (and add to ~/.bashrc for later) -export PATH="/usr/lib/ccache:$PATH" ``` -It is also possible to install `ccache` via `conda` by installing it from the -community-maintained `conda-forge` channel. Here is how to set up `ccache` this -way: +To check this is working, do two clean builds of pytorch in a row. The second +build should be substantially and noticeably faster than the first build. If +this doesn't seem to be the case, check the `CMAKE__COMPILER_LAUNCHER` +rules in `build/CMakeCache.txt`, where `` is `C`, `CXX` and `CUDA`. +Each of these 3 variables should contain ccache, e.g. +``` +//CXX compiler launcher +CMAKE_CXX_COMPILER_LAUNCHER:STRING=/usr/bin/ccache +``` +If not, you can define these variables on the command line before invoking `setup.py`. ```bash -# install ccache -conda install -c conda-forge ccache - -# set up ccache compiler symlinks -mkdir ~/ccache -mkdir ~/ccache/lib -mkdir ~/ccache/cuda -ln -s $CONDA_PREFIX/bin/ccache ~/ccache/lib/cc -ln -s $CONDA_PREFIX/bin/ccache ~/ccache/lib/c++ -ln -s $CONDA_PREFIX/bin/ccache ~/ccache/lib/gcc -ln -s $CONDA_PREFIX/bin/ccache ~/ccache/lib/g++ -ln -s $CONDA_PREFIX/bin/ccache ~/ccache/cuda/nvcc - -# update PATH to reflect symlink locations, consider -# adding this to your .bashrc -export PATH=~/ccache/lib:$PATH -export CUDA_NVCC_EXECUTABLE=~/ccache/cuda/nvcc - -# increase ccache cache size to 25 GiB -ccache -M 25Gi +export CMAKE_C_COMPILER_LAUNCHER=ccache +export CMAKE_CXX_COMPILER_LAUNCHER=ccache +export CMAKE_CUDA_COMPILER_LAUNCHER=ccache +python setup.py develop ``` -To check this is working, do two clean builds of pytorch in a row. The second -build should be substantially and noticeably faster than the first build. If this doesn't seem to be the case, check that each of the symlinks above actually link to your installation of `ccache`. For example, if you followed the first option and installed `ccache` from source on a Linux machine, running `readlink -e $(which g++)` should return `~/ccache/bin/ccache`. - - #### Use a faster linker If you are editing a single file and rebuilding in a tight loop, the time spent linking will dominate. The system linker available in most Linux distributions (GNU `ld`) is quite slow. Use a faster linker, like [lld](https://lld.llvm.org/). +People on Mac, follow [this guide](https://stackoverflow.com/questions/42730345/how-to-install-llvm-for-mac) instead. + The easiest way to use `lld` this is download the [latest LLVM binaries](http://releases.llvm.org/download.html#8.0.0) and run: ``` diff --git a/README.md b/README.md index 53ebfb1a4bec6..ed793fb8874e6 100644 --- a/README.md +++ b/README.md @@ -126,7 +126,7 @@ We hope you never spend hours debugging your code because of bad stack traces or PyTorch has minimal framework overhead. We integrate acceleration libraries such as [Intel MKL](https://software.intel.com/mkl) and NVIDIA ([cuDNN](https://developer.nvidia.com/cudnn), [NCCL](https://developer.nvidia.com/nccl)) to maximize speed. At the core, its CPU and GPU Tensor and neural network backends -(TH, THC, THNN, THCUNN) are mature and have been tested for years. +are mature and have been tested for years. Hence, PyTorch is quite fast – whether you run small or large neural networks. @@ -291,9 +291,10 @@ You can refer to the [build_pytorch.bat](https://github.com/pytorch/pytorch/blob ```cmd cmd -:: [Optional] If you want to build with the VS 2017 generator for old CUDA and PyTorch, please change the value in the next line to `Visual Studio 15 2017`. -:: Note: This value is useless if Ninja is detected. However, you can force that by using `set USE_NINJA=OFF`. -set CMAKE_GENERATOR=Visual Studio 16 2019 +:: Set the environment variables after you have downloaded and upzipped the mkl package, +:: else CMake would throw error as `Could NOT find OpenMP`. +set CMAKE_INCLUDE_PATH={Your directory}\mkl\include +set LIB={Your directory}\mkl\lib;%LIB% :: Read the content in the previous section carefully before you proceed. :: [Optional] If you want to override the underlying toolset used by Ninja and Visual Studio with CUDA, please run the following script block. diff --git a/WORKSPACE b/WORKSPACE index 6f5028d4d0912..9396a3451c360 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -1,7 +1,7 @@ workspace(name = "pytorch") load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") -load("//tools/rules:workspace.bzl", "new_patched_local_repository") +load("//tools/rules:workspace.bzl", "new_patched_local_repository", "new_empty_repository") http_archive( name = "bazel_skylib", @@ -170,3 +170,14 @@ protobuf_deps() load("@rules_python//python:repositories.bzl", "py_repositories") py_repositories() + +local_repository( + name = "local_config_cuda", + path = "third_party/tensorflow_cuda_bazel_build", +) + +# Wrapper to expose local_config_cuda in an agnostic way +new_empty_repository( + name = "cuda", + build_file = "//third_party:cuda.BUILD", +) diff --git a/aten/CMakeLists.txt b/aten/CMakeLists.txt index 400b00f8e858a..7ba92a6decee7 100644 --- a/aten/CMakeLists.txt +++ b/aten/CMakeLists.txt @@ -80,21 +80,14 @@ if(USE_ROCM) # ATen proper) set(AT_CUDA_ENABLED 1) add_subdirectory(src/THH) - add_subdirectory(src/THHUNN) message("ROCm is enabled.") elseif(USE_CUDA) set(AT_CUDA_ENABLED 1) add_subdirectory(src/THC) - add_subdirectory(src/THCUNN) else() message("disabling CUDA because USE_CUDA is set false") set(AT_CUDA_ENABLED 0) endif() -if(NOT USE_CUDA) - # we still parse THCUNN even if cuda is disabled to make sure to - # install it - install(FILES src/THCUNN/generic/THCUNN.h DESTINATION "${ATEN_INSTALL_INCLUDE_SUBDIR}/THCUNN/generic") -endif() if(NOT USE_NNPACK) set(AT_NNPACK_ENABLED 0) diff --git a/aten/src/ATen/AccumulateType.h b/aten/src/ATen/AccumulateType.h index 09c8cdb6c095a..4270ec021dbc7 100644 --- a/aten/src/ATen/AccumulateType.h +++ b/aten/src/ATen/AccumulateType.h @@ -6,7 +6,38 @@ // Defines the accumulation type for a scalar type. // Example: -// using accscalar_t = acc_type; +// using accscalar_t = acc_type; +// +// Accumulation types are an important concept in numeric computing +// because you frequently want to perform intermediate computations +// at a higher precision than the input and output precision, to avoid +// compounding internal rounding errors. Accumulation is the most +// well-known intermediate computation (it is of great importance for +// sum reduction and matrix multiply, for example), but in PyTorch +// acc_type ends up getting used for all sorts of other intermediate +// computations, so it perhaps would be more accurately (ahem) called an +// "accurate" type. acc_type is especially important for reduced +// precision operations like float16 and bfloat16, where relatively +// benign looking inputs can easily end up overflowing/underflowing. +// +// acc_type is parametrized by whether or not you are running on CUDA +// or not, because on CUDA double precision operations are expensive +// and so by default, we don't actually want to use double as an +// acc_type on CUDA. A lot of things are typed out below, but +// basically, the table is generated by a few rules: +// +// If bool: +// Use 'bool' as acc_type. +// If floating point: +// If CUDA, use 'float' as acc_type (unless scalar_t is double), +// otherwise (CPU) use 'double' +// If integral: +// Use 'int64_t' as acc_type +// +// You're not forced to use this template; if you happen to know +// something specific about your use case, you can specify your own +// desired behavior. This template, however, will give you a reasonable +// default that will work for all dtypes supported in PyTorch. #if defined(__CUDACC__) #include diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index a73f3e31ff894..114d970bf7ddc 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -167,13 +167,12 @@ else() endif() # Metal -if(USE_PYTORCH_METAL) - if(APPLE) - set(all_cpu_cpp ${all_cpu_cpp} ${metal_cpp} ${native_metal_srcs}) - else() - # Add files needed from optimized_for_mobile - set(all_cpu_cpp ${all_cpu_cpp} ${metal_cpp} ${metal_prepack_cpp}) - endif() +if(USE_PYTORCH_METAL_EXPORT) + # Add files needed from exporting metal models(optimized_for_mobile) + set(all_cpu_cpp ${all_cpu_cpp} ${metal_cpp} ${metal_prepack_cpp}) +elseif(APPLE AND USE_PYTORCH_METAL) + # Compile Metal kernels + set(all_cpu_cpp ${all_cpu_cpp} ${metal_cpp} ${native_metal_srcs}) else() set(all_cpu_cpp ${all_cpu_cpp} ${metal_cpp}) endif() @@ -450,13 +449,21 @@ install(FILES "${CMAKE_CURRENT_BINARY_DIR}/cmake-exports/ATenConfig.cmake" set(INSTALL_HEADERS ${base_h} ${ATen_CORE_HEADERS}) if(NOT INTERN_BUILD_MOBILE) list(APPEND INSTALL_HEADERS ${native_h} ${native_cpu_h} ${native_ao_sparse_h} ${native_quantized_h} ${cuda_h} ${native_cuda_h} ${native_hip_h} ${cudnn_h} ${hip_h} ${miopen_h}) + # Metal + if(USE_PYTORCH_METAL_EXPORT) + # Add files needed from exporting metal models(optimized_for_mobile) + list(APPEND INSTALL_HEADERS ${metal_h} ${metal_prepack_h}) + elseif(APPLE AND USE_PYTORCH_METAL) + # Needed by Metal kernels + list(APPEND INSTALL_HEADERS ${metal_h} ${native_metal_h}) + else() + list(APPEND INSTALL_HEADERS ${metal_h}) + endif() else() - if(USE_PYTORCH_METAL) - if(IOS) + if(IOS AND USE_PYTORCH_METAL) list(APPEND INSTALL_HEADERS ${metal_h} ${native_metal_h}) - else() + else() list(APPEND INSTALL_HEADERS ${metal_h} ${metal_prepack_h}) - endif() endif() endif() diff --git a/aten/src/ATen/CPUGeneratorImpl.h b/aten/src/ATen/CPUGeneratorImpl.h index f8b43a04c73c0..69dbb8b8de973 100644 --- a/aten/src/ATen/CPUGeneratorImpl.h +++ b/aten/src/ATen/CPUGeneratorImpl.h @@ -10,7 +10,7 @@ namespace at { struct TORCH_API CPUGeneratorImpl : public c10::GeneratorImpl { // Constructors CPUGeneratorImpl(uint64_t seed_in = default_rng_seed_val); - ~CPUGeneratorImpl() = default; + ~CPUGeneratorImpl() override = default; // CPUGeneratorImpl methods std::shared_ptr clone() const; diff --git a/aten/src/ATen/ConjugateFallback.cpp b/aten/src/ATen/ConjugateFallback.cpp index 3ae9859f2d618..2cf9538c9bb32 100644 --- a/aten/src/ATen/ConjugateFallback.cpp +++ b/aten/src/ATen/ConjugateFallback.cpp @@ -56,6 +56,21 @@ TORCH_LIBRARY_IMPL(aten, Conjugate, m) { m.impl("view", torch::CppFunction::makeFallthrough()); m.impl("_unsafe_view", torch::CppFunction::makeFallthrough()); m.impl("reshape", torch::CppFunction::makeFallthrough()); + m.impl("dot", torch::CppFunction::makeFallthrough()); + m.impl("vdot", torch::CppFunction::makeFallthrough()); + m.impl("dot.out", torch::CppFunction::makeFallthrough()); + m.impl("vdot.out", torch::CppFunction::makeFallthrough()); + m.impl("alias", torch::CppFunction::makeFallthrough()); + m.impl("mm", torch::CppFunction::makeFallthrough()); + m.impl("mm.out", torch::CppFunction::makeFallthrough()); + m.impl("addmm", torch::CppFunction::makeFallthrough()); + m.impl("addmm_", torch::CppFunction::makeFallthrough()); + m.impl("addmm.out", torch::CppFunction::makeFallthrough()); + m.impl("bmm", torch::CppFunction::makeFallthrough()); + m.impl("bmm.out", torch::CppFunction::makeFallthrough()); + m.impl("baddbmm", torch::CppFunction::makeFallthrough()); + m.impl("baddbmm_", torch::CppFunction::makeFallthrough()); + m.impl("baddbmm.out", torch::CppFunction::makeFallthrough()); } } // namespace at diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index 26f1d11f92b48..4a45ac6f8ac18 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -79,6 +80,9 @@ class TORCH_API Context { static bool hasMLC() { return c10::impl::hasDeviceGuardImpl(at::DeviceType::MLC); } + static bool hasORT() { + return c10::impl::hasDeviceGuardImpl(at::DeviceType::ORT); + } // defined in header so that getNonVariableType has ability to inline // call_once check. getNonVariableType is called fairly frequently THCState* lazyInitCUDA() { @@ -292,6 +296,10 @@ static inline bool hasMLC() { return globalContext().hasMLC(); } +static inline bool hasORT() { + return globalContext().hasORT(); +} + // Despite its name, this function returns the number of *CUDA* GPUs. static inline size_t getNumGPUs() { // WARNING: DO NOT ADD LOGIC TO HANDLE OTHER DEVICE TYPES TO THIS diff --git a/aten/src/ATen/LegacyTHFunctionsCUDA.h b/aten/src/ATen/LegacyTHFunctionsCUDA.h deleted file mode 100644 index 5670f31a089d9..0000000000000 --- a/aten/src/ATen/LegacyTHFunctionsCUDA.h +++ /dev/null @@ -1,35 +0,0 @@ -#pragma once - -#include -#include -#include - -namespace c10 { -class Scalar; -} -namespace at { -struct Generator; -class Tensor; -struct Type; -} // namespace at - -namespace at { -namespace native { -namespace legacy { -namespace cuda { - -std::tuple _th_gels_out(const Tensor & self, const Tensor & A, Tensor & res1, Tensor & res2); -std::tuple _th_gels(const Tensor & self, const Tensor & A); -Tensor & _th_potri_out(Tensor & output, const Tensor & self, bool upper); -Tensor _th_potri(const Tensor & self, bool upper); -Tensor & _th_copy_ignoring_overlaps_(Tensor & self, const Tensor & src); -Tensor _thnn_rrelu_with_noise_backward(const Tensor & grad_output, const Tensor & self, const Tensor & noise, const Scalar& lower, const Scalar& upper, bool training); -std::tuple _thnn_conv2d_forward_out(const Tensor & self, const Tensor & weight, IntArrayRef kernel_size, const c10::optional& bias_opt, IntArrayRef stride, IntArrayRef padding, Tensor & output, Tensor & columns, Tensor & ones); -std::tuple _thnn_conv2d_forward(const Tensor & self, const Tensor & weight, IntArrayRef kernel_size, const optional & bias, IntArrayRef stride, IntArrayRef padding); -std::tuple _thnn_conv2d_backward_out(Tensor & grad_input, Tensor & grad_weight, Tensor & grad_bias, const Tensor & grad_output, const Tensor & self, const Tensor & weight, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, const Tensor & columns, const Tensor & ones); -std::tuple _thnn_conv2d_backward(const Tensor & grad_output, const Tensor & self, const Tensor & weight, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, const Tensor & columns, const Tensor & ones, std::array output_mask); - -} // namespace th -} // namespace legacy -} // namespace native -} // namespace at diff --git a/aten/src/ATen/OpMathType.h b/aten/src/ATen/OpMathType.h new file mode 100644 index 0000000000000..b58d4779ac7a4 --- /dev/null +++ b/aten/src/ATen/OpMathType.h @@ -0,0 +1,16 @@ +#pragma once + +#include +#include + +namespace at { + +// For FP16 or BFloat16 inputs, ops should perform internal math in FP32. +template struct OpMathType { using type = scalar_t; }; +template<> struct OpMathType { using type = float; }; +template<> struct OpMathType { using type = float; }; + +template +using opmath_type = typename OpMathType::type; + +} // namespace at diff --git a/aten/src/ATen/ParallelNative.cpp b/aten/src/ATen/ParallelNative.cpp index 17b4b20aa9bd0..565c979e35e16 100644 --- a/aten/src/ATen/ParallelNative.cpp +++ b/aten/src/ATen/ParallelNative.cpp @@ -222,6 +222,7 @@ void set_num_threads(int nthreads) { } int get_num_threads() { + at::internal::lazy_init_num_threads(); #ifndef C10_MOBILE // not initializing pool unnecessarily, // because pool cannot be resized after initialization diff --git a/aten/src/ATen/ParallelNativeTBB.cpp b/aten/src/ATen/ParallelNativeTBB.cpp index 15040498edc5c..c38dcb64f81bd 100644 --- a/aten/src/ATen/ParallelNativeTBB.cpp +++ b/aten/src/ATen/ParallelNativeTBB.cpp @@ -66,6 +66,7 @@ void set_num_threads(int nthreads) { } int get_num_threads() { + at::internal::lazy_init_num_threads(); return tbb::global_control::active_value( tbb::global_control::max_allowed_parallelism); } diff --git a/aten/src/ATen/ParallelThreadPoolNative.cpp b/aten/src/ATen/ParallelThreadPoolNative.cpp index 2670c7bd08d1b..cc5821d494a25 100644 --- a/aten/src/ATen/ParallelThreadPoolNative.cpp +++ b/aten/src/ATen/ParallelThreadPoolNative.cpp @@ -57,6 +57,7 @@ void set_num_interop_threads(int nthreads) { } int get_num_interop_threads() { + at::internal::lazy_init_num_threads(); int nthreads = num_interop_threads.load(); if (nthreads > 0) { return nthreads; diff --git a/aten/src/ATen/TensorMeta.h b/aten/src/ATen/TensorMeta.h index ac295ec9bde79..6a5491ab3d50b 100644 --- a/aten/src/ATen/TensorMeta.h +++ b/aten/src/ATen/TensorMeta.h @@ -26,6 +26,16 @@ namespace impl { #define TORCH_META_FUNC(name) void structured_##name::meta #define TORCH_META_FUNC2(name, overload) void structured_##name##_##overload::meta +// These are versions of TORCH_META_FUNC(2) that include a precompute_out struct as a return value. +// They should be used when the kernel in question has precomputed values declared in native_functions.yaml and +// the corresponding implementation should return an instance of the aforementioned struct. +#define TORCH_PRECOMPUTE_META_FUNC(name) structured_##name::meta_return_ty structured_##name::meta +#define TORCH_PRECOMPUTE_META_FUNC2(name, overload) structured_##name##_##overload::meta_return_ty structured_##name##_##overload::meta + +// Use this to create a precompute struct in a meta function. +#define TORCH_PRECOMPUTE_STRUCT(name) structured_##name::precompute_out<> +#define TORCH_PRECOMPUTE_STRUCT2(name, overload) structured_##name##_##overload::precompute_out<> + // Use this to define the prototype for an implementation. This takes only // one argument, which is the name of the dispatch key entry you're // implementing. diff --git a/aten/src/ATen/TensorUtils.cpp b/aten/src/ATen/TensorUtils.cpp index af9a8a1b22153..1ec9f9c291c0a 100644 --- a/aten/src/ATen/TensorUtils.cpp +++ b/aten/src/ATen/TensorUtils.cpp @@ -282,7 +282,6 @@ bool geometry_is_contiguous(IntArrayRef sizes, IntArrayRef strides) { return contig_if_nonempty; } -// Correspond to THCUNN_check_dim_size/THNN_check_dim_size void check_dim_size( const Tensor& tensor, int64_t dim, diff --git a/aten/src/ATen/TensorUtils.h b/aten/src/ATen/TensorUtils.h index 8e84ecaa4a3a2..1417174a1f6d3 100644 --- a/aten/src/ATen/TensorUtils.h +++ b/aten/src/ATen/TensorUtils.h @@ -144,7 +144,6 @@ TORCH_API void* maybe_data_ptr(const TensorArg& tensor); // on whether a subgeometry is contiguous. TORCH_API bool geometry_is_contiguous(IntArrayRef sizes, IntArrayRef strides); -// Correspond to THCUNN_check_dim_size/THNN_check_dim_size TORCH_API void check_dim_size( const Tensor& tensor, int64_t dim, diff --git a/aten/src/ATen/ThreadLocalState.cpp b/aten/src/ATen/ThreadLocalState.cpp index ba7be1a06b8a1..98c2519e045ce 100644 --- a/aten/src/ATen/ThreadLocalState.cpp +++ b/aten/src/ATen/ThreadLocalState.cpp @@ -9,30 +9,26 @@ namespace at { -ThreadLocalState::ThreadLocalState(bool keep_grad_mode) +ThreadLocalState::ThreadLocalState() : dispatch_key_(c10::impl::tls_local_dispatch_key_set()), debug_info_(c10::ThreadLocalDebugInfo::current()), - inference_mode_enabled_(c10::InferenceMode::is_enabled()) { + autograd_tls_(c10::AutogradState::get_tls_state()) { rf_tls_ = at::get_record_function_tls_(); saved_tensors_default_hooks_ = SavedTensorDefaultHooks::get_hooks(); -#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - keep_grad_mode_ = keep_grad_mode; - if (keep_grad_mode_) { - grad_mode_enabled_ = GradMode::is_enabled(); - } -#endif bumped_record_all_functions_ = at::checkRecordAllFunctions(); } +void ThreadLocalState::set_grad_mode(bool enabled) { + autograd_tls_.set_grad_mode(enabled); +} + /* static */ void ThreadLocalState::setThreadLocalState( const ThreadLocalState& state) { -#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - if (state.keep_grad_mode_) { - GradMode::set_enabled(state.grad_mode_enabled_); - } -#endif + // Note that setting the InferenceMode TLS in this function is ONLY ok because we always + // restore the dispatch key set TLS at the same time. + c10::AutogradState::set_tls_state(state.autograd_tls_); at::set_record_function_tls_(state.rf_tls_); @@ -43,8 +39,6 @@ void ThreadLocalState::setThreadLocalState( c10::ThreadLocalDebugInfo::_forceCurrentDebugInfo(state.debug_info_); c10::impl::_force_tls_local_dispatch_key_set(state.dispatch_key_); - - c10::InferenceMode::_set_enabled(state.inference_mode_enabled_); } } // namespace at diff --git a/aten/src/ATen/ThreadLocalState.h b/aten/src/ATen/ThreadLocalState.h index f30f5e3442cc1..41146912819b4 100644 --- a/aten/src/ATen/ThreadLocalState.h +++ b/aten/src/ATen/ThreadLocalState.h @@ -16,10 +16,12 @@ class TORCH_API ThreadLocalState { public: // Saves the thread local variables' values and // returns them as a ThreadLocalState - // keep_grad_mode - whether grad mode has to be preserved - // (e.g. not preserved when passing from forward pass into - // the autograd engine, autograd engine takes care of grad mode) - ThreadLocalState(bool keep_grad_mode = true); + ThreadLocalState(); + + // set_grad_mode - force the value of the grad mode TLS in + // the current state object. This is used for example in the + // autograd engine. + void set_grad_mode(bool enabled); // Sets thread local variables in the current thread, // according to the thread boundary specified @@ -35,13 +37,8 @@ class TORCH_API ThreadLocalState { // RecordFunction TLS RecordFunctionTLS rf_tls_; -#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - bool keep_grad_mode_ = true; - bool grad_mode_enabled_; -#endif - - // TLS for InferenceMode - bool inference_mode_enabled_; + // TLS for AutogradModes + AutogradState autograd_tls_; // TLS for saved tensors default hooks std::pair saved_tensors_default_hooks_; diff --git a/aten/src/ATen/Version.cpp b/aten/src/ATen/Version.cpp index 750c90bb4c59f..0c0ea61ceb3c2 100644 --- a/aten/src/ATen/Version.cpp +++ b/aten/src/ATen/Version.cpp @@ -184,6 +184,10 @@ std::string show_config() { ss << detail::getCUDAHooks().showConfig(); } + if (hasORT()) { + ss << detail::getORTHooks().showConfig(); + } + ss << " - Build settings: "; for (const auto& pair : caffe2::GetBuildOptions()) { if (!pair.second.empty()) { diff --git a/aten/src/ATen/WrapDimUtils.h b/aten/src/ATen/WrapDimUtils.h index 2768efe6e683b..13e605c920ec1 100644 --- a/aten/src/ATen/WrapDimUtils.h +++ b/aten/src/ATen/WrapDimUtils.h @@ -7,6 +7,9 @@ namespace at { static inline int64_t maybe_wrap_dim(int64_t dim, int64_t dim_post_expr, bool wrap_scalar=true) { + // if dim_post_expr is 0 and wrap_scalar is true, then dim must be in the range [-1, 0]. + // This is a special case for scalar tensors and manifests in e.g. torch.sum(scalar_tensor, 0) + // Otherwise, dim should be in the range [-dim_post_expr, dim_post_expr-1]. return c10::maybe_wrap_dim(dim, dim_post_expr, wrap_scalar); } diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp index 97ec9ec69dbeb..9f5f486bb7581 100644 --- a/aten/src/ATen/autocast_mode.cpp +++ b/aten/src/ATen/autocast_mode.cpp @@ -59,7 +59,7 @@ thread_local int nesting = 0; thread_local at::ScalarType autocast_cpu_dtype = at::kBFloat16; // autocast_gpu_dtype is the lower_precision_fp used by AutocastGPU. -at::ScalarType autocast_gpu_dtype = at::kHalf; +thread_local at::ScalarType autocast_gpu_dtype = at::kHalf; } void clear_cache() { @@ -461,22 +461,22 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) { KERNEL_CPU(ADD_NS(conv1d), "conv1d", Tensor (const Tensor &, const Tensor &, const c10::optional&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t), lower_precision_fp) KERNEL_CPU(ADD_NS(conv2d), "conv2d", Tensor (const Tensor &, const Tensor &, const c10::optional&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t), lower_precision_fp) KERNEL_CPU(ADD_NS(conv3d), "conv3d", Tensor (const Tensor &, const Tensor &, const c10::optional&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t), lower_precision_fp) - KERNEL_CPU(ADD_NS(_log_softmax), "_log_softmax", Tensor (const Tensor &, int64_t, bool), lower_precision_fp) KERNEL_CPU(ADD_NS(bmm), "bmm", Tensor (const Tensor &, const Tensor &), lower_precision_fp) KERNEL_CPU(ADD_NS(mm), "mm", Tensor (const Tensor &, const Tensor &), lower_precision_fp) KERNEL_CPU(ADD_NS(baddbmm), "baddbmm", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), lower_precision_fp) KERNEL_CPU(ADD_NS(addmm), "addmm", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), lower_precision_fp) KERNEL_CPU(ADD_NS(addbmm), "addbmm", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), lower_precision_fp) KERNEL_CPU(ADD_NS(linear), "linear", Tensor (const Tensor &, const Tensor &, const c10::optional &), lower_precision_fp) + KERNEL_CPU(ADD_NS(_convolution), "_convolution.deprecated", Tensor (const Tensor &, const Tensor &, const c10::optional&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef, int64_t, bool, bool, bool), lower_precision_fp) // fp32 cast policy + KERNEL_CPU(ADD_NS(conv_transpose1d), "conv_transpose1d", Tensor (const Tensor &, const Tensor &, const c10::optional &, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, IntArrayRef), fp32) + KERNEL_CPU(ADD_NS(conv_transpose2d), "conv_transpose2d.input", Tensor (const Tensor &, const Tensor &, const c10::optional &, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, IntArrayRef), fp32) KERNEL_CPU(ADD_NS(conv_transpose3d), "conv_transpose3d.input", Tensor (const Tensor &, const Tensor &, const c10::optional &, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, IntArrayRef), fp32) KERNEL_CPU(ADD_NS(batch_norm), "batch_norm", Tensor (const Tensor &, const c10::optional &, const c10::optional &, const c10::optional &, const c10::optional &, bool, double, double, bool), fp32) - KERNEL_CPU(ADD_NS(max_pool2d), "max_pool2d", Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, bool), fp32) - KERNEL_CPU(ADD_NS(adaptive_avg_pool2d), "adaptive_avg_pool2d", Tensor (const Tensor &, IntArrayRef), fp32) - KERNEL_CPU(ADD_NS(convolution), "convolution", Tensor (const Tensor &, const Tensor &, const c10::optional&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef, int64_t), fp32) KERNEL_CPU(ADD_NS(dropout), "dropout", Tensor (const Tensor &, double, bool), fp32) + KERNEL_CPU(ADD_NS(avg_pool1d), "avg_pool1d", Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, bool, bool), fp32) KERNEL_CPU(ADD_NS(avg_pool2d), "avg_pool2d", Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, bool, bool, c10::optional), fp32) KERNEL_CPU(ADD_NS(avg_pool3d), "avg_pool3d", Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, bool, bool, c10::optional), fp32) KERNEL_CPU(ADD_NS(gelu), "gelu", Tensor (const Tensor &), fp32) @@ -492,45 +492,285 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) { KERNEL_CPU(ADD_NS(upsample_bilinear2d), "upsample_bilinear2d.vec", Tensor (const Tensor &, c10::optional, bool, c10::optional>), fp32) KERNEL_CPU(ADD_NS(upsample_trilinear3d), "upsample_trilinear3d", Tensor (const Tensor &, IntArrayRef, bool, c10::optional, c10::optional, c10::optional), fp32) KERNEL_CPU(ADD_NS(upsample_trilinear3d), "upsample_trilinear3d.vec", Tensor (const Tensor &, c10::optional, bool, c10::optional>), fp32) + KERNEL_CPU(ADD_NS(binary_cross_entropy), "binary_cross_entropy", Tensor (const Tensor &, const Tensor &, const c10::optional&, int64_t), fp32) KERNEL_CPU(ADD_NS(binary_cross_entropy_with_logits), "binary_cross_entropy_with_logits", Tensor (const Tensor &, const Tensor &, const c10::optional&, const c10::optional&, int64_t), fp32) - KERNEL_CPU(ADD_NS(pow), "pow.Tensor_Scalar", Tensor (const Tensor &, const Scalar &), fp32) - KERNEL_CPU(ADD_NS(pow), "pow.Tensor_Tensor", Tensor (const Tensor &, const Tensor &), fp32) - KERNEL_CPU(ADD_NS(pow), "pow.Scalar", Tensor (const Scalar&, const Tensor &), fp32) - KERNEL_CPU(ADD_NS(smooth_l1_loss), "smooth_l1_loss", Tensor (const Tensor &, const Tensor &, int64_t, double), fp32) - KERNEL_CPU(ADD_NS(reflection_pad1d), "reflection_pad1d", Tensor (const Tensor &, IntArrayRef), fp32) - KERNEL_CPU(ADD_NS(std), "std", Tensor (const Tensor &, bool), fp32) - KERNEL_CPU(ADD_NS(std), "std.dim", Tensor (const Tensor &, IntArrayRef, bool, bool), fp32) KERNEL_CPU(ADD_NS(instance_norm), "instance_norm", Tensor (const Tensor &, const c10::optional&, const c10::optional&, const c10::optional&, const c10::optional&, bool, double, double, bool), fp32) + KERNEL_CPU(ADD_NS(grid_sampler), "grid_sampler", Tensor(const Tensor &, const Tensor &, int64_t, int64_t, bool), fp32) + KERNEL_CPU(ADD_NS(polar), "polar", Tensor(const Tensor &, const Tensor &), fp32) + KERNEL_CPU(ADD_NS(multinomial), "multinomial", Tensor(const Tensor &, int64_t, bool, c10::optional), fp32) + KERNEL_CPU(ADD_NS(poisson), "poisson", Tensor(const Tensor &, c10::optional), fp32) + KERNEL_CPU(ADD_NS(fmod), "fmod.Tensor", Tensor(const Tensor &, const Tensor &), fp32) + KERNEL_CPU(ADD_NS(fmod), "fmod.Scalar", Tensor(const Tensor &, const Scalar &), fp32) + KERNEL_CPU(ADD_NS(prod), "prod", Tensor(const Tensor &, c10::optional), fp32) + KERNEL_CPU(ADD_NS(prod), "prod.dim_int", Tensor(const Tensor &, int64_t, bool, c10::optional), fp32) + KERNEL_CPU(ADD_NS(prod), "prod.dim_Dimname", Tensor(const Tensor &, at::Dimname, bool, c10::optional), fp32) + KERNEL_CPU(ADD_NS(quantile), "quantile", Tensor(const Tensor &, const Tensor &, c10::optional, bool), fp32) + KERNEL_CPU(ADD_NS(quantile), "quantile.scalar", Tensor(const Tensor &, double, c10::optional, bool), fp32) + KERNEL_CPU(ADD_NS(quantile), "quantile.new", Tensor(const Tensor &, const Tensor &, c10::optional, bool, c10::string_view), fp32) + KERNEL_CPU(ADD_NS(quantile), "quantile.new_scalar", Tensor(const Tensor &, double, c10::optional, bool, c10::string_view), fp32) + KERNEL_CPU(ADD_NS(nanquantile), "nanquantile", Tensor(const Tensor &, const Tensor &, c10::optional, bool), fp32) + KERNEL_CPU(ADD_NS(nanquantile), "nanquantile.scalar", Tensor(const Tensor &, double, c10::optional, bool), fp32) + KERNEL_CPU(ADD_NS(nanquantile), "nanquantile.new", Tensor(const Tensor &, const Tensor &, c10::optional, bool, c10::string_view), fp32) + KERNEL_CPU(ADD_NS(nanquantile), "nanquantile.new_scalar", Tensor(const Tensor &, double, c10::optional, bool, c10::string_view), fp32) + KERNEL_CPU(ADD_NS(stft), "stft", Tensor(const Tensor &, int64_t, c10::optional, c10::optional, const c10::optional &, bool, c10::optional, c10::optional), fp32) + KERNEL_CPU(ADD_NS(cdist), "cdist", Tensor(const Tensor &, const Tensor &, double, c10::optional), fp32) + KERNEL_CPU(ADD_NS(cross), "cross", Tensor(const Tensor &, const Tensor &, c10::optional), fp32) + KERNEL_CPU(ADD_NS(cumprod), "cumprod", Tensor(const Tensor &, int64_t, c10::optional), fp32) + KERNEL_CPU(ADD_NS(cumprod), "cumprod.dimname", Tensor(const Tensor &, at::Dimname, c10::optional), fp32) + KERNEL_CPU(ADD_NS(cumsum), "cumsum", Tensor(const Tensor &, int64_t, c10::optional), fp32) + KERNEL_CPU(ADD_NS(cumsum), "cumsum.dimname", Tensor(const Tensor &, at::Dimname, c10::optional), fp32) + KERNEL_CPU(ADD_NS(diag), "diag", Tensor(const Tensor &, int64_t), fp32) + KERNEL_CPU(ADD_NS(diagflat), "diagflat", Tensor(const Tensor &, int64_t), fp32) + KERNEL_CPU(ADD_NS(histc), "histc", Tensor(const Tensor &, int64_t, const at::Scalar &, const at::Scalar &), fp32) + KERNEL_CPU(ADD_NS(logcumsumexp), "logcumsumexp", Tensor(const Tensor &, int64_t), fp32) + KERNEL_CPU(ADD_NS(searchsorted), "searchsorted.Tensor", Tensor(const Tensor &, const Tensor &, bool, bool), fp32) + KERNEL_CPU(ADD_NS(searchsorted), "searchsorted.Scalar", Tensor(const Tensor &, const at::Scalar &, bool, bool), fp32) + KERNEL_CPU(ADD_NS(trace), "trace", Tensor(const Tensor &), fp32) + KERNEL_CPU(ADD_NS(tril), "tril", Tensor(const Tensor &, int64_t), fp32) + KERNEL_CPU(ADD_NS(triu), "triu", Tensor(const Tensor &, int64_t), fp32) + KERNEL_CPU(ADD_NS(vander), "vander", Tensor(const Tensor &, c10::optional, bool), fp32) + KERNEL_CPU(ADD_NS(view_as_complex), "view_as_complex", Tensor(const Tensor &), fp32) + KERNEL_CPU(ADD_NS(cholesky), "cholesky", Tensor(const Tensor &, bool), fp32) + KERNEL_CPU(ADD_NS(cholesky_inverse), "cholesky_inverse", Tensor(const Tensor &, bool), fp32) + KERNEL_CPU(ADD_NS(cholesky_solve), "cholesky_solve", Tensor(const Tensor &, const Tensor &, bool), fp32) + KERNEL_CPU(ADD_NS(dot), "dot", Tensor(const Tensor &, const Tensor &), fp32) + KERNEL_CPU(ADD_NS(inverse), "inverse", Tensor(const Tensor &), fp32) + KERNEL_CPU(ADD_NS(lu_solve), "lu_solve", Tensor(const Tensor &, const Tensor &, const Tensor &), fp32) + KERNEL_CPU(ADD_NS(matrix_rank), "matrix_rank", Tensor(const Tensor &, bool), fp32) + KERNEL_CPU(ADD_NS(orgqr), "orgqr", Tensor(const Tensor &, const Tensor &), fp32) + KERNEL_CPU(ADD_NS(ormqr), "ormqr", Tensor(const Tensor &, const Tensor &, const Tensor &, bool, bool), fp32) + KERNEL_CPU(ADD_NS(pinverse), "pinverse", Tensor(const Tensor &, double), fp32) + KERNEL_CPU(ADD_NS(vdot), "vdot", Tensor(const Tensor &, const Tensor &), fp32) + KERNEL_CPU(ADD_NS(im2col), "im2col", Tensor(const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef), fp32) + KERNEL_CPU(ADD_NS(col2im), "col2im", Tensor(const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef), fp32) + KERNEL_CPU(ADD_NS(max_pool3d), "max_pool3d", Tensor(const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, bool), fp32) + KERNEL_CPU(ADD_NS(max_unpool2d), "max_unpool2d", Tensor(const Tensor &, const Tensor &, IntArrayRef), fp32) + KERNEL_CPU(ADD_NS(max_unpool3d), "max_unpool3d", Tensor(const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef), fp32) + KERNEL_CPU(ADD_NS(adaptive_avg_pool3d), "adaptive_avg_pool3d", Tensor(const Tensor &, IntArrayRef), fp32) + KERNEL_CPU(ADD_NS(reflection_pad1d), "reflection_pad1d", Tensor(const Tensor &, IntArrayRef), fp32) + KERNEL_CPU(ADD_NS(reflection_pad2d), "reflection_pad2d", Tensor(const Tensor &, IntArrayRef), fp32) + KERNEL_CPU(ADD_NS(replication_pad1d), "replication_pad1d", Tensor(const Tensor &, IntArrayRef), fp32) + KERNEL_CPU(ADD_NS(replication_pad2d), "replication_pad2d", Tensor(const Tensor &, IntArrayRef), fp32) + KERNEL_CPU(ADD_NS(replication_pad3d), "replication_pad3d", Tensor(const Tensor &, IntArrayRef), fp32) + KERNEL_CPU(ADD_NS(elu), "elu", Tensor(const Tensor &, const Scalar &, const Scalar &, const Scalar &), fp32) + KERNEL_CPU(ADD_NS(hardshrink), "hardshrink", Tensor(const Tensor &, const Scalar &), fp32) + KERNEL_CPU(ADD_NS(hardsigmoid), "hardsigmoid", Tensor(const Tensor &), fp32) + KERNEL_CPU(ADD_NS(hardswish), "hardswish", Tensor(const Tensor &), fp32) + KERNEL_CPU(ADD_NS(log_sigmoid), "log_sigmoid", Tensor(const Tensor &), fp32) + KERNEL_CPU(ADD_NS(prelu), "prelu", Tensor(const Tensor &, const Tensor &), fp32) + KERNEL_CPU(ADD_NS(selu), "selu", Tensor(const Tensor &), fp32) + KERNEL_CPU(ADD_NS(celu), "celu", Tensor(const Tensor &, const Scalar &), fp32) + KERNEL_CPU(ADD_NS(softplus), "softplus", Tensor(const Tensor &, const Scalar &, const Scalar &), fp32) + KERNEL_CPU(ADD_NS(softshrink), "softshrink", Tensor(const Tensor &, const Scalar &), fp32) + KERNEL_CPU(ADD_NS(group_norm), "group_norm", Tensor(const Tensor &, int64_t, const c10::optional &, const c10::optional &, double, bool), fp32) + KERNEL_CPU(ADD_NS(smooth_l1_loss), "smooth_l1_loss", Tensor (const Tensor &, const Tensor &, int64_t, double), fp32) + KERNEL_CPU(ADD_NS(mse_loss), "mse_loss", Tensor(const Tensor &, const Tensor &, int64_t), fp32) + KERNEL_CPU(ADD_NS(ctc_loss), "ctc_loss.IntList", Tensor(const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, int64_t, int64_t, bool), fp32) + KERNEL_CPU(ADD_NS(ctc_loss), "ctc_loss.Tensor", Tensor(const Tensor &, const Tensor &, const Tensor &, const Tensor &, int64_t, int64_t, bool), fp32) + KERNEL_CPU(ADD_NS(kl_div), "kl_div", Tensor(const Tensor &, const Tensor &, int64_t, bool), fp32) + KERNEL_CPU(ADD_NS(multilabel_margin_loss), "multilabel_margin_loss", Tensor(const Tensor &, const Tensor &, int64_t), fp32) + KERNEL_CPU(ADD_NS(fft_fft), "fft_fft", Tensor(const Tensor &, c10::optional, int64_t, c10::optional), fp32) + KERNEL_CPU(ADD_NS(fft_ifft), "fft_ifft", Tensor(const Tensor &, c10::optional, int64_t, c10::optional), fp32) + KERNEL_CPU(ADD_NS(fft_fft2), "fft_fft2", Tensor(const Tensor &, c10::optional, at::IntArrayRef, c10::optional), fp32) + KERNEL_CPU(ADD_NS(fft_ifft2), "fft_ifft2", Tensor(const Tensor &, c10::optional, at::IntArrayRef, c10::optional), fp32) + KERNEL_CPU(ADD_NS(fft_fftn), "fft_fftn", Tensor(const Tensor &, c10::optional, c10::optional, c10::optional), fp32) + KERNEL_CPU(ADD_NS(fft_ifftn), "fft_ifftn", Tensor(const Tensor &, c10::optional, c10::optional, c10::optional), fp32) + KERNEL_CPU(ADD_NS(fft_rfft), "fft_rfft", Tensor(const Tensor &, c10::optional, int64_t, c10::optional), fp32) + KERNEL_CPU(ADD_NS(fft_irfft), "fft_irfft", Tensor(const Tensor &, c10::optional, int64_t, c10::optional), fp32) + KERNEL_CPU(ADD_NS(fft_rfft2), "fft_rfft2", Tensor(const Tensor &, c10::optional, at::IntArrayRef, c10::optional), fp32) + KERNEL_CPU(ADD_NS(fft_irfft2), "fft_irfft2", Tensor(const Tensor &, c10::optional, at::IntArrayRef, c10::optional), fp32) + KERNEL_CPU(ADD_NS(fft_rfftn), "fft_rfftn", Tensor(const Tensor &, c10::optional, c10::optional, c10::optional), fp32) + KERNEL_CPU(ADD_NS(fft_irfftn), "fft_irfftn", Tensor(const Tensor &, c10::optional, c10::optional, c10::optional), fp32) + KERNEL_CPU(ADD_NS(fft_hfft), "fft_hfft", Tensor(const Tensor &, c10::optional, int64_t, c10::optional), fp32) + KERNEL_CPU(ADD_NS(fft_ihfft), "fft_ihfft", Tensor(const Tensor &, c10::optional, int64_t, c10::optional), fp32) + KERNEL_CPU(ADD_NS(conv_tbc), "conv_tbc", Tensor(const Tensor &, const Tensor &, const Tensor &, int64_t), fp32) + KERNEL_CPU(ADD_NS(linalg_matrix_norm), "linalg_matrix_norm", Tensor(const Tensor &, const at::Scalar &, at::IntArrayRef, bool, c10::optional), fp32) + KERNEL_CPU(ADD_NS(linalg_matrix_norm), "linalg_matrix_norm.str_ord", Tensor(const Tensor &, c10::string_view, at::IntArrayRef, bool, c10::optional), fp32) + KERNEL_CPU(ADD_NS(linalg_cond), "linalg_cond", Tensor(const Tensor &, const c10::optional &), fp32) + KERNEL_CPU(ADD_NS(linalg_cond), "linalg_cond.p_str", Tensor(const Tensor &, c10::string_view), fp32) + KERNEL_CPU(ADD_NS(linalg_matrix_rank), "linalg_matrix_rank", Tensor(const Tensor &, const c10::optional, bool), fp32) + KERNEL_CPU(ADD_NS(linalg_matrix_rank), "linalg_matrix_rank.tol_tensor", Tensor(const Tensor &, const Tensor &, bool), fp32) + KERNEL_CPU(ADD_NS(linalg_solve), "linalg_solve", Tensor(const Tensor &, const Tensor &), fp32) + KERNEL_CPU(ADD_NS(linalg_cholesky), "linalg_cholesky", Tensor(const Tensor &, bool), fp32) + KERNEL_CPU(ADD_NS(linalg_svdvals), "linalg_svdvals", Tensor(const Tensor &), fp32) + KERNEL_CPU(ADD_NS(linalg_eigvals), "linalg_eigvals", Tensor(const Tensor &), fp32) + KERNEL_CPU(ADD_NS(linalg_eigvalsh), "linalg_eigvalsh", Tensor(const Tensor &, c10::string_view), fp32) + KERNEL_CPU(ADD_NS(linalg_inv), "linalg_inv", Tensor(const Tensor &), fp32) + KERNEL_CPU(ADD_NS(linalg_householder_product), "linalg_householder_product", Tensor(const Tensor &, const Tensor &), fp32) + KERNEL_CPU(ADD_NS(linalg_tensorinv), "linalg_tensorinv", Tensor(const Tensor &, int64_t), fp32) + KERNEL_CPU(ADD_NS(linalg_tensorsolve), "linalg_tensorsolve", Tensor(const Tensor &, const Tensor &, c10::optional), fp32) KERNEL_CPU(ADD_NS(fake_quantize_per_tensor_affine), "fake_quantize_per_tensor_affine", Tensor (const Tensor &, double, int64_t, int64_t, int64_t), fp32) + KERNEL_CPU(ADD_NS(glu), "glu", Tensor (const Tensor &, int64_t), fp32) - // promote - KERNEL_CPU(ADD_NS(cat), "cat", Tensor (TensorList, int64_t), promote) - KERNEL_CPU(ADD_NS(stack), "stack", Tensor (TensorList, int64_t), promote) + m.impl(TORCH_SELECTIVE_NAME("aten::cummax"), + TORCH_FN((&WrapFunction (const Tensor &, int64_t), + std::tuple (const Tensor &, int64_t), + &ADD_NS(cummax)>::type::call))); + + m.impl(TORCH_SELECTIVE_NAME("aten::cummax.dimname"), + TORCH_FN((&WrapFunction (const Tensor &, at::Dimname), + std::tuple (const Tensor &, at::Dimname), + &ADD_NS(cummax)>::type::call))); + + m.impl(TORCH_SELECTIVE_NAME("aten::cummin"), + TORCH_FN((&WrapFunction (const Tensor &, int64_t), + std::tuple (const Tensor &, int64_t), + &ADD_NS(cummin)>::type::call))); + + m.impl(TORCH_SELECTIVE_NAME("aten::cummin.dimname"), + TORCH_FN((&WrapFunction (const Tensor &, at::Dimname), + std::tuple (const Tensor &, at::Dimname), + &ADD_NS(cummin)>::type::call))); + + m.impl(TORCH_SELECTIVE_NAME("aten::eig"), + TORCH_FN((&WrapFunction (const Tensor &, bool), + std::tuple (const Tensor &, bool), + &ADD_NS(eig)>::type::call))); + + m.impl(TORCH_SELECTIVE_NAME("aten::geqrf"), + TORCH_FN((&WrapFunction (const Tensor &), + std::tuple (const Tensor &), + &ADD_NS(geqrf)>::type::call))); + + m.impl(TORCH_SELECTIVE_NAME("aten::lstsq"), + TORCH_FN((&WrapFunction (const Tensor &, const Tensor &), + std::tuple (const Tensor &, const Tensor &), + &ADD_NS(lstsq)>::type::call))); + + m.impl(TORCH_SELECTIVE_NAME("aten::_lu_with_info"), + TORCH_FN((&WrapFunction (const Tensor &, bool, bool), + std::tuple (const Tensor &, bool, bool), + &ADD_NS(_lu_with_info)>::type::call))); + + m.impl(TORCH_SELECTIVE_NAME("aten::lu_unpack"), + TORCH_FN((&WrapFunction (const Tensor &, const Tensor &, bool, bool), + std::tuple (const Tensor &, const Tensor &, bool, bool), + &ADD_NS(lu_unpack)>::type::call))); + + m.impl(TORCH_SELECTIVE_NAME("aten::qr"), + TORCH_FN((&WrapFunction (const Tensor &, bool), + std::tuple (const Tensor &, bool), + &ADD_NS(qr)>::type::call))); + + m.impl(TORCH_SELECTIVE_NAME("aten::solve"), + TORCH_FN((&WrapFunction (const Tensor &, const Tensor &), + std::tuple (const Tensor &, const Tensor &), + &ADD_NS(solve)>::type::call))); + + m.impl(TORCH_SELECTIVE_NAME("aten::svd"), + TORCH_FN((&WrapFunction (const Tensor &, bool, bool), + std::tuple (const Tensor &, bool, bool), + &ADD_NS(svd)>::type::call))); + + m.impl(TORCH_SELECTIVE_NAME("aten::symeig"), + TORCH_FN((&WrapFunction (const Tensor &, bool, bool), + std::tuple (const Tensor &, bool, bool), + &ADD_NS(symeig)>::type::call))); + + m.impl(TORCH_SELECTIVE_NAME("aten::triangular_solve"), + TORCH_FN((&WrapFunction (const Tensor &, const Tensor &, bool, bool, bool), + std::tuple (const Tensor &, const Tensor &, bool, bool, bool), + &ADD_NS(triangular_solve)>::type::call))); + + m.impl(TORCH_SELECTIVE_NAME("aten::fractional_max_pool2d"), + TORCH_FN((&WrapFunction (const Tensor &, IntArrayRef, IntArrayRef, const Tensor &), + std::tuple (const Tensor &, IntArrayRef, IntArrayRef, const Tensor &), + &ADD_NS(fractional_max_pool2d)>::type::call))); + + m.impl(TORCH_SELECTIVE_NAME("aten::fractional_max_pool3d"), + TORCH_FN((&WrapFunction (const Tensor &, IntArrayRef, IntArrayRef, const Tensor &), + std::tuple (const Tensor &, IntArrayRef, IntArrayRef, const Tensor &), + &ADD_NS(fractional_max_pool3d)>::type::call))); + + m.impl(TORCH_SELECTIVE_NAME("aten::adaptive_max_pool1d"), + TORCH_FN((&WrapFunction (const Tensor &, IntArrayRef), + std::tuple (const Tensor &, IntArrayRef), + &ADD_NS(adaptive_max_pool1d)>::type::call))); + + m.impl(TORCH_SELECTIVE_NAME("aten::adaptive_max_pool2d"), + TORCH_FN((&WrapFunction (const Tensor &, IntArrayRef), + std::tuple (const Tensor &, IntArrayRef), + &ADD_NS(adaptive_max_pool2d)>::type::call))); + + m.impl(TORCH_SELECTIVE_NAME("aten::adaptive_max_pool3d"), + TORCH_FN((&WrapFunction (const Tensor &, IntArrayRef), + std::tuple (const Tensor &, IntArrayRef), + &ADD_NS(adaptive_max_pool3d)>::type::call))); + + m.impl(TORCH_SELECTIVE_NAME("aten::multilabel_margin_loss_forward"), + TORCH_FN((&WrapFunction (const Tensor &, const Tensor &, int64_t), + std::tuple (const Tensor &, const Tensor &, int64_t), + &ADD_NS(multilabel_margin_loss_forward)>::type::call))); + + m.impl(TORCH_SELECTIVE_NAME("aten::linalg_qr"), + TORCH_FN((&WrapFunction (const Tensor &, c10::string_view), + std::tuple (const Tensor &, c10::string_view), + &ADD_NS(linalg_qr)>::type::call))); + + m.impl(TORCH_SELECTIVE_NAME("aten::linalg_cholesky_ex"), + TORCH_FN((&WrapFunction (const Tensor &, bool, bool), + std::tuple (const Tensor &, bool, bool), + &ADD_NS(linalg_cholesky_ex)>::type::call))); - m.impl(TORCH_SELECTIVE_NAME("aten::topk"), + m.impl(TORCH_SELECTIVE_NAME("aten::linalg_svd"), TORCH_FN((&WrapFunction (const Tensor &, int64_t, int64_t, bool, bool), - std::tuple (const Tensor &, int64_t, int64_t, bool, bool), - &ADD_NS(topk)>::type::call))); + std::tuple (const Tensor &, bool), + std::tuple (const Tensor &, bool), + &ADD_NS(linalg_svd)>::type::call))); - m.impl(TORCH_SELECTIVE_NAME("aten::sort"), + m.impl(TORCH_SELECTIVE_NAME("aten::linalg_eig"), TORCH_FN((&WrapFunction (const Tensor &, int64_t, bool), - std::tuple (const Tensor &, int64_t, bool), - &ADD_NS(sort)>::type::call))); + std::tuple (const Tensor &), + std::tuple (const Tensor &), + &ADD_NS(linalg_eig)>::type::call))); - m.impl(TORCH_SELECTIVE_NAME("aten::kthvalue"), + m.impl(TORCH_SELECTIVE_NAME("aten::linalg_eigh"), TORCH_FN((&WrapFunction (const Tensor &, int64_t, int64_t, bool), - std::tuple (const Tensor &, int64_t, int64_t, bool), - &ADD_NS(kthvalue)>::type::call))); + std::tuple (const Tensor &, c10::string_view), + std::tuple (const Tensor &, c10::string_view), + &ADD_NS(linalg_eigh)>::type::call))); - m.impl(TORCH_SELECTIVE_NAME("aten::kthvalue.dimname"), + m.impl(TORCH_SELECTIVE_NAME("aten::linalg_lstsq"), TORCH_FN((&WrapFunction (const Tensor &, int64_t, at::Dimname, bool), - std::tuple (const Tensor &, int64_t, at::Dimname, bool), - &ADD_NS(kthvalue)>::type::call))); + std::tuple (const Tensor &, const Tensor &, c10::optional, c10::optional), + std::tuple (const Tensor &, const Tensor &, c10::optional, c10::optional), + &ADD_NS(linalg_lstsq)>::type::call))); + + m.impl(TORCH_SELECTIVE_NAME("aten::linalg_inv_ex"), + TORCH_FN((&WrapFunction (const Tensor &, bool), + std::tuple (const Tensor &, bool), + &ADD_NS(linalg_inv_ex)>::type::call))); + + // promote + KERNEL_CPU(ADD_NS(cat), "cat", Tensor (TensorList, int64_t), promote) + KERNEL_CPU(ADD_NS(stack), "stack", Tensor (TensorList, int64_t), promote) + KERNEL_CPU(ADD_NS(index_copy), "index_copy", Tensor (const Tensor &, int64_t, const Tensor &, const Tensor &), promote) + KERNEL_CPU(ADD_NS(index_copy), "index_copy.dimname", Tensor (const Tensor &, at::Dimname, const Tensor &, const Tensor &), promote) + } } // namespace } // namespace autocast diff --git a/aten/src/ATen/core/Formatting.cpp b/aten/src/ATen/core/Formatting.cpp index baf1691bd1d53..dbbed6e3b0785 100644 --- a/aten/src/ATen/core/Formatting.cpp +++ b/aten/src/ATen/core/Formatting.cpp @@ -232,6 +232,9 @@ void __printTensor(std::ostream& stream, Tensor& self, int64_t linesize) } } +void print(const Tensor & t, int64_t linesize) { + print(std::cout,t,linesize); +} std::ostream& print(std::ostream& stream, const Tensor & tensor_, int64_t linesize) { FormatGuard guard(stream); if(!tensor_.defined()) { diff --git a/aten/src/ATen/core/Formatting.h b/aten/src/ATen/core/Formatting.h index 86ea603951613..55cfe7b3bdf7e 100644 --- a/aten/src/ATen/core/Formatting.h +++ b/aten/src/ATen/core/Formatting.h @@ -2,7 +2,7 @@ #include #include -#include +#include namespace c10 { @@ -18,9 +18,7 @@ TORCH_API std::ostream& print( static inline std::ostream& operator<<(std::ostream & out, const Tensor & t) { return print(out,t,80); } -static inline void print(const Tensor & t, int64_t linesize=80) { - print(std::cout,t,linesize); -} +TORCH_API void print(const Tensor & t, int64_t linesize=80); static inline std::ostream& operator<<(std::ostream & out, Scalar s) { if (s.isFloatingPoint()) { diff --git a/aten/src/ATen/core/MT19937RNGEngine.h b/aten/src/ATen/core/MT19937RNGEngine.h index 033df304e4a8e..40c1ba5f584ad 100644 --- a/aten/src/ATen/core/MT19937RNGEngine.h +++ b/aten/src/ATen/core/MT19937RNGEngine.h @@ -157,7 +157,6 @@ class mt19937_engine { data_.state_[0] = seed & 0xffffffff; for(int j = 1; j < MERSENNE_STATE_N; j++) { data_.state_[j] = (1812433253 * (data_.state_[j-1] ^ (data_.state_[j-1] >> 30)) + j); - data_.state_[j] &= 0xffffffff; } data_.left_ = 1; data_.next_ = 0; diff --git a/aten/src/ATen/core/Tensor.h b/aten/src/ATen/core/Tensor.h index fb0f86952bea4..fa2479c800c05 100644 --- a/aten/src/ATen/core/Tensor.h +++ b/aten/src/ATen/core/Tensor.h @@ -6,28 +6,22 @@ namespace at { class TORCH_API OptionalTensorRef { public: - OptionalTensorRef() {} + OptionalTensorRef() = default; ~OptionalTensorRef() { ref_.unsafeReleaseTensorImpl(); } OptionalTensorRef(const Tensor& src) - : ref_(c10::intrusive_ptr( - src.unsafeGetTensorImpl(), - c10::raw::DontIncreaseRefcount{})) { + : ref_(Tensor::unsafe_borrow_t{}, src) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src.defined()); } OptionalTensorRef(const OptionalTensorRef& rhs) - : OptionalTensorRef(rhs.ref_) {} + : ref_(Tensor::unsafe_borrow_t{}, rhs.ref_) {} - OptionalTensorRef& operator=(const OptionalTensorRef& rhs) { - // Need to call unsafeReleaseTensorImpl on ref_ since we are reassigning it - // (which does not call the destructor). - ref_.unsafeReleaseTensorImpl(); - ref_ = Tensor(c10::intrusive_ptr( - rhs.ref_.unsafeGetTensorImpl(), c10::raw::DontIncreaseRefcount{})); + OptionalTensorRef& operator=(OptionalTensorRef rhs) { + std::swap(ref_, rhs.ref_); return *this; } @@ -39,6 +33,14 @@ class TORCH_API OptionalTensorRef { return ref_; } + const Tensor& operator*() const & { + return ref_; + } + + const Tensor* operator->() const & { + return &ref_; + } + operator bool() const { return ref_.defined(); } diff --git a/aten/src/ATen/core/Vitals.cpp b/aten/src/ATen/core/Vitals.cpp index edff5211ea0f0..76fc652f9407e 100644 --- a/aten/src/ATen/core/Vitals.cpp +++ b/aten/src/ATen/core/Vitals.cpp @@ -1,5 +1,6 @@ #include #include +#include namespace at { namespace vitals { diff --git a/aten/src/ATen/core/Vitals.h b/aten/src/ATen/core/Vitals.h index c64cf7e629210..48913c54185f3 100644 --- a/aten/src/ATen/core/Vitals.h +++ b/aten/src/ATen/core/Vitals.h @@ -1,8 +1,8 @@ #pragma once #include -#include #include #include +#include #include #include diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h index 584e3db9ee193..df6b860a8a363 100644 --- a/aten/src/ATen/core/aten_interned_strings.h +++ b/aten/src/ATen/core/aten_interned_strings.h @@ -36,7 +36,6 @@ _(aten, _cast_Half) \ _(aten, _cast_Int) \ _(aten, _cast_Long) \ _(aten, _cast_Short) \ -_(aten, _cat) \ _(aten, _ceil) \ _(aten, _clamp_max) \ _(aten, _clamp_min) \ @@ -224,7 +223,6 @@ _(aten, bmm) \ _(aten, broadcast_tensors) \ _(aten, broadcast_to) \ _(aten, cartesian_prod) \ -_(aten, cat) \ _(aten, cauchy) \ _(aten, ceil) \ _(aten, celu) \ @@ -405,6 +403,7 @@ _(aten, is_complex) \ _(aten, is_contiguous) \ _(aten, is_cuda) \ _(aten, is_mlc) \ +_(aten, is_ort) \ _(aten, is_distributed) \ _(aten, is_floating_point) \ _(aten, is_inference) \ @@ -454,7 +453,6 @@ _(aten, margin_ranking_loss) \ _(aten, masked_fill) \ _(aten, masked_scatter) \ _(aten, masked_select) \ -_(aten, matmul) \ _(aten, matrix_rank) \ _(aten, matrix_exp) \ _(aten, max) \ diff --git a/aten/src/ATen/core/builtin_function.h b/aten/src/ATen/core/builtin_function.h index de30f9b7e179f..600c16bb6e5d4 100644 --- a/aten/src/ATen/core/builtin_function.h +++ b/aten/src/ATen/core/builtin_function.h @@ -123,7 +123,7 @@ struct BuiltinOpFunction : public Function { return *this; } - ~BuiltinOpFunction() {} + ~BuiltinOpFunction() override {} private: c10::QualifiedName name_; diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h index fd32a72c75102..cfa6b740f8877 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.h +++ b/aten/src/ATen/core/dispatch/Dispatcher.h @@ -344,6 +344,10 @@ class TORCH_API OperatorHandle { c10::Dispatcher::singleton().callBoxed(*this, stack); } + void callBoxed(Stack& stack) const { + callBoxed(&stack); + } + void redispatchBoxed(DispatchKeySet ks, Stack* stack) const { c10::Dispatcher::singleton().redispatchBoxed(*this, ks, stack); } diff --git a/aten/src/ATen/core/function_schema.cpp b/aten/src/ATen/core/function_schema.cpp index cc6de61dccead..a4319f03132cc 100644 --- a/aten/src/ATen/core/function_schema.cpp +++ b/aten/src/ATen/core/function_schema.cpp @@ -1,5 +1,7 @@ #include +#include + namespace c10 { void FunctionSchema::dump() const { diff --git a/aten/src/ATen/core/function_schema.h b/aten/src/ATen/core/function_schema.h index 68e177a225d76..f4b11fc4a304a 100644 --- a/aten/src/ATen/core/function_schema.h +++ b/aten/src/ATen/core/function_schema.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -34,6 +35,9 @@ struct Argument { default_value_(std::move(default_value)), kwarg_only_(kwarg_only), alias_info_(std::move(alias_info)) { + // this is an softly-enforced invariant for out arguments. + bool is_alias = alias_info_.has_value() && alias_info_.value().isWrite(); + is_out_ = kwarg_only_ && is_alias; } const std::string& name() const { return name_; @@ -50,6 +54,11 @@ struct Argument { bool kwarg_only() const { return kwarg_only_; } + + bool is_out() const { + return is_out_; + } + const c10::optional& alias_info() const { return alias_info_; } @@ -116,6 +125,8 @@ struct Argument { // is this only specifiable as a keyword argument? bool kwarg_only_; c10::optional alias_info_; + // marks if the argument is out variant of the schema + bool is_out_; }; inline bool operator==(const Argument& lhs, const Argument& rhs) { @@ -262,7 +273,7 @@ struct FunctionSchema { }); } - c10::optional argumentIndexWithName(const std::string& name) const { + c10::optional argumentIndexWithName(c10::string_view name) const { for(size_t i = 0; i < arguments().size(); ++i) { if(name == arguments()[i].name()) return i; diff --git a/aten/src/ATen/core/function_schema_inl.h b/aten/src/ATen/core/function_schema_inl.h index 168ecb4f3dc17..6e26e8c14cdab 100644 --- a/aten/src/ATen/core/function_schema_inl.h +++ b/aten/src/ATen/core/function_schema_inl.h @@ -51,6 +51,16 @@ inline std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) return out; } +inline size_t findFirstOutArg(const std::vector& args) { + // find the start of out args in the schema + for (size_t out_start_idx = 0; out_start_idx < args.size(); out_start_idx++) { + if (args.at(out_start_idx).is_out()) { + return out_start_idx; + } + } + return args.size(); +} + inline bool Argument::isBackwardCompatibleWith( const Argument& old, std::ostream* why_not) const { @@ -121,17 +131,20 @@ inline bool FunctionSchema::isBackwardCompatibleWith( } } - // Make sure that all the old arguments have their corresponding backward - // compatible arguments in this schema. - for (size_t i = 0; i < old.arguments().size(); ++i) { + // we want to test both out and default args seperately + size_t old_out_start_idx = findFirstOutArg(old.arguments()); + size_t new_out_start_idx = findFirstOutArg(arguments()); + + // make sure among the default args, they are backward compatible + for (size_t i = 0; i < old_out_start_idx; i++) { if (!arguments().at(i).isBackwardCompatibleWith( old.arguments().at(i), why_not)) { return false; } } - // Validate that all new arguments provided a default value. - for (size_t i = old.arguments().size(); i < arguments().size(); ++i) { + // // Validate that all new arguments provided has a default value + for (size_t i = old_out_start_idx; i < new_out_start_idx; ++i) { if (!arguments().at(i).default_value()) { if (why_not) { *why_not @@ -144,6 +157,15 @@ inline bool FunctionSchema::isBackwardCompatibleWith( } } + // now compare the out args + for (size_t i = old_out_start_idx; i < old.arguments().size(); i++) { + if (!arguments() + .at(i - old_out_start_idx + new_out_start_idx) + .isBackwardCompatibleWith(old.arguments().at(i), why_not)) { + return false; + } + } + return true; } diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index 2f527cdde5e91..e7aef155a5656 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -84,6 +84,7 @@ namespace c10 { _(prim, NumToTensor) \ _(prim, Uninitialized) \ _(prim, VarConcat) \ + _(prim, VarStack) \ _(prim, With) \ _(prim, Enter) \ _(prim, Exit) \ @@ -210,6 +211,8 @@ namespace c10 { _(aten, linalg_norm) \ _(aten, linalg_vector_norm) \ _(aten, linalg_matrix_norm) \ + _(aten, matmul) \ + _(aten, linalg_matmul) \ _(aten, append) \ _(aten, item) \ _(aten, format) \ @@ -305,6 +308,9 @@ namespace c10 { _(aten, bin) \ _(aten, pop) \ _(aten, insert) \ + _(aten, _cat) \ + _(aten, cat) \ + _(aten, concat) \ _(aten, vstack) \ _(aten, row_stack) \ _(prim, unchecked_unwrap_optional) \ @@ -466,7 +472,8 @@ namespace c10 { _(attr, keepdims) \ _(attr, cache_id) \ _(attr, new_axis) \ - _(attr, warn_id) + _(attr, warn_id) \ + _(attr, allowzero) // 'prim' symbols are synthetic operators that occur only in the IR // and don't have corresponding implementations in ATen. diff --git a/aten/src/ATen/core/interned_strings_class.h b/aten/src/ATen/core/interned_strings_class.h index 54303e0384d28..8bbf3294844a5 100644 --- a/aten/src/ATen/core/interned_strings_class.h +++ b/aten/src/ATen/core/interned_strings_class.h @@ -1,8 +1,6 @@ #include #include -#include #include -#include #include #include #include diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp index 6fab54ff9dd82..b81c50f063b19 100644 --- a/aten/src/ATen/core/ivalue.cpp +++ b/aten/src/ATen/core/ivalue.cpp @@ -8,6 +8,7 @@ #include #include #include +#include namespace c10 { bool _fastEqualsForContainer(const IValue& lhs, const IValue& rhs) { @@ -945,36 +946,25 @@ getClassConverter() { } // Needs to be in this .cpp file to access the full definition of PyObjectHolder -std::vector> ivalue::Future::extractStorages( - const at::IValue& value) { +std::vector> ivalue::Future:: + extractStorages(const at::IValue& value) { std::vector> weakStorageImpls; // getSubValues works poorly on Python objects: it only works if they can be // converted to a "regular" IValue type hence, for example, it doesn't support // custom subclasses. Thus, instead, we extract the tensors through pickling. + // Sparse tensors do not have storage. Instead, a sparse tensor + // contains two tensors indices and values, and both contain storage. if (value.isPyObject()) { std::vector tensors = value.toPyObjectHolder()->extractTensors(); - size_t num_storages = 0; - for (const at::Tensor& tensor : tensors) { + weakStorageImpls.reserve(2 * tensors.size()); + for (const auto& tensor : tensors) { if (tensor.is_sparse()) { - // Sparse tensor is indices and values. Both are tensors - // and contain storage. Therefore num_storages needs to be - // incremented by 2. - num_storages += 2; + weakStorageImpls.push_back( + tensor._indices().storage().getWeakStorageImpl()); + weakStorageImpls.push_back( + tensor._values().storage().getWeakStorageImpl()); } else { - // A dense/strided tensor contains 1 storage. - num_storages += 1; - } - } - weakStorageImpls.reserve(num_storages); - for (const at::Tensor& tensor : tensors) { - if (tensor.is_sparse()) { - // Sparse tensor is indices and values. Both are tensors - // and contain storage. - weakStorageImpls.push_back(tensor.indices().storage().getWeakStorageImpl()); - weakStorageImpls.push_back(tensor.values().storage().getWeakStorageImpl()); - } else { - // A dense/strided tensor contains 1 storage weakStorageImpls.push_back(tensor.storage().getWeakStorageImpl()); } } @@ -985,7 +975,15 @@ std::vector> ivalue::Future::extractSt value.getSubValues(sub_values); for (const at::IValue& sub_value : sub_values) { if (sub_value.isTensor()) { - weakStorageImpls.push_back(sub_value.toTensor().storage().getWeakStorageImpl()); + auto& tensor = sub_value.toTensor(); + if (tensor.is_sparse()) { + weakStorageImpls.push_back( + tensor._indices().storage().getWeakStorageImpl()); + weakStorageImpls.push_back( + tensor._values().storage().getWeakStorageImpl()); + } else { + weakStorageImpls.push_back(tensor.storage().getWeakStorageImpl()); + } } } } diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h index 188a619307185..6574187d06f8b 100644 --- a/aten/src/ATen/core/ivalue.h +++ b/aten/src/ATen/core/ivalue.h @@ -872,14 +872,17 @@ struct TORCH_API IValue final { struct HashAliasedIValue { size_t operator()(const IValue& val) const { if (val.isTensor()) { - if (val.toTensor().is_mkldnn()) { + auto& tensor = val.toTensor(); + if (tensor.is_mkldnn() || tensor.is_sparse()) { // MKLDNN tensors dont have storage and dont create views // or aliasing so we can just use Tensor pointer, TODO: find way // to use mkldnn storage - return reinterpret_cast(val.toTensor().unsafeGetTensorImpl()); + // Sparse tensors don't have storage use unsafeGetTensorImpl + // instead of using the storage of indices or values. + return reinterpret_cast(tensor.unsafeGetTensorImpl()); } else { return reinterpret_cast( - val.toTensor().storage().unsafeGetStorageImpl()); + tensor.storage().unsafeGetStorageImpl()); } } // If it is not a Tensor, then two mutable IValues alias each other only diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index d733fbd2da5b1..4284e296229cc 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -9,10 +9,11 @@ #include #include -#include +#include #include +#include +#include #include -#include struct ClassType; namespace torch { @@ -29,6 +30,9 @@ struct FunctionSchema; struct NamedType; using OptNameList = c10::optional>; +void standardizeVectorForUnion(std::vector& reference, std::vector* to_fill); +void standardizeVectorForUnion(std::vector* to_flatten); + struct AnyType; using AnyTypePtr = std::shared_ptr; // Any is the top of the type hierarchy, all other types are subtypes @@ -93,25 +97,84 @@ struct SingleElementType : public Type { TypePtr elem; }; +struct UnionType; +using UnionTypePtr = std::shared_ptr; +struct TORCH_API UnionType : public Type { + friend struct Type; + + static const TypeKind Kind = TypeKind::UnionType; + + bool isSubtypeOfExt(const TypePtr& rhs_, std::ostream* why_not) const override; + + std::string str() const override; + + static UnionTypePtr create(std::vector reference); + + bool operator==(const Type& rhs) const override; + + at::ArrayRef containedTypes() const override { + return types_; + } + + // For testing purposes only + at::ArrayRef getTypes() const { + return types_; + } + + TypePtr createWithContained(std::vector contained_types) const override { + return create(contained_types); + } + + bool canHoldType(TypePtr type) const; + + bool hasFreeVariables() const override { + return has_free_variables_; + } + + c10::optional toOptional() const; + + c10::optional subtractTypeSet(std::vector& to_subtract) const; + + protected: + explicit UnionType(std::vector types, TypeKind kind=TypeKind::UnionType); + std::string annotation_str_impl(TypePrinter printer = nullptr) const override; + std::string unionStr(TypePrinter printer = nullptr, bool is_annotation_str = false) const; + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + bool has_free_variables_; + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + std::vector types_; + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + bool can_hold_none_; + +}; + struct OptionalType; using OptionalTypePtr = std::shared_ptr; -// This type represents an optional type, for each element type. -// Optional[T] can accept both T and None(nullopt in C++) +// This type represents an optional type. There is one `Optional` for +// each element type. `Optional[T]` can accept both `T` and +// `None`(`c10::nullopt` in C++) // Subtype hierarchy for Optional: -// 1. Optional[T] <: Optional[R] iff T <: R -// 2. T <: Optional[R] if T <: R -// 3. None <: Optional[T] for all T -struct TORCH_API OptionalType - : public SingleElementType { - static OptionalTypePtr create(TypePtr element) { - TORCH_INTERNAL_ASSERT(element, "OptionalType requires valid TypePtr"); - // Optional is a union of [None, T], so Optional[[Optional[T]]] -> - // Optional[T] - if (auto opt_ptr = element->cast()) { - return opt_ptr; - } - return OptionalTypePtr( - new OptionalType(std::move(element))); // NOLINT(modernize-make-shared) +// - Optional[T] <: Optional[R] iff T <: R +// - T <: Optional[R] if T <: R +// - None <: Optional[T] for all T +// - Optional[T] == Union[T, None] for all T +struct TORCH_API OptionalType : public UnionType { + static OptionalTypePtr create(TypePtr contained) { + return OptionalTypePtr(new OptionalType(std::move(contained))); + } + + static const TypeKind Kind = TypeKind::OptionalType; + + friend struct Type; + + bool operator==(const Type& rhs) const override; + + TypePtr getElementType() const { + return contained_; + } + + at::ArrayRef containedTypes() const override { + return contained_; } std::string str() const override { @@ -126,20 +189,15 @@ struct TORCH_API OptionalType return create(contained_types[0]); } - bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override { - if (Type::isSubtypeOfExt(rhs, why_not)) { - return true; - } - if (auto rhs_ = rhs->cast()) { - return getElementType()->isSubtypeOfExt(rhs_->getElementType(), why_not); - } - return false; - } + bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override; + // common cast Optional[Tensor] for undefined tensor type static OptionalTypePtr ofTensor(); private: - OptionalType(TypePtr elem) : SingleElementType(elem) {} + explicit OptionalType(TypePtr contained); + + TypePtr contained_; std::string annotation_str_impl(TypePrinter printer = nullptr) const override { std::stringstream ss; @@ -907,7 +965,6 @@ struct TORCH_API RRefType } }; - struct NamedType; using NamedTypePtr = std::shared_ptr; using ConstNamedTypePtr = std::shared_ptr; @@ -1111,7 +1168,6 @@ struct TORCH_API EnumType : public NamedType { std::weak_ptr<::torch::jit::CompilationUnit> cu_; }; - // the common supertype of all Enums, only used in operator registraion. // EnumType <: AnyEnumType for all Enums struct AnyEnumType; @@ -1131,7 +1187,6 @@ struct TORCH_API AnyEnumType : public Type { : Type(TypeKind::AnyEnumType) {} }; - struct NumberType; using NumberTypePtr = std::shared_ptr; // This type represents a Python number @@ -1140,9 +1195,10 @@ using NumberTypePtr = std::shared_ptr; // FloatType <: NumberType // ComplexType <:NumberType struct TORCH_API NumberType : public Type { - bool operator==(const Type& rhs) const override { - return rhs.kind() == kind(); - } + bool operator==(const Type& rhs) const override; + + bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override; + std::string str() const override { return "Scalar"; // match what PythonArgParser says for clarity } @@ -1171,7 +1227,8 @@ struct TORCH_API FloatType : public NumberType { return "float"; } bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override { - return rhs->kind() == TypeKind::NumberType || NumberType::isSubtypeOfExt(rhs, why_not); + // NOLINTNEXTLINE(bugprone-parent-virtual-call) + return rhs->kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not); } static const TypeKind Kind = TypeKind::FloatType; // global singleton @@ -1195,7 +1252,8 @@ struct TORCH_API ComplexType : public NumberType { return "complex"; } bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override { - return rhs->kind() == TypeKind::NumberType || NumberType::isSubtypeOfExt(rhs, why_not); + // NOLINTNEXTLINE(bugprone-parent-virtual-call) + return rhs->kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not); } static const TypeKind Kind = TypeKind::ComplexType; // global singleton @@ -1219,7 +1277,8 @@ struct TORCH_API IntType : public NumberType { return "int"; } bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override { - return rhs->kind() == TypeKind::NumberType || NumberType::isSubtypeOfExt(rhs, why_not); + // NOLINTNEXTLINE(bugprone-parent-virtual-call) + return rhs->kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not); } static const TypeKind Kind = TypeKind::IntType; // global singleton @@ -1333,12 +1392,8 @@ struct TORCH_API NoneType : public Type { std::string str() const override { return "NoneType"; } - bool isSubtypeOfExt(const TypePtr& rhs, std::ostream *why_not) const override { - if (rhs->kind() == OptionalType::Kind) { - return true; - } - return Type::isSubtypeOfExt(rhs, why_not); - } + bool isSubtypeOfExt(const TypePtr& rhs, std::ostream *why_not) const override; + static const TypeKind Kind = TypeKind::NoneType; // global singleton static NoneTypePtr get(); @@ -1523,8 +1578,15 @@ TORCH_API std::ostream& operator<<(std::ostream& os, const Stride& s); // what is the type, ignoring extra size/shape information? // e.g. Tensor(2x3) -> Dynamic, and Tuple(Tensor(2x3),...) -> Tuple(Dynamic,...) -// xxx: be careful with calls because this can be very slow. If calling this on a graph -// use `EraseShapeInformation` in shape_analysis.h +// `unshapedType` is used to remove Tensor subtypes. We treat all Tensor +// subtypes as simply "Tensor"; we also create a new version of any +// container types in which internal Tensors have undergone the same +// operation. This is used for type comparisons between two Tensor types +// (`unshapedType` means that we don't falsely return `false` for e.g. +// Tensors of different dimensions). It's also used in the alias +// analysis pass. +// Be careful with calls because this can be very slow. If calling this +// on a graph, use `EraseShapeInformation` in shape_analysis.h inline TypePtr unshapedType(const TypePtr& type) { if (type->isSubtypeOf(TensorType::get())) { return TensorType::get(); @@ -1568,27 +1630,32 @@ inline at::ScalarType scalarTypeFromJitType(const c10::TypePtr& type) { return *result; } -// Attempt to find the correct supertype of t1 and t2. If none is found then -// nullopt will be returned if default_to_any is false, and Any will be returned -// if it is true. If t1 == t2, or t1 is a type refinement of t2, -// then t2 will be returned (and vice versa). +// Attempt to find the correct supertype of the two types `t1` and `t2`. +// If no supertype is found, then nullopt will be returned if +// `default_to_union` is false, and `Union[t1, t2]` will be returned +// if it is true. If `t1 == t2`, or `t1` is a type refinement of `t2`, +// then `t2` will be returned (and vice versa). +// // Two different tensortypes will return dynamic. -// Currently we chose not to support returning a NumberType for a float & int -// input because of a lack of operator support for NumberType. +// +// Currently we chose not to support returning a NumberType for +// two types from the set of {FloatType, IntType, ComplexType}, because +// there is a lack of operator support for NumberType. +// // If `type_hint` is an `InterfaceType`, then we can use that as a // potential supertype for `ClassType`s in the list. Otherwise, we have // no way to find and use some common interface type TORCH_API c10::optional unifyTypes( const TypePtr& t1, const TypePtr& t2, - bool default_to_any = false, - TypePtr type_hint=nullptr); + bool default_to_union = false, + TypePtr type_hint = nullptr); TORCH_API c10::optional unifyTypeList( at::ArrayRef elements, std::ostream& why_not, - bool default_to_any=false, - TypePtr type_hint=nullptr); + bool default_to_union = false, + TypePtr type_hint = nullptr); namespace detail { template diff --git a/aten/src/ATen/core/jit_type_base.h b/aten/src/ATen/core/jit_type_base.h index dbb4a62f73088..a9be1e8d68658 100644 --- a/aten/src/ATen/core/jit_type_base.h +++ b/aten/src/ATen/core/jit_type_base.h @@ -21,7 +21,7 @@ namespace c10 { _(DictType) \ _(NumberType) \ _(FloatType) \ - _(ComplexType) \ + _(ComplexType) \ _(FutureType) \ _(RRefType) \ _(IntType) \ @@ -44,7 +44,8 @@ namespace c10 { _(ScalarTypeType) \ _(AnyListType) \ _(AnyTupleType) \ - _(AnyClassType) + _(AnyClassType) \ + _(UnionType) enum class TypeKind { #define DEFINE_TYPE(T) T, @@ -203,7 +204,7 @@ struct TORCH_API Type : std::enable_shared_from_this { // contained_types TypePtr withContained(std::vector contained_types) { auto current_contained = containedTypes(); - AT_ASSERT(current_contained.size() == contained_types.size()); + TORCH_INTERNAL_ASSERT(current_contained.size() == contained_types.size()); if (current_contained.equals(contained_types)) { return shared_from_this(); } diff --git a/aten/src/ATen/core/op_registration/README.md b/aten/src/ATen/core/op_registration/README.md index edd9f911cd0e1..5605e962a6e5e 100644 --- a/aten/src/ATen/core/op_registration/README.md +++ b/aten/src/ATen/core/op_registration/README.md @@ -13,13 +13,13 @@ There’s four main use cases * You’re writing a new operator that isn’t supposed to be part of the public PyTorch API. * You’re writing a new operator but don’t want to change the core pytorch code base, say you’re developing a shared library with operators. * You’re writing a C++ extension for PyTorch or you’re using inline c++ in your .py model files. -* You’re writing a backend library like XLA or MSNPU that adds new kernels to all operators defined in `native_functions.yaml`. +* You’re writing a backend library like XLA or ORT that adds new kernels to all operators defined in `native_functions.yaml`. For these use cases, the custom operator API is the better solution. ### What is the price for using the custom operator API instead of `native_functions.yaml`? -If you’re just using the custom operator API to add new kernels for existing operators (e.g. the XLA/MSNPU example above), then you’re fine and don’t pay any price. If, however, you define a new operator purely using the custom op API, i.e. your operator never shows up in `native_functions.yaml`, then you need to be aware of a few caveats. +If you’re just using the custom operator API to add new kernels for existing operators (e.g. the XLA/ORT example above), then you’re fine and don’t pay any price. If, however, you define a new operator purely using the custom op API, i.e. your operator never shows up in `native_functions.yaml`, then you need to be aware of a few caveats. * It will not get a C++ API generated. There will not be `Tensor::your_op()` methods or `at::your_op()` functions to call your operator. * The API for calling the operator from Python looks a little bit different. It needs to be called through `torch.ops.your_op()` instead of `torch._C`. diff --git a/aten/src/ATen/core/stack.h b/aten/src/ATen/core/stack.h index ffc0e8fd9037d..021e8a02104f2 100644 --- a/aten/src/ATen/core/stack.h +++ b/aten/src/ATen/core/stack.h @@ -1,6 +1,9 @@ #pragma once +#include + #include +#include // TODO move this to c10 namespace @@ -9,7 +12,42 @@ namespace jit { using c10::IValue; using Stack = std::vector; -using Operation = std::function; + +class Operation { + template + using accepts = std::is_constructible, F&&>; + + public: + template ::value, int> = 0> + C10_DEPRECATED_MESSAGE("Please use void(Stack&) to register operator instead.") + Operation(F&& raw): op_([raw = std::forward(raw)](Stack& stack) { + raw(&stack); + }) {} + + template ::value && + !std::is_same, Operation>::value, int> = 0> + Operation(F&& op): op_(std::forward(op)) {} + + Operation(std::nullptr_t) noexcept {} + + explicit operator bool() const noexcept { + return op_ ? true : false; + } + + void operator()(Stack& stack) { + op_(stack); + } + + template + T* target() noexcept { + return op_.target(); + } + + private: + std::function op_; +}; // An operation with N inputs and M outputs pops the last N inputs off // the stack and pushes its M inputs onto the stack diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp index 4214f4d3e1f6f..fec0cb086ee51 100644 --- a/aten/src/ATen/core/type.cpp +++ b/aten/src/ATen/core/type.cpp @@ -265,7 +265,7 @@ AnyEnumTypePtr AnyEnumType::get() { return value; } -c10::optional unifyTypesImpl(const TypePtr& t1, const TypePtr& t2, bool default_to_any=false, TypePtr type_hint=nullptr) { +c10::optional unifyTypesImpl(const TypePtr& t1, const TypePtr& t2, bool default_to_union=false, TypePtr type_hint=nullptr) { // check direct subtyping relation if (t1->isSubtypeOf(t2)) { return t2; @@ -308,7 +308,7 @@ c10::optional unifyTypesImpl(const TypePtr& t1, const TypePtr& t2, bool } std::vector elements; for (size_t i = 0; i < tuple1->elements().size(); i++) { - if (auto elem = unifyTypes(tuple1->elements().at(i), tuple2->elements().at(i), default_to_any)) { + if (auto elem = unifyTypes(tuple1->elements().at(i), tuple2->elements().at(i), default_to_union)) { elements.push_back(*elem); } else { return c10::nullopt; @@ -347,11 +347,11 @@ c10::optional unifyTypesImpl(const TypePtr& t1, const TypePtr& t2, bool return c10::nullopt; } -c10::optional unifyTypes(const TypePtr& t1, const TypePtr& t2, bool default_to_any, TypePtr type_hint) { - auto unified = unifyTypesImpl(t1, t2, default_to_any, type_hint); +c10::optional unifyTypes(const TypePtr& t1, const TypePtr& t2, bool default_to_union, TypePtr type_hint) { + auto unified = unifyTypesImpl(t1, t2, default_to_union, type_hint); - if (default_to_any && !unified) { - return AnyType::get(); + if (default_to_union && !unified) { + return UnionType::create({t1, t2}); } return unified; @@ -360,7 +360,7 @@ c10::optional unifyTypes(const TypePtr& t1, const TypePtr& t2, bool def c10::optional unifyTypeList( at::ArrayRef elements, std::ostream& why_not, - bool default_to_any, + bool default_to_union, TypePtr type_hint) { if (elements.size() == 0) { why_not << "Cannot get unified type from empty list"; @@ -369,7 +369,7 @@ c10::optional unifyTypeList( TypePtr ret_type = elements.at(0); for (size_t i = 1; i < elements.size() && ret_type; ++i) { - c10::optional maybe_unified = unifyTypes(ret_type, elements.at(i), default_to_any, type_hint); + c10::optional maybe_unified = unifyTypes(ret_type, elements.at(i), default_to_union, type_hint); if (!maybe_unified) { why_not << "Could not unify type list since element " << i << " of type " << elements.at(i)->repr_str() @@ -547,8 +547,9 @@ TORCH_API TypePtr tryEvalTypeVariables(TypePtr type, std::unordered_mapkind() == OptionalType::Kind || - elem_type->kind() == NumberType::Kind) { + if (elem_type->kind() == UnionType::Kind + || elem_type->kind() == OptionalType::Kind + || elem_type->kind() == NumberType::Kind) { // Builtin Union types return false; } @@ -577,8 +578,16 @@ bool Type::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const { if (rhs->kind() == TypeKind::AnyType || *this == *rhs) { return true; } - if(auto rhs_ = rhs->cast()) { - return this->isSubtypeOfExt(rhs_->getElementType(), why_not); + if (auto opt_rhs = rhs->cast()) { + return this->isSubtypeOfExt(opt_rhs->getElementType(), why_not); + } + if (auto union_rhs = rhs->cast()) { + // Check if `this` is a subtype of any of the types within the Union + return std::any_of(union_rhs->containedTypes().begin(), + union_rhs->containedTypes().end(), + [&](TypePtr inner) { + return this->isSubtypeOfExt(inner, why_not); + }); } return false; } @@ -802,12 +811,459 @@ TupleTypePtr TupleType::createNamed(const c10::optional& qua auto schema = std::make_shared( /*name=*/qualName.value_or(c10::QualifiedName()).name(), /*overload_name=*/std::string(""), - /*arguments=*/arguments, + /*arguments=*/std::move(arguments), /*returns=*/std::vector{}); return std::shared_ptr(new TupleType( field_types, qualName, schema)); // NOLINT(modernize-make-shared) } +bool NoneType::isSubtypeOfExt(const TypePtr& rhs, std::ostream *why_not) const { + if (rhs->kind() == OptionalType::Kind) { + return true; + } + return Type::isSubtypeOfExt(rhs, why_not); +} + +// Remove nested Optionals/Unions during the instantiation of a Union or +// an Optional. This populates `types` with all the types found during +// flattening. At the end of `flattenUnion`, `types` may have +// duplicates, but it will not have nested Optionals/Unions +void flattenUnion(TypePtr& type, std::vector* to_fill) { + if (auto union_type = type->cast()) { + for (auto inner : union_type->containedTypes()) { + flattenUnion(inner, to_fill); + } + } else if (auto opt_type = type->cast()) { + auto inner = opt_type->getElementType(); + flattenUnion(inner, to_fill); + to_fill->emplace_back(NoneType::get()); + } else if (type->kind() == NumberType::Kind) { + to_fill->emplace_back(IntType::get()); + to_fill->emplace_back(FloatType::get()); + to_fill->emplace_back(ComplexType::get()); + } else { + to_fill->emplace_back(type); + } +} + +// Helper function for `standardizeUnion` +// +// NB: If we have types `T1`, `T2`, `T3`, and `PARENT_T` such that `T1`, +// `T2`, and `T2` are children of `PARENT_T`, then `unifyTypes(T1, T2)` +// will return `PARENT_T`. This could be a problem if we didn't want our +// Union to also be able to take `T3 `. In our current type hierarchy, +// this isn't an issue--most types SHOULD be unified even if the parent +// type wasn't in the original vector. However, later additions to the +// type system might necessitate reworking `get_supertype` +void filterDuplicateSubtypes(std::vector* types) { + if (types->empty()) { + return; + } + auto get_supertype = [](const TypePtr t1, const TypePtr t2) -> c10::optional { + // We don't want nested Optionals. Also, prematurely unifying to + // `Optional` could prevent us from coalescing other types + if ((t1->isSubtypeOf(NoneType::get()) && !t2->isSubtypeOf(NoneType::get())) + || (!t1->isSubtypeOf(NoneType::get()) && t2->isSubtypeOf(NoneType::get()))) { + return c10::nullopt; + } else { + return unifyTypes(t1, t2, /*default_to_union=*/false); + } + }; + + // Coalesce types and delete all duplicates. Moving from right to left + // through the vector, we try to unify the current element (`i`) with + // each element (`j`) before the "new" end of the vector (`end`). + // If we're able to unify the types at `types[i]` and `types[j]`, we + // decrement `end`, swap `types[j]` with the unified type, and + // break. Otherwise, we keep `end` where it is to signify that the + // new end of the vector hasn't shifted + size_t end_idx = types->size()-1; + for (size_t i = types->size()-1; i > 0; --i) { + for (size_t j = std::min(i-1, end_idx); ; --j) { + c10::optional unified; + unified = get_supertype((*types)[i], (*types)[j]); + if (unified) { + (*types)[j] = *unified; + (*types)[i] = (*types)[end_idx]; + --end_idx; + break; + } + // Break condition here so we don't get `j = 0; j = j-1` and end + // up with MAX_INT + if (j == 0) { + break; + } + } + } + // Cut off the vector's tail so that `end` is the real last element + types->erase(types->begin() + end_idx + 1, types->end()); + +} + +void sortUnion(std::vector* types) { + // We want the elements to be sorted so we can easily compare two + // UnionType objects for equality in the future. Note that this order + // is guaranteed to be stable since we've already coalesced any + // possible types + std::sort(types->begin(), types->end(), + [](const TypePtr a, const TypePtr b) -> bool { + if (a->kind() != b->kind()) { + return a->kind() < b->kind(); + } + return a->str() < b->str(); + }); +} + +void standardizeVectorForUnion(std::vector& reference, std::vector* to_fill) { + for (auto type : reference) { + flattenUnion(type, to_fill); + } + filterDuplicateSubtypes(to_fill); + sortUnion(to_fill); +} + +void standardizeVectorForUnion(std::vector* to_flatten) { + TORCH_INTERNAL_ASSERT(to_flatten, "`standardizeVectorForUnion` was ", + "passed a `nullptr`"); + std::vector to_fill; + standardizeVectorForUnion(*to_flatten, &to_fill); + *to_flatten = to_fill; +} + +UnionType::UnionType(std::vector reference, TypeKind kind) : Type(kind) { + TORCH_INTERNAL_ASSERT(!reference.empty(), "Cannot create an empty Union"); + + standardizeVectorForUnion(reference, &types_); + + // Gate the assert in a regular conditional so that we don't create + // this long error message unnecessarily + if (types_.size() == 1) { + std::stringstream msg; + msg << "After type unification was performed, the Union with the " + << "original types {"; + for (auto i = 0; i < reference.size(); ++i) { + msg << reference[i]->repr_str(); + if (i > 0) { + msg << ","; + } + msg << " "; + } + msg << "} has the single type " << types_[0]->repr_str() + << ". Use the common supertype instead of creating a Union" + << "type"; + TORCH_INTERNAL_ASSERT(false, msg.str()); + } + + can_hold_none_ = false; + has_free_variables_ = false; + + for (const TypePtr& type : types_) { + if (type->kind() == NoneType::Kind) { + can_hold_none_ = true; + } + if (type->hasFreeVariables()) { + has_free_variables_ = true; + } + } + +} + +UnionTypePtr UnionType::create(std::vector reference) { + auto union_type = new UnionType(std::move(reference)); + + // Some very special-cased logic for `Optional`. This will be deleted + // in a later PR + bool int_found = false; + bool float_found = false; + bool complex_found = false; + bool nonetype_found = false; + + auto update_is_opt_flags = [&](TypePtr t) { + if (t == IntType::get()) { + int_found = true; + } else if (t == FloatType::get()) { + float_found = true; + } else if (t == ComplexType::get()) { + complex_found = true; + } else if (t == NoneType::get()) { + nonetype_found = true; + } + }; + + for (const auto& t : union_type->containedTypes()) { + update_is_opt_flags(t); + } + + bool numbertype_found = int_found && float_found && complex_found; + + if (nonetype_found) { + if (union_type->containedTypes().size() == 4 && numbertype_found) { + return OptionalType::create(NumberType::get()); + } + if (union_type->containedTypes().size() == 2) { + auto not_none = union_type->containedTypes()[0] != NoneType::get() + ? union_type->containedTypes()[0] + : union_type->containedTypes()[1]; + return OptionalType::create(not_none); + } + } + + return UnionTypePtr(union_type); +} + +bool UnionType::operator==(const Type& rhs) const { + if (auto union_rhs = rhs.cast()) { + // We can't compare the type vectors for equality using `operator=`, + // because the vectors hold `TypePtr`s and we want to compare `Type` + // equality + if (union_rhs->containedTypes().size() != this->containedTypes().size()) { + return false; + } + // Check that all the types in `this->types_` are also in + // `union_rhs->types_` + return std::all_of(this->containedTypes().begin(), this->containedTypes().end(), + [&](TypePtr lhs_type) { + return std::any_of(union_rhs->containedTypes().begin(), + union_rhs->containedTypes().end(), + [&](TypePtr rhs_type) { + return *lhs_type == *rhs_type; + }); + }); + } else if (auto optional_rhs = rhs.cast()) { + if (optional_rhs->getElementType() == NumberType::get()) { + return this->containedTypes().size() == 4 + && this->can_hold_none_ + && this->canHoldType(NumberType::get()); + } + auto optional_lhs = this->toOptional(); + return optional_lhs && *optional_rhs == *((optional_lhs.value())->expect()); + } else if (rhs.kind() == NumberType::Kind) { + return this->containedTypes().size() == 3 && canHoldType(NumberType::get()); + } else { + return false; + } +} + +bool UnionType::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const { + std::vector rhs_types; + if (const auto union_rhs = rhs->cast()) { + // Fast path + if (this->containedTypes() == rhs->containedTypes()) { + return true; + } + rhs_types = rhs->containedTypes().vec(); + } else if (const auto optional_rhs = rhs->cast()) { + rhs_types.push_back(NoneType::get()); + if (optional_rhs->getElementType() == NumberType::get()) { + std::vector number_types{IntType::get(), FloatType::get(), ComplexType::get()}; + rhs_types.insert(rhs_types.end(), number_types.begin(), number_types.end()); + } else { + rhs_types.push_back(optional_rhs->getElementType()); + } + } else if (const auto number_rhs = rhs->cast()) { + std::vector number_types{IntType::get(), FloatType::get(), ComplexType::get()}; + rhs_types.insert(rhs_types.end(), number_types.begin(), number_types.end()); + } else { + rhs_types.push_back(rhs); + } + return std::all_of(this->containedTypes().begin(), this->containedTypes().end(), + [&](TypePtr lhs_type) -> bool { + return std::any_of(rhs_types.begin(), + rhs_types.end(), + [&](TypePtr rhs_type) -> bool { + return lhs_type->isSubtypeOfExt(rhs_type, why_not); + }); + }); +} + + +std::string UnionType::unionStr(TypePrinter printer, bool is_annotation_str) const { + std::stringstream ss; + + bool can_hold_numbertype = this->canHoldType(NumberType::get()); + + std::vector number_types{IntType::get(), FloatType::get(), ComplexType::get()}; + + auto is_numbertype = [&](TypePtr lhs) { + for (const auto& rhs : number_types) { + if (*lhs == *rhs) { + return true; + } + } + return false; + }; + + ss << "Union["; + bool printed = false; + for (size_t i = 0; i < types_.size(); ++i) { + if (!can_hold_numbertype || !is_numbertype(types_[i])) { + if (i > 0) { + ss << ", "; + printed = true; + } + if (is_annotation_str) { + ss << this->containedTypes()[i]->annotation_str(printer); + } else { + ss << this->containedTypes()[i]->str(); + } + } + } + if (can_hold_numbertype) { + if (printed) { + ss << ", "; + } + if (is_annotation_str) { + ss << NumberType::get()->annotation_str(printer); + } else { + ss << NumberType::get()->str(); + } + } + ss << "]"; + return ss.str(); +} + +std::string UnionType::str() const { + return this->unionStr(nullptr, /*is_annotation_str=*/false); +} + +std::string UnionType::annotation_str_impl(TypePrinter printer) const { + return this->unionStr(printer, /*is_annotation_str=*/true); +} + +bool UnionType::canHoldType(TypePtr type) const { + if (type == NumberType::get()) { + return canHoldType(IntType::get()) + && canHoldType(FloatType::get()) + && canHoldType(ComplexType::get()); + } else { + return std::any_of(this->containedTypes().begin(), this->containedTypes().end(), + [&](TypePtr inner) { + return type->isSubtypeOf(inner); + }); + } +} + +c10::optional UnionType::toOptional() const { + if (!canHoldType(NoneType::get())) { + return c10::nullopt; + } + + std::vector copied_types = this->containedTypes().vec(); + + auto maybe_opt = UnionType::create(std::move(copied_types)); + + if (maybe_opt->kind() == UnionType::Kind) { + return c10::nullopt; + } else { + return maybe_opt; + } +} + +c10::optional UnionType::subtractTypeSet(std::vector& to_subtract) const { + std::vector types; + + // Given a TypePtr `lhs`, this function says whether or not `lhs` (or + // one of its parent types) is in the `to_subtract` vector + auto should_subtract = [&](TypePtr lhs) -> bool { + return std::any_of(to_subtract.begin(), to_subtract.end(), + [&](TypePtr rhs) { + return lhs->isSubtypeOf(rhs); + }); + }; + + // Copy all the elements that should NOT be subtracted to the `types` + // vector + std::copy_if(this->containedTypes().begin(), this->containedTypes().end(), + std::back_inserter(types), + [&](const TypePtr t) { + return !should_subtract(t); + }); + + if (types.size() == 0) { + return c10::nullopt; + } else if (types.size() == 1) { + return types[0]; + } else { + return UnionType::create(std::move(types)); + } +} + +OptionalType::OptionalType(TypePtr contained) + : UnionType({contained, NoneType::get()}, TypeKind::OptionalType) { + bool is_numbertype = false; + if (auto as_union = contained->cast()) { + is_numbertype = as_union->containedTypes().size() == 3 && + as_union->canHoldType(NumberType::get()); + } + if (UnionType::containedTypes().size() == 2) { + contained_ = UnionType::containedTypes()[0]->kind()!= NoneType::Kind + ? UnionType::containedTypes()[0] + : UnionType::containedTypes()[1]; + } else if (contained == NumberType::get() || is_numbertype) { + contained_ = NumberType::get(); + types_.clear(); + types_.push_back(NumberType::get()); + types_.push_back(NoneType::get()); + } else { + std::vector to_subtract{NoneType::get()}; + auto without_none = this->subtractTypeSet(to_subtract); + contained_ = UnionType::create({*without_none}); + } + has_free_variables_ = contained_->hasFreeVariables(); +} + +bool OptionalType::operator==(const Type& rhs) const { + if (auto union_rhs = rhs.cast()) { + auto optional_rhs = union_rhs->toOptional(); + // `**optional_rhs` = `*` to get value of `c10::optional`, + // then `*` to dereference the pointer + return optional_rhs && *this == **optional_rhs; + } else if (auto optional_rhs = rhs.cast()) { + return *this->getElementType() == *optional_rhs->getElementType(); + } else { + return false; + } +} + +bool OptionalType::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const { + if (OptionalTypePtr optional_rhs = rhs->cast()) { + return getElementType()->isSubtypeOfExt(optional_rhs->getElementType(), why_not); + } else if (UnionTypePtr union_rhs = rhs->cast()) { + if (!union_rhs->canHoldType(NoneType::get())) { + if (why_not) { + *why_not << rhs->repr_str() << " cannot hold None"; + } + return false; + } else if (!union_rhs->canHoldType(this->getElementType())) { + if (why_not) { + *why_not << rhs->repr_str() << " cannot hold " << this->getElementType(); + } + return false; + } else { + return true; + } + } else { + // NOLINTNEXTLINE(bugprone-argument-comment) + return Type::isSubtypeOfExt(rhs, why_not); + } +} + +bool NumberType::operator==(const Type& rhs) const { + if (auto union_type = rhs.cast()) { + return union_type->containedTypes().size() == 3 && union_type->canHoldType(NumberType::get()); + } else { + return rhs.kind() == this->kind(); + } +} + +bool NumberType::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const { + if (auto union_type = rhs->cast()) { + return union_type->canHoldType(NumberType::get()); + } else { + return Type::isSubtypeOfExt(rhs, why_not); + } +} + TupleType::TupleType( std::vector elements, c10::optional name, @@ -1732,8 +2188,10 @@ size_t ClassType::addAttribute( TORCH_CHECK( (type->kind() == TensorType::Kind) || (type->kind() == OptionalType::Kind && - type->expectRef().getElementType()->kind() == + type->expect()->getElementType()->kind() == TensorType::Kind) || + (type->kind() == UnionType::Kind && + TensorType::get()->isSubtypeOf(type->expect())) || (type->kind() == NoneType::Kind), "Expecting parameter or buffer to have either None, Tensor or Optional[Tensor] type, but got: ", toString(type)); @@ -1880,7 +2338,9 @@ void SymbolicShape::dump() const { bool EnumType::isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const { return rhs->kind() == TypeKind::AnyType || - rhs->kind() == TypeKind::AnyEnumType || *this == *rhs; + rhs->kind() == TypeKind::AnyEnumType || + *this == *rhs || + Type::isSubtypeOfExt(rhs, why_not); } } // namespace c10 diff --git a/aten/src/ATen/cpu/vec/vec256/vec256.h b/aten/src/ATen/cpu/vec/vec256/vec256.h index 0d13458bc4c1c..906d8a8653661 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256.h @@ -23,7 +23,7 @@ #include #include #include -#include +#include namespace at { namespace vec { diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_int.h b/aten/src/ATen/cpu/vec/vec256/vec256_int.h index 86cf42556d192..5ee9919abca02 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_int.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_int.h @@ -6,6 +6,7 @@ #include #include #include +#include namespace at { namespace vec { @@ -237,12 +238,6 @@ class Vectorized : public Vectorizedi { std::memcpy(ptr, tmp_values, count * sizeof(int32_t)); } } - void dump() const { - for (size_t i = 0; i < size(); ++i) { - std::cout << (int)((value_type*)&values)[i] << " "; - } - std::cout << std::endl; - } const int32_t& operator[](int idx) const = delete; int32_t& operator[](int idx) = delete; Vectorized abs() const { diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_qint.h b/aten/src/ATen/cpu/vec/vec256/vec256_qint.h index dc5e833127327..8cde485c90d7d 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_qint.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_qint.h @@ -11,6 +11,7 @@ #include #include +#include // This file defines Vectorized<> for the quantized types. // @@ -309,12 +310,6 @@ struct Vectorized : public Vectorizedqi { return _mm256_add_epi32(rounded, zero_point_v); } - void dump() const { - for (size_t i = 0; i < 8; ++i) { - std::cout << ((int32_t*)&vals)[i] << " "; - } - std::cout << std::endl; - } private: // Load from memory constructor Vectorized(const void* ptr) { @@ -537,12 +532,6 @@ struct Vectorized : public Vectorizedqi { return RequantizeAvx2(inp, multiplier_v, zero_point_v); } - void dump() const { - for (size_t i = 0; i < size(); ++i) { - std::cout << (int)((value_type*)&vals)[i] << " "; - } - std::cout << std::endl; - } private: // Load from memory constructor Vectorized(const void* ptr) { @@ -702,12 +691,6 @@ struct Vectorized : public Vectorizedqi { return RequantizeAvx2(inp, multiplier_v, zero_point_v); } - void dump() const { - for (size_t i = 0; i < size(); ++i) { - std::cout << (int)((value_type*)&vals)[i] << " "; - } - std::cout << std::endl; - } private: // Load from memory constructor @@ -792,13 +775,6 @@ struct VectorizedQuantizedConverter { return rv; } - void dump() const { - for (int i = 0; i < size(); ++i) { - std::cout << vals[i] << " "; - } - std::cout << std::endl; - } - protected: VectorizedQuantizedConverter() {} }; diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h index ed457b9adefc8..5b1622e825cb0 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h @@ -196,18 +196,6 @@ struct Vectorized { return {veci0, veci1}; } - void dump() const { - std::cout << _vec0[0] << " "; - std::cout << _vec0[1] << " "; - std::cout << _vec0[2] << " "; - std::cout << _vec0[3] << " "; - std::cout << _vec1[0] << " "; - std::cout << _vec1[1] << " "; - std::cout << _vec1[2] << " "; - std::cout << _vec1[3] << " "; - std::cout << std::endl; - } - DEFINE_MEMBER_OP(operator==, c10::qint32, vec_cmpeq) DEFINE_MEMBER_OP(operator!=, c10::qint32, vec_cmpne) DEFINE_MEMBER_OP(operator<, c10::qint32, vec_cmplt) diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint8_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint8_vsx.h index f2a8446cd0ed9..82b2530b7ef3f 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint8_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint8_vsx.h @@ -361,15 +361,6 @@ struct Vectorized { return {vec0, vec1}; } - void dump() const { - value_type vals[size()]; - store((void*)vals); - for (int i = 0; i < size(); ++i) { - std::cout << (int)(vals[i]) << " "; - } - std::cout << std::endl; - } - DEFINE_MEMBER_OP(operator==, c10::qint8, vec_cmpeq) DEFINE_MEMBER_OP(operator!=, c10::qint8, vec_cmpne) DEFINE_MEMBER_OP(operator<, c10::qint8, vec_cmplt) diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_int.h b/aten/src/ATen/cpu/vec/vec512/vec512_int.h index cc866c065bfba..f28c14ed3f73f 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512_int.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512_int.h @@ -270,12 +270,6 @@ class Vectorized : public Vectorizedi { std::memcpy(ptr, tmp_values, count * sizeof(int32_t)); } } - void dump() const { - for (size_t i = 0; i < size(); ++i) { - std::cout << (int)((value_type*)&values)[i] << " "; - } - std::cout << std::endl; - } const int32_t& operator[](int idx) const = delete; int32_t& operator[](int idx) = delete; Vectorized abs() const { diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_qint.h b/aten/src/ATen/cpu/vec/vec512/vec512_qint.h index 5b5ac195f3caa..3a1eda8874f1a 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512_qint.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512_qint.h @@ -321,12 +321,6 @@ struct Vectorized : public Vectorizedqi { return _mm512_add_epi32(rounded, zero_point_v); } - void dump() const { - for (size_t i = 0; i < 16; ++i) { - std::cout << ((int32_t*)&vals)[i] << " "; - } - std::cout << std::endl; - } private: // Load from memory constructor Vectorized(const void* ptr) { @@ -549,12 +543,6 @@ struct Vectorized : public Vectorizedqi { return RequantizeAvx512(inp, multiplier_v, zero_point_v); } - void dump() const { - for (size_t i = 0; i < size(); ++i) { - std::cout << (int)((value_type*)&vals)[i] << " "; - } - std::cout << std::endl; - } private: // Load from memory constructor Vectorized(const void* ptr) { @@ -714,12 +702,6 @@ struct Vectorized : public Vectorizedqi { return RequantizeAvx512(inp, multiplier_v, zero_point_v); } - void dump() const { - for (size_t i = 0; i < size(); ++i) { - std::cout << (int)((value_type*)&vals)[i] << " "; - } - std::cout << std::endl; - } private: // Load from memory constructor @@ -806,13 +788,6 @@ struct VectorizedQuantizedConverter { return rv; } - void dump() const { - for (int i = 0; i < size(); ++i) { - std::cout << vals[i] << " "; - } - std::cout << std::endl; - } - protected: VectorizedQuantizedConverter() {} }; diff --git a/aten/src/ATen/cpu/vml.h b/aten/src/ATen/cpu/vml.h index b9cc47f3fe73b..dbdef0b459928 100644 --- a/aten/src/ATen/cpu/vml.h +++ b/aten/src/ATen/cpu/vml.h @@ -28,7 +28,6 @@ #include #include #include -#include #include #if AT_MKL_ENABLED() && !defined(__APPLE__) diff --git a/aten/src/ATen/cuda/CUDAApplyUtils.cuh b/aten/src/ATen/cuda/CUDAApplyUtils.cuh index 2617870eea519..2b1538ec15ade 100644 --- a/aten/src/ATen/cuda/CUDAApplyUtils.cuh +++ b/aten/src/ATen/cuda/CUDAApplyUtils.cuh @@ -5,7 +5,7 @@ #include #include #include -#include +#include #include @@ -453,13 +453,11 @@ inline bool CUDA_tensor_apply2(at::Tensor a, if (aType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(a)) { // Must perform in contiguous space - oldA = a; - a = a.contiguous(); + oldA = std::exchange(a, a.contiguous()); } if (bType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(b)) { // Must perform in contiguous space - oldB = b; - b = b.contiguous(); + oldB = std::exchange(b, b.contiguous()); } // It is possible that the tensor dimensions are able to be collapsed, @@ -547,17 +545,11 @@ inline bool CUDA_tensor_apply2(at::Tensor a, #undef HANDLE_A_CASE if (oldA.defined()) { - // Ignore overlaps when copying back; if we use copy - // instead, it will recursively try and invoke ourselves to make - // oldA contiguous. - at::native::legacy::cuda::_th_copy_ignoring_overlaps_(oldA, a); + at::native::copy_ignoring_overlaps(oldA, a); } if (oldB.defined()) { - // Ignore overlaps when copying back; if we use copy - // instead, it will recursively try and invoke ourselves to make - // oldB contiguous. - at::native::legacy::cuda::_th_copy_ignoring_overlaps_(oldB, b); + at::native::copy_ignoring_overlaps(oldB, b); } return true; diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index 75e59d0ecc100..70c3dda6f3401 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -64,8 +64,8 @@ static void _cublasAdjustLdLevel3( int64_t* lda, int64_t* ldb, int64_t* ldc) { - bool transa_ = ((transa == 't') || (transa == 'T')); - bool transb_ = ((transb == 't') || (transb == 'T')); + bool transa_ = ((transa != 'n') && (transa != 'N')); + bool transb_ = ((transb != 'n') && (transb != 'N')); // Note: leading dimensions generally are checked that they are > 0 // and at least as big the result requires (even if the value won't diff --git a/aten/src/ATen/cuda/CUDAGraphsUtils.cuh b/aten/src/ATen/cuda/CUDAGraphsUtils.cuh index c25ba88a6537c..9d42ed759939b 100644 --- a/aten/src/ATen/cuda/CUDAGraphsUtils.cuh +++ b/aten/src/ATen/cuda/CUDAGraphsUtils.cuh @@ -42,5 +42,18 @@ inline void assertNotCapturing(std::string attempt) { status); } +inline void errorIfCapturingCudnnBenchmark(std::string version_specific) { + auto status = currentStreamCaptureStatus(); + TORCH_CHECK(status == CaptureStatus::None, + "Current cudaStreamCaptureStatus: ", + status, + "\nCapturing ", + version_specific, + "is prohibited. Possible causes of this error:\n" + "1. No warmup iterations occurred before capture.\n" + "2. The convolutions you're trying to capture use dynamic shapes, " + "in which case capturing them is generally prohibited."); +} + } // namespace cuda } // namespace at diff --git a/aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp b/aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp deleted file mode 100644 index 30c61a3e8b355..0000000000000 --- a/aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp +++ /dev/null @@ -1,448 +0,0 @@ -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#undef THNN_ -#undef THCIndexTensor_ -#include -#include -#include -#include - -namespace at { -namespace native { -namespace legacy { -namespace cuda { - -namespace { - ScalarType infer_scalar_type(const Tensor & t) { - return t.scalar_type(); - } - ScalarType infer_scalar_type(const TensorList & tl) { - TORCH_CHECK(tl.size() > 0, "expected a non-empty list of Tensors"); - return tl[0].scalar_type(); - } - - TensorOptions options(ScalarType s) { - return TensorOptions().dtype(s) - .device(DeviceType::CUDA) - .layout(kStrided); - } - - Allocator* allocator() { - return at::cuda::getCUDADeviceAllocator(); - } -} - -std::tuple _th_gels_out(const Tensor & self, const Tensor & A, Tensor & res1, Tensor & res2) { - TORCH_WARN_ONCE( - "torch.lstsq is deprecated in favor of torch.linalg.lstsq and will be removed in a future PyTorch release.\n", - "torch.linalg.lstsq has reversed arguments and does not return the QR decomposition in " - "the returned tuple (although it returns other information about the problem).\n", - "To get the qr decomposition consider using torch.linalg.qr.\n", - "The returned solution in torch.lstsq stored the residuals of the solution in the ", - "last m - n columns of the returned value whenever m > n. In torch.linalg.lstsq, the ", - "residuals in the field 'residuals' of the returned named tuple.\n", - "The unpacking of the solution, as in\n", - "X, _ = torch.lstsq(B, A).solution[:A.size(1)]\n", - "should be replaced with\n", - "X = torch.linalg.lstsq(A, B).solution" - ); - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(self); - - switch (dispatch_scalar_type) { - case ScalarType::Double: { - auto res1_ = checked_dense_tensor_unwrap(res1, "res1", 0, "_th_gels_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto res2_ = checked_dense_tensor_unwrap(res2, "res2", 0, "_th_gels_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_gels_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto A_ = checked_dense_tensor_unwrap(A, "A", 2, "_th_gels_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaDoubleTensor_gels(globalContext().getTHCState(), res1_, res2_, self_, A_); - break; - } - case ScalarType::Float: { - auto res1_ = checked_dense_tensor_unwrap(res1, "res1", 0, "_th_gels_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto res2_ = checked_dense_tensor_unwrap(res2, "res2", 0, "_th_gels_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_gels_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto A_ = checked_dense_tensor_unwrap(A, "A", 2, "_th_gels_out", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaTensor_gels(globalContext().getTHCState(), res1_, res2_, self_, A_); - break; - } - default: - AT_ERROR("_th_gels_out not supported on CUDAType for ", dispatch_scalar_type); - } - return std::tuple(res1, res2); -} -std::tuple _th_gels(const Tensor & self, const Tensor & A) { - TORCH_WARN_ONCE( - "torch.lstsq is deprecated in favor of torch.linalg.lstsq and will be removed in a future PyTorch release.\n", - "torch.linalg.lstsq has reversed arguments and does not return the QR decomposition in " - "the returned tuple (although it returns other information about the problem).\n", - "To get the qr decomposition consider using torch.linalg.qr.\n", - "The returned solution in torch.lstsq stored the residuals of the solution in the ", - "last m - n columns of the returned value whenever m > n. In torch.linalg.lstsq, the ", - "residuals in the field 'residuals' of the returned named tuple.\n", - "The unpacking of the solution, as in\n", - "X, _ = torch.lstsq(B, A).solution[:A.size(1)]\n", - "should be replaced with\n", - "X = torch.linalg.lstsq(A, B).solution" - ); - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(self); - auto res1_ = c10::make_intrusive(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(dispatch_scalar_type)).release(); - auto res1 = Tensor(c10::intrusive_ptr::reclaim(res1_)); - auto res2_ = c10::make_intrusive(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(dispatch_scalar_type)).release(); - auto res2 = Tensor(c10::intrusive_ptr::reclaim(res2_)); - switch (dispatch_scalar_type) { - case ScalarType::Double: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_gels", false, DeviceType::CUDA, dispatch_scalar_type); - auto A_ = checked_dense_tensor_unwrap(A, "A", 2, "_th_gels", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaDoubleTensor_gels(globalContext().getTHCState(), res1_, res2_, self_, A_); - break; - } - case ScalarType::Float: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_gels", false, DeviceType::CUDA, dispatch_scalar_type); - auto A_ = checked_dense_tensor_unwrap(A, "A", 2, "_th_gels", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaTensor_gels(globalContext().getTHCState(), res1_, res2_, self_, A_); - break; - } - default: - AT_ERROR("_th_gels not supported on CUDAType for ", dispatch_scalar_type); - } - return std::tuple(res1, res2); -} -Tensor & _th_copy_ignoring_overlaps_(Tensor & self, const Tensor & src) { - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(self); - - switch (dispatch_scalar_type) { - case ScalarType::Byte: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_copy_ignoring_overlaps_", false, DeviceType::CUDA, dispatch_scalar_type); - auto src_ = checked_dense_tensor_unwrap(src, "src", 2, "_th_copy_ignoring_overlaps_", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaByteTensor_copyIgnoringOverlaps(globalContext().getTHCState(), self_, src_); - break; - } - case ScalarType::Char: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_copy_ignoring_overlaps_", false, DeviceType::CUDA, dispatch_scalar_type); - auto src_ = checked_dense_tensor_unwrap(src, "src", 2, "_th_copy_ignoring_overlaps_", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaCharTensor_copyIgnoringOverlaps(globalContext().getTHCState(), self_, src_); - break; - } - case ScalarType::Double: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_copy_ignoring_overlaps_", false, DeviceType::CUDA, dispatch_scalar_type); - auto src_ = checked_dense_tensor_unwrap(src, "src", 2, "_th_copy_ignoring_overlaps_", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaDoubleTensor_copyIgnoringOverlaps(globalContext().getTHCState(), self_, src_); - break; - } - case ScalarType::Float: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_copy_ignoring_overlaps_", false, DeviceType::CUDA, dispatch_scalar_type); - auto src_ = checked_dense_tensor_unwrap(src, "src", 2, "_th_copy_ignoring_overlaps_", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaTensor_copyIgnoringOverlaps(globalContext().getTHCState(), self_, src_); - break; - } - case ScalarType::Int: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_copy_ignoring_overlaps_", false, DeviceType::CUDA, dispatch_scalar_type); - auto src_ = checked_dense_tensor_unwrap(src, "src", 2, "_th_copy_ignoring_overlaps_", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaIntTensor_copyIgnoringOverlaps(globalContext().getTHCState(), self_, src_); - break; - } - case ScalarType::Long: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_copy_ignoring_overlaps_", false, DeviceType::CUDA, dispatch_scalar_type); - auto src_ = checked_dense_tensor_unwrap(src, "src", 2, "_th_copy_ignoring_overlaps_", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaLongTensor_copyIgnoringOverlaps(globalContext().getTHCState(), self_, src_); - break; - } - case ScalarType::Short: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_copy_ignoring_overlaps_", false, DeviceType::CUDA, dispatch_scalar_type); - auto src_ = checked_dense_tensor_unwrap(src, "src", 2, "_th_copy_ignoring_overlaps_", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaShortTensor_copyIgnoringOverlaps(globalContext().getTHCState(), self_, src_); - break; - } - case ScalarType::Half: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_copy_ignoring_overlaps_", false, DeviceType::CUDA, dispatch_scalar_type); - auto src_ = checked_dense_tensor_unwrap(src, "src", 2, "_th_copy_ignoring_overlaps_", false, DeviceType::CUDA, dispatch_scalar_type); - THCudaHalfTensor_copyIgnoringOverlaps(globalContext().getTHCState(), self_, src_); - break; - } - default: - AT_ERROR("_th_copy_ignoring_overlaps_ not supported on CUDAType for ", dispatch_scalar_type); - } - return self; -} -std::tuple _thnn_conv2d_forward_out(const Tensor & self, const Tensor & weight, IntArrayRef kernel_size, const c10::optional& bias_opt, IntArrayRef stride, IntArrayRef padding, Tensor & output, Tensor & columns, Tensor & ones) { - // See [Note: hacky wrapper removal for optional tensor] - c10::MaybeOwned bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); - const Tensor& bias = *bias_maybe_owned; - - const OptionalDeviceGuard device_guard(device_of(self)); - auto dispatch_scalar_type = infer_scalar_type(self); - - switch (dispatch_scalar_type) { - case ScalarType::Double: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_thnn_conv2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 2, "_thnn_conv2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto kernel_size_ = check_intlist<2>(kernel_size, "kernel_size", 3); - auto bias_ = checked_dense_tensor_unwrap(bias, "bias", 4, "_thnn_conv2d_forward_out", true, DeviceType::CUDA, dispatch_scalar_type); - auto stride_ = check_intlist<2>(stride, "stride", 5); - auto padding_ = check_intlist<2>(padding, "padding", 6); - auto output_ = checked_dense_tensor_unwrap(output, "output", 6, "_thnn_conv2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto columns_ = checked_dense_tensor_unwrap(columns, "columns", 6, "_thnn_conv2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto ones_ = checked_dense_tensor_unwrap(ones, "ones", 6, "_thnn_conv2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type); - THNN_CudaDoubleSpatialConvolutionMM_updateOutput(globalContext().getTHCState(), self_, output_, weight_, bias_ ? bias_ : NULL, columns_, ones_, kernel_size_[1], kernel_size_[0], stride_[1], stride_[0], padding_[1], padding_[0]); - break; - } - case ScalarType::Float: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_thnn_conv2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 2, "_thnn_conv2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto kernel_size_ = check_intlist<2>(kernel_size, "kernel_size", 3); - auto bias_ = checked_dense_tensor_unwrap(bias, "bias", 4, "_thnn_conv2d_forward_out", true, DeviceType::CUDA, dispatch_scalar_type); - auto stride_ = check_intlist<2>(stride, "stride", 5); - auto padding_ = check_intlist<2>(padding, "padding", 6); - auto output_ = checked_dense_tensor_unwrap(output, "output", 6, "_thnn_conv2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto columns_ = checked_dense_tensor_unwrap(columns, "columns", 6, "_thnn_conv2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto ones_ = checked_dense_tensor_unwrap(ones, "ones", 6, "_thnn_conv2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type); - THNN_CudaSpatialConvolutionMM_updateOutput(globalContext().getTHCState(), self_, output_, weight_, bias_ ? bias_ : NULL, columns_, ones_, kernel_size_[1], kernel_size_[0], stride_[1], stride_[0], padding_[1], padding_[0]); - break; - } - case ScalarType::Half: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_thnn_conv2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 2, "_thnn_conv2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto kernel_size_ = check_intlist<2>(kernel_size, "kernel_size", 3); - auto bias_ = checked_dense_tensor_unwrap(bias, "bias", 4, "_thnn_conv2d_forward_out", true, DeviceType::CUDA, dispatch_scalar_type); - auto stride_ = check_intlist<2>(stride, "stride", 5); - auto padding_ = check_intlist<2>(padding, "padding", 6); - auto output_ = checked_dense_tensor_unwrap(output, "output", 6, "_thnn_conv2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto columns_ = checked_dense_tensor_unwrap(columns, "columns", 6, "_thnn_conv2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto ones_ = checked_dense_tensor_unwrap(ones, "ones", 6, "_thnn_conv2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type); - THNN_CudaHalfSpatialConvolutionMM_updateOutput(globalContext().getTHCState(), self_, output_, weight_, bias_ ? bias_ : NULL, columns_, ones_, kernel_size_[1], kernel_size_[0], stride_[1], stride_[0], padding_[1], padding_[0]); - break; - } - case ScalarType::BFloat16: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_thnn_conv2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 2, "_thnn_conv2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto kernel_size_ = check_intlist<2>(kernel_size, "kernel_size", 3); - auto bias_ = checked_dense_tensor_unwrap(bias, "bias", 4, "_thnn_conv2d_forward_out", true, DeviceType::CUDA, dispatch_scalar_type); - auto stride_ = check_intlist<2>(stride, "stride", 5); - auto padding_ = check_intlist<2>(padding, "padding", 6); - auto output_ = checked_dense_tensor_unwrap(output, "output", 6, "_thnn_conv2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto columns_ = checked_dense_tensor_unwrap(columns, "columns", 6, "_thnn_conv2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto ones_ = checked_dense_tensor_unwrap(ones, "ones", 6, "_thnn_conv2d_forward_out", false, DeviceType::CUDA, dispatch_scalar_type); - THNN_CudaBFloat16SpatialConvolutionMM_updateOutput(globalContext().getTHCState(), self_, output_, weight_, bias_ ? bias_ : NULL, columns_, ones_, kernel_size_[1], kernel_size_[0], stride_[1], stride_[0], padding_[1], padding_[0]); - break; - } - default: - AT_ERROR("_thnn_conv2d_forward_out not supported on CUDAType for ", dispatch_scalar_type); - } - return std::tuple(output, columns, ones); -} -std::tuple _thnn_conv2d_forward(const Tensor & self, const Tensor & weight, IntArrayRef kernel_size, const c10::optional& bias_opt, IntArrayRef stride, IntArrayRef padding) { - // See [Note: hacky wrapper removal for optional tensor] - c10::MaybeOwned bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); - const Tensor& bias = *bias_maybe_owned; - - const OptionalDeviceGuard device_guard(device_of(self)); - auto dispatch_scalar_type = infer_scalar_type(self); - auto output_ = c10::make_intrusive(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(dispatch_scalar_type)).release(); - auto output = Tensor(c10::intrusive_ptr::reclaim(output_)); - auto columns_ = c10::make_intrusive(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(dispatch_scalar_type)).release(); - auto columns = Tensor(c10::intrusive_ptr::reclaim(columns_)); - auto ones_ = c10::make_intrusive(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(dispatch_scalar_type)).release(); - auto ones = Tensor(c10::intrusive_ptr::reclaim(ones_)); - switch (dispatch_scalar_type) { - case ScalarType::Double: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_thnn_conv2d_forward", false, DeviceType::CUDA, dispatch_scalar_type); - auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 2, "_thnn_conv2d_forward", false, DeviceType::CUDA, dispatch_scalar_type); - auto kernel_size_ = check_intlist<2>(kernel_size, "kernel_size", 3); - auto bias_ = checked_dense_tensor_unwrap(bias, "bias", 4, "_thnn_conv2d_forward", true, DeviceType::CUDA, dispatch_scalar_type); - auto stride_ = check_intlist<2>(stride, "stride", 5); - auto padding_ = check_intlist<2>(padding, "padding", 6); - THNN_CudaDoubleSpatialConvolutionMM_updateOutput(globalContext().getTHCState(), self_, output_, weight_, bias_ ? bias_ : NULL, columns_, ones_, kernel_size_[1], kernel_size_[0], stride_[1], stride_[0], padding_[1], padding_[0]); - break; - } - case ScalarType::Float: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_thnn_conv2d_forward", false, DeviceType::CUDA, dispatch_scalar_type); - auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 2, "_thnn_conv2d_forward", false, DeviceType::CUDA, dispatch_scalar_type); - auto kernel_size_ = check_intlist<2>(kernel_size, "kernel_size", 3); - auto bias_ = checked_dense_tensor_unwrap(bias, "bias", 4, "_thnn_conv2d_forward", true, DeviceType::CUDA, dispatch_scalar_type); - auto stride_ = check_intlist<2>(stride, "stride", 5); - auto padding_ = check_intlist<2>(padding, "padding", 6); - THNN_CudaSpatialConvolutionMM_updateOutput(globalContext().getTHCState(), self_, output_, weight_, bias_ ? bias_ : NULL, columns_, ones_, kernel_size_[1], kernel_size_[0], stride_[1], stride_[0], padding_[1], padding_[0]); - break; - } - case ScalarType::Half: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_thnn_conv2d_forward", false, DeviceType::CUDA, dispatch_scalar_type); - auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 2, "_thnn_conv2d_forward", false, DeviceType::CUDA, dispatch_scalar_type); - auto kernel_size_ = check_intlist<2>(kernel_size, "kernel_size", 3); - auto bias_ = checked_dense_tensor_unwrap(bias, "bias", 4, "_thnn_conv2d_forward", true, DeviceType::CUDA, dispatch_scalar_type); - auto stride_ = check_intlist<2>(stride, "stride", 5); - auto padding_ = check_intlist<2>(padding, "padding", 6); - THNN_CudaHalfSpatialConvolutionMM_updateOutput(globalContext().getTHCState(), self_, output_, weight_, bias_ ? bias_ : NULL, columns_, ones_, kernel_size_[1], kernel_size_[0], stride_[1], stride_[0], padding_[1], padding_[0]); - break; - } - case ScalarType::BFloat16: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_thnn_conv2d_forward", false, DeviceType::CUDA, dispatch_scalar_type); - auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 2, "_thnn_conv2d_forward", false, DeviceType::CUDA, dispatch_scalar_type); - auto kernel_size_ = check_intlist<2>(kernel_size, "kernel_size", 3); - auto bias_ = checked_dense_tensor_unwrap(bias, "bias", 4, "_thnn_conv2d_forward", true, DeviceType::CUDA, dispatch_scalar_type); - auto stride_ = check_intlist<2>(stride, "stride", 5); - auto padding_ = check_intlist<2>(padding, "padding", 6); - THNN_CudaBFloat16SpatialConvolutionMM_updateOutput(globalContext().getTHCState(), self_, output_, weight_, bias_ ? bias_ : NULL, columns_, ones_, kernel_size_[1], kernel_size_[0], stride_[1], stride_[0], padding_[1], padding_[0]); - break; - } - default: - AT_ERROR("_thnn_conv2d_forward not supported on CUDAType for ", dispatch_scalar_type); - } - return std::tuple(output, columns, ones); -} -std::tuple _thnn_conv2d_backward_out(Tensor & grad_input, Tensor & grad_weight, Tensor & grad_bias, const Tensor & grad_output, const Tensor & self, const Tensor & weight, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, const Tensor & columns, const Tensor & ones) { - const OptionalDeviceGuard device_guard(device_of(self)); - auto dispatch_scalar_type = infer_scalar_type(self); - - switch (dispatch_scalar_type) { - case ScalarType::Double: { - auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_conv2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_conv2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 3, "_thnn_conv2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto kernel_size_ = check_intlist<2>(kernel_size, "kernel_size", 4); - auto stride_ = check_intlist<2>(stride, "stride", 5); - auto padding_ = check_intlist<2>(padding, "padding", 6); - auto columns_ = checked_dense_tensor_unwrap(columns, "columns", 7, "_thnn_conv2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto ones_ = checked_dense_tensor_unwrap(ones, "ones", 8, "_thnn_conv2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto grad_input_ = checked_dense_tensor_unwrap(grad_input, "grad_input", 8, "_thnn_conv2d_backward_out", true, DeviceType::CUDA, dispatch_scalar_type); - auto grad_weight_ = checked_dense_tensor_unwrap(grad_weight, "grad_weight", 8, "_thnn_conv2d_backward_out", true, DeviceType::CUDA, dispatch_scalar_type); - auto grad_bias_ = checked_dense_tensor_unwrap(grad_bias, "grad_bias", 8, "_thnn_conv2d_backward_out", true, DeviceType::CUDA, dispatch_scalar_type); - if (grad_input_) THNN_CudaDoubleSpatialConvolutionMM_updateGradInput(globalContext().getTHCState(), self_, grad_output_, grad_input_ ? grad_input_ : NULL, weight_, columns_, ones_, kernel_size_[1], kernel_size_[0], stride_[1], stride_[0], padding_[1], padding_[0]); - if (grad_weight_ || grad_bias_) THNN_CudaDoubleSpatialConvolutionMM_accGradParameters(globalContext().getTHCState(), self_, grad_output_, grad_weight_ ? grad_weight_ : NULL, grad_bias_ ? grad_bias_ : NULL, columns_, ones_, kernel_size_[1], kernel_size_[0], stride_[1], stride_[0], padding_[1], padding_[0], 1); - break; - } - case ScalarType::Float: { - auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_conv2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_conv2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 3, "_thnn_conv2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto kernel_size_ = check_intlist<2>(kernel_size, "kernel_size", 4); - auto stride_ = check_intlist<2>(stride, "stride", 5); - auto padding_ = check_intlist<2>(padding, "padding", 6); - auto columns_ = checked_dense_tensor_unwrap(columns, "columns", 7, "_thnn_conv2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto ones_ = checked_dense_tensor_unwrap(ones, "ones", 8, "_thnn_conv2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto grad_input_ = checked_dense_tensor_unwrap(grad_input, "grad_input", 8, "_thnn_conv2d_backward_out", true, DeviceType::CUDA, dispatch_scalar_type); - auto grad_weight_ = checked_dense_tensor_unwrap(grad_weight, "grad_weight", 8, "_thnn_conv2d_backward_out", true, DeviceType::CUDA, dispatch_scalar_type); - auto grad_bias_ = checked_dense_tensor_unwrap(grad_bias, "grad_bias", 8, "_thnn_conv2d_backward_out", true, DeviceType::CUDA, dispatch_scalar_type); - if (grad_input_) THNN_CudaSpatialConvolutionMM_updateGradInput(globalContext().getTHCState(), self_, grad_output_, grad_input_ ? grad_input_ : NULL, weight_, columns_, ones_, kernel_size_[1], kernel_size_[0], stride_[1], stride_[0], padding_[1], padding_[0]); - if (grad_weight_ || grad_bias_) THNN_CudaSpatialConvolutionMM_accGradParameters(globalContext().getTHCState(), self_, grad_output_, grad_weight_ ? grad_weight_ : NULL, grad_bias_ ? grad_bias_ : NULL, columns_, ones_, kernel_size_[1], kernel_size_[0], stride_[1], stride_[0], padding_[1], padding_[0], 1); - break; - } - case ScalarType::Half: { - auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_conv2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_conv2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 3, "_thnn_conv2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto kernel_size_ = check_intlist<2>(kernel_size, "kernel_size", 4); - auto stride_ = check_intlist<2>(stride, "stride", 5); - auto padding_ = check_intlist<2>(padding, "padding", 6); - auto columns_ = checked_dense_tensor_unwrap(columns, "columns", 7, "_thnn_conv2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto ones_ = checked_dense_tensor_unwrap(ones, "ones", 8, "_thnn_conv2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto grad_input_ = checked_dense_tensor_unwrap(grad_input, "grad_input", 8, "_thnn_conv2d_backward_out", true, DeviceType::CUDA, dispatch_scalar_type); - auto grad_weight_ = checked_dense_tensor_unwrap(grad_weight, "grad_weight", 8, "_thnn_conv2d_backward_out", true, DeviceType::CUDA, dispatch_scalar_type); - auto grad_bias_ = checked_dense_tensor_unwrap(grad_bias, "grad_bias", 8, "_thnn_conv2d_backward_out", true, DeviceType::CUDA, dispatch_scalar_type); - if (grad_input_) THNN_CudaHalfSpatialConvolutionMM_updateGradInput(globalContext().getTHCState(), self_, grad_output_, grad_input_ ? grad_input_ : NULL, weight_, columns_, ones_, kernel_size_[1], kernel_size_[0], stride_[1], stride_[0], padding_[1], padding_[0]); - if (grad_weight_ || grad_bias_) THNN_CudaHalfSpatialConvolutionMM_accGradParameters(globalContext().getTHCState(), self_, grad_output_, grad_weight_ ? grad_weight_ : NULL, grad_bias_ ? grad_bias_ : NULL, columns_, ones_, kernel_size_[1], kernel_size_[0], stride_[1], stride_[0], padding_[1], padding_[0], 1); - break; - } - case ScalarType::BFloat16: { - auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_conv2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_conv2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 3, "_thnn_conv2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto kernel_size_ = check_intlist<2>(kernel_size, "kernel_size", 4); - auto stride_ = check_intlist<2>(stride, "stride", 5); - auto padding_ = check_intlist<2>(padding, "padding", 6); - auto columns_ = checked_dense_tensor_unwrap(columns, "columns", 7, "_thnn_conv2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto ones_ = checked_dense_tensor_unwrap(ones, "ones", 8, "_thnn_conv2d_backward_out", false, DeviceType::CUDA, dispatch_scalar_type); - auto grad_input_ = checked_dense_tensor_unwrap(grad_input, "grad_input", 8, "_thnn_conv2d_backward_out", true, DeviceType::CUDA, dispatch_scalar_type); - auto grad_weight_ = checked_dense_tensor_unwrap(grad_weight, "grad_weight", 8, "_thnn_conv2d_backward_out", true, DeviceType::CUDA, dispatch_scalar_type); - auto grad_bias_ = checked_dense_tensor_unwrap(grad_bias, "grad_bias", 8, "_thnn_conv2d_backward_out", true, DeviceType::CUDA, dispatch_scalar_type); - if (grad_input_) THNN_CudaBFloat16SpatialConvolutionMM_updateGradInput(globalContext().getTHCState(), self_, grad_output_, grad_input_ ? grad_input_ : NULL, weight_, columns_, ones_, kernel_size_[1], kernel_size_[0], stride_[1], stride_[0], padding_[1], padding_[0]); - if (grad_weight_ || grad_bias_) THNN_CudaBFloat16SpatialConvolutionMM_accGradParameters(globalContext().getTHCState(), self_, grad_output_, grad_weight_ ? grad_weight_ : NULL, grad_bias_ ? grad_bias_ : NULL, columns_, ones_, kernel_size_[1], kernel_size_[0], stride_[1], stride_[0], padding_[1], padding_[0], 1); - break; - } - default: - AT_ERROR("_thnn_conv2d_backward_out not supported on CUDAType for ", dispatch_scalar_type); - } - return std::tuple(grad_input, grad_weight, grad_bias); -} -std::tuple _thnn_conv2d_backward(const Tensor & grad_output, const Tensor & self, const Tensor & weight, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, const Tensor & columns, const Tensor & ones, std::array output_mask) { - const OptionalDeviceGuard device_guard(device_of(self)); - auto dispatch_scalar_type = infer_scalar_type(self); - auto grad_input_ = output_mask[0] ? c10::make_intrusive(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(dispatch_scalar_type)).release() : nullptr; - auto grad_input = Tensor(c10::intrusive_ptr::reclaim(grad_input_ == nullptr ? (TensorImpl*)UndefinedTensorImpl::singleton() : (TensorImpl*)grad_input_)); - auto grad_weight_ = output_mask[1] ? c10::make_intrusive(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(dispatch_scalar_type)).release() : nullptr; - auto grad_weight = Tensor(c10::intrusive_ptr::reclaim(grad_weight_ == nullptr ? (TensorImpl*)UndefinedTensorImpl::singleton() : (TensorImpl*)grad_weight_)); - auto grad_bias_ = output_mask[2] ? c10::make_intrusive(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(dispatch_scalar_type)).release() : nullptr; - auto grad_bias = Tensor(c10::intrusive_ptr::reclaim(grad_bias_ == nullptr ? (TensorImpl*)UndefinedTensorImpl::singleton() : (TensorImpl*)grad_bias_)); - switch (dispatch_scalar_type) { - case ScalarType::Double: { - auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_conv2d_backward", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_conv2d_backward", false, DeviceType::CUDA, dispatch_scalar_type); - auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 3, "_thnn_conv2d_backward", false, DeviceType::CUDA, dispatch_scalar_type); - auto kernel_size_ = check_intlist<2>(kernel_size, "kernel_size", 4); - auto stride_ = check_intlist<2>(stride, "stride", 5); - auto padding_ = check_intlist<2>(padding, "padding", 6); - auto columns_ = checked_dense_tensor_unwrap(columns, "columns", 7, "_thnn_conv2d_backward", false, DeviceType::CUDA, dispatch_scalar_type); - auto ones_ = checked_dense_tensor_unwrap(ones, "ones", 8, "_thnn_conv2d_backward", false, DeviceType::CUDA, dispatch_scalar_type); - if (grad_input_) THNN_CudaDoubleSpatialConvolutionMM_updateGradInput(globalContext().getTHCState(), self_, grad_output_, grad_input_ ? grad_input_ : NULL, weight_, columns_, ones_, kernel_size_[1], kernel_size_[0], stride_[1], stride_[0], padding_[1], padding_[0]); - if (grad_weight_ || grad_bias_) THNN_CudaDoubleSpatialConvolutionMM_accGradParameters(globalContext().getTHCState(), self_, grad_output_, grad_weight_ ? grad_weight_ : NULL, grad_bias_ ? grad_bias_ : NULL, columns_, ones_, kernel_size_[1], kernel_size_[0], stride_[1], stride_[0], padding_[1], padding_[0], 1); - break; - } - case ScalarType::Float: { - auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_conv2d_backward", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_conv2d_backward", false, DeviceType::CUDA, dispatch_scalar_type); - auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 3, "_thnn_conv2d_backward", false, DeviceType::CUDA, dispatch_scalar_type); - auto kernel_size_ = check_intlist<2>(kernel_size, "kernel_size", 4); - auto stride_ = check_intlist<2>(stride, "stride", 5); - auto padding_ = check_intlist<2>(padding, "padding", 6); - auto columns_ = checked_dense_tensor_unwrap(columns, "columns", 7, "_thnn_conv2d_backward", false, DeviceType::CUDA, dispatch_scalar_type); - auto ones_ = checked_dense_tensor_unwrap(ones, "ones", 8, "_thnn_conv2d_backward", false, DeviceType::CUDA, dispatch_scalar_type); - if (grad_input_) THNN_CudaSpatialConvolutionMM_updateGradInput(globalContext().getTHCState(), self_, grad_output_, grad_input_ ? grad_input_ : NULL, weight_, columns_, ones_, kernel_size_[1], kernel_size_[0], stride_[1], stride_[0], padding_[1], padding_[0]); - if (grad_weight_ || grad_bias_) THNN_CudaSpatialConvolutionMM_accGradParameters(globalContext().getTHCState(), self_, grad_output_, grad_weight_ ? grad_weight_ : NULL, grad_bias_ ? grad_bias_ : NULL, columns_, ones_, kernel_size_[1], kernel_size_[0], stride_[1], stride_[0], padding_[1], padding_[0], 1); - break; - } - case ScalarType::Half: { - auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_conv2d_backward", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_conv2d_backward", false, DeviceType::CUDA, dispatch_scalar_type); - auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 3, "_thnn_conv2d_backward", false, DeviceType::CUDA, dispatch_scalar_type); - auto kernel_size_ = check_intlist<2>(kernel_size, "kernel_size", 4); - auto stride_ = check_intlist<2>(stride, "stride", 5); - auto padding_ = check_intlist<2>(padding, "padding", 6); - auto columns_ = checked_dense_tensor_unwrap(columns, "columns", 7, "_thnn_conv2d_backward", false, DeviceType::CUDA, dispatch_scalar_type); - auto ones_ = checked_dense_tensor_unwrap(ones, "ones", 8, "_thnn_conv2d_backward", false, DeviceType::CUDA, dispatch_scalar_type); - if (grad_input_) THNN_CudaHalfSpatialConvolutionMM_updateGradInput(globalContext().getTHCState(), self_, grad_output_, grad_input_ ? grad_input_ : NULL, weight_, columns_, ones_, kernel_size_[1], kernel_size_[0], stride_[1], stride_[0], padding_[1], padding_[0]); - if (grad_weight_ || grad_bias_) THNN_CudaHalfSpatialConvolutionMM_accGradParameters(globalContext().getTHCState(), self_, grad_output_, grad_weight_ ? grad_weight_ : NULL, grad_bias_ ? grad_bias_ : NULL, columns_, ones_, kernel_size_[1], kernel_size_[0], stride_[1], stride_[0], padding_[1], padding_[0], 1); - break; - } - case ScalarType::BFloat16: { - auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_conv2d_backward", false, DeviceType::CUDA, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_conv2d_backward", false, DeviceType::CUDA, dispatch_scalar_type); - auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 3, "_thnn_conv2d_backward", false, DeviceType::CUDA, dispatch_scalar_type); - auto kernel_size_ = check_intlist<2>(kernel_size, "kernel_size", 4); - auto stride_ = check_intlist<2>(stride, "stride", 5); - auto padding_ = check_intlist<2>(padding, "padding", 6); - auto columns_ = checked_dense_tensor_unwrap(columns, "columns", 7, "_thnn_conv2d_backward", false, DeviceType::CUDA, dispatch_scalar_type); - auto ones_ = checked_dense_tensor_unwrap(ones, "ones", 8, "_thnn_conv2d_backward", false, DeviceType::CUDA, dispatch_scalar_type); - if (grad_input_) THNN_CudaBFloat16SpatialConvolutionMM_updateGradInput(globalContext().getTHCState(), self_, grad_output_, grad_input_ ? grad_input_ : NULL, weight_, columns_, ones_, kernel_size_[1], kernel_size_[0], stride_[1], stride_[0], padding_[1], padding_[0]); - if (grad_weight_ || grad_bias_) THNN_CudaBFloat16SpatialConvolutionMM_accGradParameters(globalContext().getTHCState(), self_, grad_output_, grad_weight_ ? grad_weight_ : NULL, grad_bias_ ? grad_bias_ : NULL, columns_, ones_, kernel_size_[1], kernel_size_[0], stride_[1], stride_[0], padding_[1], padding_[0], 1); - break; - } - default: - AT_ERROR("_thnn_conv2d_backward not supported on CUDAType for ", dispatch_scalar_type); - } - return std::tuple(grad_input, grad_weight, grad_bias); -} - -} // namespace th -} // namespace legacy -} // namespace native -} // namespace at diff --git a/aten/src/ATen/cuda/detail/KernelUtils.h b/aten/src/ATen/cuda/detail/KernelUtils.h index 836504a729fea..91a61b04b8590 100644 --- a/aten/src/ATen/cuda/detail/KernelUtils.h +++ b/aten/src/ATen/cuda/detail/KernelUtils.h @@ -2,9 +2,6 @@ #include -// Contents of this file are copied from THCUNN/common.h for the ease of porting -// THCUNN functions into ATen. - namespace at { namespace cuda { namespace detail { // CUDA: grid stride looping diff --git a/aten/src/ATen/cudnn/Descriptors.cpp b/aten/src/ATen/cudnn/Descriptors.cpp index 873431c1d96e5..f52280e9d2401 100644 --- a/aten/src/ATen/cudnn/Descriptors.cpp +++ b/aten/src/ATen/cudnn/Descriptors.cpp @@ -2,7 +2,7 @@ #include -#include +#include #include namespace at { namespace native { diff --git a/aten/src/ATen/detail/ORTHooksInterface.cpp b/aten/src/ATen/detail/ORTHooksInterface.cpp new file mode 100644 index 0000000000000..33f70935a04d0 --- /dev/null +++ b/aten/src/ATen/detail/ORTHooksInterface.cpp @@ -0,0 +1,31 @@ +#include + +#include + +#include +#include +#include + +namespace at { +namespace detail { + +// See getCUDAHooks for some more commentary +const ORTHooksInterface& getORTHooks() { + static std::unique_ptr ort_hooks; + static std::once_flag once; + std::call_once(once, [] { + ort_hooks = ORTHooksRegistry()->Create("ORTHooks", {}); + if (!ort_hooks) { + ort_hooks = + // NOLINTNEXTLINE(modernize-make-unique) + std::unique_ptr(new ORTHooksInterface()); + } + }); + return *ort_hooks; +} +} // namespace detail + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +C10_DEFINE_REGISTRY(ORTHooksRegistry, ORTHooksInterface, ORTHooksArgs) + +} // namespace at diff --git a/aten/src/ATen/detail/ORTHooksInterface.h b/aten/src/ATen/detail/ORTHooksInterface.h new file mode 100644 index 0000000000000..caee55cdfaf99 --- /dev/null +++ b/aten/src/ATen/detail/ORTHooksInterface.h @@ -0,0 +1,36 @@ +#pragma once + +#include +#include + +constexpr const char* ORT_HELP = + " You need to 'import torch_ort' to use the 'ort' device in PyTorch. " + "The 'torch_ort' module is provided by the ONNX Runtime itself " + "(https://onnxruntime.ai)."; + +// NB: Class must live in `at` due to limitations of Registry.h. +namespace at { + +struct TORCH_API ORTHooksInterface { + // This should never actually be implemented, but it is used to + // squelch -Werror=non-virtual-dtor + virtual ~ORTHooksInterface() {} + + virtual std::string showConfig() const { + TORCH_CHECK(false, "Cannot query detailed ORT version information.", ORT_HELP); + } +}; + +// NB: dummy argument to suppress "ISO C++11 requires at least one argument +// for the "..." in a variadic macro" +struct TORCH_API ORTHooksArgs {}; + +C10_DECLARE_REGISTRY(ORTHooksRegistry, ORTHooksInterface, ORTHooksArgs); +#define REGISTER_ORT_HOOKS(clsname) \ + C10_REGISTER_CLASS(ORTHooksRegistry, clsname, clsname) + +namespace detail { +TORCH_API const ORTHooksInterface& getORTHooks(); +} // namespace detail + +} // namespace at diff --git a/aten/src/ATen/miopen/Descriptors.cpp b/aten/src/ATen/miopen/Descriptors.cpp index 6a6476706ac6f..38875191b448b 100644 --- a/aten/src/ATen/miopen/Descriptors.cpp +++ b/aten/src/ATen/miopen/Descriptors.cpp @@ -1,6 +1,8 @@ #include #include +#include + namespace at { namespace native { namespace { diff --git a/aten/src/ATen/native/Activation.cpp b/aten/src/ATen/native/Activation.cpp index e6ae3c9ebc3d7..37700bb586793 100644 --- a/aten/src/ATen/native/Activation.cpp +++ b/aten/src/ATen/native/Activation.cpp @@ -524,7 +524,7 @@ Tensor& rrelu_with_noise_out_cpu(const Tensor& self, c10::optional generator, Tensor& output) { if (training) { - AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "rrelu_with_noise_out_cpu", [&] { + AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, self.scalar_type(), "rrelu_with_noise_out_cpu", [&] { _rrelu_with_noise_train(output, self.contiguous(), noise, lower, upper, generator); }); return output; diff --git a/aten/src/ATen/native/Activation.h b/aten/src/ATen/native/Activation.h index 01782fae1de3f..f0c6d82af2b29 100644 --- a/aten/src/ATen/native/Activation.h +++ b/aten/src/ATen/native/Activation.h @@ -51,7 +51,7 @@ DECLARE_DISPATCH(softshrink_fn, softshrink_stub); DECLARE_DISPATCH(shrink_backward_fn, shrink_backward_stub); DECLARE_DISPATCH(leaky_relu_fn, leaky_relu_stub); DECLARE_DISPATCH(leaky_relu_backward_fn, leaky_relu_backward_stub); -DECLARE_DISPATCH(activation_fn, glu_stub); +DECLARE_DISPATCH(structured_activation_fn, glu_stub); DECLARE_DISPATCH(activation_backward_fn, glu_backward_stub); DECLARE_DISPATCH(structured_activation_fn, silu_stub); DECLARE_DISPATCH(structured_activation_backward_fn, silu_backward_stub); diff --git a/aten/src/ATen/native/AveragePool2d.cpp b/aten/src/ATen/native/AveragePool2d.cpp index 2693cc6ba49c5..8f264c007c6be 100644 --- a/aten/src/ATen/native/AveragePool2d.cpp +++ b/aten/src/ATen/native/AveragePool2d.cpp @@ -8,59 +8,81 @@ namespace at { namespace meta{ using namespace native; -TORCH_META_FUNC(avg_pool2d) ( - const Tensor& input, - IntArrayRef kernel_size, - IntArrayRef stride, - IntArrayRef padding, - bool ceil_mode, - bool count_include_pad, - c10::optional divisor_override -) { +TORCH_PRECOMPUTE_META_FUNC(avg_pool2d) +(const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + bool ceil_mode, + bool count_include_pad, + c10::optional divisor_override) { // #20866, #22032: Guarantee this for the official C++ API? TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 2, "avg_pool2d: kernel_size must either be a single int, or a tuple of two ints"); - const int kH = safe_downcast(kernel_size[0]); - const int kW = kernel_size.size() == 1 ? kH : safe_downcast(kernel_size[1]); + const int64_t kH = kernel_size[0]; + const int64_t kW = kernel_size.size() == 1 ? kH : kernel_size[1]; TORCH_CHECK(stride.empty() || stride.size() == 1 || stride.size() == 2, "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints"); - const int dH = stride.empty() ? kH : safe_downcast(stride[0]); - const int dW = stride.empty() ? kW : - stride.size() == 1 ? dH : safe_downcast(stride[1]); + const int64_t dH = stride.empty() ? kH : stride[0]; + const int64_t dW = stride.empty() ? kW : stride.size() == 1 ? dH : stride[1]; TORCH_CHECK(padding.size() == 1 || padding.size() == 2, "avg_pool2d: padding must either be a single int, or a tuple of two ints"); - const int padH = safe_downcast(padding[0]); - const int padW = padding.size() == 1 ? padH : safe_downcast(padding[1]); + const int64_t padH = padding[0]; + const int64_t padW = padding.size() == 1 ? padH : padding[1]; TORCH_CHECK(!divisor_override.has_value() || divisor_override.value() != 0, "divisor must be not zero"); - /* sizes */ const int64_t nbatch = input.ndimension() == 4 ? input.size(-4) : 1; const int64_t nInputPlane = input.size(-3); const int64_t inputHeight = input.size(-2); const int64_t inputWidth = input.size(-1); - const int64_t outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode); - const int64_t outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode); + const int64_t outputHeight = pooling_output_shape( + inputHeight, kH, padH, dH, 1, ceil_mode); + const int64_t outputWidth = + pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode); auto memory_format = input.suggest_memory_format(); pool2d_shape_check( - input, - kH, kW, dH, dW, padH, padW, 1, 1, - nInputPlane, - inputHeight, inputWidth, - outputHeight, outputWidth, memory_format); + input, + kH, + kW, + dH, + dW, + padH, + padW, + 1, + 1, + nInputPlane, + inputHeight, + inputWidth, + outputHeight, + outputWidth, + memory_format); /* resize output */ if (input.ndimension() == 3) { - set_output(0, {nInputPlane, outputHeight, outputWidth}, input.options()); + set_output( + 0, + {nInputPlane, + outputHeight, + outputWidth}, + input.options()); } else { - set_output(0, {nbatch, nInputPlane, outputHeight, outputWidth}, input.options().memory_format(memory_format)); + set_output( + 0, + {nbatch, + nInputPlane, + outputHeight, + outputWidth}, + input.options().memory_format(memory_format)); } + + return TORCH_PRECOMPUTE_STRUCT(avg_pool2d)().set_kH(kH).set_kW(kW).set_dH(dH).set_dW(dW).set_padH(padH).set_padW(padW); } TORCH_META_FUNC(avg_pool2d_backward) ( @@ -119,30 +141,30 @@ TORCH_META_FUNC(avg_pool2d_backward) ( namespace native { -TORCH_IMPL_FUNC(avg_pool2d_out_cpu) ( - const Tensor &input, - IntArrayRef kernel_size, - IntArrayRef stride, - IntArrayRef padding, - bool ceil_mode, - bool count_include_pad, - c10::optional divisor_override, - const Tensor &output -) { - const int kH = safe_downcast(kernel_size[0]); - const int kW = kernel_size.size() == 1 ? kH : safe_downcast(kernel_size[1]); - - const int dH = stride.empty() ? kH : safe_downcast(stride[0]); - const int dW = stride.empty() ? kW : - stride.size() == 1 ? dH : safe_downcast(stride[1]); - - const int padH = safe_downcast(padding[0]); - const int padW = padding.size() == 1 ? padH : safe_downcast(padding[1]); - +TORCH_IMPL_FUNC(avg_pool2d_out_cpu) +(const Tensor& input, + int64_t kH, + int64_t kW, + int64_t dH, + int64_t dW, + int64_t padH, + int64_t padW, + bool ceil_mode, + bool count_include_pad, + c10::optional divisor_override, + const Tensor& output) { avg_pool2d_kernel( - kCPU, output, input, - kW, kH, dW, dH, padW, padH, - count_include_pad, divisor_override); + kCPU, + output, + input, + kW, + kH, + dW, + dH, + padW, + padH, + count_include_pad, + divisor_override); } TORCH_IMPL_FUNC(avg_pool2d_backward_out_cpu) ( diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index d80f9184567b1..498b51b38187c 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -1549,6 +1549,8 @@ Tensor cholesky_inverse(const Tensor &input, bool upper) { DEFINE_DISPATCH(lu_stub); +// TODO: remove check_errors argument +// https://github.com/pytorch/pytorch/issues/64014 std::tuple _lu_with_info(const Tensor& self, bool compute_pivots, bool check_errors) { TORCH_CHECK(self.dim() >= 2, "expected tensor with 2 or more dimensions, got size: ", self.sizes(), @@ -1566,14 +1568,6 @@ std::tuple _lu_with_info(const Tensor& self, bool comput // 'lu' tensor is modified in-place and must be a copy of 'self' Tensor lu = cloneBatchedColumnMajor(self); lu_stub(self.device().type(), lu, pivots_tensor, infos_tensor, compute_pivots); - - if (check_errors) { - if (self.dim() > 2) { - batchCheckErrors(infos_tensor, "lu", /*allow_singular=*/true); - } else { - singleCheckErrors(infos_tensor.item(), "lu", /*allow_singular=*/true); - } - } return std::make_tuple(lu, pivots_tensor, infos_tensor); } diff --git a/aten/src/ATen/native/Blas.cpp b/aten/src/ATen/native/Blas.cpp index ab522ac21ea92..eb025f47e9d76 100644 --- a/aten/src/ATen/native/Blas.cpp +++ b/aten/src/ATen/native/Blas.cpp @@ -3,6 +3,12 @@ #include #include #include +#include + +#if AT_MKLDNN_ENABLED() +#include +#include +#endif // AT_MKLDNN_ENABLED namespace at { namespace meta { @@ -62,6 +68,19 @@ TORCH_IMPL_FUNC(addmv_out_cpu)(const Tensor &self, const Tensor &mat, const Tens at::native::copy_(const_cast(result), *self_); } if (result.numel() != 0) { + +#if AT_MKLDNN_ENABLED() + NoNamesGuard guard; + // mkldnn matmul expect dim >= 2 + auto vec_ = vec.unsqueeze(1); + if (use_mkldnn_bf16_gemm(mat, vec_, /*result=*/Tensor())){ + mkldnn_matmul(mat, vec_, result.unsqueeze_(1), beta_.to(), alpha_.to()); + // recover tensor's dim = 1 + result.squeeze_(1); + return; + } +#endif // AT_MKLDNN_ENABLED + auto r_stride = result.stride(0); AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, mat.scalar_type(), "addmv_impl_cpu", [&] { auto beta = beta_.to(); @@ -133,11 +152,35 @@ inline void dot_check(const Tensor& self, const Tensor& other) { } Tensor dot(const Tensor &self, const Tensor &other){ - at::NoNamesGuard guard; + if (self.is_complex()) { + if (self.is_conj()) { + if (other.is_conj()) { + return (at::native::dot(self.conj(), other.conj())).conj(); + } else { + return at::native::vdot(self.conj(), other); + } + } else if (other.is_conj()) { + return at::native::vdot(other.conj(), self); + } + } + at::NoNamesGuard guard; dot_check(self, other); - return AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Half, self.scalar_type(), "dot", [&] { +#if AT_MKLDNN_ENABLED() + // mkldnn matmul expect dim >= 2 + auto self_ = self.unsqueeze(0); + auto other_= other.unsqueeze(1); + if (use_mkldnn_bf16_gemm(self_, other_, /*result=*/Tensor())){ + // mkldnn matmul expect result have sizes info to create ideep tensor + auto r = at::empty({1, 1}, self.options()); + mkldnn_matmul(self_, other_, r, /*beta=*/0); + // recovery tensor's dim = 1 + return r.squeeze_(); + } +#endif // AT_MKLDNN_ENABLED + + return AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "dot", [&] { Tensor result = at::empty({}, self.options()); result.fill_(dot_impl(self.numel(), self.data_ptr(), self.stride(0), other.data_ptr(), other.stride(0))); return result; @@ -145,15 +188,25 @@ Tensor dot(const Tensor &self, const Tensor &other){ } Tensor vdot(const Tensor &self, const Tensor &other){ - at::NoNamesGuard guard; - // Dispatch to `dot` for real dtypes. if (!self.is_complex()){ return at::dot(self, other); } + if (self.is_conj()) { + if (other.is_conj()) { + return at::native::vdot(other.conj(), self.conj()); + } else { + return at::native::dot(self.conj(), other); + } + } else if (other.is_conj()) { + return (at::native::dot(self, other.conj())).conj(); + } + + at::NoNamesGuard guard; // For complex dtypes. dot_check(self, other); + return AT_DISPATCH_COMPLEX_TYPES(self.scalar_type(), "vdot", [&] { Tensor result = at::empty({}, self.options()); result.fill_(vdot_impl(self.numel(), self.data_ptr(), self.stride(0), other.data_ptr(), other.stride(0))); diff --git a/aten/src/ATen/native/Bucketization.cpp b/aten/src/ATen/native/Bucketization.cpp index 7dc76a7577aa2..c11ce253f1d4a 100644 --- a/aten/src/ATen/native/Bucketization.cpp +++ b/aten/src/ATen/native/Bucketization.cpp @@ -74,12 +74,12 @@ void searchsorted_cpu_contiguous(Tensor& result, const Tensor& input, const Tens void dispatch(Tensor& result, const Tensor& input, const Tensor& boundaries, bool out_int32, bool right) { if (!out_int32) { - AT_DISPATCH_ALL_TYPES(input.scalar_type(), "searchsorted_out_cpu", [&] { + AT_DISPATCH_ALL_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "searchsorted_out_cpu", [&] { searchsorted_cpu_contiguous(result, input, boundaries, right); }); } else { - AT_DISPATCH_ALL_TYPES(input.scalar_type(), "searchsorted_out_cpu", [&] { + AT_DISPATCH_ALL_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "searchsorted_out_cpu", [&] { searchsorted_cpu_contiguous(result, input, boundaries, right); }); } diff --git a/aten/src/ATen/native/CPUBlas.cpp b/aten/src/ATen/native/CPUBlas.cpp index 1a1f6737f23f1..f14e4dce68b5a 100644 --- a/aten/src/ATen/native/CPUBlas.cpp +++ b/aten/src/ATen/native/CPUBlas.cpp @@ -78,7 +78,7 @@ char to_blas(TransposeType trans) { switch (trans) { case Transpose: return 't'; case NoTranspose: return 'n'; - // case ConjTranspose: return 'c'; + case ConjTranspose: return 'c'; } TORCH_INTERNAL_ASSERT(false, "Invalid transpose type"); } @@ -89,7 +89,7 @@ fbgemm::matrix_op_t to_fbgemm(TransposeType trans) { switch (trans) { case Transpose: return fbgemm::matrix_op_t::Transpose; case NoTranspose: return fbgemm::matrix_op_t::NoTranspose; - // case ConjTranspose: return fbgemm::matrix_op_t::Transpose; + case ConjTranspose: TORCH_INTERNAL_ASSERT(false, "ConjTranspose type is not supported in fbgemm"); } TORCH_INTERNAL_ASSERT(false, "Invalid transpose type"); } diff --git a/aten/src/ATen/native/CPUBlas.h b/aten/src/ATen/native/CPUBlas.h index e61207f7c76b8..3a483e4361bd2 100644 --- a/aten/src/ATen/native/CPUBlas.h +++ b/aten/src/ATen/native/CPUBlas.h @@ -12,7 +12,7 @@ namespace cpublas { enum TransposeType { Transpose, NoTranspose, - // ConjTranspose, -- Not implemented + ConjTranspose, }; namespace internal { diff --git a/aten/src/ATen/native/Col2Im.cpp b/aten/src/ATen/native/Col2Im.cpp index e1cc31df60f54..7e11b1bdd5b6f 100644 --- a/aten/src/ATen/native/Col2Im.cpp +++ b/aten/src/ATen/native/Col2Im.cpp @@ -136,7 +136,7 @@ static void col2im_out_cpu_template( output.resize_({batch_size, n_output_plane, output_height, output_width}); output.zero_(); - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kHalf, + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "col2im_out_cpu", [&] { Tensor input_n = Tensor(); Tensor output_n = Tensor(); diff --git a/aten/src/ATen/native/Copy.cpp b/aten/src/ATen/native/Copy.cpp index 1b8538ec07601..6dc1fc7af5e5c 100644 --- a/aten/src/ATen/native/Copy.cpp +++ b/aten/src/ATen/native/Copy.cpp @@ -28,6 +28,7 @@ bool copy_transpose_valid(const Tensor& self, const Tensor& src) { return self.is_contiguous() && src.numel() != 0 && src.dim() == 2 && src.stride(0) == 1 && src.stride(1) == src.size(0) && self.scalar_type() == src.scalar_type() && + self.sizes().equals(src.sizes()) && self.numel() >= MIN_SZ; } @@ -45,6 +46,9 @@ void copy_same_type_transpose_(Tensor& self, const Tensor& src) { } Tensor buf = empty({BLOCK_SZ, BLOCK_SZ}, self.options()); + // The code below is implemented with the assumption that sizes are equal + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self.sizes().equals(src.sizes())); + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBool, kBFloat16, self.scalar_type(), "copy_", [&] { scalar_t* sp = src.data_ptr(); scalar_t* rp = self.data_ptr(); @@ -253,6 +257,22 @@ Tensor& copy_(Tensor& self, const Tensor& src, bool non_blocking) { return self; } +void copy_ignoring_overlaps(const Tensor &dst, const Tensor &src) { + // Called when we are copying into an overlapping index `dst`, but we don't + // care which writer wins. Hacky but it works. This is only used by + // CUDA_tensor_apply2 in case that there are write overlaps. + // FIXME: really, overlapping writes should be illegal/an error in Torch + auto iter = TensorIteratorConfig() + .add_output(dst) + .add_input(src) + .resize_outputs(false) + .set_check_mem_overlap(false) + .check_all_same_dtype(true) + .check_all_same_device(true) + .build(); + copy_stub(iter.device_type(), iter, /*non_blocking=*/false); +} + DEFINE_DISPATCH(copy_stub); } // namespace native diff --git a/aten/src/ATen/native/Copy.h b/aten/src/ATen/native/Copy.h index 2dfd9e9f4922b..938466102b469 100644 --- a/aten/src/ATen/native/Copy.h +++ b/aten/src/ATen/native/Copy.h @@ -13,5 +13,7 @@ using copy_fn = void (*)(TensorIterator&, bool non_blocking); DECLARE_DISPATCH(copy_fn, copy_stub); +TORCH_API void copy_ignoring_overlaps(const Tensor &dst, const Tensor &src); + } // namespace native } // namespace at diff --git a/aten/src/ATen/native/GatedLinearUnit.cpp b/aten/src/ATen/native/GatedLinearUnit.cpp index a0e2c16ed645f..c585caa71a011 100644 --- a/aten/src/ATen/native/GatedLinearUnit.cpp +++ b/aten/src/ATen/native/GatedLinearUnit.cpp @@ -3,12 +3,11 @@ #include namespace at { -namespace native { - -DEFINE_DISPATCH(glu_stub); -DEFINE_DISPATCH(glu_backward_stub); -Tensor& glu_out(const Tensor& self, int64_t dim, Tensor &result) { +namespace meta { +TORCH_META_FUNC(glu) ( + const Tensor& self, int64_t dim +) { // this can't pass anyway because a 0-dimensional tensor has "size" 1, which // can't be evenly halved, but give a nicer error message here. TORCH_CHECK(self.dim() > 0, "glu does not support 0-dimensional tensors"); @@ -16,23 +15,24 @@ Tensor& glu_out(const Tensor& self, int64_t dim, Tensor &result) { const int64_t nIn = self.size(wrap_dim); TORCH_CHECK(nIn % 2 == 0, "Halving dimension must be even, but dimension ", wrap_dim, " is size ", nIn); + // size output to half of input const int64_t selfSize = nIn / 2; - auto newSizes = self.sizes().vec(); - newSizes[wrap_dim] = selfSize; - result.resize_(newSizes); - // half tensor Tensor firstHalf = self.narrow(wrap_dim, 0, selfSize); Tensor secondHalf = self.narrow(wrap_dim, selfSize, selfSize); - - auto iter = TensorIterator::borrowing_binary_op(result, firstHalf, secondHalf); - glu_stub(iter.device_type(), iter); - return result; + build_borrowing_binary_op(maybe_get_output(), firstHalf, secondHalf); } +} // namespace meta + +namespace native { + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +DEFINE_DISPATCH(glu_stub); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +DEFINE_DISPATCH(glu_backward_stub); -Tensor glu(const Tensor& self, int64_t dim) { - auto result = at::empty({0}, self.options()); - return at::glu_out(result, self, dim); +TORCH_IMPL_FUNC(glu_out) (const Tensor& self, int64_t dim, const Tensor& out) { + glu_stub(device_type(), *this); } Tensor& glu_backward_cpu_out(const Tensor& grad_output, const Tensor& input, diff --git a/aten/src/ATen/native/Im2Col.cpp b/aten/src/ATen/native/Im2Col.cpp index 0970095a68fa9..586b9612f80f4 100644 --- a/aten/src/ATen/native/Im2Col.cpp +++ b/aten/src/ATen/native/Im2Col.cpp @@ -86,7 +86,7 @@ static void im2col_out_cpu_template( output.resize_({batch_size, n_output_plane, output_length}); output.zero_(); - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kHalf, + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "im2col_out_cpu", [&] { Tensor input_n; Tensor output_n; diff --git a/aten/src/ATen/native/Integration.cpp b/aten/src/ATen/native/Integration.cpp index 262519f69a61c..e57dc4505df4e 100644 --- a/aten/src/ATen/native/Integration.cpp +++ b/aten/src/ATen/native/Integration.cpp @@ -52,7 +52,21 @@ Tensor do_cumulative_trapezoid(const Tensor& y, double dx, int64_t dim) { return (dx /2. * (left + right)).cumsum(dim); } - +// Given the current shape of a Tensor and a target number of dimensions, +// returns a new shape with the same values as the original shape, +// but with '1's padded in the beginning to match the target number of dimensions. +// For example, curr_shape = (5,5,5) and target_n_dim = 6 ==> (1,1,1,5,5,5) +// Note that no padding will be added if the current shape has the greater than or equal +// number of dimensions than the target numbers of dimensions. +DimVector add_padding_to_shape(IntArrayRef curr_shape, int64_t target_n_dim) { + if (curr_shape.size() >= target_n_dim) + target_n_dim = curr_shape.size(); + DimVector new_shape(target_n_dim, 1); + for (decltype(curr_shape.size()) i = 0; i < curr_shape.size(); i++) { + new_shape[target_n_dim-i-1] = curr_shape[curr_shape.size()-i-1]; + } + return new_shape; +} } Tensor trapezoid(const Tensor& y, const Tensor& x, int64_t dim) { @@ -71,9 +85,15 @@ Tensor trapezoid(const Tensor& y, const Tensor& x, int64_t dim) { // Note: This behavior differs from numpy in that numpy tries to // broadcast 'dx', but this tries to broadcast 'x' to match 'y' instead. TORCH_CHECK(x.size(0) == y.size(dim), "trapezoid: There must be one `x` value for each sample point"); - DimVector sizes(y.dim(), 1); - sizes[dim] = x.size(0); - x_viewed = x.view(sizes); + DimVector new_sizes(y.dim(), 1); // shape = [1] * y. + new_sizes[dim] = x.size(0); // shape[axis] = d.shape[0] + x_viewed = x.view(new_sizes); + } else if (x.dim() < y.dim()) { + // When 'y' has more dimension than 'x', this step takes 'x' with dimension (n_1, n_2, ...), + // and add '1's as dimensions in front to become (1, 1, ..., n_1, n_2), matching the dimension of 'y'. + // This allows the subsequent slicing operations to proceed with any 'dim' without going out of bound. + DimVector new_sizes = add_padding_to_shape(x.sizes(), y.dim()); + x_viewed = x.view(new_sizes); } else { x_viewed = x; } @@ -110,9 +130,12 @@ Tensor cumulative_trapezoid(const Tensor& y, const Tensor& x, int64_t dim) { Tensor x_viewed; if (x.dim() == 1) { TORCH_CHECK(x.size(0) == y.size(dim), "cumulative_trapezoid: There must be one `x` value for each sample point"); - DimVector sizes(y.dim(), 1); // shape = [1] * y. - sizes[dim] = x.size(0); // shape[axis] = d.shape[0] - x_viewed = x.view(sizes); + DimVector new_sizes(y.dim(), 1); // shape = [1] * y. + new_sizes[dim] = x.size(0); // shape[axis] = d.shape[0] + x_viewed = x.view(new_sizes); + } else if (x.dim() < y.dim()) { + DimVector new_sizes = add_padding_to_shape(x.sizes(), y.dim()); + x_viewed = x.view(new_sizes); } else { x_viewed = x; } diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index bbb6fce844524..0576bd667c3f6 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -23,6 +23,10 @@ #include #include +#if AT_MKLDNN_ENABLED() +#include +#include +#endif // AT_MKLDNN_ENABLED namespace at { namespace meta { @@ -959,7 +963,6 @@ Tensor outer(const Tensor& self, const Tensor& vec2) { static void addmm_impl_cpu_( Tensor &result, const Tensor &self, Tensor m1, Tensor m2, const Scalar& beta, const Scalar& alpha) { TORCH_INTERNAL_ASSERT(self.dim() == 2 && m1.dim() == 2 && m2.dim() == 2); - // Array access is faster than .size(n) and .stride(n) const auto self_sizes = self.sizes(); auto m1_strides = m1.strides(); @@ -992,18 +995,18 @@ static void addmm_impl_cpu_( if (result_strides[0] == 1 && (result_sizes[1] == 1 || result_strides[1] >= std::max(int64_t{1}, result_sizes[0]))) { transpose_c = false; - c = result; + c = result.resolve_conj(); } else if (result_strides[1] == 1 && (result_sizes[0] == 1 || result_strides[0] >= std::max(int64_t{1}, result_sizes[1]))) { std::swap(m1, m2); std::swap(m1_sizes, m2_sizes); std::swap(m1_strides, m2_strides); transpose_c = true; - c = result; + c = result.resolve_conj(); } else { transpose_c = false; // make c FORTRAN contiguous - c = result.transpose(0, 1).contiguous().transpose_(0, 1); + c = result.resolve_conj().transpose(0, 1).contiguous().transpose_(0, 1); } const int64_t m = result_sizes[transpose_c ? 1 : 0]; @@ -1017,7 +1020,7 @@ static void addmm_impl_cpu_( if (m1_strides[transpose_c ? 1 : 0] == 1 && m1_strides[transpose_c ? 0 : 1] >= std::max(int64_t{1}, m)) { transpose_a = false; - a = m1; + a = m1.resolve_conj(); } else if (m1_strides[transpose_c ? 0 : 1] == 1 && m1_strides[transpose_c ? 1 : 0] >= std::max(int64_t{1}, k)) { transpose_a = true; @@ -1034,7 +1037,7 @@ static void addmm_impl_cpu_( if (m2_strides[transpose_c ? 1 : 0] == 1 && m2_strides[transpose_c ? 0 : 1] >= std::max(int64_t{1}, k)) { transpose_b = false; - b = m2; + b = m2.resolve_conj(); } else if (m2_strides[transpose_c ? 0 : 1] == 1 && m2_strides[transpose_c ? 1 : 0] >= std::max(int64_t{1}, n)) { transpose_b = true; @@ -1048,13 +1051,31 @@ static void addmm_impl_cpu_( const int64_t ldb = b.strides()[(transpose_b == transpose_c) ? 1 : 0]; const int64_t ldc = c.strides()[transpose_c ? 0 : 1]; + // Always ensure the conjugation for c is resolved since there's no way to specify c's conjugation in the gemm call + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!c.is_conj()); + +#if AT_MKLDNN_ENABLED() + if (use_mkldnn_bf16_gemm(a, b, c)){ + if (transpose_c){ + // m1, m2 are swapped + mkldnn_matmul(b, a, c, beta.to(), alpha.to()); + } else { + mkldnn_matmul(a, b, c, beta.to(), alpha.to()); + } + if (!c.is_same(result)) { + result.copy_(c); + } + return; + } +#endif // AT_MKLDNN_ENABLED + // Apply BLAS routine AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, result.scalar_type(), "addmm_impl_cpu_", [&]{ at::native::cpublas::gemm( - transpose_a ? cpublas::Transpose : cpublas::NoTranspose, - transpose_b ? cpublas::Transpose : cpublas::NoTranspose, + transpose_a ? a.is_conj() ? cpublas::ConjTranspose : cpublas::Transpose : cpublas::NoTranspose, + transpose_b ? b.is_conj() ? cpublas::ConjTranspose : cpublas::Transpose : cpublas::NoTranspose, m, n, k, alpha.to(), a.data_ptr(), lda, @@ -1102,6 +1123,13 @@ static void addbmm_impl_( return; } +#if AT_MKLDNN_ENABLED() + if (use_mkldnn_bf16_gemm(batch1, batch2, result)){ + mkldnn_matmul(batch1, batch2, result, beta.to(), alpha.to()); + return; + } +#endif // AT_MKLDNN_ENABLED + auto adjusted_beta(beta); for (int64_t batch = 0; batch < num_batches; ++batch) { result.addmm_(batch1[batch], batch2[batch], adjusted_beta, alpha); @@ -1252,6 +1280,13 @@ static inline Tensor& bmm_out_or_baddbmm_(Tensor& self_or_result, const Tensor& || (strides[1] == 1 && strides[2] >= sizes[1]); }; +#if AT_MKLDNN_ENABLED() + if (use_mkldnn_bf16_gemm(batch1, batch2, self_or_result)){ + mkldnn_matmul(batch1, batch2, self_or_result, beta.to(), alpha.to()); + return self_or_result; + } +#endif // AT_MKLDNN_ENABLED + if (contraction_size * res_rows * res_cols < 400) { if (is_bmm_out) { AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, batch1.scalar_type(), "bmm", [&] { @@ -1349,8 +1384,18 @@ Tensor& baddbmm_out_cpu(const Tensor& self_, const Tensor& batch1, const Tensor& return at::native::baddbmm__cpu(result, batch1, batch2, beta, alpha); } +Tensor& conjugate_mutable_input_if_needed(Tensor& self, bool conjugate) { + if (conjugate) { + self.conj_physical_(); + } + return self; +} + Tensor& baddbmm__cpu(Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha) { - return bmm_out_or_baddbmm_(self, batch1, batch2, beta, alpha, false); + bool self_is_conj = self.is_conj(); + conjugate_mutable_input_if_needed(self, self_is_conj); + bmm_out_or_baddbmm_(self, batch1.resolve_conj(), batch2.resolve_conj(), beta, alpha, false); + return conjugate_mutable_input_if_needed(self, self_is_conj); } Tensor bmm_cpu(const Tensor& self, const Tensor& mat2) { @@ -1363,7 +1408,10 @@ Tensor& bmm_out_cpu(const Tensor& batch1, const Tensor& batch2, Tensor &result) Scalar alpha(1.0); { NoNamesGuard guard; - bmm_out_or_baddbmm_(result, batch1, batch2, beta, alpha, true); + bool result_is_conj = result.is_conj(); + conjugate_mutable_input_if_needed(result, result_is_conj); + bmm_out_or_baddbmm_(result, batch1.resolve_conj(), batch2.resolve_conj(), beta, alpha, true); + conjugate_mutable_input_if_needed(result, result_is_conj); } namedinference::propagate_names_if_nonempty( result, @@ -1552,6 +1600,15 @@ Tensor& matmul_out(const Tensor & tensor1, const Tensor & tensor2, Tensor &resul return result; } +// torch.linalg.matmul, alias for torch.matmul +Tensor linalg_matmul(const Tensor & tensor1, const Tensor & tensor2) { + return at::native::matmul(tensor1, tensor2); +} + +Tensor& linalg_matmul_out(const Tensor & tensor1, const Tensor & tensor2, Tensor &result) { + return at::native::matmul_out(tensor1, tensor2, result); +} + // helper methods for matrix_exp namespace { @@ -2651,12 +2708,9 @@ Tensor linalg_tensorinv(const Tensor& self, int64_t ind) { shape_ind_end.insert(shape_ind_end.cend(), shape_start_ind.cbegin(), shape_start_ind.cend()); // If the reshaped self is not invertible catch this error - Tensor result; - try { - result = at::inverse(self.reshape({prod_ind_end, prod_ind_end})); - } catch (...) { - TORCH_CHECK(false, "Failed to invert the input tensor, because it is singular."); - } + Tensor result, info; + std::tie(result, info) = at::linalg_inv_ex(self.reshape({prod_ind_end, prod_ind_end}), /*check_errors=*/false); + TORCH_CHECK(info.item() == 0, "Failed to invert the input tensor, because it is singular."); return result.reshape(shape_ind_end); } diff --git a/aten/src/ATen/native/LinearAlgebraUtils.h b/aten/src/ATen/native/LinearAlgebraUtils.h index 19e41d7a8e815..abbf82ceb148c 100644 --- a/aten/src/ATen/native/LinearAlgebraUtils.h +++ b/aten/src/ATen/native/LinearAlgebraUtils.h @@ -213,62 +213,66 @@ static inline void squareCheckInputs(const Tensor& self) { "but they are ", self.size(-1), " by ", self.size(-2), " matrices"); } +/* + * Given a info int, obtained after a single operation, this function check if the computation + * has been successful (info = 0) or not, and report in case of the latter. + */ +static inline void singleCheckErrors(int64_t info, const char* name, int64_t batch_id=-1) { + std::string batch_string{""}; + if (batch_id >= 0) { + batch_string = ": (Batch element " + std::to_string(batch_id) + ")"; + } + if (info < 0) { + TORCH_INTERNAL_ASSERT(false, name, batch_string, + ": Argument ", -info, " has illegal value. Most certainly there is a bug in the implementation calling the backend library."); + } else if (info > 0) { + if (strstr(name, "inv")) { + // inv, inverse, cholesky_inverse, etc. + TORCH_CHECK(false, name, batch_string, + ": The diagonal element ", info, " is zero, the inversion could not be completed because the input matrix is singular."); + } else if (strstr(name, "solve")) { + // solve, linalg_solve, cholesky_solve, etc. + TORCH_CHECK(false, name, batch_string, + ": The diagonal element ", info, " is zero, the solve could not be completed because the input matrix is singular."); + } else if (strstr(name, "cholesky")) { + TORCH_CHECK(false, name, batch_string, + ": The factorization could not be completed because the input is not positive-definite (the leading minor of order ", info, " is not positive-definite)."); + } else if (strstr(name, "svd")) { + TORCH_CHECK(false, name, batch_string, + ": The algorithm failed to converge because the input matrix is ill-conditioned or has too many repeated singular values (error code: ", info, ")."); + } else if (strstr(name, "eig") || strstr(name, "syevd")) { + TORCH_CHECK(false, name, batch_string, + ": The algorithm failed to converge because the input matrix is ill-conditioned or has too many repeated eigenvalues (error code: ", info, ")."); + } else if (strstr(name, "lstsq")) { + TORCH_CHECK(false, name, batch_string, + ": The least squares solution could not be computed because the input matrix does not have full rank (error code: ", info, ")."); + } else { + TORCH_INTERNAL_ASSERT(false, name, ": Unknown error code: ", info, "."); + } + } +} + /* * Given a vector of int64_t infos, obtained after a batch operations, * this function checks if the computation over all these batches has been * successful (info = 0) or not, and report in case of the latter. */ -static inline void batchCheckErrors(std::vector& infos, const char* name, bool allow_singular=false) { +static inline void batchCheckErrors(const std::vector& infos, const char* name) { for (size_t i = 0; i < infos.size(); i++) { auto info = infos[i]; - if (info < 0) { - AT_ERROR(name, ": For batch ", i, ": Argument ", -info, " has illegal value"); - } else if (info > 0) { - if (strstr(name, "svd")) { - AT_ERROR(name, ": the updating process of SBDSDC did not converge (error: ", info, ")"); - } else if (strstr(name, "symeig") || strstr(name, "syevd")) { - AT_ERROR(name, ": For batch ", i, ": the algorithm failed to converge; ", info, - " off-diagonal elements of an intermediate tridiagonal form did not converge to zero."); - } else if (!allow_singular) { - AT_ERROR(name, ": For batch ", i, ": U(", info, ",", info, ") is zero, singular U."); - } - } + singleCheckErrors(info, name, i); } } /* * This is an overloaded case of the previous function for a tensor of infos. */ -static inline void batchCheckErrors(const Tensor& infos, const char* name, bool allow_singular=false, int info_per_batch=1) { - auto batch_size = infos.numel(); +static inline void batchCheckErrors(const Tensor& infos, const char* name) { auto infos_cpu = infos.to(at::kCPU); auto infos_data = infos_cpu.data_ptr(); - for (int64_t i = 0; i < batch_size; i++) { + for (int64_t i = 0; i < infos.numel(); i++) { auto info = infos_data[i]; - if (info < 0) { - AT_ERROR(name, ": For batch ", i/info_per_batch, ": Argument ", -info, " has illegal value"); - } else if (!allow_singular && info > 0) { - AT_ERROR(name, ": For batch ", i/info_per_batch, ": U(", info, ",", info, ") is zero, singular U."); - } - } -} - -/* - * Given a info int, obtained after a single operation, this function check if the computation - * has been successful (info = 0) or not, and report in case of the latter. - */ -static inline void singleCheckErrors(int64_t info, const char* name, bool allow_singular=false) { - if (info < 0) { - AT_ERROR(name, ": Argument ", -info, " has illegal value"); - } else if (info > 0) { - if (strstr(name, "svd")) { - AT_ERROR(name, ": the updating process of SBDSDC did not converge (error: ", info, ")"); - } else if (strstr(name, "eig")) { // this catches both "eig" and "symeig" - AT_ERROR(name, ": the algorithm failed to converge; ", info, - " off-diagonal elements of an intermediate tridiagonal form did not converge to zero."); - } else if (!allow_singular) { - AT_ERROR(name, ": U(", info, ",", info, ") is zero, singular U."); - } + singleCheckErrors(info, name, i); } } diff --git a/aten/src/ATen/native/LossNLL.cpp b/aten/src/ATen/native/LossNLL.cpp index 7c306c2bb863c..83f169972942f 100644 --- a/aten/src/ATen/native/LossNLL.cpp +++ b/aten/src/ATen/native/LossNLL.cpp @@ -22,10 +22,12 @@ TORCH_META_FUNC(nll_loss_forward) TORCH_CHECK( self.dim() > 0 && self.dim() <= 2, "input tensor should be 1D or 2D"); TORCH_CHECK( - target.dim() == 1, - "1D target tensor expected, multi-target not supported"); + target.dim() <= 1, + "0D or 1D target tensor expected, multi-target not supported"); + + auto no_batch_dim = self.dim() == 1 && target.dim() == 0; TORCH_CHECK( - self.size(0) == target.size(0), + no_batch_dim || (self.size(0) == target.size(0)), "size mismatch (got input: ", self.sizes(), ", target: ", @@ -66,10 +68,12 @@ TORCH_META_FUNC(nll_loss_backward) TORCH_CHECK( self.dim() > 0 && self.dim() <= 2, "input tensor should be 1D or 2D"); TORCH_CHECK( - target.dim() == 1, - "1D target tensor expected, multi-target not supported"); + target.dim() <= 1, + "0D or 1D target tensor expected, multi-target not supported"); + + auto no_batch_dim = self.dim() == 1 && target.dim() == 0; TORCH_CHECK( - self.size(0) == target.size(0), + no_batch_dim || (self.size(0) == target.size(0)), "size mismatch (got input: ", self.sizes(), ", target: ", @@ -181,7 +185,6 @@ static void nll_loss_out_frame( const int64_t ndim = input.dim(); TORCH_CHECK(ndim <= 2); const int64_t batch_size = ndim == 1 ? 1 : input.size(0); - TORCH_CHECK(target.size(0) == batch_size); constexpr int64_t cascade_sum_num_levels = 8; const int64_t level_power = @@ -298,7 +301,11 @@ static void nll_loss_backward_out_frame( const auto n_dims = input.dim(); const auto n_classes = input.size(-1); - auto target_acc = target.accessor(); + auto target_ = target; + if (target.dim() == 0) { + target_ = target.unsqueeze(0); + } + auto target_acc = target_.accessor(); auto weight_contiguous = optional_contiguous(weight); const scalar_t* weight_data = optional_data(weight_contiguous); @@ -349,7 +356,6 @@ static void nll_loss_backward_out_frame( auto grad_input_acc = grad_input.accessor(); const auto batch_size = input.size(0); - TORCH_CHECK(target.size(0) == batch_size); for (int64_t i = 0; i < batch_size; i++) { const auto cur_target = target_acc[i]; @@ -453,9 +459,10 @@ TORCH_IMPL_FUNC(nll_loss_backward_out_cpu) Tensor cross_entropy_loss_prob_target( const Tensor& self, - const Tensor& target, + const Tensor& target_, const Tensor& weight, - int64_t reduction) { + int64_t reduction, + double label_smoothing) { const auto n_classes = self.size(1); TORCH_CHECK( !weight.defined() || (weight.dim() == 1 && weight.numel() == n_classes), @@ -466,6 +473,15 @@ Tensor cross_entropy_loss_prob_target( weight.sizes()); auto input = at::log_softmax(self, 1, self.scalar_type()); + Tensor target; + + if (label_smoothing > 0.0) { + TORCH_CHECK(label_smoothing <= 1.0, "label_smoothing must be between 0.0 and 1.0. Got: ", label_smoothing); + target = target_ * (1 - label_smoothing) + label_smoothing / n_classes; + } else { + target = target_; + } + if (weight.defined()) { // Expand weight to the correct number of dims for broadcasting with input / target auto weight_broadcast_shape = SmallBuffer(input.dim()); @@ -497,12 +513,66 @@ Tensor cross_entropy_loss_prob_target( } } +Tensor cross_entropy_loss_label_smoothing( + const Tensor& self, + const Tensor& target, + const Tensor& weight, + int64_t reduction, + int64_t ignore_index, + double label_smoothing) { + + auto input = at::log_softmax(self, 1, self.scalar_type()); + auto nllloss = at::nll_loss_nd(input, target, weight, reduction, ignore_index); + + auto n_classes = input.size(1); + + Tensor smooth_loss; + if (weight.defined()) { + // Expand weight to the correct number of dims for broadcasting with input / target + auto weight_broadcast_shape = SmallBuffer(input.dim()); + std::fill(weight_broadcast_shape.begin(), weight_broadcast_shape.end(), 1); + weight_broadcast_shape[1] = weight.size(0); + Tensor weight_ = weight.view(weight_broadcast_shape); + + smooth_loss = -(input * weight_).sum(1); + } else { + smooth_loss = -input.sum(1); + } + + if (ignore_index >= 0) { + smooth_loss.index_put_({target == ignore_index}, 0.0); + } + + Tensor ret; + switch (reduction) { + case Reduction::Mean: + if (weight.defined()) { + // TODO: This code can path can be removed if #61309 is resolved + // loss is normalized by the weights to be consistent with nll_loss_nd + ret = smooth_loss.sum() / weight.gather(0, target.flatten()).sum(); + } else { + ret = smooth_loss.mean(); + } + break; + case Reduction::Sum: + ret = smooth_loss.sum(); + break; + case Reduction::None: + ret = smooth_loss; + break; + default: + TORCH_CHECK(false, "Invalid reduction type encountered in cross_entropy: ", reduction); + } + return (1 - label_smoothing) * nllloss + ret * (label_smoothing / n_classes); +} + Tensor cross_entropy_loss( const Tensor& self, const Tensor& target, const c10::optional& weight, int64_t reduction, - int64_t ignore_index) { + int64_t ignore_index, + double label_smoothing) { Tensor ret; if (self.sizes() == target.sizes()) { // Assume soft targets when input and target shapes are the same @@ -513,7 +583,14 @@ Tensor cross_entropy_loss( // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight); const Tensor& weight_ = *weight_maybe_owned; - ret = cross_entropy_loss_prob_target(self, target, weight_, reduction); + ret = cross_entropy_loss_prob_target(self, target, weight_, reduction, label_smoothing); + } else if (label_smoothing > 0.0) { + TORCH_CHECK(label_smoothing <= 1.0, "label_smoothing must be between 0.0 and 1.0. Got: ", label_smoothing); + + // See [Note: hacky wrapper removal for optional tensor] + c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight); + const Tensor& weight_ = *weight_maybe_owned; + ret = cross_entropy_loss_label_smoothing(self, target, weight_, reduction, ignore_index, label_smoothing); } else { ret = at::nll_loss_nd( at::log_softmax(self, 1, self.scalar_type()), @@ -548,12 +625,12 @@ Tensor nll_loss_nd( const c10::optional& weight, int64_t reduction, int64_t ignore_index) { - if (self.dim() < 2) { + if (self.dim() < 1) { TORCH_CHECK_VALUE( - false, "Expected 2 or more dimensions (got ", self.dim(), ")"); + false, "Expected 1 or more dimensions (got ", self.dim(), ")"); } - if (self.sizes()[0] != target.sizes()[0]) { + if (self.dim() != 1 && self.sizes()[0] != target.sizes()[0]) { TORCH_CHECK_VALUE( false, "Expected input batch_size (", @@ -566,7 +643,7 @@ Tensor nll_loss_nd( Tensor ret; Tensor input_ = self; Tensor target_ = target; - if (input_.dim() == 2) { + if (input_.dim() == 1 || input_.dim() == 2) { ret = at::nll_loss(input_, target_, weight, reduction, ignore_index); } else if (input_.dim() == 4) { ret = at::nll_loss2d(input_, target_, weight, reduction, ignore_index); diff --git a/aten/src/ATen/native/MaxUnpooling.cpp b/aten/src/ATen/native/MaxUnpooling.cpp index b3c01941c73de..99874084470f4 100644 --- a/aten/src/ATen/native/MaxUnpooling.cpp +++ b/aten/src/ATen/native/MaxUnpooling.cpp @@ -1,90 +1,17 @@ #include #include -#include -#include +#include namespace at { namespace native { -template -Tensor max_unpooling2d_forward_out_cpu_frame( - Tensor& output, - const Tensor& input, - const Tensor& indices, - int64_t oheight, - int64_t owidth) { - int64_t numBatch = 1; - int64_t dimc = 0; - int64_t dimh = 1; - int64_t dimw = 2; - if (input.ndimension() == 4) { - numBatch = input.size(0); - dimc++; - dimh++; - dimw++; - } - int64_t numChannels = input.size(dimc); - int64_t inputHeight = input.size(dimh); - int64_t inputWidth = input.size(dimw); - - auto* rawInput = input.data_ptr(); - auto* rawIndices = indices.data_ptr(); - auto* rawOutput = output.data_ptr(); - - at::internal::lazy_init_num_threads(); - - for (int64_t n = 0; n < numBatch; n++) { - int64_t nOutputOffset = n * numChannels * owidth * oheight; - int64_t nInputOffset = n * numChannels * inputWidth * inputHeight; - int64_t k = 0; - bool has_error = false; - int64_t error_index = 0; -#pragma omp parallel for private(k) - for (k = 0; k < numChannels; k++) { - int64_t finalOutputOffset = nOutputOffset + k * owidth * oheight; - int64_t finalInputOffset = nInputOffset + k * inputWidth * inputHeight; - scalar_t* output_p_k = rawOutput + finalOutputOffset; - scalar_t* input_p_k = rawInput + finalInputOffset; - int64_t* ind_p_k = rawIndices + finalInputOffset; - - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t maxp; - for (int64_t i = 0; i < inputHeight; i++) { - for (int64_t j = 0; j < inputWidth; j++) { - maxp = ind_p_k[i * inputWidth + j]; - if (maxp < 0 || maxp >= owidth * oheight) { -#pragma omp critical - { - has_error = true; - error_index = maxp; - } - } else { - output_p_k[maxp] = input_p_k[i * inputWidth + j]; - } - } - } - } - if (has_error) { - AT_ERROR( - "Found an invalid max index: ", - error_index, - " (output volumes are of size ", - oheight, - "x", - owidth); - (void)error_index; - } - } - return output; -} - -Tensor& max_unpooling2d_forward_out_cpu(const Tensor& self_, +Tensor& max_unpooling2d_forward_out_cpu( + const Tensor& self_, const Tensor& indices_, IntArrayRef output_size, Tensor& output) { auto oheight = output_size[0]; auto owidth = output_size[1]; - TORCH_CHECK(output.is_contiguous(), "output must be contiguous"); TORCH_CHECK( indices_.scalar_type() == at::ScalarType::Long, "elements in indices should be type int64"); @@ -100,8 +27,9 @@ Tensor& max_unpooling2d_forward_out_cpu(const Tensor& self_, TORCH_CHECK(self_.numel() > 0, "Input must be non-empty"); - auto self = self_.contiguous(); - auto indices = indices_.contiguous(); + auto memory_format = self_.suggest_memory_format(); + auto self = self_.contiguous(memory_format); + auto indices = indices_.contiguous(memory_format); if (self.ndimension() == 3) { int64_t numChannels = self.size(0); @@ -109,15 +37,11 @@ Tensor& max_unpooling2d_forward_out_cpu(const Tensor& self_, } else { int64_t numBatch = self.size(0); int64_t numChannels = self.size(1); - output.resize_({numBatch, numChannels, oheight, owidth}); + output.resize_({numBatch, numChannels, oheight, owidth}, memory_format); } output.zero_(); - AT_DISPATCH_FLOATING_TYPES( - self.scalar_type(), "max_unpooling2d_forward_out_cpu_frame", ([&] { - max_unpooling2d_forward_out_cpu_frame( - output, self, indices, oheight, owidth); - })); + max_unpool2d_kernel(kCPU, output, self, indices); return output; }; @@ -130,87 +54,6 @@ Tensor max_unpooling2d_forward_cpu( return output; } -template -Tensor max_unpooling3d_forward_out_cpu_frame( - Tensor& output, - const Tensor& input, - const Tensor& indices, - int64_t oT, - int64_t oH, - int64_t oW) { - int64_t nBatch = 1; - int64_t dimw = 3; - int64_t dimh = 2; - int64_t dimt = 1; - - if (input.ndimension() == 5) { - nBatch = input.size(0); - dimw++; - dimh++; - dimt++; - } - - int64_t nSlices = input.size(dimt - 1); - int64_t iT = input.size(dimt); - int64_t iH = input.size(dimh); - int64_t iW = input.size(dimw); - - scalar_t* input_data = input.data_ptr(); - scalar_t* output_data = output.data_ptr(); - int64_t* indices_data = indices.data_ptr(); - - at::internal::lazy_init_num_threads(); - - for (int64_t p = 0; p < nBatch; p++) { - int64_t inputOffset = p * nSlices * iT * iW * iH; - int64_t outputOffset = p * nSlices * oT * oW * oH; - int64_t k = 0; - bool has_error = false; - int error_index = 0; -#pragma omp parallel for private(k) - for (k = 0; k < nSlices; k++) { - int64_t finalInputOffset = inputOffset + k * iT * iW * iH; - int64_t finalOutputOffset = outputOffset + k * oT * oW * oH; - - scalar_t* output_p_k = output_data + finalOutputOffset; - scalar_t* input_p_k = input_data + finalInputOffset; - int64_t* ind_p_k = indices_data + finalInputOffset; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int maxp; - for (int64_t t = 0; t < iT; t++) { - for (int64_t i = 0; i < iH; i++) { - for (int64_t j = 0; j < iW; j++) { - int64_t index = t * iH * iW + i * iW + j; - maxp = ind_p_k[index]; - if (maxp < 0 || maxp >= oT * oW * oH) { -#pragma omp critical - { - has_error = true; - error_index = maxp; - } - } else { - output_p_k[maxp] = input_p_k[index]; - } - } - } - } - if (has_error) { - AT_ERROR( - "found an invalid max index ", - error_index, - " (output volumes are of size ", - oT, - "x", - oH, - "x", - oW); - (void)error_index; - } - } - } - return output; -} - static void max_unpooling3d_shape_check( const Tensor& input, const Tensor& gradOutput, @@ -310,16 +153,7 @@ Tensor& max_unpooling3d_forward_out_cpu(const Tensor& self_, } output.zero_(); - AT_DISPATCH_FLOATING_TYPES( - self.scalar_type(), "max_unpooling3d_forward_out_cpu_frame", ([&] { - max_unpooling3d_forward_out_cpu_frame( - output, - self, - indices, - oT, - oH, - oW); - })); + max_unpool3d_kernel(kCPU, output, self, indices); return output; } @@ -335,59 +169,6 @@ Tensor max_unpooling3d_forward_cpu( return output; } -template -static void max_unpooling2d_backward_out_cpu_frame( - scalar_t* gradInput_p, - scalar_t* gradOutput_p, - int64_t* ind_p, - int64_t nslices, - int64_t iheight, - int64_t iwidth, - int64_t oheight, - int64_t owidth) { - bool has_error = false; - int64_t error_index = 0; - int64_t k = 0; - - at::internal::lazy_init_num_threads(); -#pragma omp parallel for private(k) - for (k = 0; k < nslices; k++) { - scalar_t* gradInput_p_k = gradInput_p + k * iwidth * iheight; - scalar_t* gradOutput_p_k = gradOutput_p + k * owidth * oheight; - int64_t* ind_p_k = ind_p + k * iwidth * iheight; - - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t i, j; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t maxp; - - for (i = 0; i < iheight; i++) { - for (j = 0; j < iwidth; j++) { - maxp = ind_p_k[i * iwidth + j]; /* retrieve position of max */ - if (maxp < 0 || maxp >= owidth * oheight) { -#pragma omp critical - { - has_error = true; - error_index = maxp; - } - } - gradInput_p_k[i * iwidth + j] = - gradOutput_p_k[maxp]; /* update gradient */ - } - } - } - if (has_error) { - AT_ERROR( - "invalid max index ", - error_index, - ", owidth= ", - owidth, - ", oheight= ", - oheight); - (void)error_index; - } -} - Tensor& max_unpooling2d_backward_out_cpu(const Tensor& grad_output_, const Tensor& self, const Tensor& indices_, @@ -396,42 +177,24 @@ Tensor& max_unpooling2d_backward_out_cpu(const Tensor& grad_output_, TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous"); int64_t oheight = output_size[0]; int64_t owidth = output_size[1]; - int dimw = 2; - int dimh = 1; - int nbatch = 1; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int nslices; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int iheight; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int iwidth; + int64_t ndim = self.ndimension(); + int64_t dimh = ndim == 3 ? 1 : 2; + int64_t dimw = ndim == 3 ? 2 : 3; + TORCH_CHECK( indices_.scalar_type() == at::ScalarType::Long, "elements in indices should be type int64"); TORCH_CHECK( self.sizes() == indices_.sizes(), "Input shape must match indices shape"); - TORCH_CHECK(output_size.size() == 2, "Output size must be 2"); - /* get contiguous gradOutput and indices */ - auto grad_output = grad_output_.contiguous(); - auto indices = indices_.contiguous(); + auto memory_format = self.suggest_memory_format(); + auto grad_output = grad_output_.contiguous(memory_format); + auto indices = indices_.contiguous(memory_format); - /* resize */ - grad_input.resize_as_(self); + grad_input.resize_(self.sizes(), memory_format); grad_input.zero_(); - if (self.ndimension() == 4) { - nbatch = self.size(0); - dimw++; - dimh++; - } - - /* sizes */ - nslices = self.size(dimh - 1); - iheight = self.size(dimh); - iwidth = self.size(dimw); - if (owidth != grad_output.size(dimw) || oheight != grad_output.size(dimh)) { AT_ERROR( "Inconsistent gradOutput size. output height = ", @@ -443,23 +206,8 @@ Tensor& max_unpooling2d_backward_out_cpu(const Tensor& grad_output_, "x", grad_output.size(dimw)); } - AT_DISPATCH_FLOATING_TYPES( - self.scalar_type(), "max_unpooling2d_backward_out_cpu_frame", ([&] { - int p; - for (p = 0; p < nbatch; p++) { - auto inputOffset = p * nslices * iheight * iwidth; - auto outputOffset = p * nslices * oheight * owidth; - max_unpooling2d_backward_out_cpu_frame( - grad_input.data_ptr() + inputOffset, - grad_output.data_ptr() + outputOffset, - indices.data_ptr() + inputOffset, - nslices, - iheight, - iwidth, - oheight, - owidth); - } - })); + + max_unpool2d_backward_kernel(kCPU, grad_input, grad_output, indices); return grad_input; } @@ -468,72 +216,14 @@ Tensor max_unpooling2d_backward_cpu( const Tensor& self, const Tensor& indices, IntArrayRef output_size) { - auto grad_input = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - at::native::max_unpooling2d_backward_out_cpu( + auto grad_input = at::empty({0}, self.options()); + max_unpooling2d_backward_out_cpu( grad_output, self, indices, output_size, grad_input); return grad_input; } -template -static void max_unpooling3d_backward_out_cpu_frame( - scalar_t* gradInput_p, - scalar_t* gradOutput_p, - int64_t* ind_p, - int64_t nslices, - int64_t iT, - int64_t iH, - int64_t iW, - int64_t oT, - int64_t oH, - int64_t oW) { - int64_t k = 0; - bool has_error = false; - int error_index = 0; - - at::internal::lazy_init_num_threads(); - -#pragma omp parallel for private(k) - for (k = 0; k < nslices; k++) { - scalar_t* gradInput_p_k = gradInput_p + k * iT * iH * iW; - scalar_t* gradOutput_p_k = gradOutput_p + k * oT * oH * oW; - int64_t* ind_p_k = ind_p + k * iT * iH * iW; - - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t t, i, j, index; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t maxp; - for (t = 0; t < iT; t++) { - for (i = 0; i < iH; i++) { - for (j = 0; j < iW; j++) { - index = t * iH * iW + i * iW + j; - maxp = ind_p_k[index]; /* retrieve position of max */ - if (maxp < 0 || maxp >= oT * oH * oW) { -#pragma omp critical - { - has_error = true; - error_index = maxp; - } - } - gradInput_p_k[index] = gradOutput_p_k[maxp]; /* update gradient */ - } - } - } - } - if (has_error) { - AT_ERROR( - "invalid max index ", - error_index, - ", oT= ", - oT, - ", oW= ", - oW, - ",oH= ", - oH); - (void)error_index; - } -} - -Tensor& max_unpooling3d_backward_out_cpu(const Tensor& grad_output_, +Tensor& max_unpooling3d_backward_out_cpu( + const Tensor& grad_output_, const Tensor& self, const Tensor& indices_, IntArrayRef output_size, @@ -541,26 +231,17 @@ Tensor& max_unpooling3d_backward_out_cpu(const Tensor& grad_output_, IntArrayRef padding, Tensor& grad_input) { TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous"); - auto oT = output_size[0]; - auto oH = output_size[1]; - auto oW = output_size[2]; - int dimw = 3; - int dimh = 2; - int dimt = 1; - int nbatch = 1; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int nslices; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int iT; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int iH; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int iW; + int64_t oT = output_size[0]; + int64_t oH = output_size[1]; + int64_t oW = output_size[2]; + int64_t ndim = self.ndimension(); + int64_t dimt = ndim == 4 ? 1 : 2; + int64_t dimh = ndim == 4 ? 2 : 3; + int64_t dimw = ndim == 4 ? 3 : 4; max_unpooling3d_shape_check( self, grad_output_, indices_, output_size, stride, padding); - // TODO (from THNN): check gradOutput shape /* get contiguous gradOutput */ auto grad_output = grad_output_.contiguous(); auto indices = indices_.contiguous(); @@ -568,39 +249,24 @@ Tensor& max_unpooling3d_backward_out_cpu(const Tensor& grad_output_, /* resize */ grad_input.resize_as_(self); grad_input.zero_(); - if (self.ndimension() == 5) { - nbatch = self.size(0); - dimt++; - dimw++; - dimh++; + + if (oW != grad_output.size(dimw) || oH != grad_output.size(dimh) || oT != grad_output.size(dimt)) { + AT_ERROR( + "Inconsistent gradOutput size. output depth = ", + oT, + ", output height = ", + oH, + ", output width = ", + oW, + ", gradOutput: ", + grad_output.size(dimt), + "x", + grad_output.size(dimh), + "x", + grad_output.size(dimw)); } - /* sizes */ - nslices = self.size(dimt - 1); - iT = self.size(dimt); - iH = self.size(dimh); - iW = self.size(dimw); - - /* backprop */ - AT_DISPATCH_FLOATING_TYPES( - self.scalar_type(), "max_unpooling3d_backward_out_cpu_frame", ([&] { - int p; - for (p = 0; p < nbatch; p++) { - int inputOffset = p * nslices * iT * iH * iW; - int outputOffset = p * nslices * oT * oH * oW; - max_unpooling3d_backward_out_cpu_frame( - grad_input.data_ptr() + inputOffset, - grad_output.data_ptr() + outputOffset, - indices.data_ptr() + inputOffset, - nslices, - iT, - iH, - iW, - oT, - oH, - oW); - } - })); + max_unpool3d_backward_kernel(kCPU, grad_input, grad_output, indices); return grad_input; } @@ -611,10 +277,16 @@ Tensor max_unpooling3d_backward_cpu( IntArrayRef output_size, IntArrayRef stride, IntArrayRef padding) { - auto grad_input = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + auto grad_input = at::empty({0}, self.options()); at::native::max_unpooling3d_backward_out_cpu( grad_output, self, indices, output_size, stride, padding, grad_input); return grad_input; } + +DEFINE_DISPATCH(max_unpool2d_kernel); +DEFINE_DISPATCH(max_unpool2d_backward_kernel); +DEFINE_DISPATCH(max_unpool3d_kernel); +DEFINE_DISPATCH(max_unpool3d_backward_kernel); + } // namespace native } // namespace at diff --git a/aten/src/ATen/native/NegateFallback.cpp b/aten/src/ATen/native/NegateFallback.cpp index 86dbe05ff904f..d8381f58d036b 100644 --- a/aten/src/ATen/native/NegateFallback.cpp +++ b/aten/src/ATen/native/NegateFallback.cpp @@ -55,6 +55,7 @@ TORCH_LIBRARY_IMPL(aten, Negative, m) { m.impl("view", torch::CppFunction::makeFallthrough()); m.impl("_unsafe_view", torch::CppFunction::makeFallthrough()); m.impl("reshape", torch::CppFunction::makeFallthrough()); + m.impl("alias", torch::CppFunction::makeFallthrough()); } } // namespace at diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp index 40ee1d5d4a152..25ae1a765e85f 100644 --- a/aten/src/ATen/native/Normalization.cpp +++ b/aten/src/ATen/native/Normalization.cpp @@ -74,6 +74,13 @@ static inline bool is_contiguous(const Tensor& t) { return t.is_contiguous() || t.is_contiguous(at::MemoryFormat::ChannelsLast); } +// For some ambiguous cases, it is possible a channels last contiguous Tensor has +// `suggest_memory_format` of Contiguous. +// See https://github.com/pytorch/pytorch/issues/63224 for details. +static inline MemoryFormat suggest_memory_format_contig(const Tensor& t) { + return t.is_contiguous() ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast; +} + template std::tuple batch_norm_cpu_transform_input_template( const Tensor& input, const Tensor& weight, const Tensor& bias, @@ -87,10 +94,9 @@ std::tuple batch_norm_cpu_transform_input_template( && running_mean.is_contiguous() && running_var.is_contiguous(); - Tensor output = at::empty_like(input, input.suggest_memory_format()); - // inference contiguous path if (all_contiguous) { + Tensor output = at::empty_like(input, suggest_memory_format_contig(input)); batch_norm_cpu_stub(kCPU, output, input, weight, bias, save_mean, save_invstd, running_mean, running_var, train, eps); return std::make_tuple(output, save_mean, save_invstd); @@ -120,6 +126,7 @@ std::tuple batch_norm_cpu_transform_input_template( auto b = bias.defined() ? as_nd(bias) : at::detail::scalar_tensor_static(0, input.scalar_type(), kCPU); + Tensor output = at::empty_like(input, input.suggest_memory_format()); auto iter = TensorIteratorConfig() .add_output(output) .add_input(input) @@ -240,7 +247,7 @@ std::tuple batch_norm_backward_cpu_template( grad_weight = at::empty_like(weight, at::MemoryFormat::Contiguous); } if (grad_input_mask[2]) { - grad_bias = at::empty_like(weight, at::MemoryFormat::Contiguous); + grad_bias = at::empty({input.size(1)}, input.options()); } // since we are directly manipulating pointers in contiguous path, @@ -250,6 +257,9 @@ std::tuple batch_norm_backward_cpu_template( && input.suggest_memory_format() == grad_out_.suggest_memory_format(); if (all_contiguous) { + if (grad_input_mask[0]) { + grad_input = at::empty_like(input, suggest_memory_format_contig(input)); + } batch_norm_cpu_backward_stub(kCPU, grad_input, grad_weight, grad_bias, grad_out_, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps); return std::make_tuple(grad_input, grad_weight, grad_bias); @@ -416,6 +426,22 @@ std::tuple _batch_norm_impl_index( const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();}); auto num_features = input.sizes()[1]; + + if (input.numel() == 0) { + Tensor reserve = at::empty({0}, input.options().dtype(kByte)); + auto options = input.options().dtype( + at::toAccumulateType(input.scalar_type(), /*is_cuda=*/input.is_cuda())); + auto save_mean = at::empty({num_features}, options); + auto save_invstd = at::empty({num_features}, options); + + // don't return view of input, don't return empty tensor because it will break gradient chain + auto out = input.clone(); + if (weight.defined()) out = out * weight[0]; + if (bias.defined()) out = out + bias[0]; + return std::tuple( + out, save_mean, save_invstd, reserve, 0); + } + if (running_mean.defined()) { check_dims_match_num_input_features("running_mean", num_features, running_mean.numel()); } else if (!training) { @@ -508,7 +534,30 @@ std::tuple _batch_norm_impl_index_backward( const Tensor& save_mean = c10::value_or_else(save_mean_opt, [] {return Tensor();}); const Tensor& save_var_transform = c10::value_or_else(save_var_transform_opt, [] {return Tensor();}); - if (impl_index == 0) { + if (input.numel() == 0) { + std::vector dims(input.dim() - 1); + dims[0] = 0; + std::iota(dims.begin() + 1, dims.end(), 2); + + // don't return empty tensor because it will break gradient chain + Tensor grad_input; + Tensor grad_weight; + Tensor grad_bias; + if (output_mask[2]) { + grad_bias = grad_output.sum(dims); + } + if (output_mask[1]) { + grad_weight = (grad_output * input).sum(dims); + } + if (output_mask[0] && weight.defined()) { + grad_input = grad_output * weight[0]; + } + return std::make_tuple(grad_input, grad_weight, grad_bias); + } + + // backward in inference mode is not supported in cudnn, fallback to native + // TODO: verify the same thing in miopen + if (impl_index == 0 || (!train)) { return at::native_batch_norm_backward(grad_output, input, weight, running_mean, running_var, save_mean, save_var_transform, train, epsilon, output_mask); } else if (impl_index == 1) { // TODO: _batch_norm_impl_index_backward is only used in JIT. cudnn NHWC @@ -528,13 +577,6 @@ Tensor batch_norm( const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();}); const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();}); const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();}); - if (input.numel()==0){ - //don't return view of input, don't return empty tensor because it will break gradient chain - auto out = input.clone(); - if (weight.defined()) out = out * weight[0]; - if (bias.defined()) out = out + bias[0]; - return out; - } return std::get<0>(at::_batch_norm_impl_index(input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled)); } @@ -602,7 +644,9 @@ std::tuple batch_norm_cpu(const Tensor& self, const c10: return AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "batch_norm", [&] { if (!train) { - return batch_norm_cpu_transform_input_template(self, weight, bias, {}, {}, running_mean, running_var, train, eps); + auto save_mean = at::empty({0}, self.options()); + auto save_var = at::empty({0}, self.options()); + return batch_norm_cpu_transform_input_template(self, weight, bias, save_mean, save_var, running_mean, running_var, train, eps); } else { auto save_stats = batch_norm_cpu_update_stats_template(self, running_mean, running_var, momentum, eps); return batch_norm_cpu_transform_input_template(self, weight, bias, std::get<0>(save_stats), std::get<1>(save_stats), running_mean, running_var, train, eps); diff --git a/aten/src/ATen/native/Pool.h b/aten/src/ATen/native/Pool.h index 5fe979df2c953..da774911b5737 100644 --- a/aten/src/ATen/native/Pool.h +++ b/aten/src/ATen/native/Pool.h @@ -16,10 +16,13 @@ DECLARE_DISPATCH(max_pool2d_fn, max_pool2d_kernel); DECLARE_DISPATCH(max_pool2d_backward_fn, max_pool2d_backward_kernel); // averge pooling has same signature for forward and backward -using avg_pool2d_fn = void(*)(const Tensor& output, const Tensor& input, int kW, int kH, +using avg_pool2d_fn = void(*)(const Tensor& output, const Tensor& input, int64_t kW, int64_t kH, + int64_t dW, int64_t dH, int64_t padW, int64_t padH, bool count_include_pad, c10::optional divisor_override); +using avg_pool2d_backward_fn = void(*)(const Tensor& output, const Tensor& input, int kW, int kH, int dW, int dH, int padW, int padH, bool count_include_pad, c10::optional divisor_override); + DECLARE_DISPATCH(avg_pool2d_fn, avg_pool2d_kernel); -DECLARE_DISPATCH(avg_pool2d_fn, avg_pool2d_backward_kernel); +DECLARE_DISPATCH(avg_pool2d_backward_fn, avg_pool2d_backward_kernel); namespace { diff --git a/aten/src/ATen/native/README.md b/aten/src/ATen/native/README.md index 8d2f2de7367d7..2c77fb348228b 100644 --- a/aten/src/ATen/native/README.md +++ b/aten/src/ATen/native/README.md @@ -7,8 +7,8 @@ Like all ATen methods/functions, native functions are made available from both ATen's C++ and Python APIs. In C++, they are made available either as methods on `Tensor` (`t.mymeth()`) and functions in the ATen namespace (`at::myfunc()`). In PyTorch, they are made available as -methods on `Variable` or as functions on `torch._C._FunctionBase` -(it is the user's responsibility to re-exporting these functions in +methods on `Variable` or as functions on `torch._C._FunctionBase`. +(It is the user's responsibility to re-export these functions in a more user-facing module.) The rest of this document describes how to implement an ATen function. diff --git a/aten/src/ATen/native/RangeFactories.cpp b/aten/src/ATen/native/RangeFactories.cpp index 508c157965edc..7d48c63b755ce 100644 --- a/aten/src/ATen/native/RangeFactories.cpp +++ b/aten/src/ATen/native/RangeFactories.cpp @@ -113,7 +113,7 @@ Tensor& logspace_cpu_out(const Scalar& start, const Scalar& end, c10::optional; auto xstart = start.to(); auto xend = end.to(); @@ -133,7 +133,7 @@ Tensor& range_cpu_out(const Scalar& start, const Scalar& end, const Scalar& step scalar_t *data_ptr = r.data_ptr(); at::parallel_for(0, size, internal::GRAIN_SIZE, [&](int64_t p_begin, int64_t p_end) { - scalar_t is = p_begin; + accscalar_t is = p_begin; for (int64_t i = p_begin; i < p_end; ++i, ++is) { data_ptr[i] = xstart + is * xstep; } diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp index 6e5a1532bd8d1..620908b5b79bf 100644 --- a/aten/src/ATen/native/ReduceOps.cpp +++ b/aten/src/ATen/native/ReduceOps.cpp @@ -109,16 +109,18 @@ void check_all_any(const char* name, const Tensor& self, const Tensor& result) { } } -TORCH_META_FUNC2(all, dim)(const Tensor& self, int64_t dim, bool keepdim) { +TORCH_PRECOMPUTE_META_FUNC2(all, dim)(const Tensor& self, int64_t dim, bool keepdim) { check_all_any("all", self, maybe_get_output()); auto out_dtype = get_result_or_bytebool_dtype(self, maybe_get_output()); resize_reduction(*this, self, dim, keepdim, out_dtype); + return TORCH_PRECOMPUTE_STRUCT2(all, dim)().set_dim(maybe_wrap_dim(dim, self.dim())); } -TORCH_META_FUNC2(any, dim)(const Tensor& self, int64_t dim, bool keepdim) { +TORCH_PRECOMPUTE_META_FUNC2(any, dim)(const Tensor& self, int64_t dim, bool keepdim) { check_all_any("any", self, maybe_get_output()); auto out_dtype = get_result_or_bytebool_dtype(self, maybe_get_output()); resize_reduction(*this, self, dim, keepdim, out_dtype); + return TORCH_PRECOMPUTE_STRUCT2(any, dim)().set_dim(maybe_wrap_dim(dim, self.dim())); } void check_argmax_argmin( @@ -1338,7 +1340,6 @@ Tensor all(const Tensor& self) { TORCH_IMPL_FUNC(all_out) (const Tensor& self, int64_t dim, bool keepdim, const Tensor& result) { - dim = maybe_wrap_dim(dim, self.dim()); auto iter = get_allany_iter(self, result, dim, keepdim); auto mut_result = const_cast(result); if (!_dimreduce_return_trivial(mut_result, self, 1, dim, keepdim)) { @@ -1370,8 +1371,10 @@ Tensor any(const Tensor& self) { } TORCH_IMPL_FUNC(any_out) -(const Tensor& self, int64_t dim, bool keepdim, const Tensor& result) { - dim = maybe_wrap_dim(dim, self.dim()); +(const Tensor& self, + int64_t dim, + bool keepdim, + const Tensor& result) { auto iter = get_allany_iter(self, result, dim, keepdim); auto mut_result = const_cast(result); if (!_dimreduce_return_trivial(mut_result, self, 0, dim, keepdim)) { diff --git a/aten/src/ATen/native/Resize.cpp b/aten/src/ATen/native/Resize.cpp index f4bff473d2333..1937a8b3d545a 100644 --- a/aten/src/ATen/native/Resize.cpp +++ b/aten/src/ATen/native/Resize.cpp @@ -8,7 +8,7 @@ namespace at { namespace native { // Returns true if resize is necessary bool resize_output_check(const Tensor& output, IntArrayRef shape) { - // Tests for resizing of tensors with one more elements + // Tests for resizing of tensors with one or more elements if (output.sizes().equals(shape)) { return false; } diff --git a/aten/src/ATen/native/Resize.h b/aten/src/ATen/native/Resize.h index 5e391a0ce7571..6fb52bc0803ac 100644 --- a/aten/src/ATen/native/Resize.h +++ b/aten/src/ATen/native/Resize.h @@ -10,7 +10,10 @@ namespace at { namespace native { // TODO: make all operations that resize given outputs use this function -// for consistency and maintainability +// for consistency and maintainability. +// Some operations like `cat` might not be able to make the use of +// resize_output directly. For more details to understand how it works in `cat`, +// see https://github.com/pytorch/pytorch/pull/62560#discussion_r687363362 // Resizes outputs // Functions accepting output tensors, like with the "out" kwarg, should // call this function to handle resizing their output tensor. @@ -20,6 +23,9 @@ namespace at { namespace native { // Returns a bool saying whether or not the resize actually happened or not TORCH_API bool resize_output(const Tensor& output, IntArrayRef shape); +// Utility for resize_output +// Returns a bool saying resize should happen or not and +// raises a warning if resizing for one or more elements TORCH_API bool resize_output_check(const Tensor& output, IntArrayRef shape); TORCH_API void resize_bytes_cpu(StorageImpl* storage, size_t size_bytes); diff --git a/aten/src/ATen/native/ScatterGatherChecks.h b/aten/src/ATen/native/ScatterGatherChecks.h index ad3b3fca097ca..0fc38d5bd7418 100644 --- a/aten/src/ATen/native/ScatterGatherChecks.h +++ b/aten/src/ATen/native/ScatterGatherChecks.h @@ -9,7 +9,7 @@ namespace at { namespace native { namespace { // checks whether index.dtype == int64 -// and self.dtyp == src.dtype if src is a Tensor +// and self.dtype == src.dtype if src is a Tensor static void scatter_gather_dtype_check( const std::string& method_name, const Tensor& self, @@ -31,42 +31,31 @@ static void scatter_gather_dtype_check( } // Used for `gather`-like methods +// Note: self means the input tensor here // Test: -// 1. index.size(d) == self.size(d) for all d != dim -// 2. index.size(d) <= src.size(d) for all d != dim -// 3. index.dim() == self.dim() == src.dim() +// 1. index.size(d) <= self.size(d) for all d != dim +// 2. index.dim() == self.dim() static C10_UNUSED void gather_shape_check(const Tensor& self, int64_t dim, - const Tensor& index, const Tensor& src + const Tensor& index ) { auto self_dims = ensure_nonempty_dim(self.dim()); TORCH_CHECK(self_dims == ensure_nonempty_dim(index.dim()), - "Index tensor must have the same number of dimensions as out tensor" - ); - - auto src_dims = ensure_nonempty_dim(src.dim()); - TORCH_CHECK(src_dims == ensure_nonempty_dim(index.dim()), "Index tensor must have the same number of dimensions as input tensor" ); for (int64_t i = 0; i < self_dims; ++i) { if (i != dim) { TORCH_CHECK( - ensure_nonempty_size(index, i) == ensure_nonempty_size(self, i), - "Size does not match at dimension ", i, - " get ", ensure_nonempty_size(self, i), - " vs ", ensure_nonempty_size(index, i) - ); - - TORCH_CHECK( - ensure_nonempty_size(index, i) <= ensure_nonempty_size(src, i), + ensure_nonempty_size(index, i) <= ensure_nonempty_size(self, i), "Size does not match at dimension ", i, " expected index ", index.sizes(), - " to be smaller than src ", src.sizes(), + " to be smaller than self ", self.sizes(), " apart from dimension ", dim ); } } } + // Used for `scatter` and `scatter_add` // Tests: // 1. index.size(d) <= self.size(d) for all d != dim @@ -76,10 +65,7 @@ static C10_UNUSED void scatter_shape_check( const Tensor& self, int64_t dim, const Tensor& index, const c10::optional& src_opt = c10::nullopt ) { - if (index.numel() == 0) { - return; - } - + if (index.numel() == 0) return; TORCH_CHECK( ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()), "Index tensor must have the same number of dimensions as self tensor" diff --git a/aten/src/ATen/native/SpectralOps.cpp b/aten/src/ATen/native/SpectralOps.cpp index cd042073794c3..f9472b1f3dd3d 100644 --- a/aten/src/ATen/native/SpectralOps.cpp +++ b/aten/src/ATen/native/SpectralOps.cpp @@ -920,7 +920,7 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const optional ho // We need to trim the front padding away if centered const auto start = center ? n_fft / 2 : 0; - const auto end = lengthOpt.has_value()? start + lengthOpt.value() : - n_fft / 2; + const auto end = lengthOpt.has_value() ? start + lengthOpt.value() : (center ? - n_fft / 2 : -1); y = y.slice(2, start, end, 1); window_envelop = window_envelop.slice(2, start, end, 1); @@ -935,6 +935,14 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const optional ho if (input_dim == 3) { y = y.squeeze(0); } + // zero padding if the given lengthOpt is longer than expected + if(end > expected_output_signal_len) { + TORCH_WARN_ONCE( + "The length of signal is shorter than the length parameter. Result is being padded with zeros in the tail. " + "Please check your center and hop_length settings." + ); + y = at::constant_pad_nd(y, {0, end - expected_output_signal_len}, 0); + } return y; #undef REPR diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp index 43cebba51b9e7..3fb38cc8832ec 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp +++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp @@ -83,6 +83,31 @@ native::SCATTER_GATHER_OP get_operator_enum(const c10::string_view reduce) { } } +TORCH_META_FUNC(gather) +(const Tensor & self, int64_t dim, const Tensor & index, bool sparse_grad) { + const Tensor& result = maybe_get_output(0); + int64_t wrapped_dim = at::maybe_wrap_dim(dim, self.dim()); + + // Memory overlap checks need to be done after resizing (if required) is done. + // But it only makes sense to do these checks when result was defined, hence + // the boolean variable `check_result` here. + // For more details, see: https://github.com/pytorch/pytorch/pull/63312#discussion_r694794832 + // and https://github.com/pytorch/pytorch/issues/63837 + bool check_result = result.defined(); + set_output(index.sizes(), self.options()); + if (check_result) { + at::assert_no_internal_overlap(result); + at::assert_no_overlap(result, self); + at::assert_no_partial_overlap(result, index); + } + + TORCH_CHECK( + index.scalar_type() == at::ScalarType::Long, + "gather", "(): Expected dtype int64 for index" + ); + at::native::gather_shape_check(self, wrapped_dim, index); +} + template void scatter_meta_impl( Meta& meta, @@ -1112,23 +1137,12 @@ Tensor index_fill(const Tensor & self, int64_t dim, const Tensor & index, const return self.clone(at::MemoryFormat::Preserve).index_fill_(dim, index, source); } -Tensor& gather_out_cpu_cuda( - const Tensor& self, - int64_t dim, - const Tensor& index, - bool sparse_grad, - Tensor& result) { - at::native::resize_output(result, index.sizes()); - at::assert_no_internal_overlap(result); - at::assert_no_overlap(result, self); - at::assert_no_partial_overlap(result, index); +// gather_out_cpu_cuda +TORCH_IMPL_FUNC(gather_out) +(const Tensor& self, int64_t dim, const Tensor& index, bool sparse_grad, const Tensor& result) { + if (index.numel() == 0) return; + dim = at::maybe_wrap_dim(dim, self.dim()); gather_stub(result.device().type(), result, self, dim, index); - return result; -} - -Tensor gather(const Tensor & self, int64_t dim, const Tensor & index, bool sparse_grad) { - Tensor result = at::empty({0}, self.options()); - return at::native::gather_out_cpu_cuda(self, dim, index, sparse_grad, result); } Tensor gather_backward(const Tensor& grad, const Tensor& self, int64_t dim, const Tensor& index, bool sparse_grad) { @@ -1148,6 +1162,8 @@ void scatter_impl( ReduceStub& reduce_stub, FillStub& fill_stub, const c10::optional reduce = nullopt) { + if (index.numel() == 0) return; + dim = at::maybe_wrap_dim(dim, self.dim()); auto mut_out = const_cast(out); if (!self.is_same(mut_out)) { @@ -1217,11 +1233,14 @@ TORCH_IMPL_FUNC(scatter_add) const Tensor& src, const Tensor& out) { auto mut_out = const_cast(out); + dim = maybe_wrap_dim(dim, self.dim()); if (!self.is_same(mut_out)) { mut_out.copy_(self); } + if (index.numel() == 0) return; + if (globalContext().deterministicAlgorithms() && self.device().type() == DeviceType::CUDA && self.dim() == 1) { TORCH_CHECK(index.dim() == 1 && src.dim() == 1, "index and src should be 1D tensors when self is a 1D tensor, " "but their dims are ", index.dim(), " and ", src.dim(), ", respectively"); diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.h b/aten/src/ATen/native/TensorAdvancedIndexing.h index cd2835aa8139b..d8271a8355ded 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexing.h +++ b/aten/src/ATen/native/TensorAdvancedIndexing.h @@ -24,13 +24,13 @@ using take_fn = void(*)(TensorIterator & iter, const Tensor& input); using masked_select_fn = void(*)(TensorIterator &, int64_t orig_stride); using masked_scatter_fn = void(*)(TensorIterator &, const Tensor &); -using gather_fn = void (*)(Tensor & result, const Tensor & self, int64_t dim, const Tensor & index); -using scatter_fn = void(*)(Tensor& self, int64_t dim, const Tensor& index, const Tensor& src); -using scatter_fill_fn = void(*)(Tensor& self, int64_t dim, const Tensor& index, const Scalar& src); -using scatter_add_fn = void(*)(Tensor& self, int64_t dim, const Tensor& index, const Tensor& src); -using scatter_reduce_fn = void(*)(Tensor& self, const int64_t dim, const Tensor& index, +using gather_fn = void (*)(const Tensor & result, const Tensor & self, int64_t dim, const Tensor & index); +using scatter_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src); +using scatter_fill_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Scalar& src); +using scatter_add_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src); +using scatter_reduce_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index, const Tensor& src, const SCATTER_GATHER_OP& reduce); -using scatter_scalar_reduce_fn = void(*)(Tensor& self, const int64_t dim, const Tensor& index, +using scatter_scalar_reduce_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index, const Scalar& value, const SCATTER_GATHER_OP& reduce); DECLARE_DISPATCH(index_fn, index_stub); diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp index 90a57d1d30c94..3f69cab48b090 100644 --- a/aten/src/ATen/native/TensorCompare.cpp +++ b/aten/src/ATen/native/TensorCompare.cpp @@ -108,8 +108,6 @@ bool allclose(const Tensor& self, const Tensor& other, double rtol, double atol, // https://github.com/numpy/numpy/issues/15959 is resolved Tensor isclose(const Tensor& self, const Tensor& other, double rtol, double atol, bool equal_nan) { TORCH_CHECK(self.scalar_type() == other.scalar_type(), self.scalar_type(), " did not match ", other.scalar_type()); - TORCH_CHECK(!(self.is_complex() && equal_nan), - "isclose with equal_nan=True is not supported for complex inputs."); TORCH_CHECK(!(self.is_quantized() || other.is_quantized()), "isclose is not supported for quantized inputs."); @@ -121,8 +119,8 @@ Tensor isclose(const Tensor& self, const Tensor& other, double rtol, double atol // Computes equality closeness Tensor close = self == other; - if (equal_nan && self.is_floating_point()) { - close.__ior__((self != self).__iand__(other != other)); + if (equal_nan && (self.is_floating_point() || self.is_complex())) { + close.__ior__(self.isnan().__iand__(other.isnan())); } // In case of zero tolerances the closeness inequality degenerates to an equality check. diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp index 3ee909be029ff..4712c3d99b6d8 100644 --- a/aten/src/ATen/native/TensorFactories.cpp +++ b/aten/src/ATen/native/TensorFactories.cpp @@ -1411,17 +1411,18 @@ Tensor from_file(c10::string_view filename, c10::optional shared, c10::opt Tensor clone(const Tensor& src, c10::optional optional_memory_format) { auto memory_format = optional_memory_format.value_or(MemoryFormat::Preserve); + Tensor self; if (memory_format == MemoryFormat::Preserve) { if (src.is_non_overlapping_and_dense()) { - // Copy all strides - auto self = at::empty_strided(src.sizes(), src.strides(), src.options()); - self.copy_(src); - return self; + // Copy all strides, this is marginally faster than calling empty_like + self = at::empty_strided(src.sizes(), src.strides(), src.options()); } else { - memory_format = src.suggest_memory_format(); + self = at::empty_like(src); } + } else { + self = at::empty_like(src, src.options(), memory_format); } - auto self = at::empty_like(src, src.options(), memory_format); + self.copy_(src); return self; } diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index e915078249171..8f397862687ba 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -6,7 +6,6 @@ #include #include #include -#include #include #include #include @@ -193,7 +192,10 @@ Tensor & _cat_out_cpu(TensorList tensors, int64_t dim, Tensor& result) { result_size[dim] = cat_dim_size; // skip resizing if size of result is same as expected - if (result.sizes() != result_size) { + // raise a warning while resizing if output has one or more elements + // See https://github.com/pytorch/pytorch/pull/62560#discussion_r687363362 + // for understanding why at::native::resize_output is not called directly. + if (at::native::resize_output_check(result, result_size)) { result.resize_(result_size, first_tensor_mem_format); } @@ -301,6 +303,23 @@ Tensor cat(TensorList tensors, Dimname dim) { return at::cat(tensors, dimname_to_position(tensors[0], dim)); } +// torch.concat, alias for torch.cat +Tensor& concat_out(TensorList tensors, Dimname dim, Tensor& result) { + return at::cat_out(result, tensors, dimname_to_position(tensors[0], dim)); +} + +Tensor concat(TensorList tensors, Dimname dim) { + return at::cat(tensors, dimname_to_position(tensors[0], dim)); +} + +Tensor & concat_out(TensorList tensors, int64_t dim, Tensor & result) { + return at::cat_out(result, tensors, dim); +} + +Tensor concat(TensorList tensors, int64_t dim) { + return at::cat(tensors, dim); +} + static bool sizes_match_except(IntArrayRef s1, IntArrayRef s2, int64_t dim_except /* should already be wrapped */) { if (s1.size() != s2.size()) { return false; @@ -609,7 +628,13 @@ std::vector tensor_split(const Tensor& self, const Tensor& tensor_indice return self.tensor_split(sections, dim); } else { auto indices_data = tensor_indices_or_sections.data_ptr(); - std::vector indices(indices_data, indices_data + tensor_indices_or_sections.numel()); + auto stride = tensor_indices_or_sections.stride(0); + auto numel = tensor_indices_or_sections.numel(); + std::vector indices(numel); + for (size_t offset = 0; offset < numel; offset++) { + // indices tensor could be non-contiguous + indices[offset] = *(indices_data + offset * stride); + } return self.tensor_split(indices, dim); } } @@ -1203,12 +1228,15 @@ Tensor index_select_sparse(const Tensor& self, int64_t dim, const Tensor& index) if (dim < sparse_dim) { - auto dim_indices = indices[dim]; + auto cpu_dim_indices = indices[dim].to(c10::kCPU).contiguous(); + int64_t* cpu_dim_indices_ptr = cpu_dim_indices.data_ptr(); + auto cpu_index = index.to(c10::kCPU).contiguous(); + int64_t* cpu_index_ptr = cpu_index.data_ptr(); std::vector zindices; std::vector iindices; int64_t new_nnz = 0; - for (const auto i : c10::irange(new_sizes[dim])) { - auto idx = index[i].item(); + for (int64_t i = 0; i < new_sizes[dim]; i++) { + int64_t idx = cpu_index_ptr[i]; if (idx < -size || idx >= size) { TORCH_CHECK_INDEX(false, "index_select(): index contains ", idx, " that is out of range for tensor of size ", self.sizes(), " at dimension ", dim); @@ -1216,8 +1244,8 @@ Tensor index_select_sparse(const Tensor& self, int64_t dim, const Tensor& index) if (idx < 0) { idx += size; } - for (const auto j : c10::irange(nnz)) { - auto jdx = dim_indices[j].item(); + for (int64_t j = 0; j < nnz; j++) { + int64_t jdx = cpu_dim_indices_ptr[j]; if (idx == jdx) { new_nnz++; iindices.push_back(i); @@ -1488,9 +1516,8 @@ bool inline maybe_native_stack(Tensor& result, TensorList tensors, int64_t dim) result_sizes.insert(result_sizes.begin() + dim, tensors.size()); // skip resizing if size of result is same as expected - if (result.sizes() != result_sizes) { - result.resize_(result_sizes); - } + // raise a warning while resizing if output has one or more elements + at::native::resize_output(result, result_sizes); stack_serial_stub(kCPU, result, tensors, dim); return true; } @@ -2033,6 +2060,8 @@ Tensor flatten(const Tensor& self, Dimname start_dim, Dimname end_dim, Dimname o Tensor flatten(const Tensor& self, DimnameList dims, Dimname out_dim) { auto positions = dimnames_to_positions(self, dims); + TORCH_CHECK(positions.size() > 0, + "flatten(tensor, dims, out_dim): dims cannot be empty"); for (const auto i : c10::irange(positions.size() - 1)) { if (positions[i] + 1 == positions[i + 1]) continue; TORCH_CHECK(positions[i] + 1 == positions[i + 1], diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp index d5052a77f5b62..b7e596392c716 100644 --- a/aten/src/ATen/native/UnaryOps.cpp +++ b/aten/src/ATen/native/UnaryOps.cpp @@ -191,7 +191,6 @@ TORCH_IMPL_FUNC(polygamma_out) } TORCH_IMPL_FUNC(signbit_out) (const Tensor& self, const Tensor& result) { - at::native::resize_output(result, self.sizes()); if (self.dtype() == at::kBool) { result.fill_(false); } else { diff --git a/aten/src/ATen/native/UpSample.h b/aten/src/ATen/native/UpSample.h index e50b053949d37..602abcebbe3a0 100644 --- a/aten/src/ATen/native/UpSample.h +++ b/aten/src/ATen/native/UpSample.h @@ -251,12 +251,16 @@ static inline scalar_t area_pixel_compute_scale( bool align_corners, const c10::optional scale) { // see Note [area_pixel_compute_scale] - if (output_size > 1) { - return align_corners - ? static_cast(input_size - 1) / (output_size - 1) - : compute_scales_value(scale, input_size, output_size); - } else { - return scalar_t(0); + if(align_corners){ + if(output_size > 1) { + return static_cast(input_size - 1) / (output_size - 1); + } + else { + return static_cast(0); + } + } + else{ + return compute_scales_value(scale, input_size, output_size); } } diff --git a/aten/src/ATen/native/cpu/Activation.cpp b/aten/src/ATen/native/cpu/Activation.cpp index ae1403d1a25d1..34b54719fe502 100644 --- a/aten/src/ATen/native/cpu/Activation.cpp +++ b/aten/src/ATen/native/cpu/Activation.cpp @@ -401,41 +401,80 @@ void hardswish_backward_kernel(TensorIterator& iter) { } static void leaky_relu_kernel(TensorIteratorBase& iter, const Scalar& negval_) { - AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "leaky_relu_cpu", [&] { - using Vec = Vectorized; - auto zero_vec = Vec((scalar_t)(0)); - auto one_vec = Vec((scalar_t)(1)); - scalar_t negval = negval_.to(); - Vec negval_v = Vec(negval); + if (iter.common_dtype() == kBFloat16) { + auto zero_vec = Vectorized((float)(0)); + auto one_vec = Vectorized((float)(1)); + float negval = negval_.to(); + Vectorized negval_v = Vectorized(negval); cpu_kernel_vec( iter, - [&](scalar_t a) -> scalar_t { - return a > scalar_t(0) ? a : a * negval; + [&](BFloat16 a) -> BFloat16 { + return float(a) > float(0) ? float(a) : float(a) * negval; }, - [&](Vec a) -> Vec { - auto r = Vec::blendv(negval_v, one_vec, a > zero_vec); - return a * r; + [&](Vectorized a) -> Vectorized { + Vectorized a0, a1; + std::tie(a0, a1) = convert_bfloat16_float(a); + auto res0 = a0 * (Vectorized::blendv(negval_v, one_vec, a0 > zero_vec)); + auto res1 = a1 * (Vectorized::blendv(negval_v, one_vec, a1 > zero_vec)); + return convert_float_bfloat16(res0, res1); }); - }); + } else { + AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "leaky_relu_cpu", [&] { + using Vec = Vectorized; + auto zero_vec = Vec((scalar_t)(0)); + auto one_vec = Vec((scalar_t)(1)); + scalar_t negval = negval_.to(); + Vec negval_v = Vec(negval); + cpu_kernel_vec( + iter, + [&](scalar_t a) -> scalar_t { + return a > scalar_t(0) ? a : a * negval; + }, + [&](Vec a) -> Vec { + auto r = Vec::blendv(negval_v, one_vec, a > zero_vec); + return a * r; + }); + }); + } } static void leaky_relu_backward_kernel(TensorIteratorBase& iter, const Scalar& negval_) { - AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "leaky_relu_backward_cpu", [&] { - using Vec = Vectorized; - auto zero_vec = Vec((scalar_t)(0)); - auto one_vec = Vec((scalar_t)(1)); - scalar_t negval = negval_.to(); - Vec negval_v = Vec(negval); + if (iter.common_dtype() == kBFloat16) { + auto zero_vec = Vectorized((float)(0)); + auto one_vec = Vectorized((float)(1)); + float negval = negval_.to(); + Vectorized negval_v = Vectorized(negval); cpu_kernel_vec( - iter, - [&](scalar_t a, scalar_t b) -> scalar_t { - return a > scalar_t(0) ? b : b * negval; - }, - [&](Vec a, Vec b) -> Vec { - auto r = Vec::blendv(negval_v, one_vec, a > zero_vec); - return b * r; - }); - }); + iter, + [&](BFloat16 a, BFloat16 b) -> BFloat16 { + return float(a) > float(0) ? float(b) : float(b) * negval; + }, + [&](Vectorized a, Vectorized b) -> Vectorized { + Vectorized a0, a1, b0, b1; + std::tie(a0, a1) = convert_bfloat16_float(a); + std::tie(b0, b1) = convert_bfloat16_float(b); + auto res0 = b0 * (Vectorized::blendv(negval_v, one_vec, a0 > zero_vec)); + auto res1 = b1 * (Vectorized::blendv(negval_v, one_vec, a1 > zero_vec)); + return convert_float_bfloat16(res0, res1); + }); + } else { + AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "leaky_relu_backward_cpu", [&] { + using Vec = Vectorized; + auto zero_vec = Vec((scalar_t)(0)); + auto one_vec = Vec((scalar_t)(1)); + scalar_t negval = negval_.to(); + Vec negval_v = Vec(negval); + cpu_kernel_vec( + iter, + [&](scalar_t a, scalar_t b) -> scalar_t { + return a > scalar_t(0) ? b : b * negval; + }, + [&](Vec a, Vec b) -> Vec { + auto r = Vec::blendv(negval_v, one_vec, a > zero_vec); + return b * r; + }); + }); + } } void softplus_kernel(TensorIteratorBase& iter, const Scalar& beta_, const Scalar& threshold_) { @@ -480,7 +519,7 @@ void softplus_backward_kernel(TensorIteratorBase& iter, const Scalar& beta_, con }); } -void glu_kernel(TensorIterator& iter) { +void glu_kernel(TensorIteratorBase& iter) { AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "glu_cpu", [&] { using Vec = Vectorized; const scalar_t one_val(1); diff --git a/aten/src/ATen/native/cpu/AvgPoolKernel.cpp b/aten/src/ATen/native/cpu/AvgPoolKernel.cpp index 2aa075f5933bd..2bee0206ff6b5 100644 --- a/aten/src/ATen/native/cpu/AvgPoolKernel.cpp +++ b/aten/src/ATen/native/cpu/AvgPoolKernel.cpp @@ -14,9 +14,9 @@ template void cpu_avg_pool( const Tensor& output_, const Tensor& input_, - int kW, int kH, - int dW, int dH, - int padW, int padH, + int64_t kW, int64_t kH, + int64_t dW, int64_t dH, + int64_t padW, int64_t padH, bool count_include_pad, c10::optional divisor_override) { auto input = input_.contiguous(); @@ -98,9 +98,9 @@ template void cpu_avg_pool_channels_last( const Tensor& output_, const Tensor& input_, - int kW, int kH, - int dW, int dH, - int padW, int padH, + int64_t kW, int64_t kH, + int64_t dW, int64_t dH, + int64_t padW, int64_t padH, bool count_include_pad, c10::optional divisor_override) { TORCH_CHECK(input_.ndimension() == 4, @@ -359,9 +359,9 @@ void cpu_avg_pool_backward_channels_last( void avg_pool2d_kernel_impl( const Tensor& output, const Tensor& input, - int kW, int kH, - int dW, int dH, - int padW, int padH, + int64_t kW, int64_t kH, + int64_t dW, int64_t dH, + int64_t padW, int64_t padH, bool count_include_pad, c10::optional divisor_override) { switch (input.suggest_memory_format()) { diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp index 2a8f73cb88dd0..16efa2511899f 100644 --- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp @@ -684,19 +684,35 @@ void sigmoid_backward_kernel(TensorIteratorBase& iter) { return a * ((one_vec - b) * b).conj(); }); }); + } else if (iter.dtype() == kBFloat16) { + auto one_vec = Vectorized((float)(1)); + cpu_kernel_vec( + iter, + [=](BFloat16 a, BFloat16 b) -> BFloat16 { + float a0 = static_cast(a); + float b0 = static_cast(b); + return a0 * (float(1) - b0) * b0; + }, + [=](Vectorized a, Vectorized b) { + Vectorized a0, a1, b0, b1; + std::tie(a0, a1) = convert_bfloat16_float(a); + std::tie(b0, b1) = convert_bfloat16_float(b); + a0 = a0 * (one_vec - b0) * b0; + a1 = a1 * (one_vec - b1) * b1; + return convert_float_bfloat16(a0, a1); + }); } else { - AT_DISPATCH_FLOATING_TYPES_AND2( - kBFloat16, kHalf, iter.dtype(), "sigmoid_backward_cpu", [&]() { - auto one_vec = Vectorized((scalar_t)(1)); - cpu_kernel_vec( - iter, - [=](scalar_t a, scalar_t b) -> scalar_t { - return a * (scalar_t(1) - b) * b; - }, - [=](Vectorized a, Vectorized b) { - return a * (one_vec - b) * b; - }); + AT_DISPATCH_FLOATING_TYPES_AND(kHalf, iter.dtype(), "sigmoid_backward_cpu", [&]() { + auto one_vec = Vectorized((scalar_t)(1)); + cpu_kernel_vec( + iter, + [=](scalar_t a, scalar_t b) -> scalar_t { + return a * (scalar_t(1) - b) * b; + }, + [=](Vectorized a, Vectorized b) { + return a * (one_vec - b) * b; }); + }); } } @@ -754,15 +770,32 @@ void tanh_backward_kernel(TensorIteratorBase& iter) { if (isComplexType(iter.dtype())) { AT_DISPATCH_COMPLEX_TYPES(iter.dtype(), "tanh_backward_cpu", [&]() { auto one_vec = Vectorized(scalar_t{1}); + cpu_kernel_vec( + iter, + [=](scalar_t a, scalar_t b) -> scalar_t { + return a * std::conj(scalar_t{1} - b * b); + }, + [=](Vectorized a, Vectorized b) { + return a * (one_vec - b * b).conj(); + }); + }); + } else if (iter.dtype() == kBFloat16) { + auto one_vec = Vectorized(float{1}); cpu_kernel_vec( iter, - [=](scalar_t a, scalar_t b) -> scalar_t { - return a * std::conj(scalar_t{1} - b * b); + [=](BFloat16 a, BFloat16 b) -> BFloat16 { + float a0 = float(a); + float b0 = float(b); + return a0 * (float{1} - b0 * b0); }, - [=](Vectorized a, Vectorized b) { - return a * (one_vec - b * b).conj(); + [=](Vectorized a, Vectorized b) { + Vectorized a0, a1, b0, b1; + std::tie(a0, a1) = convert_bfloat16_float(a); + std::tie(b0, b1) = convert_bfloat16_float(b); + a0 = a0 * (one_vec - b0 * b0); + a1 = a1 * (one_vec - b1 * b1); + return convert_float_bfloat16(a0, a1); }); - }); } else { AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "tanh_backward_cpu", [&]() { auto one_vec = Vectorized(scalar_t{1}); diff --git a/aten/src/ATen/native/cpu/DistributionTemplates.h b/aten/src/ATen/native/cpu/DistributionTemplates.h index 66bd31fa74d45..15b1916b9892c 100644 --- a/aten/src/ATen/native/cpu/DistributionTemplates.h +++ b/aten/src/ATen/native/cpu/DistributionTemplates.h @@ -308,7 +308,7 @@ struct ExponentialKernel { template void bernoulli_kernel(Tensor& self, const Tensor& p_, RNG generator) { - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool, self.scalar_type(), "bernoulli_tensor_cpu_self_", [&] { + AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "bernoulli_tensor_cpu_self_", [&] { // See Note [Acquire lock when using random generators] std::lock_guard lock(generator->mutex_); using self_t = scalar_t; @@ -325,7 +325,7 @@ void bernoulli_kernel(Tensor& self, const Tensor& p_, RNG generator) { return static_cast(bernoulli(generator)); }); } else { - AT_DISPATCH_FLOATING_TYPES(p_.scalar_type(), "bernoulli_tensor_cpu_p_", [&] { + AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, p_.scalar_type(), "bernoulli_tensor_cpu_p_", [&] { using p_t = scalar_t; cpu_serial_kernel(iter, [&](const p_t p_val) -> self_t { at::bernoulli_distribution bernoulli(p_val); @@ -338,7 +338,7 @@ void bernoulli_kernel(Tensor& self, const Tensor& p_, RNG generator) { template void bernoulli_kernel(Tensor& self, double p, RNG generator) { - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool, self.scalar_type(), "bernoulli_scalar_cpu_", [&] { + AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "bernoulli_scalar_cpu_", [&] { // See Note [Acquire lock when using random generators] std::lock_guard lock(generator->mutex_); auto iter = TensorIterator::borrowing_nullary_op(self); diff --git a/aten/src/ATen/native/cpu/MaxUnpoolKernel.cpp b/aten/src/ATen/native/cpu/MaxUnpoolKernel.cpp new file mode 100644 index 0000000000000..5a7b03128766b --- /dev/null +++ b/aten/src/ATen/native/cpu/MaxUnpoolKernel.cpp @@ -0,0 +1,385 @@ +#include + +#include +#include +#include +#include + +namespace at { namespace native { + +namespace { + +template +void cpu_max_unpool( + Tensor& output_, + const Tensor& input, + const Tensor& indices) { + auto output = output_.contiguous(); + + auto input_data = input.data_ptr(); + auto indices_data = indices.data_ptr(); + auto output_data = output.data_ptr(); + + // NB: input tensor dimensions: + // MaxUnpool2d: + // dim = 3: CHW + // dim = 4: NCHW + // MaxUnpool3d: + // dim = 4: CDHW + // dim = 5: NCDHW + + int64_t numel = input.numel(); + int64_t ndim = input.ndimension(); + + // treat batch size and channels as one dimension + // and the feature map as another dimension + int64_t channels, output_depth, output_height, output_width; + if (is_3d) { + TORCH_CHECK(ndim == 4 || ndim == 5, "MaxUnpool3d: expect input to be 4d or 5d tensor."); + channels = ndim == 4 ? input.size(0) : input.size(0) * input.size(1); + output_depth = output.size(-3); + output_height = output.size(-2); + output_width = output.size(-1); + } else { + TORCH_CHECK(ndim == 3 || ndim == 4, "MaxUnpool2d: expect input to be 3d or 4d tensor."); + channels = ndim == 3 ? input.size(0) : input.size(0) * input.size(1); + output_depth = 1; + output_height = output.size(-2); + output_width = output.size(-1); + } + int64_t input_image_size = numel / channels; + int64_t output_image_size = output.numel() / channels; + + bool has_error = false; + int64_t error_index = 0; + + // parallel on dim N, C, D, H, W: [channels, input_image_size] + at::parallel_for(0, numel, 0, [&](int64_t begin, int64_t end) { + int64_t c = 0; + int64_t ip = 0; + data_index_init(begin, c, channels, ip, input_image_size); + + for (int64_t i = begin; i < end; i++) { + scalar_t* output_ptr = output_data + c * output_image_size; + + int64_t maxp = indices_data[i]; + if (maxp < 0 || maxp >= output_image_size) { + #pragma omp critical + { + has_error = true; + error_index = maxp; + } + } else { + output_ptr[maxp] = input_data[i]; + } + + // move on to next input index + data_index_step(c, channels, ip, input_image_size); + } + }); + + if (has_error) { + if (is_3d) { + AT_ERROR("Found an invalid max index: ", error_index, + " (output volumes are of size ", output_depth, + "x", output_height, "x", output_width); + (void)error_index; + } else { + AT_ERROR("Found an invalid max index: ", error_index, + " (output volumes are of size ", output_height, + "x", output_width); + (void)error_index; + } + } + + if (!output_.is_contiguous()) { + output_.copy_(output); + } +} + +template +void cpu_max_unpool_channels_last( + Tensor& output_, + const Tensor& input, + const Tensor& indices) { + TORCH_CHECK(input.ndimension() == 4, + "max_unpool2d with channels last format supports tensors with 4 dims"); + auto memory_format = at::MemoryFormat::ChannelsLast; + auto output = output_.contiguous(memory_format); + + auto input_data = input.data_ptr(); + auto indices_data = indices.data_ptr(); + auto output_data = output.data_ptr(); + + int64_t nbatch = input.size(0); + int64_t channels = input.size(1); + int64_t input_height = input.size(2); + int64_t input_width = input.size(3); + int64_t output_height = output.size(2); + int64_t output_width = output.size(3); + int64_t input_image_size = input_height * input_width; + int64_t output_image_size = output_height * output_width; + + bool has_error = false; + int64_t error_index = 0; + + // parallel on dim N, H, W + at::parallel_for(0, nbatch * input_image_size, 0, [&](int64_t begin, int64_t end) { + int64_t n = 0; + int64_t ip = 0; + data_index_init(begin, n, nbatch, ip, input_image_size); + + for (int64_t i = begin; i < end; i++) { + scalar_t* input_ptr = input_data + i * channels; + int64_t* indices_ptr = indices_data + i * channels; + scalar_t* output_ptr = output_data + n * output_image_size * channels; + + // can't do scatter on avx2 (only available on avx512) + for (int64_t c = 0; c < channels; c++) { + int64_t maxp = indices_ptr[c]; + if (maxp < 0 || maxp >= output_image_size) { + #pragma omp critical + { + has_error = true; + error_index = maxp; + } + } else { + output_ptr[maxp * channels + c] = input_ptr[c]; + } + } + + // move on to next input index + data_index_step(n, nbatch, ip, input_image_size); + } + }); + + if (has_error) { + AT_ERROR("Found an invalid max index: ", error_index, + " (output volumes are of size ", output_height, + "x", output_width); + (void)error_index; + } + + if (!output_.is_contiguous(memory_format)) { + output_.copy_(output); + } +} + +template +void cpu_max_unpool_backward( + Tensor& grad_input_, + const Tensor& grad_output, + const Tensor& indices) { + auto grad_input = grad_input_.contiguous(); + + auto grad_output_data = grad_output.data_ptr(); + auto indices_data = indices.data_ptr(); + auto grad_input_data = grad_input.data_ptr(); + + int64_t numel = grad_input.numel(); + int64_t ndim = grad_output.ndimension(); + + // treat batch size and channels as one dimension + // and the feature map as another dimension + int64_t channels, output_depth, output_height, output_width; + if (is_3d) { + TORCH_CHECK(ndim == 4 || ndim == 5, "MaxUnpool3d_backward: expect grad_output to be 4d or 5d tensor."); + channels = ndim == 4 ? grad_output.size(0) : grad_output.size(0) * grad_output.size(1); + output_depth = grad_output.size(-3); + output_height = grad_output.size(-2); + output_width = grad_output.size(-1); + } else { + TORCH_CHECK(ndim == 3 || ndim == 4, "MaxUnpool2d_backward: expect grad_output to be 3d or 4d tensor."); + channels = ndim == 3 ? grad_output.size(0) : grad_output.size(0) * grad_output.size(1); + output_depth = 1; + output_height = grad_output.size(-2); + output_width = grad_output.size(-1); + } + int64_t input_image_size = numel / channels; + int64_t output_image_size = grad_output.numel() / channels; + + bool has_error = false; + int64_t error_index = 0; + + // parallel on dim N, C, D, H, W + at::parallel_for(0, numel, 0, [&](int64_t begin, int64_t end) { + int64_t c = 0; + int64_t ip = 0; + data_index_init(begin, c, channels, ip, input_image_size); + + for (int64_t i = begin; i < end; i++) { + scalar_t* grad_output_ptr = grad_output_data + c * output_image_size; + + int64_t maxp = indices_data[i]; + if (maxp < 0 || maxp >= output_image_size) { + #pragma omp critical + { + has_error = true; + error_index = maxp; + } + } else { + grad_input_data[i] = grad_output_ptr[maxp]; + } + + // move on to next input index + data_index_step(c, channels, ip, input_image_size); + } + }); + + if (has_error) { + if (is_3d) { + AT_ERROR("invalid max index ", error_index, + ", odepth= ", output_depth, + ", owidth= ", output_width, + ", oheight= ", output_height); + (void)error_index; + } else { + AT_ERROR("invalid max index ", error_index, + ", owidth= ", output_width, + ", oheight= ", output_height); + (void)error_index; + } + } + + if (!grad_input_.is_contiguous()) { + grad_input_.copy_(grad_input); + } +} + +template +void cpu_max_unpool_backward_channels_last( + Tensor& grad_input_, + const Tensor& grad_output, + const Tensor& indices) { + TORCH_CHECK(grad_output.ndimension() == 4, + "max_unpool2d backward with channels last format supports tensors with 4 dims."); + auto memory_format = at::MemoryFormat::ChannelsLast; + auto grad_input = grad_input_.contiguous(memory_format); + + auto grad_input_data = grad_input.data_ptr(); + auto grad_output_data = grad_output.data_ptr(); + auto indices_data = indices.data_ptr(); + + int64_t nbatch = grad_input.size(0); + int64_t channels = grad_input.size(1); + int64_t input_height = grad_input.size(2); + int64_t input_width = grad_input.size(3); + int64_t output_height = grad_output.size(2); + int64_t output_width = grad_output.size(3); + int64_t input_image_size = input_height * input_width; + int64_t output_image_size = output_height * output_width; + + bool has_error = false; + int64_t error_index = 0; + + // parallel on dim N, H, W + at::parallel_for(0, nbatch * input_image_size, 0, [&](int64_t begin, int64_t end) { + int64_t n = 0; + int64_t ip = 0; + data_index_init(begin, n, nbatch, ip, input_image_size); + + for (int64_t i = begin; i < end; i++) { + scalar_t* grad_output_ptr = grad_output_data + n * output_image_size * channels; + scalar_t* grad_input_ptr = grad_input_data + i * channels; + int64_t* indices_ptr = indices_data + i * channels; + + for (int64_t c = 0; c < channels; c++) { + int64_t maxp = indices_ptr[c]; + if (maxp < 0 || maxp >= output_image_size) { + #pragma omp critical + { + has_error = true; + error_index = maxp; + } + } else { + grad_input_ptr[c] = grad_output_ptr[maxp * channels + c]; + } + } + + // move on to next input index + data_index_step(n, nbatch, ip, input_image_size); + } + }); + + if (has_error) { + AT_ERROR("invalid max index ", error_index, + ", owidth= ", output_width, + ", oheight= ", output_height); + (void)error_index; + } + + if (!grad_input_.is_contiguous(memory_format)) { + grad_input_.copy_(grad_input); + } +} + +void max_unpool2d_kernel_impl( + Tensor& output, + const Tensor& input, + const Tensor& indices) { + switch(input.suggest_memory_format()) { + case at::MemoryFormat::Contiguous: { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "max_unpool2d", [&] { + cpu_max_unpool(output, input, indices); + }); + break; + } + case at::MemoryFormat::ChannelsLast: { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "max_unpool2d_channels_last", [&] { + cpu_max_unpool_channels_last(output, input, indices); + }); + break; + } + default: + TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous"); + } +} + +void max_unpool3d_kernel_impl( + Tensor& output, + const Tensor& input, + const Tensor& indices) { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "max_unpool3d", [&] { + cpu_max_unpool(output, input, indices); + }); +} + +void max_unpool2d_backward_kernel_impl( + Tensor& grad_input, + const Tensor& grad_output, + const Tensor& indices) { + switch(grad_output.suggest_memory_format()) { + case at::MemoryFormat::Contiguous: { + AT_DISPATCH_FLOATING_TYPES(grad_output.scalar_type(), "max_unpool2d_backward", [&] { + cpu_max_unpool_backward(grad_input, grad_output, indices); + }); + break; + } + case at::MemoryFormat::ChannelsLast: { + AT_DISPATCH_FLOATING_TYPES(grad_output.scalar_type(), "max_unpool2d_backward_channels_last", [&] { + cpu_max_unpool_backward_channels_last(grad_input, grad_output, indices); + }); + break; + } + default: + TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous"); + } +} + +void max_unpool3d_backward_kernel_impl( + Tensor& grad_input, + const Tensor& grad_output, + const Tensor& indices) { + AT_DISPATCH_FLOATING_TYPES(grad_output.scalar_type(), "max_unpool3d_backward", [&] { + cpu_max_unpool_backward(grad_input, grad_output, indices); + }); +} + +} // anonymous namespace + +REGISTER_DISPATCH(max_unpool2d_kernel, &max_unpool2d_kernel_impl); +REGISTER_DISPATCH(max_unpool2d_backward_kernel, &max_unpool2d_backward_kernel_impl); +REGISTER_DISPATCH(max_unpool3d_kernel, &max_unpool3d_kernel_impl); +REGISTER_DISPATCH(max_unpool3d_backward_kernel, &max_unpool3d_backward_kernel_impl); + +}} // at::native diff --git a/aten/src/ATen/native/cpu/MaxUnpoolKernel.h b/aten/src/ATen/native/cpu/MaxUnpoolKernel.h new file mode 100644 index 0000000000000..00fbeb64213d6 --- /dev/null +++ b/aten/src/ATen/native/cpu/MaxUnpoolKernel.h @@ -0,0 +1,16 @@ +#include +#include +#include + +#pragma once + +namespace at { namespace native { + +using max_unpooling_fn = void(*)(Tensor&, const Tensor&, const Tensor&); + +DECLARE_DISPATCH(max_unpooling_fn, max_unpool2d_kernel); +DECLARE_DISPATCH(max_unpooling_fn, max_unpool2d_backward_kernel); +DECLARE_DISPATCH(max_unpooling_fn, max_unpool3d_kernel); +DECLARE_DISPATCH(max_unpooling_fn, max_unpool3d_backward_kernel); + +}} // at::native diff --git a/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp b/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp index e0807d14b1b5a..0d0508adb7c11 100644 --- a/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp @@ -12,38 +12,82 @@ namespace { static void addcmul_cpu_kernel(TensorIteratorBase& iter, const Scalar& value) { ScalarType dtype = iter.dtype(0); - AT_DISPATCH_ALL_TYPES_AND_COMPLEX(dtype, "addcmul_cpu_out", [&] { - scalar_t scalar_val = value.to(); - auto scalar_vec = Vectorized(scalar_val); + if (iter.dtype() == kBFloat16) { + float float_val = value.to(); + auto float_vec = Vectorized(float_val); cpu_kernel_vec( iter, - [=](scalar_t self_val, scalar_t t1_val, scalar_t t2_val) -> scalar_t { - return self_val + scalar_val * t1_val * t2_val; + [=](BFloat16 self_val, BFloat16 t1_val, BFloat16 t2_val) -> BFloat16 { + return float(self_val) + float_val * float(t1_val) * float(t2_val); }, - [=](Vectorized self_vec, - Vectorized t1_vec, - Vectorized t2_vec) { - return self_vec + scalar_vec * t1_vec * t2_vec; + [=](Vectorized self_vec, + Vectorized t1_vec, + Vectorized t2_vec) { + Vectorized self_vec0, self_vec1; + std::tie(self_vec0, self_vec1) = convert_bfloat16_float(self_vec); + Vectorized t1_vec0, t1_vec1, t2_vec0, t2_vec1; + std::tie(t1_vec0, t1_vec1) = convert_bfloat16_float(t1_vec); + std::tie(t2_vec0, t2_vec1) = convert_bfloat16_float(t2_vec); + self_vec0 = self_vec0 + float_vec * t1_vec0 * t2_vec0; + self_vec1 = self_vec1 + float_vec * t1_vec1 * t2_vec1; + return convert_float_bfloat16(self_vec0, self_vec1); }); - }); + } else { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX(dtype, "addcmul_cpu_out", [&] { + scalar_t scalar_val = value.to(); + auto scalar_vec = Vectorized(scalar_val); + cpu_kernel_vec( + iter, + [=](scalar_t self_val, scalar_t t1_val, scalar_t t2_val) -> scalar_t { + return self_val + scalar_val * t1_val * t2_val; + }, + [=](Vectorized self_vec, + Vectorized t1_vec, + Vectorized t2_vec) { + return self_vec + scalar_vec * t1_vec * t2_vec; + }); + }); + } } static void addcdiv_cpu_kernel(TensorIteratorBase& iter, const Scalar& value) { ScalarType dtype = iter.dtype(0); - AT_DISPATCH_ALL_TYPES_AND_COMPLEX(dtype, "addcdiv_cpu_out", [&] { - scalar_t scalar_val = value.to(); - auto scalar_vec = Vectorized(scalar_val); + if (dtype == kBFloat16) { + float float_val = value.to(); + auto float_vec = Vectorized(float_val); cpu_kernel_vec( iter, - [=](scalar_t self_val, scalar_t t1_val, scalar_t t2_val) -> scalar_t { - return self_val + scalar_val * t1_val / t2_val; + [=](BFloat16 self_val, BFloat16 t1_val, BFloat16 t2_val) -> BFloat16 { + return float(self_val) + float_val * float(t1_val) / float(t2_val); }, - [=](Vectorized self_vec, - Vectorized t1_vec, - Vectorized t2_vec) { - return self_vec + scalar_vec * t1_vec / t2_vec; + [=](Vectorized self_vec, + Vectorized t1_vec, + Vectorized t2_vec) { + Vectorized self_vec0, self_vec1; + std::tie(self_vec0, self_vec1) = convert_bfloat16_float(self_vec); + Vectorized t1_vec0, t1_vec1, t2_vec0, t2_vec1; + std::tie(t1_vec0, t1_vec1) = convert_bfloat16_float(t1_vec); + std::tie(t2_vec0, t2_vec1) = convert_bfloat16_float(t2_vec); + self_vec0 = self_vec0 + float_vec * t1_vec0 / t2_vec0; + self_vec1 = self_vec1 + float_vec * t1_vec1 / t2_vec1; + return convert_float_bfloat16(self_vec0, self_vec1); }); - }); + } else { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX(dtype, "addcdiv_cpu_out", [&] { + scalar_t scalar_val = value.to(); + auto scalar_vec = Vectorized(scalar_val); + cpu_kernel_vec( + iter, + [=](scalar_t self_val, scalar_t t1_val, scalar_t t2_val) -> scalar_t { + return self_val + scalar_val * t1_val / t2_val; + }, + [=](Vectorized self_vec, + Vectorized t1_vec, + Vectorized t2_vec) { + return self_vec + scalar_vec * t1_vec / t2_vec; + }); + }); + } } static void smooth_l1_backward_cpu_kernel(TensorIterator& iter, const Scalar& norm, double beta) { diff --git a/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp b/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp index c32efeb276bd7..2ab92fbdb2bb2 100644 --- a/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp +++ b/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp @@ -100,15 +100,9 @@ struct _cpu_scatter_gather_dim_loop { template struct cpu_scatter_gather_base_kernel { template - void operator()(Tensor& self, int64_t dim, + void operator()(const Tensor& self, int64_t dim, const Tensor& index, const Scalar& value, const std::string& method_name, func_t& kernel_func) { - // no-op if index is empty - if (index.numel() == 0) { - return; - } - - dim = maybe_wrap_dim(dim, self.dim()); auto index_sizes = ensure_nonempty_vec(index.sizes().vec()); auto index_strides = ensure_nonempty_vec(index.strides().vec()); @@ -193,22 +187,10 @@ struct cpu_scatter_gather_base_kernel { } template - void operator()(Tensor& self, int64_t dim, + void operator()(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src, const std::string& method_name, func_t& kernel_func) { - // no-op if index is empty - if (index.numel() == 0) { - return; - } - - dim = maybe_wrap_dim(dim, self.dim()); - - scatter_gather_dtype_check(method_name, self, index, src); - if (!is_scatter_like) { - gather_shape_check(self, dim, index, src); - } - auto iter = TensorIteratorConfig() .check_all_same_dtype(false) .resize_outputs(false) @@ -292,30 +274,30 @@ struct cpu_scatter_gather_base_kernel { } }; -void gather_cpu_kernel(Tensor& result, const Tensor& self, int64_t dim, const Tensor& index) { +void gather_cpu_kernel(const Tensor& result, const Tensor& self, int64_t dim, const Tensor& index) { cpu_scatter_gather_base_kernel()( result, dim, index, self, "gather_out_cpu", tensor_assign); } -void scatter_cpu_kernel(Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) { +void scatter_cpu_kernel(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) { cpu_scatter_gather_base_kernel<>()( self, dim, index, src, "scatter_cpu_", tensor_assign); } -void scatter_fill_cpu_kernel(Tensor& self, int64_t dim, const Tensor& index, const Scalar& value) { +void scatter_fill_cpu_kernel(const Tensor& self, int64_t dim, const Tensor& index, const Scalar& value) { cpu_scatter_gather_base_kernel<>()( self, dim, index, value, "scatter_fill_cpu_", tensor_assign); } -void scatter_add_cpu_kernel(Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) { +void scatter_add_cpu_kernel(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) { cpu_scatter_gather_base_kernel<>()( self, dim, index, src, "scatter_add_", reduce_add); } -void scatter_reduce_cpu_kernel(Tensor& self, const int64_t dim, const Tensor& index, +void scatter_reduce_cpu_kernel(const Tensor& self, const int64_t dim, const Tensor& index, const Tensor& src, const SCATTER_GATHER_OP& reduce) { switch (reduce) { case SCATTER_GATHER_OP::REDUCE_ADD : @@ -329,7 +311,7 @@ void scatter_reduce_cpu_kernel(Tensor& self, const int64_t dim, const Tensor& in } } -void scatter_scalar_reduce_cpu_kernel(Tensor& self, const int64_t dim, const Tensor& index, +void scatter_scalar_reduce_cpu_kernel(const Tensor& self, const int64_t dim, const Tensor& index, const Scalar& value, const SCATTER_GATHER_OP& reduce) { switch (reduce) { case SCATTER_GATHER_OP::REDUCE_ADD : diff --git a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp index 6288cec2ea3b3..f86f0a349dace 100644 --- a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp @@ -35,18 +35,36 @@ namespace CPU_CAPABILITY { using namespace vec; static void sigmoid_kernel(TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, iter.common_dtype(), "sigmoid_cpu", [&]() { + if (iter.common_dtype() == kBFloat16) { cpu_kernel_vec( iter, - [=](scalar_t a) -> scalar_t { return (static_cast(1) / (static_cast(1) + std::exp((-a)))); }, - [=](Vectorized a) { - a = Vectorized(static_cast(0)) - a; - a = a.exp(); - a = Vectorized(static_cast(1)) + a; - a = a.reciprocal(); - return a; + [=](BFloat16 a) -> BFloat16 { + float a0 = static_cast(a); + return static_cast(1) / (static_cast(1) + std::exp((-a0))); + }, + [=](Vectorized a) { + Vectorized a0, a1; + std::tie(a0, a1) = convert_bfloat16_float(a); + a0 = (Vectorized(static_cast(1)) + a0.neg().exp()).reciprocal(); + a1 = (Vectorized(static_cast(1)) + a1.neg().exp()).reciprocal(); + return convert_float_bfloat16(a0, a1); }); - }); + } else { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.common_dtype(), "sigmoid_cpu", [&]() { + cpu_kernel_vec( + iter, + [=](scalar_t a) -> scalar_t { + return (static_cast(1) / (static_cast(1) + std::exp((-a)))); + }, + [=](Vectorized a) { + a = Vectorized(static_cast(0)) - a; + a = a.exp(); + a = Vectorized(static_cast(1)) + a; + a = a.reciprocal(); + return a; + }); + }); + } } #if AT_MKL_ENABLED() @@ -322,7 +340,7 @@ static void sinc_kernel(TensorIteratorBase& iter) { } static void sinh_kernel(TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.dtype(), "sinh_cpu", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, iter.dtype(), "sinh_cpu", [&]() { cpu_kernel_vec( iter, [=](scalar_t a) -> scalar_t { return std::sinh(a); }, @@ -331,7 +349,7 @@ static void sinh_kernel(TensorIteratorBase& iter) { } static void cosh_kernel(TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.dtype(), "cosh_cpu", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, iter.dtype(), "cosh_cpu", [&]() { cpu_kernel_vec( iter, [=](scalar_t a) -> scalar_t { return std::cosh(a); }, @@ -407,7 +425,7 @@ static void nan_to_num_kernel( c10::optional nan, c10::optional pos_inf, c10::optional neg_inf) { - AT_DISPATCH_FLOATING_TYPES_AND(kHalf, iter.dtype(), "nan_to_num", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "nan_to_num", [&]() { scalar_t nan_replacement = static_cast(nan.value_or(0.)); scalar_t pos_inf_replacement = pos_inf.has_value() ? static_cast(pos_inf.value()) @@ -470,7 +488,7 @@ void bernoulli_scalar_kernel(Tensor &self, double p, c10::optional ge int64_t n = self.numel(); bool contig = self.is_contiguous(); - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool, self.scalar_type(), "bernoulli_scalar_cpu_", [&] { + AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "bernoulli_scalar_cpu_", [&] { at::Tensor tmp_int_tensor; if (std::is_same::value && contig) { tmp_int_tensor = self; @@ -586,7 +604,7 @@ static void entr_kernel(TensorIteratorBase& iter) { } static void frexp_kernel(TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_TYPES_AND(kHalf, + AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, // The iter.dtype() here is the dtype of mantissa output. // It's a floating point type and must be the same as the input's dtype. iter.dtype(), diff --git a/aten/src/ATen/native/cpu/batch_norm_kernel.cpp b/aten/src/ATen/native/cpu/batch_norm_kernel.cpp index 2d1275538d89f..75037606d3ff4 100644 --- a/aten/src/ATen/native/cpu/batch_norm_kernel.cpp +++ b/aten/src/ATen/native/cpu/batch_norm_kernel.cpp @@ -611,48 +611,38 @@ void batch_norm_cpu_backward_channels_last_impl(Tensor& grad_input, Tensor& grad void batch_norm_cpu_kernel(Tensor& output, const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& save_mean, const Tensor& save_invstd, const Tensor& running_mean, const Tensor& running_var, bool train, double eps) { - switch (input.suggest_memory_format()) { - case at::MemoryFormat::Contiguous: { - AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "batch_norm_cpu_contiguous", [&] { - batch_norm_cpu_contiguous_impl(output, input, weight, bias, - save_mean, save_invstd, running_mean, running_var, train, eps); - }); - break; - } - case at::MemoryFormat::ChannelsLast: { - AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "batch_norm_cpu_channels_last", [&] { - batch_norm_cpu_channels_last_impl(output, input, weight, bias, - save_mean, save_invstd, running_mean, running_var, train, eps); - }); - break; - } - default: - TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous"); + if (input.is_contiguous()) { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "batch_norm_cpu_contiguous", [&] { + batch_norm_cpu_contiguous_impl(output, input, weight, bias, + save_mean, save_invstd, running_mean, running_var, train, eps); + }); + } else if (input.is_contiguous(at::MemoryFormat::ChannelsLast)) { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "batch_norm_cpu_channels_last", [&] { + batch_norm_cpu_channels_last_impl(output, input, weight, bias, + save_mean, save_invstd, running_mean, running_var, train, eps); + }); + } else { + TORCH_CHECK(false, "batch_norm_cpu_kernel: expecting input to be contiguous."); } } void batch_norm_cpu_collect_stats_kernel( Tensor& mean, Tensor& var_sum, const Tensor& input) { int64_t image_size = input.numel() / input.size(0) / input.size(1); - switch (input.suggest_memory_format()) { - case at::MemoryFormat::Contiguous: { - AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "batch_norm_cpu_collect_stats_contiguous", [&] { - if (image_size == 1) { // NC11 is also channels last - batch_norm_cpu_collect_stats_channels_last_impl(mean, var_sum, input); - } else { - batch_norm_cpu_collect_stats_contiguous_impl(mean, var_sum, input); - } - }); - break; - } - case at::MemoryFormat::ChannelsLast: { - AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "batch_norm_cpu_collect_stats_channels_last", [&] { + if (input.is_contiguous()) { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "batch_norm_cpu_collect_stats_contiguous", [&] { + if (image_size == 1) { // NC11 is also channels last batch_norm_cpu_collect_stats_channels_last_impl(mean, var_sum, input); - }); - break; - } - default: - TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous"); + } else { + batch_norm_cpu_collect_stats_contiguous_impl(mean, var_sum, input); + } + }); + } else if (input.is_contiguous(at::MemoryFormat::ChannelsLast)) { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "batch_norm_cpu_collect_stats_channels_last", [&] { + batch_norm_cpu_collect_stats_channels_last_impl(mean, var_sum, input); + }); + } else { + TORCH_CHECK(false, "batch_norm_cpu_collect_stats_kernel: expecting input to be contiguous."); } } @@ -661,28 +651,23 @@ void batch_norm_cpu_backward_kernel(Tensor& grad_input, Tensor& grad_weight, Ten const Tensor& running_mean, const Tensor& running_var, const Tensor& save_mean, const Tensor& save_invstd, bool train, double eps) { int64_t image_size = input.numel() / input.size(0) / input.size(1); - switch (input.suggest_memory_format()) { - case at::MemoryFormat::Contiguous: { - AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "batch_norm_cpu_backward_contiguous", [&] { - if (image_size == 1) { // NC11 is also channels last - batch_norm_cpu_backward_channels_last_impl(grad_input, grad_weight, grad_bias, - grad_output, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps); - } else { - batch_norm_cpu_backward_contiguous_impl(grad_input, grad_weight, grad_bias, - grad_output, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps); - } - }); - break; - } - case at::MemoryFormat::ChannelsLast: { - AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "batch_norm_cpu_backward_channels_last", [&] { + if (input.is_contiguous()) { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "batch_norm_cpu_backward_contiguous", [&] { + if (image_size == 1) { // NC11 is also channels last batch_norm_cpu_backward_channels_last_impl(grad_input, grad_weight, grad_bias, grad_output, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps); - }); - break; - } - default: - TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous"); + } else { + batch_norm_cpu_backward_contiguous_impl(grad_input, grad_weight, grad_bias, + grad_output, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps); + } + }); + } else if (input.is_contiguous(at::MemoryFormat::ChannelsLast)) { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "batch_norm_cpu_backward_channels_last", [&] { + batch_norm_cpu_backward_channels_last_impl(grad_input, grad_weight, grad_bias, + grad_output, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps); + }); + } else { + TORCH_CHECK(false, "batch_norm_cpu_backward_kernel: expecting input to be contiguous."); } } diff --git a/aten/src/ATen/native/cpu/group_norm_kernel.cpp b/aten/src/ATen/native/cpu/group_norm_kernel.cpp index 290a6315da445..fb8db7e61800f 100644 --- a/aten/src/ATen/native/cpu/group_norm_kernel.cpp +++ b/aten/src/ATen/native/cpu/group_norm_kernel.cpp @@ -74,6 +74,136 @@ void GroupNormKernelImplInternal( }); } +template +void GroupNormKernelImplChannelsLastInternal( + const Tensor& X, + const Tensor& gamma, + const Tensor& beta, + int64_t N, + int64_t C, + int64_t HxW, + int64_t group, + T eps, + Tensor& Y, + Tensor& mean, + Tensor& rstd) { + TORCH_CHECK(X.numel() == N * C * HxW); + TORCH_CHECK(!gamma.defined() || gamma.numel() == C); + TORCH_CHECK(!beta.defined() || beta.numel() == C); + const int64_t G = group; + const int64_t D = C / G; + const T* X_data = X.data_ptr(); + const T* gamma_data = gamma.defined() ? gamma.data_ptr() : nullptr; + const T* beta_data = beta.defined() ? beta.data_ptr() : nullptr; + T* Y_data = Y.data_ptr(); + T* mean_data = mean.data_ptr(); + T* rstd_data = rstd.data_ptr(); + const T s = T(1) / static_cast(D * HxW); + const bool gamma_null = (gamma_data == nullptr); + const bool beta_null = beta_data == nullptr; + + // temp buffer holding x and x2 + Tensor buffer = at::empty({N, 2 * C}, X.options()).zero_(); + T* buffer_data = buffer.data_ptr(); + + using Vec = vec::Vectorized; + at::parallel_for(0, N, 1, [&](int64_t start, int64_t end) { + constexpr int64_t K = Vec::size(); + const int64_t inner_size = C / K * K; + for (int64_t n = start; n < end; ++n) { + T* mean_ptr = buffer_data + n * 2 * C; + T* rstd_ptr = mean_ptr + C; + for (int64_t i = 0; i < HxW; ++i) { + const T* X_ptr = X_data + n * HxW * C + i * C; + for (int64_t j = 0; j < inner_size; j += K) { + const Vec x_vec = Vec::loadu(X_ptr + j); + Vec mean_vec = Vec::loadu(mean_ptr + j) + x_vec; + Vec rstd_vec = Vec::loadu(rstd_ptr + j) + x_vec * x_vec; + mean_vec.store(mean_ptr + j); + rstd_vec.store(rstd_ptr + j); + } + for (int64_t j = inner_size; j < C; ++j) { + mean_ptr[j] += X_ptr[j]; + rstd_ptr[j] += X_ptr[j] * X_ptr[j]; + } + } + + for (int64_t g = 0; g < G; ++g) { + T mean_val = T(0); + T rstd_val = T(0); + for (int64_t d = 0; d < D; ++d) { + mean_val += mean_ptr[g * D + d]; + rstd_val += rstd_ptr[g * D + d]; + } + mean_val *= s; + rstd_val = std::max(rstd_val * s - mean_val * mean_val, T(0)); + rstd_val = T(1) / std::sqrt(rstd_val + eps); + + // continue to use the temp buffer for mean and rstd value, + // so that we can vectorize the following math on entire C dimension. + for (int64_t d = 0; d < D; ++d) { + mean_ptr[g * D + d] = mean_val; + rstd_ptr[g * D + d] = rstd_val; + } + + mean_data[n * G + g] = mean_val; + rstd_data[n * G + g] = rstd_val; + } + + // expand gamma_null and beta_null to reduce if-else on critial path. + if (!gamma_null && !beta_null) { + for (int64_t i = 0; i < HxW; ++i) { + const T* X_ptr = X_data + n * HxW * C + i * C; + T* Y_ptr = Y_data + n * HxW * C + i * C; + for (int64_t j = 0; j < inner_size; j += K) { + Vec scale_vec = Vec::loadu(rstd_ptr + j) * Vec::loadu(gamma_data + j); + Vec bias_vec = Vec::loadu(beta_data + j) - scale_vec * Vec::loadu(mean_ptr + j); + Vec y_vec = scale_vec * Vec::loadu(X_ptr + j) + bias_vec; + y_vec.store(Y_ptr + j); + } + for (int64_t j = inner_size; j < C; ++j) { + T scale = rstd_ptr[j] * gamma_data[j]; + T bias = -scale * mean_ptr[j] + beta_data[j]; + Y_ptr[j] = scale * X_ptr[j] + bias; + } + } + } else if (gamma_null && beta_null) { + for (int64_t i = 0; i < HxW; ++i) { + const T* X_ptr = X_data + n * HxW * C + i * C; + T* Y_ptr = Y_data + n * HxW * C + i * C; + for (int64_t j = 0; j < inner_size; j += K) { + Vec scale_vec = Vec::loadu(rstd_ptr + j); + Vec y_vec = scale_vec * Vec::loadu(X_ptr + j) - scale_vec * Vec::loadu(mean_ptr + j); + y_vec.store(Y_ptr + j); + } + for (int64_t j = inner_size; j < C; ++j) { + T scale = rstd_ptr[j]; + Y_ptr[j] = scale * X_ptr[j] -scale * mean_ptr[j]; + } + } + } else { + for (int64_t i = 0; i < HxW; ++i) { + const T* X_ptr = X_data + n * HxW * C + i * C; + T* Y_ptr = Y_data + n * HxW * C + i * C; + for (int64_t j = 0; j < inner_size; j += K) { + Vec gamma_vec = gamma_null ? Vec(1) : Vec::loadu(gamma_data + j); + Vec beta_vec = beta_null ? Vec(0) : Vec::loadu(beta_data + j); + Vec scale_vec = Vec::loadu(rstd_ptr + j) * gamma_vec; + Vec bias_vec = beta_vec - scale_vec * Vec::loadu(mean_ptr + j); + Vec y_vec = scale_vec * Vec::loadu(X_ptr + j) + bias_vec; + y_vec.store(Y_ptr + j); + } + for (int64_t j = inner_size; j < C; ++j) { + T scale = rstd_ptr[j] * (gamma_null ? T(1) : gamma_data[j]); + T bias = -scale * mean_ptr[j] + (beta_null ? T(0) : beta_data[j]); + Y_ptr[j] = scale * X_ptr[j] + bias; + } + } + } + } + }); +} + void GroupNormKernelImpl( const Tensor& X, const Tensor& gamma, @@ -86,20 +216,24 @@ void GroupNormKernelImpl( Tensor& Y, Tensor& mean, Tensor& rstd) { - AT_DISPATCH_FLOATING_TYPES(X.scalar_type(), "GroupNormKernelImpl", [&]() { - GroupNormKernelImplInternal( - X, - gamma, - beta, - N, - C, - HxW, - group, - static_cast(eps), - Y, - mean, - rstd); - }); + switch (X.suggest_memory_format()) { + case at::MemoryFormat::Contiguous: { + AT_DISPATCH_FLOATING_TYPES(X.scalar_type(), "GroupNormKernelImpl", [&]() { + GroupNormKernelImplInternal( + X, gamma, beta, N, C, HxW, group, static_cast(eps), Y, mean, rstd); + }); + break; + } + case at::MemoryFormat::ChannelsLast: { + AT_DISPATCH_FLOATING_TYPES(X.scalar_type(), "GroupNormKernelImpl", [&]() { + GroupNormKernelImplChannelsLastInternal( + X, gamma, beta, N, C, HxW, group, static_cast(eps), Y, mean, rstd); + }); + break; + } + default: + TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous"); + } } template diff --git a/aten/src/ATen/native/cuda/Activation.cu b/aten/src/ATen/native/cuda/Activation.cu index 1229149d76aee..7c8783028a5ac 100644 --- a/aten/src/ATen/native/cuda/Activation.cu +++ b/aten/src/ATen/native/cuda/Activation.cu @@ -28,7 +28,7 @@ namespace native { // ----------------------------------- // glu forward // ----------------------------------- -void glu_kernel(TensorIterator& iter) { +void glu_kernel(TensorIteratorBase& iter) { AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "glu_cuda", [&]() { using acc_t = at::acc_type; gpu_kernel(iter, [] GPU_LAMBDA (scalar_t a_, scalar_t b_) -> scalar_t { diff --git a/aten/src/ATen/native/cuda/AmpKernels.cu b/aten/src/ATen/native/cuda/AmpKernels.cu index a5d8a643648e7..c89d8a09e8d1d 100644 --- a/aten/src/ATen/native/cuda/AmpKernels.cu +++ b/aten/src/ATen/native/cuda/AmpKernels.cu @@ -59,7 +59,7 @@ void _amp_non_finite_check_and_unscale_cuda_(Tensor& scaled_grad, auto* found_inf_ptr = found_inf.data_ptr(); auto* inv_scale_ptr = inv_scale.data_ptr(); - using opmath_t = get_opmath_t::opmath_t; + using opmath_t = at::opmath_type; gpu_kernel(iter, [found_inf_ptr, inv_scale_ptr] GPU_LAMBDA (scalar_t val_in) -> scalar_t { @@ -154,7 +154,7 @@ void _amp_foreach_non_finite_check_and_unscale_cuda_(TensorList scaled_grads, auto* found_inf_ptr = found_inf.data_ptr(); auto* inv_scale_ptr = inv_scale.data_ptr(); - using opmath_t = get_opmath_t::opmath_t; + using opmath_t = at::opmath_type; // multi_tensor_apply guards onto tensor_lists[0][0], no need to guard explicitly. multi_tensor_apply<1>(tensor_lists, diff --git a/aten/src/ATen/native/cuda/AveragePool2d.cu b/aten/src/ATen/native/cuda/AveragePool2d.cu index 5de3adc08bee8..df9fcfef64167 100644 --- a/aten/src/ATen/native/cuda/AveragePool2d.cu +++ b/aten/src/ATen/native/cuda/AveragePool2d.cu @@ -232,30 +232,31 @@ __global__ void avg_pool2d_backward_out_cuda_frame_nhwc(const int nthreads, } // anonymous namespace -TORCH_IMPL_FUNC(avg_pool2d_out_cuda) ( - const Tensor& input_, - IntArrayRef kernel_size, - IntArrayRef stride, - IntArrayRef padding, - bool ceil_mode, - bool count_include_pad, - c10::optional divisor_override, - const Tensor& output -) { +TORCH_IMPL_FUNC(avg_pool2d_out_cuda) +(const Tensor& input_, + int64_t kH_, + int64_t kW_, + int64_t dH_, + int64_t dW_, + int64_t padH_, + int64_t padW_, + bool ceil_mode, + bool count_include_pad, + c10::optional divisor_override, + const Tensor& output) { TensorArg output_arg{ output, "output", 1 }; TensorArg input_arg{ input_, "input_", 2 }; checkAllSameGPU("avg_pool2d_out_cuda", {output_arg, input_arg}); - const int kH = safe_downcast(kernel_size[0]); - const int kW = kernel_size.size() == 1 ? kH : safe_downcast(kernel_size[1]); + const int kH = safe_downcast(kH_); + const int kW = safe_downcast(kW_); - const int dH = stride.empty() ? kH : safe_downcast(stride[0]); - const int dW = stride.empty() ? kW : - stride.size() == 1 ? dH : safe_downcast(stride[1]); + const int dH = safe_downcast(dH_); + const int dW = safe_downcast(dW_); - const int padH = safe_downcast(padding[0]); - const int padW = padding.size() == 1 ? padH : safe_downcast(padding[1]); + const int padH = safe_downcast(padH_); + const int padW = safe_downcast(padW_); /* sizes */ const int64_t nbatch = input_.ndimension() == 4 ? input_.size(-4) : 1; @@ -263,8 +264,8 @@ TORCH_IMPL_FUNC(avg_pool2d_out_cuda) ( const int64_t inputHeight = input_.size(-2); const int64_t inputWidth = input_.size(-1); - const int64_t outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode); - const int64_t outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode); + int64_t outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode); + int64_t outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode); const auto memory_format = input_.suggest_memory_format(); Tensor input = input_.contiguous(memory_format); @@ -289,37 +290,55 @@ TORCH_IMPL_FUNC(avg_pool2d_out_cuda) ( case MemoryFormat::ChannelsLast: { output.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::ChannelsLast); avg_pool2d_out_cuda_frame_nhwc - <<>>( - count, - input_data, - nbatch, - nInputPlane, - inputHeight, inputWidth, - outputHeight, outputWidth, - kH, kW, - dH, dW, - padH, padW, - output_data, - divisor_override_value, - count_include_pad, use_divisor); + <<>>( + count, + input_data, + nbatch, + nInputPlane, + inputHeight, + inputWidth, + outputHeight, + outputWidth, + kH, + kW, + dH, + dW, + padH, + padW, + output_data, + divisor_override_value, + count_include_pad, + use_divisor); C10_CUDA_KERNEL_LAUNCH_CHECK(); break; } case MemoryFormat::Contiguous: { avg_pool2d_out_cuda_frame - <<>>( - count, - input_data, - nbatch, - nInputPlane, - inputHeight, inputWidth, - outputHeight, outputWidth, - kH, kW, - dH, dW, - padH, padW, - output_data, - divisor_override_value, - count_include_pad, use_divisor); + <<>>( + count, + input_data, + nbatch, + nInputPlane, + inputHeight, + inputWidth, + outputHeight, + outputWidth, + kH, + kW, + dH, + dW, + padH, + padW, + output_data, + divisor_override_value, + count_include_pad, + use_divisor); C10_CUDA_KERNEL_LAUNCH_CHECK(); break; } diff --git a/aten/src/ATen/native/cuda/AveragePool3d.cu b/aten/src/ATen/native/cuda/AveragePool3d.cu index 671b354734db0..6c712af93cc68 100644 --- a/aten/src/ATen/native/cuda/AveragePool3d.cu +++ b/aten/src/ATen/native/cuda/AveragePool3d.cu @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -210,7 +211,7 @@ __global__ void avg_pool3d_cuda_update_grad_input_atomic( int dT, int dH, int dW, int padT, int padH, int padW, bool count_include_pad, - int offsetZ, int divisor_override) + int offsetZ, int divisor_override, const int gradInput_numel) { int oCol = blockIdx.x * blockDim.x + threadIdx.x; int oRow = blockIdx.y * blockDim.y + threadIdx.y; @@ -253,7 +254,8 @@ __global__ void avg_pool3d_cuda_update_grad_input_atomic( { for (int iCol = wstart; iCol < wend; ++iCol) { - gpuAtomicAddNoReturn(&gradInput[slice][iFrame][iRow][iCol], val); + const int index = slice * gradInput.stride(0) + iFrame * gradInput.stride(1) + iRow * gradInput.stride(2) + iCol * gradInput.stride(3); + fastAtomicAdd(gradInput.data(), index, gradInput_numel, val, true); } } } @@ -568,7 +570,7 @@ TORCH_IMPL_FUNC(avg_pool3d_backward_out_cuda) ( dT, dH, dW, padT, padH, padW, count_include_pad, - offsetZ, divisor); + offsetZ, divisor, work_grad_input.numel()); C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cpp similarity index 96% rename from aten/src/ATen/native/cuda/BatchLinearAlgebra.cu rename to aten/src/ATen/native/cuda/BatchLinearAlgebra.cpp index 0dae7a2aa3c11..7fdc55d818084 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cpp @@ -1701,7 +1701,7 @@ static void cholesky_kernel(const Tensor& input, const Tensor& info, bool upper) #endif // USE_CUSOLVER } -REGISTER_DISPATCH(cholesky_stub, &cholesky_kernel) +REGISTER_CUDA_DISPATCH(cholesky_stub, &cholesky_kernel) // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ cholesky_inverse ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1773,7 +1773,7 @@ Tensor& cholesky_inverse_kernel_impl(Tensor &result, Tensor& infos, bool upper) } -REGISTER_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl); +REGISTER_CUDA_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1945,7 +1945,7 @@ static void apply_lu(const Tensor& input, const Tensor& pivots, const Tensor& in } } -REGISTER_DISPATCH(lu_stub, &apply_lu); +REGISTER_CUDA_DISPATCH(lu_stub, &apply_lu); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangular_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -2039,7 +2039,7 @@ void triangular_solve_kernel(Tensor& A, Tensor& B, Tensor& infos, bool upper, bo } } -REGISTER_DISPATCH(triangular_solve_stub, &triangular_solve_kernel); +REGISTER_CUDA_DISPATCH(triangular_solve_stub, &triangular_solve_kernel); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ orgqr ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -2057,7 +2057,7 @@ Tensor& orgqr_kernel_impl(Tensor& result, const Tensor& tau) { #endif } -REGISTER_DISPATCH(orgqr_stub, &orgqr_kernel_impl); +REGISTER_CUDA_DISPATCH(orgqr_stub, &orgqr_kernel_impl); void ormqr_kernel(const Tensor& input, const Tensor& tau, const Tensor& other, bool left, bool transpose) { #if defined(USE_CUSOLVER) @@ -2069,7 +2069,7 @@ void ormqr_kernel(const Tensor& input, const Tensor& tau, const Tensor& other, b #endif } -REGISTER_DISPATCH(ormqr_stub, &ormqr_kernel); +REGISTER_CUDA_DISPATCH(ormqr_stub, &ormqr_kernel); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ qr ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -2148,7 +2148,7 @@ void geqrf_kernel(const Tensor& input, const Tensor& tau) { } } -REGISTER_DISPATCH(geqrf_stub, &geqrf_kernel); +REGISTER_CUDA_DISPATCH(geqrf_stub, &geqrf_kernel); template static void apply_qr(Tensor& Q, Tensor& R, int64_t q_size_minus_2, int64_t r_size_minus_1, int64_t n_columns, @@ -2423,7 +2423,7 @@ void linalg_eigh_kernel(const Tensor& eigenvalues, const Tensor& eigenvectors, c #endif } -REGISTER_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel); +REGISTER_CUDA_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ eig ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -2513,7 +2513,7 @@ std::tuple eig_kernel_impl(const Tensor& self, bool& eigenvector return std::tuple(out_eigvals, out_eigvecs); } -REGISTER_DISPATCH(eig_stub, &eig_kernel_impl); +REGISTER_CUDA_DISPATCH(eig_stub, &eig_kernel_impl); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_eig ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -2599,7 +2599,7 @@ void linalg_eig_kernel(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos, }); } -REGISTER_DISPATCH(linalg_eig_stub, &linalg_eig_kernel); +REGISTER_CUDA_DISPATCH(linalg_eig_stub, &linalg_eig_kernel); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ svd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -2671,8 +2671,7 @@ AT_ERROR("svd: MAGMA library not found in " std::tuple _svd_helper_cuda_legacy(const Tensor& self, bool some, bool compute_uv) { std::vector infos(batchCount(self), 0); - int64_t m = self.size(-2), n = self.size(-1); - int64_t k = std::min(m, n); + int64_t m = self.size(-2); char jobchar = compute_uv ? (some ? 'S' : 'A') : 'N'; @@ -2922,13 +2921,13 @@ static void lu_solve_trans_dispatch(const Tensor& b, const Tensor& lu, const Ten } } -REGISTER_DISPATCH(lu_solve_trans_stub, &lu_solve_trans_dispatch); +REGISTER_CUDA_DISPATCH(lu_solve_trans_stub, &lu_solve_trans_dispatch); static void lu_solve_dispatch(const Tensor& b, const Tensor& lu, const Tensor& pivots) { lu_solve_trans_dispatch(b, lu, pivots, 'N'); } -REGISTER_DISPATCH(lu_solve_stub, &lu_solve_dispatch); +REGISTER_CUDA_DISPATCH(lu_solve_stub, &lu_solve_dispatch); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lstsq ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -3112,7 +3111,85 @@ void lstsq_kernel(const Tensor& a, Tensor& b, Tensor& /*rank*/, Tensor& /*singul } } -REGISTER_DISPATCH(lstsq_stub, &lstsq_kernel); +REGISTER_CUDA_DISPATCH(lstsq_stub, &lstsq_kernel); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ legacy_lstsq ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +std::tuple legacy_lstsq_cuda(const Tensor &B, const Tensor &A) { + TORCH_WARN_ONCE( + "torch.lstsq is deprecated in favor of torch.linalg.lstsq and will be removed in a future PyTorch release.\n", + "torch.linalg.lstsq has reversed arguments and does not return the QR decomposition in " + "the returned tuple (although it returns other information about the problem).\n", + "To get the qr decomposition consider using torch.linalg.qr.\n", + "The returned solution in torch.lstsq stored the residuals of the solution in the ", + "last m - n columns of the returned value whenever m > n. In torch.linalg.lstsq, the ", + "residuals in the field 'residuals' of the returned named tuple.\n", + "The unpacking of the solution, as in\n", + "X, _ = torch.lstsq(B, A).solution[:A.size(1)]\n", + "should be replaced with\n", + "X = torch.linalg.lstsq(A, B).solution" + ); + +#ifndef USE_MAGMA + TORCH_CHECK(false, "solve: MAGMA library not found in " + "compilation. Please rebuild with MAGMA."); +#else + const auto dtype = A.scalar_type(); + TORCH_CHECK(B.scalar_type() == dtype, "exepected A and B dtypes to match but found ", + dtype, " and ", B.scalar_type()); + TORCH_CHECK(A.numel() > 0 && A.dim() == 2, "A should be (non-empty) 2 dimensional"); + TORCH_CHECK(B.numel() > 0 && B.dim() == 2, "B should be (non-empty) 2 dimensional"); + auto a_sizes = A.sizes(); + auto b_sizes = B.sizes(); + TORCH_CHECK(a_sizes[0] == b_sizes[0], "Expected A and b to have same size " + "at dim 0, but A has ", a_sizes[0], " rows and B has ", b_sizes[0], " rows"); + TORCH_CHECK(a_sizes[0] >= a_sizes[1], "Expected A with shape (m x n) to have " + "m >= n. The case for m < n is not implemented yet."); + + Tensor A_working = cloneBatchedColumnMajor(A); + Tensor B_working = cloneBatchedColumnMajor(B); + + int64_t m = a_sizes[0]; + int64_t n = a_sizes[1]; + int64_t nrhs = b_sizes[1]; + + int info; + AT_DISPATCH_FLOATING_TYPES(A.scalar_type(), "legacy_lstsq_cuda", [&] { + scalar_t *a_data = A_working.data_ptr(); + scalar_t *b_data = B_working.data_ptr(); + scalar_t wkopt; + magmaGels(MagmaNoTrans, m, n, nrhs, a_data, m, b_data, m, &wkopt, -1, &info); + + const auto hwork_size = static_cast(wkopt); + scalar_t *hwork = nullptr; + ALLOCATE_ARRAY(hwork, scalar_t, hwork_size); + + magmaGels(MagmaNoTrans, m, n, nrhs, a_data, m, b_data, m, hwork, hwork_size, &info); + }); + + TORCH_CHECK(info == 0, "MAGMA gels : Argument %d : illegal value", -info); + return std::tuple(B_working, A_working); +#endif // USE_MAGMA +} + +std::tuple legacy_lstsq_out_cuda( + const Tensor& B, const Tensor& A, Tensor& B_out, Tensor& A_out) { + const auto dtype = A.scalar_type(); + TORCH_CHECK(B.scalar_type() == dtype, "exepected A and B dtypes to match but found ", + A.scalar_type(), " and ", B.scalar_type()); + TORCH_CHECK(A_out.scalar_type() == dtype, "A_out to have scalar type ", dtype, + " but found", A_out.scalar_type()); + TORCH_CHECK(B_out.scalar_type() == dtype, "A_out to have scalar type ", dtype, + " but found", B_out.scalar_type()); + Tensor A_tmp, B_tmp; + std::tie(B_tmp, A_tmp) = native::legacy_lstsq_cuda(B, A); + resize_output(A_out, A_tmp.sizes()); + A_out.copy_(A_tmp); + resize_output(B_out, B_tmp.sizes()); + B_out.copy_(B_tmp); + return std::tuple(B_out, A_out); +} + }} // namespace at::native diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cpp similarity index 98% rename from aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu rename to aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cpp index bb9af142955f0..13d67e571e7dc 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cpp @@ -143,10 +143,6 @@ static void apply_triangular_solve_batched(Tensor& A, Tensor& B, bool upper, boo cublasDiagType_t diag = unitriangular ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT; cublasSideMode_t side = CUBLAS_SIDE_LEFT; - auto A_data = A.data_ptr(); - auto B_data = B.data_ptr(); - auto A_mat_stride = matrixStride(A); - auto B_mat_stride = matrixStride(B); auto batch_size = cuda_int_cast(batchCount(A), "batch_size"); auto m = cuda_int_cast(A.size(-2), "m"); auto n = cuda_int_cast(A.size(-1), "n"); @@ -329,8 +325,6 @@ Tensor& _linalg_inv_out_helper_cuda_lib(Tensor& result, Tensor& infos_getrf, Ten result.zero_(); result.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(1); - const int batch_size = cuda_int_cast(batchCount(result), "batchCount"); - if (result.dim() > 2) { AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(result.scalar_type(), "linalg_inv_out_cuda", [&]{ apply_batched_inverse_lib( @@ -435,10 +429,6 @@ inline static void _apply_svd_lib_gesvdjBatched(const Tensor& self, Tensor& U, T auto U_data = U.data_ptr(); auto S_data = S.data_ptr(); auto VT_data = VT.data_ptr(); - auto self_stride = matrixStride(self); - auto U_stride = matrixStride(U); - auto S_stride = S.size(-1); - auto VT_stride = matrixStride(VT); int batchsize = cuda_int_cast(batchCount(self), "batch size"); int m = cuda_int_cast(self.size(-2), "m"); @@ -481,7 +471,6 @@ std::tuple _svd_helper_cuda_lib(const Tensor& self, bool at::Tensor infos = at::zeros({batch_size}, self.options().dtype(at::kInt)); const int64_t m = self.size(-2); const int64_t n = self.size(-1); - const int64_t k = std::min(m, n); Tensor U_working_copy, S_working_copy, VT_working_copy; std::tie(U_working_copy, S_working_copy, VT_working_copy) = \ @@ -686,11 +675,7 @@ inline static void apply_cholesky_cusolver_potrsBatched(Tensor& self_working_cop const int64_t nrhs = self_working_copy.size(-1); const int64_t lda = std::max(1, n); const int64_t batch_size = batchCount(self_working_copy); - const int64_t self_matrix_stride = matrixStride(self_working_copy); - scalar_t* self_working_copy_ptr = self_working_copy.data_ptr(); - const scalar_t* A_ptr = A_column_major_copy.data_ptr(); - const int64_t A_matrix_stride = matrixStride(A_column_major_copy); const int64_t ldb = std::max(1, A_column_major_copy.size(-1)); int* infos_ptr = infos.data_ptr(); @@ -882,8 +867,6 @@ void geqrf_cusolver(const Tensor& input, const Tensor& tau) { */ template static void apply_ormqr(const Tensor& input, const Tensor& tau, const Tensor& other, bool left, bool transpose) { - using value_t = typename c10::scalar_value_type::type; - auto side = left ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT; auto trans = transpose ? (input.is_complex() ? CUBLAS_OP_C : CUBLAS_OP_T) : CUBLAS_OP_N; @@ -957,7 +940,6 @@ void ormqr_cusolver(const Tensor& input, const Tensor& tau, const Tensor& other, */ template inline static void apply_orgqr(Tensor& self, const Tensor& tau) { - using value_t = typename c10::scalar_value_type::type; auto self_data = self.data_ptr(); auto tau_data = tau.data_ptr(); auto self_matrix_stride = matrixStride(self); diff --git a/aten/src/ATen/native/cuda/BinaryAddSubKernel.cu b/aten/src/ATen/native/cuda/BinaryAddSubKernel.cu index a07fd663581fe..b1c76e119a78a 100644 --- a/aten/src/ATen/native/cuda/BinaryAddSubKernel.cu +++ b/aten/src/ATen/native/cuda/BinaryAddSubKernel.cu @@ -10,53 +10,20 @@ namespace at { namespace native { -template +template struct AddFunctor { - AddFunctor(accscalar_t a): alpha(a) {} - __device__ __forceinline__ scalar_t operator() (const scalar_t a, const scalar_t b) const { - return a + alpha * b; + AddFunctor(T alpha) : alpha_(alpha) {} + T alpha_; + __device__ __forceinline__ T operator()(T a, T b) const __ubsan_ignore_undefined__ { + return a + b * alpha_; } - private: - accscalar_t alpha; -}; - -template -struct AddScalarFunctor { - static_assert(SCALAR_ARG == 1 || SCALAR_ARG == 2, "SCALAR_ARG must be either 1 or 2"); - AddScalarFunctor(accscalar_t alpha, accscalar_t b): alpha(alpha), b(b) {} - __device__ __forceinline__ scalar_t operator() (const scalar_t a) const { - return static_cast(SCALAR_ARG == 1 ? b + alpha * a : a + alpha * b); - } - private: - accscalar_t alpha; - accscalar_t b; }; void add_kernel_cuda(TensorIteratorBase& iter, const Scalar& alpha_scalar) { - if (!isIntegralType(iter.common_dtype(), /* includeBool */ true) && (iter.is_cpu_scalar(1) || iter.is_cpu_scalar(2))) { - // if common dtype is half the scalar constant can overflow in half precision, and yet the result can - // still be representable in the half dtype. Cast scalar to acc_type to have better accuracy. - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "add_cuda/sub_cuda", [&]() { - using accscalar_t = at::acc_type; - int scalar_arg = iter.is_cpu_scalar(1) ? 1 : 2; - auto b = iter.scalar_value(scalar_arg); - iter.remove_operand(scalar_arg); - const cuda::OptionalCUDAGuard device_guard(device_of(iter.tensor(1))); - if (scalar_arg == 1) { - AddScalarFunctor f(alpha_scalar.to(), b); - gpu_kernel(iter, f); - } else { - AddScalarFunctor f(alpha_scalar.to(), b); - gpu_kernel(iter, f); - } - }); - } else { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBool, kBFloat16, iter.common_dtype(), "add_cuda/sub_cuda", [&]() { - using accscalar_t = at::acc_type; - AddFunctor f(alpha_scalar.to()); - gpu_kernel_with_scalars(iter, f); - }); - } + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBool, kBFloat16, iter.common_dtype(), "add_cuda/sub_cuda", [&]() { + using opmath_t = at::opmath_type; + opmath_gpu_kernel_with_scalars(iter, AddFunctor(alpha_scalar.to())); + }); } static void sub_kernel_cuda(TensorIteratorBase& iter, const Scalar& alpha_scalar) { diff --git a/aten/src/ATen/native/cuda/BinaryMulDivKernel.cu b/aten/src/ATen/native/cuda/BinaryMulDivKernel.cu index da615fe12221b..e6a5300780e57 100644 --- a/aten/src/ATen/native/cuda/BinaryMulDivKernel.cu +++ b/aten/src/ATen/native/cuda/BinaryMulDivKernel.cu @@ -14,16 +14,6 @@ namespace at { namespace native { -template -struct MulScalarFunctor { - MulScalarFunctor(accscalar_t b_): b(b_) {} - __device__ scalar_t operator() (scalar_t a) const { - return a * b; - } - private: - accscalar_t b; -}; - template struct DivFunctor { __device__ scalar_t operator() (scalar_t a, scalar_t b) const { @@ -31,9 +21,9 @@ struct DivFunctor { } }; -template +template struct MulFunctor { - __device__ scalar_t operator() (scalar_t a, scalar_t b) const { + __device__ T operator() (T a, T b) const { return a * b; } }; @@ -53,11 +43,11 @@ void div_true_kernel_cuda(TensorIteratorBase& iter) { // scalar, compute a * reciprocal(b). Note that this may lose one bit of // precision compared to computing the division. AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "div_true_cuda", [&]() { - using accscalar_t = at::acc_type; - auto inv_b = accscalar_t(1.0) / iter.scalar_value(2); + using opmath_t = at::opmath_type; + auto inv_b = opmath_t(1.0) / iter.scalar_value(2); iter.remove_operand(2); - MulScalarFunctor f(inv_b); - gpu_kernel(iter, f); + gpu_kernel(iter, BUnaryFunctor>( + MulFunctor(), inv_b)); }); } else { AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "div_true_cuda", [&]() { @@ -180,25 +170,10 @@ void div_floor_kernel_cuda(TensorIteratorBase& iter) { } void mul_kernel_cuda(TensorIteratorBase& iter) { - if (!isIntegralType(iter.common_dtype(), /*includeBool*/ true) && - (iter.is_cpu_scalar(1) || iter.is_cpu_scalar(2))) { - //if common dtype is half the scalar constant can overflow in half precision, and yet the result can - //still be representable in the half dtype. Cast scalar to acc_type to have better accuracy - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "mul_cuda", [&]() { - using accscalar_t = at::acc_type; - int scalar_arg = iter.is_cpu_scalar(1) ? 1 : 2; - auto b = iter.scalar_value(scalar_arg); - iter.remove_operand(scalar_arg); - const cuda::OptionalCUDAGuard device_guard(device_of(iter.tensor(1))); - MulScalarFunctor f(b); - gpu_kernel(iter, f); - }); - } else { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBFloat16, kBool, iter.common_dtype(), "mul_cuda", [&]() { - MulFunctor f; - gpu_kernel_with_scalars(iter, f); - }); - } + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBFloat16, kBool, iter.common_dtype(), "mul_cuda", [&]() { + using opmath_t = at::opmath_type; + opmath_gpu_kernel_with_scalars(iter, MulFunctor()); + }); } REGISTER_DISPATCH(div_true_stub, &div_true_kernel_cuda); diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index b0fe0ac7a05b6..269307d605aec 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -4,24 +4,51 @@ #include #include - namespace at { namespace native { namespace { +// TODO: https://github.com/pytorch/pytorch/pull/59380#pullrequestreview-725310492 +c10::MaybeOwned inline resolve_conj_if_indicated(const Tensor& tensor, bool resolve_conj) { + if (resolve_conj && tensor.is_conj()) { + return c10::MaybeOwned::owned(tensor.resolve_conj()); + } else { + return c10::MaybeOwned::borrowed(tensor); + } +} + +c10::MaybeOwned inline prepare_matrix_for_cublas(const Tensor& tensor, bool& transpose_tensor, bool transpose_result) { + if (tensor.is_non_overlapping_and_dense()) { // common case + transpose_tensor = tensor.is_contiguous(); + return resolve_conj_if_indicated(tensor, transpose_result ? transpose_tensor : !transpose_tensor); + } + IntArrayRef tensor_strides = tensor.strides(); + IntArrayRef tensor_sizes = tensor.sizes(); + if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max(1, tensor_sizes[0]))) { + transpose_tensor = false; + return resolve_conj_if_indicated(tensor, !transpose_result); + } else if ((tensor_strides[1] == 1) && (tensor_strides[0] >= std::max(1, tensor_sizes[1]))) { + transpose_tensor = true; + return resolve_conj_if_indicated(tensor, transpose_result); + } else { + transpose_tensor = true; + return c10::MaybeOwned::owned(tensor.clone(at::MemoryFormat::Contiguous)); + } +} + c10::MaybeOwned inline prepare_matrix_for_cublas(const Tensor& tensor, bool& transpose_tensor) { if (tensor.is_non_overlapping_and_dense()) { // common case transpose_tensor = tensor.is_contiguous(); - return c10::MaybeOwned::borrowed(tensor); + return resolve_conj_if_indicated(tensor, true); } IntArrayRef tensor_strides = tensor.strides(); IntArrayRef tensor_sizes = tensor.sizes(); if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max(1, tensor_sizes[0]))) { transpose_tensor = false; - return c10::MaybeOwned::borrowed(tensor); + return resolve_conj_if_indicated(tensor, true); } else if ((tensor_strides[1] == 1) && (tensor_strides[0] >= std::max(1, tensor_sizes[1]))) { transpose_tensor = true; - return c10::MaybeOwned::borrowed(tensor); + return resolve_conj_if_indicated(tensor, true); } else { transpose_tensor = true; return c10::MaybeOwned::owned(tensor.clone(at::MemoryFormat::Contiguous)); @@ -39,19 +66,19 @@ c10::MaybeOwned prepare_batch_matrix_for_cublas(const Tensor& tensor, bo if (tensor_strides[fast_dim] == 1 && (tensor_strides[leading_dim] >= std::max(1, m))) { transpose_tensor = false; - tensor_ = c10::MaybeOwned::borrowed(tensor); - ld_tensor = tensor_strides[leading_dim]; + tensor_ = resolve_conj_if_indicated(tensor, true); + ld_tensor = tensor_->strides()[leading_dim]; } else if ((tensor_strides[leading_dim] == 1) && (tensor_strides[fast_dim] >= std::max(1, n))) { transpose_tensor = true; - tensor_ = c10::MaybeOwned::borrowed(tensor); - ld_tensor = tensor_strides[fast_dim]; + tensor_ = resolve_conj_if_indicated(tensor, false); + ld_tensor = tensor_->strides()[fast_dim]; } else { transpose_tensor = !transpose_result; // gemm call requires leading dimension and stride parameters to be non-zero bool is_stride_non_zero = tensor.strides()[1] != 0 && tensor.strides()[2] != 0; if (tensor.is_contiguous() && is_stride_non_zero) { - tensor_ = c10::MaybeOwned::borrowed(tensor); + tensor_ = resolve_conj_if_indicated(tensor, transpose_result); } else { tensor_ = c10::MaybeOwned::owned(tensor.clone(at::MemoryFormat::Contiguous)); } @@ -104,8 +131,8 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma c10::MaybeOwned result_ = prepare_matrix_for_cublas(result, transpose_result); bool transpose_mat1; bool transpose_mat2; - c10::MaybeOwned mat1_ = prepare_matrix_for_cublas(transpose_result ? mat2 : mat1, transpose_mat1); - c10::MaybeOwned mat2_ = prepare_matrix_for_cublas(transpose_result ? mat1 : mat2, transpose_mat2); + auto mat1_ = prepare_matrix_for_cublas(transpose_result ? mat2 : mat1, transpose_mat1, transpose_result); + auto mat2_ = prepare_matrix_for_cublas(transpose_result ? mat1 : mat2, transpose_mat2, transpose_result); if (transpose_result) { transpose_mat1 = !transpose_mat1; @@ -141,6 +168,8 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma c10::nullopt /* pin_memory */)); } + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!result_->is_conj()); + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, scalar_type, "addmm_cuda", [&] { scalar_t alpha_val = alpha.to(); scalar_t beta_val = beta.to(); @@ -148,8 +177,8 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma scalar_t* mat2_ptr = mat2_->data_ptr(); scalar_t* result_ptr = result_->data_ptr(); at::cuda::blas::gemm( - transpose_mat1 ? 't' : 'n', - transpose_mat2 ? 't' : 'n', + transpose_mat1 ? mat1_->is_conj() ? 'c' : 't' : 'n', + transpose_mat2 ? mat2_->is_conj() ? 'c' : 't' : 'n', m, n, k, alpha_val, mat1_ptr, mat1_ld, @@ -207,11 +236,11 @@ Tensor& baddbmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& if ((result_strides[1] == 1) && ((result_sizes[2] == 1) || (result_strides[2] >= std::max(1, result_sizes[1])))) { - result_ = c10::MaybeOwned::borrowed(result); + result_ = resolve_conj_if_indicated(result, true); } else if ((result_strides[2] == 1) && (result_sizes[1] == 1 || (result_strides[1] >= std::max(1, result_sizes[2])))) { transpose_result = true; - result_ = c10::MaybeOwned::borrowed(result); + result_ = resolve_conj_if_indicated(result, true); } else { result_ = c10::MaybeOwned::owned(result.transpose(1, 2).clone(at::MemoryFormat::Contiguous).transpose(1, 2)); } @@ -230,6 +259,8 @@ Tensor& baddbmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ldc = result_->strides()[leading_dim]; int64_t num_batches = result_->sizes()[0]; + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!result_->is_conj()); + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "baddbmm_cuda", [&] { scalar_t alpha_val = alpha.to(); scalar_t beta_val = beta.to(); @@ -237,8 +268,8 @@ Tensor& baddbmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& scalar_t* batch2_ptr = batch2_->data_ptr(); scalar_t* result_ptr = result_->data_ptr(); at::cuda::blas::bgemm( - transpose_batch1 ? 't' : 'n', - transpose_batch2 ? 't' : 'n', + transpose_batch1 ? batch1_->is_conj() ? 'c' : 't' : 'n', + transpose_batch2 ? batch2_->is_conj() ? 'c' : 't' : 'n', m, n, k, alpha_val, batch1_ptr, lda, batch1_->strides()[0], @@ -353,8 +384,19 @@ inline void dot_check(const Tensor& self, const Tensor& other) { } // anonymous namespace Tensor dot_cuda(const Tensor& self, const Tensor& other) { - at::NoNamesGuard guard; + if (self.is_complex()) { + if (self.is_conj()) { + if (other.is_conj()) { + return (dot_cuda(self.conj(), other.conj())).conj(); + } else { + return vdot_cuda(self.conj(), other); + } + } else if (other.is_conj()) { + return vdot_cuda(other.conj(), self); + } + } + at::NoNamesGuard guard; dot_check(self, other); const int n = static_cast(self.numel()); @@ -391,6 +433,16 @@ Tensor vdot_cuda(const Tensor& self, const Tensor& other) { return dot_cuda(self, other); } + if (self.is_conj()) { + if (other.is_conj()) { + return vdot_cuda(other.conj(), self.conj()); + } else { + return dot_cuda(self.conj(), other); + } + } else if (other.is_conj()) { + return (dot_cuda(self, other.conj())).conj(); + } + at::NoNamesGuard guard; dot_check(self, other); diff --git a/aten/src/ATen/native/cuda/ConvolutionMM2d.cu b/aten/src/ATen/native/cuda/ConvolutionMM2d.cu index ede7e1fb39b29..bf3f8ac0a6eff 100644 --- a/aten/src/ATen/native/cuda/ConvolutionMM2d.cu +++ b/aten/src/ATen/native/cuda/ConvolutionMM2d.cu @@ -1,12 +1,482 @@ #include -#include -namespace at { -namespace native { +#include +#include +#include +#include +#include +#include -std::tuple slow_conv2d_backward_out_cuda(const Tensor& grad_output, - const Tensor& self, - const Tensor& weight, +namespace at { namespace native { +namespace { + +void slow_conv2d_shape_check( + const Tensor& input, const Tensor& grad_output, + const Tensor& weight, const Tensor& bias, + int64_t kH, int64_t kW, + int64_t dH, int64_t dW, + int64_t padH, int64_t padW, + bool weight_nullable) { + TORCH_CHECK(kW > 0 && kH > 0, + "kernel size should be greater than zero, but got kH: ", kH, " kW: ", kW); + TORCH_CHECK(dW > 0 && dH > 0, + "stride should be greater than zero, but got dH: ", dH, " dW: ", dW); + + TORCH_CHECK(weight_nullable || weight.defined(), + "weight tensor is expected to be non-nullable"); + TORCH_CHECK(!weight.defined() || + ((weight.numel() > 0) && (weight.dim() == 2)), + "non-empty 2D weight tensor expected, but got: ", weight.sizes()); + TORCH_CHECK(!bias.defined() || (bias.dim() == 1 && bias.sizes()[0] == weight.sizes()[0]), + "Expected bias to have shape [", weight.sizes()[0], "] but got ", bias.sizes()); + + const auto in_sizes = input.sizes(); + constexpr int ndim = 4; + constexpr int dimf = 1; + constexpr int dimh = 2; + constexpr int dimw = 3; + TORCH_CHECK(in_sizes.size() == ndim, "Expected 4D input tensor, but got ", in_sizes); + + // Allow for empty batch size but not other dimensions + const bool valid_empty = c10::multiply_integers(in_sizes.slice(1)) != 0; + TORCH_CHECK(valid_empty, "non-empty input tensor expected but got: ", in_sizes); + + int64_t inputHeight = in_sizes[dimh]; + int64_t inputWidth = in_sizes[dimw]; + + int64_t exactInputHeight = inputHeight + 2 * padH; + int64_t exactInputWidth = inputWidth + 2 * padW; + + TORCH_CHECK(exactInputHeight >= kH && exactInputWidth >= kW, + "Calculated padded input size per channel: ", + IntArrayRef{exactInputHeight, exactInputWidth}, + ". Kernel size: ", IntArrayRef{kH, kW}, + ". Kernel size can't be greater than actual input size"); + + // NOTE: can't use conv_output_size if the weight isn't defined + auto outputHeight = div_rtn(exactInputHeight - kH, dH) + 1; + auto outputWidth = div_rtn(exactInputWidth - kW, dW) + 1; + + TORCH_CHECK(outputWidth >= 1 && outputHeight >= 1, + "Given input size per channel: ", + IntArrayRef{inputHeight, inputWidth}, + ". Calculated output size per channel: ", + IntArrayRef{outputHeight, outputWidth}, + ". Output size is too small"); + + if (weight.defined()) { + const auto w_sizes = weight.sizes(); + int64_t nInputPlane = w_sizes[1]; + if (w_sizes.size() == 2) { + nInputPlane /= (kH * kW); + } + TORCH_CHECK(in_sizes[dimf] == nInputPlane, + "Expected input dim ", dimf, " to have size ", nInputPlane, + " but got ", in_sizes[dimf]); + } + + if (grad_output.defined()) { + const auto gO_sizes = grad_output.sizes(); + TORCH_CHECK(gO_sizes.size() == ndim, + "Expected grad_output to have ", ndim, + " dimensions but got shape", gO_sizes); + + if (weight.defined()) { + const auto w_sizes = weight.sizes(); + TORCH_CHECK(gO_sizes[dimf] == w_sizes[0], + "Expected dim ", dimf, " to have size ", w_sizes[0], + " but got ", gO_sizes[dimf]); + } else if (bias.defined()) { + const auto b_sizes = bias.sizes(); + int64_t nOutputPlane = b_sizes.size() == 0 ? 1 : b_sizes[0]; + TORCH_CHECK(gO_sizes[dimf] == nOutputPlane, + "Expected grad_output dim ", dimf, " to have size ", + nOutputPlane, " but got ", gO_sizes[dimf]); + } + TORCH_CHECK(gO_sizes[dimh] == outputHeight, + "Expected grad_output dim ", dimh, " to have size ", + outputHeight, " but got ", gO_sizes[dimh]); + TORCH_CHECK(gO_sizes[dimw] == outputWidth, + "Expected grad_output dim ", dimw, " to have size ", + outputWidth, " but got ", gO_sizes[dimw]); + } +} + +Tensor new_view_weight_MM2d(const Tensor& weight_) { + auto weight = weight_.expect_contiguous(); + const auto w_sizes = weight->sizes(); + TORCH_CHECK(w_sizes.size() == 4); + int64_t s1 = w_sizes[0]; + int64_t s2 = c10::multiply_integers(w_sizes.slice(1)); + return weight->view({s1, s2}); +} + +void slow_conv2d_forward( + const Tensor &input, + const Tensor &output, + const Tensor &weight_, + const Tensor &bias, + const Tensor &columns, + const Tensor &ones_, + int64_t kH, int64_t kW, + int64_t dH, int64_t dW, + int64_t padH, int64_t padW) { + auto weight = new_view_weight_MM2d(weight_); + slow_conv2d_shape_check( + input, {}, weight, bias, kH, kW, dH, dW, padH, padW, /*weight_nullable*/false); + + TORCH_CHECK(!bias.defined() || bias.is_contiguous(), + "bias tensor has to be contiguous"); + + constexpr int ndim = 4; + constexpr int dimf = 1; + constexpr int dimh = 2; + constexpr int dimw = 3; + + auto in_sizes = input.sizes(); + int64_t batchSize = in_sizes[0]; + int64_t nInputPlane = in_sizes[dimf]; + int64_t inputHeight = in_sizes[dimh]; + int64_t inputWidth = in_sizes[dimw]; + int64_t nOutputPlane = weight.sizes()[0]; + int64_t outputHeight = (inputHeight + 2*padH - kH) / dH + 1; + int64_t outputWidth = (inputWidth + 2*padW - kW) / dW + 1; + + // Resize output + resize_output(output, {batchSize, nOutputPlane, outputHeight, outputWidth}); + + // Resize temporary columns + resize_output(columns, {nInputPlane * kW * kH, outputHeight * outputWidth}); + + // Define a buffer of ones, for bias accumulation + // Note: this buffer can be shared with other modules, it only ever gets increased, + // and always contains ones. + Tensor ones; + if (bias.defined()) { + ones = at::ones({outputHeight, outputWidth}, input.options()); + } + const bool requires_columns = ( + kW != 1 || kH != 1 || dW != 1 || dH != 1 || padH != 0 || padW != 0); + + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), + "slow_conv2d_cuda", [&] { + // For each elt in batch, do: + for (int elt = 0; elt < batchSize; elt ++) { + // Matrix mulitply per output: + auto input_n = input.select(0, elt); + auto output_n = output.select(0, elt); + + // Do Bias first: + // M,N,K are dims of matrix A and B + // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm) + int64_t m_ = nOutputPlane; + int64_t n_ = outputHeight * outputWidth; + int64_t k_ = 1; + + // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices) + if (bias.defined()) { + at::cuda::blas::gemm( + 't', 'n', + n_, m_, k_, + scalar_t(1), + ones.data_ptr(), k_, + bias.data_ptr(), k_, + scalar_t(0), + output_n.data_ptr(), n_ + ); + } else { + output_n.zero_(); + } + + if (requires_columns) { + // Extract columns: + at::native::im2col( + c10::cuda::getCurrentCUDAStream(), + input_n.data_ptr(), + nInputPlane, inputHeight, inputWidth, + outputHeight, outputWidth, + kH, kW, padH, padW, dH, dW, + 1, 1, + columns.data_ptr() + ); + } + + // M,N,K are dims of matrix A and B + // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm) + int64_t m = nOutputPlane; + int64_t n = columns.size(1); + int64_t k = nInputPlane*kH*kW; + + // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices) + auto gemm_in_ptr = requires_columns ? + columns.data_ptr() : + input_n.data_ptr(); + at::cuda::blas::gemm( + 'n', 'n', + n, m, k, + scalar_t(1), + gemm_in_ptr, n, + weight.data_ptr(), k, + scalar_t(1), + output_n.data_ptr(), n + ); + } + }); +} + +void slow_conv2d_backward( + const Tensor &input, + const Tensor &grad_output, + const Tensor &grad_input, + const Tensor &weight_, + const Tensor &grad_columns, + const Tensor &ones, + int kH, int kW, + int dH, int dW, + int padH, int padW) { + Tensor weight = new_view_weight_MM2d(weight_); + slow_conv2d_shape_check(input, grad_output, weight, {}, + kH, kW, dH, dW, padH, padW, /*weight_nullable=*/false); + + // Params + auto weight_sizes = weight.sizes(); + int nInputPlane = weight_sizes[1]/(kW*kH); + int nOutputPlane = weight_sizes[0]; + + TORCH_INTERNAL_ASSERT(grad_output.is_contiguous()); + + auto input_sizes = input.sizes(); + int64_t inputWidth = input_sizes[3]; + int64_t inputHeight = input_sizes[2]; + auto output_sizes = grad_output.sizes(); + int64_t outputWidth = output_sizes[3]; + int64_t outputHeight = output_sizes[2]; + + // Batch size + input planes + int64_t batchSize = input_sizes[0]; + + // Resize output + resize_output(grad_input, input_sizes); + TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous"); + + // Resize temporary columns + resize_output(grad_columns, {nInputPlane*kW*kH, outputHeight*outputWidth}); + TORCH_CHECK(grad_columns.is_contiguous(), "grad_columns must be contiguous"); + + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), + "slow_conv2d_backward_cuda", [&] { + // For each elt in batch, do: + for (int elt = 0; elt < batchSize; elt ++) { + // Matrix mulitply per sample: + auto grad_input_n = grad_input.select(0, elt); + auto grad_output_n = grad_output.select(0, elt); + + // M,N,K are dims of matrix A and B + // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm) + int64_t m = nInputPlane*kW*kH; + int64_t n = grad_columns.sizes()[1]; + int64_t k = nOutputPlane; + + // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices) + at::cuda::blas::gemm( + 'n', 't', + n, m, k, + scalar_t(1), + grad_output_n.data_ptr(), n, + weight.data_ptr(), m, + scalar_t(0), + grad_columns.data_ptr(), n + ); + + // Unpack columns back into input: + using acc_t = at::acc_type; + at::native::col2im( + c10::cuda::getCurrentCUDAStream(), + grad_columns.data_ptr(), + nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, kH, kW, padH, padW, dH, dW, + 1, 1, grad_input_n.data_ptr() + ); + } + }); +} + +void slow_conv2d_grad_weight_bias( + const Tensor &input, + const Tensor &grad_output, + const Tensor &grad_weight_, + const Tensor &grad_bias, + const Tensor &columns, + const Tensor &ones, + int64_t kH, int64_t kW, + int64_t dH, int64_t dW, + int64_t padH, int64_t padW) { + if (grad_weight_.defined()) { + TORCH_CHECK(grad_weight_.is_contiguous(), "grad_weight needs to be contiguous"); + } + if (grad_bias.defined()) { + TORCH_CHECK(grad_bias.is_contiguous(), "grad_bias needs to be contiguous"); + TORCH_CHECK(ones.is_contiguous(), "ones needs to be contiguous"); + } + + auto grad_weight = new_view_weight_MM2d(grad_weight_); + slow_conv2d_shape_check(input, grad_output, grad_weight, grad_bias, + kH, kW, dH, dW, padH, padW, /*weight_nullable=*/true); + + // Params + TORCH_INTERNAL_ASSERT(input.is_contiguous()); + TORCH_INTERNAL_ASSERT(grad_output.is_contiguous()); + + auto input_sizes = input.sizes(); + int64_t nInputPlane = input_sizes[1]; + int64_t nOutputPlane = grad_output.sizes()[1]; + + int64_t inputWidth = input_sizes[3]; + int64_t inputHeight = input_sizes[2]; + int64_t outputWidth = (inputWidth + 2*padW - kW) / dW + 1; + int64_t outputHeight = (inputHeight + 2*padH - kH) / dH + 1; + + // Batch size + input planes + int64_t batchSize = input_sizes[0]; + + // Define a buffer of ones, for bias accumulation + if (ones.defined() && ones.numel() < outputHeight * outputWidth) { + ones.resize_({outputHeight, outputWidth}); + ones.fill_(1); + } + + // Resize temporary columns + resize_output(columns, {nInputPlane * kH * kW, outputHeight * outputWidth}); + + const bool requires_columns = ( + kW != 1 || kH != 1 || dW != 1 || dH != 1 || padH != 0 || padW != 0); + + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), + "slow_conv2d_grad_weight_bias_cuda", [&] { + // For each elt in batch, do: + for (int elt = 0; elt < batchSize; elt ++) { + // Matrix mulitply per output: + auto grad_output_n = grad_output.select(0, elt); + + // Do Weight: + if (grad_weight.defined()) { + // Matrix mulitply per output: + auto input_n = input.select(0, elt); + + if (requires_columns) { + // Extract columns: + at::native::im2col( + c10::cuda::getCurrentCUDAStream(), + input_n.data_ptr(), + nInputPlane, inputHeight, inputWidth, + outputHeight, outputWidth, + kH, kW, padH, padW, dH, dW, + 1, 1, + columns.data_ptr() + ); + } + + // M,N,K are dims of matrix A and B + // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm) + int64_t m = nOutputPlane; + int64_t n = nInputPlane*kW*kH; + int64_t k = columns.sizes()[1]; + + // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices) + auto gemm_in_ptr = requires_columns ? + columns.data_ptr() : + input_n.data_ptr(); + at::cuda::blas::gemm( + 't', 'n', + n, m, k, + scalar_t(1), + gemm_in_ptr, k, + grad_output_n.data_ptr(), k, + scalar_t(1), + grad_weight.data_ptr(), n + ); + } + + // Do Bias: + if (grad_bias.defined()) { + // M,N,K are dims of matrix A and B + // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm) + int64_t m_ = nOutputPlane; + int64_t k_ = outputHeight * outputWidth; + + // Do GEMV (note: this is a bit confusing because gemv assumes column-major matrices) + at::cuda::blas::gemv( + 't', + k_, m_, + scalar_t(1), + grad_output_n.data_ptr(), k_, + ones.data_ptr(), 1, + scalar_t(1), + grad_bias.data_ptr(), 1 + ); + } + } + }); +} + +} // namespace (anonymous) + + +std::tuple slow_conv2d_forward_out_cuda( + const Tensor &self_, + const Tensor &weight_, + IntArrayRef kernel_size, + const c10::optional &bias_, + IntArrayRef stride, + IntArrayRef padding, + Tensor &output, + Tensor &finput, + Tensor &fgrad_input) { + TORCH_CHECK(kernel_size.size() == 2); + TORCH_CHECK(stride.size() == 2); + TORCH_CHECK(padding.size() == 2); + + auto self = self_.expect_contiguous(); + auto weight = weight_.expect_contiguous(); + auto bias = [&] { + if (bias_.has_value() && bias_->defined()) { + return bias_->expect_contiguous(); + } + return MaybeOwned::owned(c10::in_place); + }(); + + slow_conv2d_forward( + *self, + output, + *weight, + *bias, + finput, + fgrad_input, + kernel_size[0], kernel_size[1], + stride[0], stride[1], + padding[0], padding[1] + ); + return std::tuple{ + output, finput, fgrad_input}; +} + +std::tuple slow_conv2d_forward_cuda( + const Tensor &self, + const Tensor &weight, + IntArrayRef kernel_size, + const c10::optional &bias, + IntArrayRef stride, + IntArrayRef padding) { + auto output = at::empty({0}, self.options()); + auto finput = at::empty({0}, self.options()); + auto fgrad_input = at::empty({0}, self.options()); + return slow_conv2d_forward_out_cuda( + self, weight, kernel_size, bias, stride, padding, output, finput, fgrad_input); +} + +std::tuple slow_conv2d_backward_out_cuda( + const Tensor& grad_output_, + const Tensor& self_, + const Tensor& weight_, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, @@ -16,17 +486,42 @@ std::tuple slow_conv2d_backward_out_cuda(const Tensor Tensor& grad_weight, Tensor& grad_bias) { if (grad_weight.defined()) { - grad_weight.resize_(weight.sizes()); + resize_output(grad_weight, weight_.sizes()); grad_weight.zero_(); } if (grad_bias.defined()) { - grad_bias.resize_({ weight.size(0) }); + resize_output(grad_bias, {weight_.sizes()[0]}); grad_bias.zero_(); } - return legacy::cuda::_thnn_conv2d_backward_out(grad_input, grad_weight, grad_bias, - grad_output, self, weight, - kernel_size, stride, padding, - finput, fgrad_input); + auto grad_output = grad_output_.expect_contiguous(); + if (grad_input.defined()) { + resize_output(grad_input, self_.sizes()); + auto weight = weight_.expect_contiguous(); + + slow_conv2d_backward( + self_, *grad_output, + grad_input, *weight, + finput, fgrad_input, + kernel_size[0], kernel_size[1], + stride[0], stride[1], + padding[0], padding[1]); + } + if (grad_weight.defined() || grad_bias.defined()) { + auto self = self_.expect_contiguous(); + slow_conv2d_grad_weight_bias( + *self, + *grad_output, + grad_weight, + grad_bias, + finput, + fgrad_input, + kernel_size[0], kernel_size[1], + stride[0], stride[1], + padding[0], padding[1] + ); + } + return std::tuple{ + grad_input, grad_weight, grad_bias}; } std::tuple slow_conv2d_backward_cuda( diff --git a/aten/src/ATen/native/cuda/DistributionBernoulli.cu b/aten/src/ATen/native/cuda/DistributionBernoulli.cu index 3acf87c3c4b40..e113d82c0f5c7 100644 --- a/aten/src/ATen/native/cuda/DistributionBernoulli.cu +++ b/aten/src/ATen/native/cuda/DistributionBernoulli.cu @@ -16,10 +16,8 @@ #include #include #include -#include #include -#include #include #include diff --git a/aten/src/ATen/native/cuda/DistributionCauchyKernel.cu b/aten/src/ATen/native/cuda/DistributionCauchyKernel.cu index 35a1e6ef5a98c..b33ee792ea4cc 100644 --- a/aten/src/ATen/native/cuda/DistributionCauchyKernel.cu +++ b/aten/src/ATen/native/cuda/DistributionCauchyKernel.cu @@ -16,10 +16,8 @@ #include #include #include -#include #include -#include #include #include diff --git a/aten/src/ATen/native/cuda/DistributionExponentialKernel.cu b/aten/src/ATen/native/cuda/DistributionExponentialKernel.cu index b4cf288bcb7b8..f28a910e9980b 100644 --- a/aten/src/ATen/native/cuda/DistributionExponentialKernel.cu +++ b/aten/src/ATen/native/cuda/DistributionExponentialKernel.cu @@ -16,10 +16,8 @@ #include #include #include -#include #include -#include #include #include diff --git a/aten/src/ATen/native/cuda/DistributionGeometricKernel.cu b/aten/src/ATen/native/cuda/DistributionGeometricKernel.cu index eb71ab3231f12..6cafba0dcbe78 100644 --- a/aten/src/ATen/native/cuda/DistributionGeometricKernel.cu +++ b/aten/src/ATen/native/cuda/DistributionGeometricKernel.cu @@ -16,10 +16,8 @@ #include #include #include -#include #include -#include #include #include diff --git a/aten/src/ATen/native/cuda/DistributionLogNormalKernel.cu b/aten/src/ATen/native/cuda/DistributionLogNormalKernel.cu index 89b9c04b3a687..c5da3bdf92d2a 100644 --- a/aten/src/ATen/native/cuda/DistributionLogNormalKernel.cu +++ b/aten/src/ATen/native/cuda/DistributionLogNormalKernel.cu @@ -16,10 +16,8 @@ #include #include #include -#include #include -#include #include #include diff --git a/aten/src/ATen/native/cuda/DistributionNormal.cu b/aten/src/ATen/native/cuda/DistributionNormal.cu index da647277c1762..1b2dd19eec0d1 100644 --- a/aten/src/ATen/native/cuda/DistributionNormal.cu +++ b/aten/src/ATen/native/cuda/DistributionNormal.cu @@ -16,10 +16,8 @@ #include #include #include -#include #include -#include #include #include diff --git a/aten/src/ATen/native/cuda/DistributionRandomKernel.cu b/aten/src/ATen/native/cuda/DistributionRandomKernel.cu index 8d6614b9010d8..ea2aaad9445b2 100644 --- a/aten/src/ATen/native/cuda/DistributionRandomKernel.cu +++ b/aten/src/ATen/native/cuda/DistributionRandomKernel.cu @@ -16,10 +16,8 @@ #include #include #include -#include #include -#include #include #include diff --git a/aten/src/ATen/native/cuda/Distributions.cu b/aten/src/ATen/native/cuda/Distributions.cu index cf1281d320b14..81f8fe8fa227f 100644 --- a/aten/src/ATen/native/cuda/Distributions.cu +++ b/aten/src/ATen/native/cuda/Distributions.cu @@ -16,10 +16,8 @@ #include #include #include -#include #include -#include #include #include diff --git a/aten/src/ATen/native/cuda/Embedding.cu b/aten/src/ATen/native/cuda/Embedding.cu index 100ffbd99388c..ba79fa10f926a 100644 --- a/aten/src/ATen/native/cuda/Embedding.cu +++ b/aten/src/ATen/native/cuda/Embedding.cu @@ -224,14 +224,16 @@ __global__ void renorm_kernel( template void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count); -Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indices, +Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indices_, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) { auto grad_arg = TensorArg(grad_, "grad", 1); - auto indices_arg = TensorArg(indices, "indices", 1); + auto indices_arg = TensorArg(indices_, "indices", 1); checkScalarTypes("embedding_backward", indices_arg, {kLong, kInt}); checkSameGPU("embedding_backward", grad_arg, indices_arg); + auto indices = indices_.contiguous(); + auto num_indices = indices.numel(); auto grad = grad_.contiguous().view({num_indices, grad_.size(-1)}); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -277,7 +279,7 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice cuda::cub::sort_pairs( indices.data_ptr(), sorted_indices.data_ptr(), range.data_ptr(), orig_indices.data_ptr(), - num_indices, false, 0, nbits); + num_indices, false/*, 0, nbits*/); if (scale_grad_by_freq) { count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); diff --git a/aten/src/ATen/native/cuda/EmbeddingBag.cu b/aten/src/ATen/native/cuda/EmbeddingBag.cu index 34a9d9dd82133..35094681a79c8 100644 --- a/aten/src/ATen/native/cuda/EmbeddingBag.cu +++ b/aten/src/ATen/native/cuda/EmbeddingBag.cu @@ -17,6 +17,7 @@ #include #include +#include #include @@ -235,7 +236,7 @@ template __global__ void EmbeddingBag_accGradParametersKernel_max( index_t *max_indices, scalar_t *gradOutput, scalar_t *gradWeight, int64_t stride, int64_t numBags, - index_t padding_idx) { + index_t padding_idx, const index_t numel) { using accscalar_t = acc_type; @@ -252,8 +253,9 @@ __global__ void EmbeddingBag_accGradParametersKernel_max( index_t word_idx = max_indices[bag * stride + featureDim]; if (word_idx >= 0 && word_idx != padding_idx) { // If bag is empty, we have max_indices[idx] set to -1 in forward. - gpuAtomicAddNoReturn(&(gradWeight[word_idx * stride + featureDim]), - gradOutput[bag * stride + featureDim]); + fastAtomicAdd( + gradWeight, static_cast(word_idx * stride + featureDim), + numel, gradOutput[bag * stride + featureDim], true); } } } @@ -289,7 +291,7 @@ Tensor embedding_bag_backward_cuda_max(const Tensor &grad, scalar_t, index_t><<>>( max_indices.data_ptr(), grad.data_ptr(), grad_weight.data_ptr(), stride, numBags, - padding_idx); + padding_idx, grad_weight.numel()); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); }); diff --git a/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu b/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu index 0277aee6f02b1..67a27ce116feb 100644 --- a/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu +++ b/aten/src/ATen/native/cuda/ForeachBinaryOpList.cu @@ -18,7 +18,7 @@ std::vector foreach_tensor_list_op(TensorList tensors1, TensorList tenso tensor_lists.emplace_back(std::move(vec_res)); AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, tensors1[0].scalar_type(), "foreach_binary_op_list_cuda", [&]() { - using opmath_t = get_opmath_t::opmath_t; + using opmath_t = at::opmath_type; multi_tensor_apply<3>(tensor_lists, BinaryOpListAlphaFunctor::opmath_t; + using opmath_t = at::opmath_type; multi_tensor_apply<2>(tensor_lists, BinaryOpListAlphaFunctor foreach_binary_op(TensorList tensors, const Scalar& scalar) tensor_lists.emplace_back(std::move(vec_res)); AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, tensors[0].scalar_type(), "foreach_binary_op_scalar_cuda", [&]() { - using opmath_t = get_opmath_t::opmath_t; + using opmath_t = at::opmath_type; multi_tensor_apply<2>(tensor_lists, BinaryOpScalarFunctor::opmath_t; + using opmath_t = at::opmath_type; multi_tensor_apply<1>(tensor_lists, BinaryOpScalarFunctor foreach_binary_op(TensorList tensors, at::ArrayRef s tensor_lists.emplace_back(vec_res); AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBFloat16, kHalf, kBool, tensors[0].scalar_type(), "foreach_binary_op_scalarlist_cuda", [&]() { - using opmath_t = get_opmath_t::opmath_t; + using opmath_t = at::opmath_type; multi_tensor_apply<2, opmath_t>(tensor_lists, scalars, BinaryOpScalarListFunctor scalars) { tensor_lists.emplace_back(tensors.vec()); AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBFloat16, kHalf, kBool, tensors[0].scalar_type(), "foreach_binary_op_scalarlist_cuda_", [&]() { - using opmath_t = get_opmath_t::opmath_t; + using opmath_t = at::opmath_type; multi_tensor_apply<1, opmath_t>(tensor_lists, scalars, BinaryOpScalarListFunctor #include +#include namespace at { namespace native { namespace { -// For FP16 or BFloat16 inputs, ops should perform internal math in FP32. -template struct get_opmath_t { using opmath_t = scalar_t; }; -template<> struct get_opmath_t { using opmath_t = float; }; -template<> struct get_opmath_t { using opmath_t = float; }; - // Initializes args and checks if all args are aligned template __device__ bool init_args( @@ -158,7 +154,7 @@ __device__ __forceinline__ void pointwise_op_scalar( // template struct BinaryOpScalarFunctor { - using opmath_t = typename get_opmath_t::opmath_t; + using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator() ( int chunk_size, TensorListMetadata& tl, @@ -179,7 +175,7 @@ struct BinaryOpScalarFunctor { template struct BinaryOpScalarListFunctor { - using opmath_t = typename get_opmath_t::opmath_t; + using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator() ( int chunk_size, TensorListScalarListMetadata& tl, @@ -200,7 +196,7 @@ struct BinaryOpScalarListFunctor { template struct BinaryOpListAlphaFunctor { - using opmath_t = typename get_opmath_t::opmath_t; + using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator() ( int chunk_size, TensorListMetadata& tl, @@ -287,7 +283,7 @@ struct ZeroFunctor { template struct UnaryOpFunctor { - using opmath_t = typename get_opmath_t::opmath_t; + using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator() ( int chunk_size, TensorListMetadata& tl, @@ -333,7 +329,7 @@ struct UnaryOpFunctor { template struct PointwiseOpScalarFunctor { - using opmath_t = typename get_opmath_t::opmath_t; + using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator() ( int chunk_size, TensorListMetadata& tl, @@ -354,7 +350,7 @@ struct PointwiseOpScalarFunctor { template struct PointwiseOpScalarListFunctor { - using opmath_t = typename get_opmath_t::opmath_t; + using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator() ( int chunk_size, TensorListScalarListMetadata& tl, @@ -375,7 +371,7 @@ struct PointwiseOpScalarListFunctor { template struct PointwiseOpListFunctor { - using opmath_t = typename get_opmath_t::opmath_t; + using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator() ( int chunk_size, TensorListMetadata& tl, diff --git a/aten/src/ATen/native/cuda/ForeachPointwiseOp.cu b/aten/src/ATen/native/cuda/ForeachPointwiseOp.cu index 977425984e99e..9440b87caedac 100644 --- a/aten/src/ATen/native/cuda/ForeachPointwiseOp.cu +++ b/aten/src/ATen/native/cuda/ForeachPointwiseOp.cu @@ -20,7 +20,7 @@ std::vector foreach_pointwise_op(TensorList input, TensorList tensors1, tensor_lists.emplace_back(std::move(vec_res)); AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, input[0].scalar_type(), "foreach_pointwise_op_cuda", [&]() { - using opmath_t = get_opmath_t::opmath_t; + using opmath_t = at::opmath_type; multi_tensor_apply<4>(tensor_lists, PointwiseOpScalarFunctor::opmath_t; + using opmath_t = at::opmath_type; multi_tensor_apply<3>(tensor_lists, PointwiseOpScalarFunctor::opmath_t; + using opmath_t = at::opmath_type; multi_tensor_apply<3, opmath_t>(tensor_lists, scalars, PointwiseOpScalarListFunctor foreach_pointwise_op(TensorList input, TensorList tensors1, tensor_lists.emplace_back(std::move(vec_res)); AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, input[0].scalar_type(), "foreach_pointwise_op_cuda", [&]() { - using opmath_t = get_opmath_t::opmath_t; + using opmath_t = at::opmath_type; multi_tensor_apply<4, opmath_t>(tensor_lists, scalars, PointwiseOpScalarListFunctor foreach_tensor_##NAME##_cuda(TensorList tensors1, TensorList tensor_lists.emplace_back(std::move(vec_res)); \ \ AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, tensors1[0].scalar_type(), "foreach_maximum_minimum_op_cuda", [&]() { \ - using opmath_t = get_opmath_t::opmath_t; \ + using opmath_t = at::opmath_type; \ auto op = [] GPU_LAMBDA (opmath_t a, opmath_t b) -> opmath_t { \ opmath_t c = a OP b ? a : b; \ if (_isnan(a)) { \ diff --git a/aten/src/ATen/native/cuda/ForeachUnaryOp.cu b/aten/src/ATen/native/cuda/ForeachUnaryOp.cu index 8d606824d2cc6..fd7a12b9dfac6 100644 --- a/aten/src/ATen/native/cuda/ForeachUnaryOp.cu +++ b/aten/src/ATen/native/cuda/ForeachUnaryOp.cu @@ -15,7 +15,7 @@ template class Op> std::vector forea tensor_lists.emplace_back(tensors.vec()); tensor_lists.emplace_back(std::move(vec_res)); - using opmath_t = typename get_opmath_t::opmath_t; + using opmath_t = typename at::opmath_type; multi_tensor_apply<2>(tensor_lists, UnaryOpFunctor class Op> std::vector forea template class Op> void foreach_unary_op_(TensorList tensors) { std::vector> tensor_lists; tensor_lists.emplace_back(tensors.vec()); - using opmath_t = typename get_opmath_t::opmath_t; + using opmath_t = typename at::opmath_type; multi_tensor_apply<1>(tensor_lists, UnaryOpFunctor(0); - auto strides = at::detail::Array(0); - auto index_ptrs = at::detail::Array(nullptr); + auto sizes = at::detail::Array(0); + auto strides = at::detail::Array(0); + auto index_ptrs = at::detail::Array(nullptr); for (int i = 0; i < num_indices; i++) { sizes[i] = index_size[i]; strides[i] = index_stride[i]; diff --git a/aten/src/ATen/native/cuda/LinearAlgebra.cu b/aten/src/ATen/native/cuda/LinearAlgebra.cu index b7ecf386c6edc..b4936c069b0b1 100644 --- a/aten/src/ATen/native/cuda/LinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/LinearAlgebra.cu @@ -5,7 +5,6 @@ #include #include #include -#include #include #include diff --git a/aten/src/ATen/native/cuda/Loops.cuh b/aten/src/ATen/native/cuda/Loops.cuh index fde8e86409db7..8849293e20210 100644 --- a/aten/src/ATen/native/cuda/Loops.cuh +++ b/aten/src/ATen/native/cuda/Loops.cuh @@ -5,6 +5,7 @@ #include #include #include +#include #include @@ -111,49 +112,64 @@ void gpu_kernel(TensorIteratorBase& iter, const func_t& f) { gpu_kernel_impl(iter, f); } -template +template struct AUnaryFunctor { using traits = function_traits; - using arg1_t = typename traits::template arg<0>::type; - using arg2_t = typename traits::template arg<1>::type; - using return_t = typename traits::result_type; + using opmath_arg1_t = typename traits::template arg<0>::type; __device__ return_t operator()(arg2_t b) const { return f(a, b); } - AUnaryFunctor(func_t f_, arg1_t a_): f(f_), a(a_) {} + // NB: scalar is stored in higher precision! + AUnaryFunctor(func_t f_, opmath_arg1_t a_): f(f_), a(a_) {} private: func_t f; - arg1_t a; + opmath_arg1_t a; }; -template +template struct BUnaryFunctor { using traits = function_traits; - using arg1_t = typename traits::template arg<0>::type; - using arg2_t = typename traits::template arg<1>::type; - using return_t = typename traits::result_type; + using opmath_arg2_t = typename traits::template arg<1>::type; __device__ return_t operator()(arg1_t a) const { return f(a, b); } - BUnaryFunctor(func_t f_, arg2_t b_): f(f_), b(b_) {} + // NB: scalar is stored in higher precision! + BUnaryFunctor(func_t f_, opmath_arg2_t b_): f(f_), b(b_) {} private: func_t f; - arg2_t b; + opmath_arg2_t b; }; -template -void gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) { +// Though seemingly noop, this inserts casts from arg1_t to func_t's type +// (which may be higher precision), as well as casts to return_t +template +struct BinaryFunctor { + __device__ return_t operator()(arg1_t a, arg2_t b) const { + return f(a, b); + } + BinaryFunctor(func_t f_): f(f_) {} + private: + func_t f; +}; + +// Unlike gpu_kernel_with_scalars, this allows you to pass a func_t which +// accepts inputs at higher precision (typically opmath_t), but then +// ensure that we load from memory at the correct precision (scalar_t) +// to avoid expensive loads. For the whole sordid story see +// https://dev-discuss.pytorch.org/t/cuda-loops-case-study-code-generation-vs-templates/302 +template +void opmath_gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) { TORCH_INTERNAL_ASSERT(iter.ntensors() == 3); using traits = function_traits; + using opmath_arg1_t = typename traits::template arg<0>::type; + using opmath_arg2_t = typename traits::template arg<1>::type; static_assert( traits::arity == 2, "gpu_kernel_with_scalars only supports two input arguments"); - using arg1_t = typename traits::template arg<0>::type; - using arg2_t = typename traits::template arg<1>::type; if (iter.is_cpu_scalar(1)) { - AUnaryFunctor af(f, iter.scalar_value(1)); + AUnaryFunctor af(f, iter.scalar_value(1)); iter.remove_operand(1); // TODO: When all kernels that use gpu_kernel_with_scalars are // ported to structured, this device guard can be deleted. This @@ -163,14 +179,28 @@ void gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) { const OptionalDeviceGuard device_guard(device_of(iter.tensor(1))); gpu_kernel(iter, af); } else if (iter.is_cpu_scalar(2)) { - BUnaryFunctor bf(f, iter.scalar_value(2)); + BUnaryFunctor bf(f, iter.scalar_value(2)); iter.remove_operand(2); gpu_kernel(iter, bf); } else { - gpu_kernel(iter, f); + gpu_kernel(iter, BinaryFunctor(f)); } } +// Legacy variant that assumes that func_t has the correct types +// that we expect to load from memory +template +void gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) { + using traits = function_traits; + static_assert( + traits::arity == 2, + "gpu_kernel_with_scalars only supports two input arguments"); + using arg1_t = typename traits::template arg<0>::type; + using arg2_t = typename traits::template arg<1>::type; + using return_t = typename traits::result_type; + opmath_gpu_kernel_with_scalars(iter, f); +} + namespace { // functions for `gpu_kernel_multiple_outputs`. // check the return type is `thrust::tuple`, not `std::tuple`. diff --git a/aten/src/ATen/native/cuda/Loss.cu b/aten/src/ATen/native/cuda/Loss.cu index d814eae01f4ec..2087f19dd3486 100644 --- a/aten/src/ATen/native/cuda/Loss.cu +++ b/aten/src/ATen/native/cuda/Loss.cu @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -207,7 +208,7 @@ __global__ void nll_loss_forward_reduce_cuda_kernel_1d( bool size_average, int n_classes, int64_t ignore_index) { - CUDA_KERNEL_ASSERT(threadIdx.x == 0 && threadIdx.y == 0 & threadIdx.z == 0); + CUDA_KERNEL_ASSERT(threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0); int t = static_cast(*target); if (t != static_cast(ignore_index)) { @@ -263,7 +264,7 @@ __global__ void nll_loss_forward_reduce_cuda_kernel_2d( *total_weight = static_cast(total_weight_acc); if (size_average && nframe == 0) { // Mean reduction on empty tensors produces NaN - *output = std::numeric_limits::quiet_NaN(); + *output = std::numeric_limits::quiet_NaN(); } else if (size_average && total_weight_acc != 0) { *output = static_cast(output_acc / total_weight_acc); } else { @@ -286,7 +287,7 @@ void nll_loss_forward_out_cuda_template( auto weight_ = weight.defined() ? weight.contiguous() : weight; - if (reduction == Reduction::None & n_dims == 2) { + if (reduction == Reduction::None && n_dims == 2) { output.resize_({batch_size}); if (batch_size == 0) { // This guards from unnecessary operations and launching CUDA kernel with @@ -365,7 +366,8 @@ void nll_loss_forward_out_cuda_template( target.scalar_type(), "nll_loss_forward_reduce_cuda_kernel_2d_index", [&] { - nll_loss_forward_reduce_cuda_kernel_2d + using accscalar_t = at::acc_type; + nll_loss_forward_reduce_cuda_kernel_2d <<<1, NLL_LOSS_THREADS, 0, @@ -468,7 +470,6 @@ void nll_loss_backward_out_cuda_template( int64_t n_dims = input.dim(); int64_t n_classes = input.size(-1); int64_t batch_size = n_dims == 1 ? 1 : input.size(0); - int64_t num_targets = target.size(0); auto weight_ = weight.defined() ? weight.contiguous() : weight; diff --git a/aten/src/ATen/native/cuda/MultinomialKernel.cu b/aten/src/ATen/native/cuda/MultinomialKernel.cu index 3912af58e1d99..65c45e7027964 100644 --- a/aten/src/ATen/native/cuda/MultinomialKernel.cu +++ b/aten/src/ATen/native/cuda/MultinomialKernel.cu @@ -1,6 +1,5 @@ #include #include -#include #include #include #include diff --git a/aten/src/ATen/native/cuda/Normalization.cu b/aten/src/ATen/native/cuda/Normalization.cu index dff3f69bcc43c..44e27a95647b1 100644 --- a/aten/src/ATen/native/cuda/Normalization.cu +++ b/aten/src/ATen/native/cuda/Normalization.cu @@ -2,7 +2,6 @@ #include #include #include -#include #include #include #include @@ -487,7 +486,8 @@ std::tuple batch_norm_backward_cuda(const Tensor& grad_o // save_mean and save_invstd, so it needs recalculated. const auto acc_type = at::toAccumulateType(input.scalar_type(), /*is_cuda=*/true); Tensor mean; - if (save_mean->defined()) { + TORCH_INTERNAL_ASSERT(save_mean->defined(), "save_mean should always be defined\n"); + if (save_mean->numel() != 0) { mean = *save_mean; } else if (needs_reduction) { TORCH_CHECK(!train && running_mean->defined()); @@ -496,7 +496,8 @@ std::tuple batch_norm_backward_cuda(const Tensor& grad_o } Tensor invstd; - if (save_invstd->defined()) { + TORCH_INTERNAL_ASSERT(save_invstd->defined(), "save_invstd should always be defined\n"); + if (save_invstd->numel() != 0) { invstd = *save_invstd; } else { TORCH_CHECK(!train && running_var->defined()); @@ -646,7 +647,9 @@ Tensor batch_norm_backward_elemt_cuda(const Tensor& self, const Tensor& input, c c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); const Tensor& weight = *weight_maybe_owned; - if (at::cuda::detail::canUse32BitIndexMath(self) && batch_norm_use_channels_last_kernels(self)){ + if (at::cuda::detail::canUse32BitIndexMath(self) && + batch_norm_use_channels_last_kernels(self) && + batch_norm_use_channels_last_kernels(input)) { return batch_norm_backward_elemt_channels_last_cuda_template(self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count); } diff --git a/aten/src/ATen/native/cuda/Normalization.cuh b/aten/src/ATen/native/cuda/Normalization.cuh index af074f5d2c6fd..6daa2b0858044 100644 --- a/aten/src/ATen/native/cuda/Normalization.cuh +++ b/aten/src/ATen/native/cuda/Normalization.cuh @@ -1649,7 +1649,8 @@ at::Tensor batch_norm_backward_elemt_channels_last_cuda_template( const auto stride = input.sizes()[1]; const auto reduction_size = input.numel() / stride; - at::Tensor grad_input = at::empty_like(input, input.suggest_memory_format()); + // Input is guarunteed to be channels-last compatible + at::Tensor grad_input = at::empty_like(input); dim3 block; dim3 grid; @@ -1716,7 +1717,8 @@ at::Tensor batch_norm_backward_elemt_channels_last_cuda_template( const auto reduction_size = input.numel() / stride; auto norm_fct = 1.0 / reduction_size; - at::Tensor grad_input = at::empty_like(input, input.suggest_memory_format()); + // Input is guarunteed to be channels-last compatible + at::Tensor grad_input = at::empty_like(input); dim3 block; dim3 grid; diff --git a/aten/src/ATen/native/cuda/Reduce.cuh b/aten/src/ATen/native/cuda/Reduce.cuh index 8c423061a79f6..3be7100483b3c 100644 --- a/aten/src/ATen/native/cuda/Reduce.cuh +++ b/aten/src/ATen/native/cuda/Reduce.cuh @@ -919,10 +919,11 @@ inline void gpu_reduce_kernel(TensorIterator& iter, const ops_t& ops, ident_t id // acc_buf_ptr holds buffer used for accumulation among multiple sub_iter // when accumulation in output is not possible. if (!can_accumulate_in_output && !can_use_32bit_indexing) { - int64_t output_memory_size = 1; + int64_t output_memory_size = iter.element_size(0); for (int dim = 0; dim < iter.ndim(); dim++) { output_memory_size = std::max(output_memory_size, iter.shape()[dim] * iter.strides(0)[dim]); } + output_memory_size /= iter.element_size(0); //iter.strides is in bytes owned_buf_ptr.reset(new AccumulationBuffer(sizeof(arg_t), sizeof(out_scalar_t), (char*) iter.data_ptr(0), @@ -988,14 +989,14 @@ inline void gpu_reduce_kernel(TensorIterator& iter, const ops_t& ops, ident_t id // Map block.x to the fastest reducing dimension. It implies: // 1. block_x_reduce is required. // 2. block.y now max out to num_outputs. - dim0 = iter.shape()[0]; + dim0 = inputs_per_output; dim1 = num_outputs; fastest_moving_stride = iter.strides(/*arg=*/input_index)[0]; } else { // Map block.x to the fastest non reducing dimension. It implies: // 1. block_x_reduce is turned off. // 2. block.y now max out to inputs_per_output. - dim0 = iter.shape()[iter.num_reduce_dims()]; + dim0 = num_outputs; dim1 = inputs_per_output; fastest_moving_stride = iter.strides(/*arg=*/input_index)[iter.num_reduce_dims()]; } diff --git a/aten/src/ATen/native/cuda/ScatterGatherKernel.cu b/aten/src/ATen/native/cuda/ScatterGatherKernel.cu index b95570109de91..5f03cc450f206 100644 --- a/aten/src/ATen/native/cuda/ScatterGatherKernel.cu +++ b/aten/src/ATen/native/cuda/ScatterGatherKernel.cu @@ -89,10 +89,6 @@ struct _cuda_scatter_gather_internal_kernel { int64_t index_stride, const func_t& f ) { - if (iter.numel() == 0) { - return; - } - if (!iter.can_use_32bit_indexing()) { for (auto& sub_iter : iter.with_32bit_indexing()) { _cuda_scatter_gather_internal_kernel()( @@ -132,24 +128,13 @@ template struct cuda_scatter_gather_base_kernel { template void operator()( - Tensor& self, int64_t dim, + const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src, const std::string& method_name, const func_t& f ) { - // no-op if index is empty - if (index.numel() == 0) { - return; - } at::assert_no_internal_overlap(self); - dim = maybe_wrap_dim(dim, self.dim()); - - scatter_gather_dtype_check(method_name, self, index, src); - if (!is_scatter_like) { - gather_shape_check(self, dim, index, src); - } - auto index_sizes = ensure_nonempty_vec(index.sizes().vec()); auto self_strides = ensure_nonempty_vec(self.strides().vec()); auto src_strides = ensure_nonempty_vec(src.strides().vec()); @@ -201,24 +186,13 @@ struct cuda_scatter_gather_base_kernel { } void operator()( - Tensor& self, int64_t dim, + const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src, const std::string& method_name, const ReduceMultiply& f ) { - // no-op if index is empty - if (index.numel() == 0) { - return; - } at::assert_no_internal_overlap(self); - dim = maybe_wrap_dim(dim, self.dim()); - - scatter_gather_dtype_check(method_name, self, index, src); - if (!is_scatter_like) { - gather_shape_check(self, dim, index, src); - } - auto index_sizes = ensure_nonempty_vec(index.sizes().vec()); auto self_strides = ensure_nonempty_vec(self.strides().vec()); auto src_strides = ensure_nonempty_vec(src.strides().vec()); @@ -280,10 +254,6 @@ struct _cuda_scatter_fill_internal_kernel { int64_t index_stride, const func_t& f ) { - if (iter.numel() == 0) { - return; - } - if (!iter.can_use_32bit_indexing()) { for (auto& sub_iter : iter.with_32bit_indexing()) { _cuda_scatter_fill_internal_kernel()( @@ -322,19 +292,13 @@ template struct cuda_scatter_fill_base_kernel { template void operator()( - Tensor& self, int64_t dim, + const Tensor& self, int64_t dim, const Tensor& index, Scalar src, const std::string& method_name, const func_t& f ) { - // no-op if index is empty - if (index.numel() == 0) { - return; - } at::assert_no_internal_overlap(self); - dim = maybe_wrap_dim(dim, self.dim()); - auto index_sizes = ensure_nonempty_vec(index.sizes().vec()); // restride self such that @@ -371,19 +335,13 @@ struct cuda_scatter_fill_base_kernel { } void operator()( - Tensor& self, int64_t dim, + const Tensor& self, int64_t dim, const Tensor& index, Scalar src, const std::string& method_name, const ReduceMultiply& f ) { - // no-op if index is empty - if (index.numel() == 0) { - return; - } at::assert_no_internal_overlap(self); - dim = maybe_wrap_dim(dim, self.dim()); - auto index_sizes = ensure_nonempty_vec(index.sizes().vec()); // restride self such that @@ -420,25 +378,25 @@ struct cuda_scatter_fill_base_kernel { } }; // struct cuda_scatter_fill_base_kernel -void gather_cuda_kernel(Tensor& result, const Tensor& self, int64_t dim, const Tensor& index) { +void gather_cuda_kernel(const Tensor& result, const Tensor& self, int64_t dim, const Tensor& index) { cuda_scatter_gather_base_kernel()( result, dim, index, self, "gather_out_cuda", tensor_assign); } -void scatter_cuda_kernel(Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) { +void scatter_cuda_kernel(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) { cuda_scatter_gather_base_kernel<>()( self, dim, index, src, "scatter_cuda_", tensor_assign); } -void scatter_fill_cuda_kernel(Tensor& self, int64_t dim, const Tensor& index, const Scalar& src) { +void scatter_fill_cuda_kernel(const Tensor& self, int64_t dim, const Tensor& index, const Scalar& src) { cuda_scatter_fill_base_kernel<>()( self, dim, index, src, "scatter_fill_cuda_", tensor_assign); } -void scatter_add_cuda_kernel(Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) { +void scatter_add_cuda_kernel(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) { // See Note [Writing Nondeterministic Operations] // Nondeterministic because of atomicAdd usage globalContext().alertNotDeterministic("scatter_add_cuda_kernel"); @@ -447,7 +405,7 @@ void scatter_add_cuda_kernel(Tensor& self, int64_t dim, const Tensor& index, con "scatter_add_cuda_", reduce_add); } -void scatter_reduce_cuda_kernel(Tensor& self, const int64_t dim, const Tensor& index, +void scatter_reduce_cuda_kernel(const Tensor& self, const int64_t dim, const Tensor& index, const Tensor& src, const SCATTER_GATHER_OP& reduce) { switch (reduce) { case SCATTER_GATHER_OP::REDUCE_ADD : @@ -461,7 +419,7 @@ void scatter_reduce_cuda_kernel(Tensor& self, const int64_t dim, const Tensor& i } } -void scatter_scalar_reduce_cuda_kernel(Tensor& self, const int64_t dim, const Tensor& index, +void scatter_scalar_reduce_cuda_kernel(const Tensor& self, const int64_t dim, const Tensor& index, const Scalar& value, const SCATTER_GATHER_OP& reduce) { switch (reduce) { case SCATTER_GATHER_OP::REDUCE_ADD : diff --git a/aten/src/ATen/native/cuda/Shape.cu b/aten/src/ATen/native/cuda/Shape.cu index dec985447944e..05fa4c6e165c4 100644 --- a/aten/src/ATen/native/cuda/Shape.cu +++ b/aten/src/ATen/native/cuda/Shape.cu @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -528,7 +529,10 @@ Tensor& cat_out_cuda(TensorList inputs, int64_t dimension, Tensor& out) { size[dimension] = cat_dim_size; // skip resizing if size of result is same as expected - if (out.sizes() != size) { + // raise a warning while resizing if output has one or more elements + // See https://github.com/pytorch/pytorch/pull/62560#discussion_r687363362 + // for understanding why at::native::resize_output is not called directly. + if (at::native::resize_output_check(out, size)) { out.resize_(size, memory_format); } diff --git a/aten/src/ATen/native/cuda/Sort.cu b/aten/src/ATen/native/cuda/Sort.cu index f53f7b478dadf..9cb32bc5ac14c 100644 --- a/aten/src/ATen/native/cuda/Sort.cu +++ b/aten/src/ATen/native/cuda/Sort.cu @@ -3,7 +3,6 @@ #include #include #include -#include #include #include #include @@ -208,6 +207,87 @@ struct offset_t { } +namespace { + +// Segmented sort by full sort algorithm:. +// Say we are sorting a (2, 3) tensor. We have in flattened form: +// values 0.4 1.2 5.3 6.2 1.3 2.3 +// indices 0 1 2 0 1 2 +// segment_id 0 0 0 1 1 1 + +// First we sort by values, globally: +// values 6.2 5.3 2.3 1.2 1.3 0.4 +// indices 0 2 2 1 1 0 +// segment_id 1 0 1 0 1 0 + +// Then we stable sort by segment id: +// values 5.3 1.2 0.4 6.2 2.3 1.3 +// indices 2 1 0 0 2 1 +// segment_id 0 0 0 1 1 1 + +// This method can only work if the slice we are sorting (`dim`) is +// innermost, and both values and indices are contiguous. We do this +// by re-arranging the input into this form as needed, which will +// unfortunately allocate memory if the request is not in this form. +// Vectorized sort is slower than iterated sort if the number of +// slices is small (since we're sorting twice, instead of invoking a +// smaller sort `numSlices` times), but the cub sort +// implementation here is a catch-all, so we're not looking for +// efficiency, but instead correctness. + +template +__global__ void sort_postprocess_kernel(const scalar_t *in, scalar_t *out, int64_t *index, const int2 *i_s_ptr, int nsegments, int nsort) { + CUDA_KERNEL_LOOP(i, nsegments * nsort) { + int segment = i / nsort; + int j = i % nsort; + + int offset = segment * nsort; + const scalar_t *in_ = in + offset; + scalar_t *out_ = out + offset; + int64_t *index_ = index + offset; + const int2 *i_s_ptr_ = i_s_ptr + offset; + + int idx = i_s_ptr_[j].y; + index_[j] = idx; + out_[j] = in_[idx]; + } +} + +template +inline void segmented_sort_pairs_by_full_sort( + int64_t nsegments, int64_t nsort, int64_t n, bool descending, const Tensor &indices, + const scalar_t *self_ptr, scalar_t *values_ptr, int64_t *indices_ptr +) { + int64_t segment_bits = std::max(1L, static_cast(std::ceil(std::log2(nsegments)))); + + auto int_options = indices.options().dtype(kInt); + auto indices_and_segment = at::empty({nsegments, nsort, 2}, int_options); + indices_and_segment.select(-1, 0).copy_( // segment id + at::arange(nsegments, int_options).view({nsegments, 1}).expand({nsegments, nsort})); + indices_and_segment.select(-1, 1).copy_( // reverse indices + at::arange(nsort, int_options).view({1, nsort}).expand({nsegments, nsort})); + + auto i_s_ptr = reinterpret_cast(indices_and_segment.data_ptr()); + auto indices_and_segment2 = at::empty_like(indices_and_segment); + auto i_s_ptr2 = reinterpret_cast(indices_and_segment2.data_ptr()); + + at::cuda::cub::sort_pairs( + self_ptr, nullptr, i_s_ptr, i_s_ptr2, + n, descending); + + TORCH_INTERNAL_ASSERT(segment_bits <= 32); + + // sort on lower 32bits, i.e. segment index + at::cuda::cub::sort_keys( + reinterpret_cast(i_s_ptr2), reinterpret_cast(i_s_ptr), + n, false, 0, segment_bits); + + sort_postprocess_kernel<<<(n + 511) / 512, 512, 0, at::cuda::getCurrentCUDAStream()>>>( + self_ptr, values_ptr, indices_ptr, i_s_ptr, nsegments, nsort); +} + +} // namespace + // We perform a segmented sort in cub with inputs that have // more than 1024/2048 elements along the selected dimension. // Otherwise, we do an inplace bitonic sort (see sortKeyValueInplace). @@ -350,11 +430,15 @@ std::tuple sort_out_stable_cuda(const Tensor & self, c10::opt int64_t n = std::min(remaining, nbatch); int64_t nsegments = n / nsort; - auto reverse_indices = at::arange(nsort, indices.options()).view({1, nsort}).expand({nsegments, nsort}).contiguous(); - - at::cuda::cub::segmented_sort_pairs(self_ptr, values_ptr, - reverse_indices.data_ptr(), indices_ptr, n, nsegments, - offset_t{(int)nsort, 0}, offset_t{(int)nsort, 1}, descending); + if (nsegments < 128) { + segmented_sort_pairs_by_full_sort(nsegments, nsort, n, descending, + indices, self_ptr, values_ptr, indices_ptr); + } else { + auto reverse_indices = at::arange(nsort, indices.options()).view({1, nsort}).expand({nsegments, nsort}).contiguous(); + at::cuda::cub::segmented_sort_pairs(self_ptr, values_ptr, + reverse_indices.data_ptr(), indices_ptr, n, nsegments, + offset_t{(int)nsort, 0}, offset_t{(int)nsort, 1}, descending); + } remaining -= n; self_ptr += n; diff --git a/aten/src/ATen/native/cuda/TensorTopK.cu b/aten/src/ATen/native/cuda/TensorTopK.cu index d6b4fe2620191..121208dd58dc2 100644 --- a/aten/src/ATen/native/cuda/TensorTopK.cu +++ b/aten/src/ATen/native/cuda/TensorTopK.cu @@ -1,7 +1,6 @@ #include #include #include -#include #include #include #include @@ -15,7 +14,7 @@ namespace at { namespace native { namespace { template -C10_LAUNCH_BOUNDS_1(512) +C10_LAUNCH_BOUNDS_1(1024) __global__ void gatherTopK(at::cuda::detail::TensorInfo input, IndexType inputSliceSize, IndexType outputSliceSize, // aka `k` @@ -255,7 +254,7 @@ TORCH_IMPL_FUNC(topk_out_cuda) dim3 grid; \ TORCH_INTERNAL_ASSERT(getGridFromTiles(inputSlices, grid), "Too many slices to sort"); \ \ - dim3 block(std::min(at::cuda::ATenCeilDiv(sliceSize, (int64_t) C10_WARP_SIZE)*(int64_t) C10_WARP_SIZE, (int64_t) 512)); \ + dim3 block(std::min(at::cuda::ATenCeilDiv(sliceSize, (int64_t) C10_WARP_SIZE)*(int64_t) C10_WARP_SIZE, (int64_t) 1024)); \ \ /* This is used as a template parameter to calculate indices. */ \ /* We only specialize it if all collapsed dim sizes are the */ \ diff --git a/aten/src/ATen/native/cuda/UpSample.cuh b/aten/src/ATen/native/cuda/UpSample.cuh index 71443e19755d5..c69a2597b74bb 100644 --- a/aten/src/ATen/native/cuda/UpSample.cuh +++ b/aten/src/ATen/native/cuda/UpSample.cuh @@ -94,11 +94,16 @@ __host__ __forceinline__ static accscalar_t area_pixel_compute_scale( int output_size, bool align_corners, const c10::optional scale) { - if (output_size > 1) { - return align_corners ? (accscalar_t)(input_size - 1) / (output_size - 1) - : compute_scales_value(scale, input_size, output_size); - } else { - return static_cast(0); + if(align_corners) { + if(output_size > 1) { + return (accscalar_t)(input_size - 1) / (output_size - 1); + } + else { + return static_cast(0); + } + } + else{ + return compute_scales_value(scale, input_size, output_size); } } diff --git a/aten/src/ATen/native/cuda/UpSampleNearest3d.cu b/aten/src/ATen/native/cuda/UpSampleNearest3d.cu index 522225b5fd85a..6270bba9eafee 100644 --- a/aten/src/ATen/native/cuda/UpSampleNearest3d.cu +++ b/aten/src/ATen/native/cuda/UpSampleNearest3d.cu @@ -1,6 +1,5 @@ #include #include -#include #include #include #include diff --git a/aten/src/ATen/native/cudnn/BatchNorm.cpp b/aten/src/ATen/native/cudnn/BatchNorm.cpp index 3a34e327e2697..1c70aa353b517 100644 --- a/aten/src/ATen/native/cudnn/BatchNorm.cpp +++ b/aten/src/ATen/native/cudnn/BatchNorm.cpp @@ -212,6 +212,9 @@ std::tuple cudnn_batch_norm( #endif // CUDNN_VERSION >= 7400 } else { reserve = at::empty({0}, input->options().dtype(kByte)); + // This keeps a consistent output with native_batch_norm + save_mean = at::empty({0}, weight_t.options()); + save_var = at::empty({0}, weight_t.options()); AT_CUDNN_CHECK(cudnnBatchNormalizationForwardInference( handle, mode, &one, &zero, idesc.desc(), input->data_ptr(), diff --git a/aten/src/ATen/native/cudnn/Conv_v7.cpp b/aten/src/ATen/native/cudnn/Conv_v7.cpp index 7d16f0a9a910f..27863d060d2dd 100644 --- a/aten/src/ATen/native/cudnn/Conv_v7.cpp +++ b/aten/src/ATen/native/cudnn/Conv_v7.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include @@ -292,6 +293,7 @@ struct algorithm_search { } else { size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos); Workspace ws(max_ws_size); + at::cuda::errorIfCapturingCudnnBenchmark("cudnnFind"); AT_CUDNN_CHECK_WITH_SHAPES(cudnnFindConvolutionForwardAlgorithmEx( args.handle, args.idesc.desc(), args.input.data_ptr(), @@ -362,6 +364,7 @@ struct algorithm_search { } else { size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos); Workspace ws(max_ws_size); + at::cuda::errorIfCapturingCudnnBenchmark("cudnnFind"); AT_CUDNN_CHECK_WITH_SHAPES(cudnnFindConvolutionBackwardDataAlgorithmEx( args.handle, args.wdesc.desc(), args.weight.data_ptr(), @@ -434,6 +437,7 @@ struct algorithm_search { } else { size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos); Workspace ws(max_ws_size); + at::cuda::errorIfCapturingCudnnBenchmark("cudnnFind"); AT_CUDNN_CHECK_WITH_SHAPES(cudnnFindConvolutionBackwardFilterAlgorithmEx( args.handle, args.idesc.desc(), args.input.data_ptr(), diff --git a/aten/src/ATen/native/group_norm.cpp b/aten/src/ATen/native/group_norm.cpp index 3a60d19959f83..5533780a4547e 100644 --- a/aten/src/ATen/native/group_norm.cpp +++ b/aten/src/ATen/native/group_norm.cpp @@ -31,7 +31,10 @@ std::tuple native_group_norm( const Tensor& gamma = *gamma_maybe_owned; const Tensor& beta = c10::value_or_else(beta_opt, [] { return Tensor(); }); - TORCH_CHECK(X.is_contiguous()); + auto memory_format = X.device().is_cpu() ? + X.suggest_memory_format() : at::MemoryFormat::Contiguous; + + TORCH_CHECK(X.is_contiguous(memory_format)); Tensor Y = at::native::empty_like( X, @@ -39,7 +42,7 @@ std::tuple native_group_norm( c10::nullopt /* layout */, c10::nullopt /* device */, c10::nullopt /* pin_memory */, - LEGACY_CONTIGUOUS_MEMORY_FORMAT); + memory_format); Tensor mean = at::empty({N, group}, X.options()); Tensor rstd = at::empty({N, group}, X.options()); GroupNormKernel( @@ -73,7 +76,7 @@ std::tuple native_group_norm_backward( c10::nullopt /* layout */, c10::nullopt /* device */, c10::nullopt /* pin_memory */, - LEGACY_CONTIGUOUS_MEMORY_FORMAT); + at::MemoryFormat::Contiguous); } if (grad_input_mask[1]) { dgamma = at::native::empty_like( @@ -82,7 +85,7 @@ std::tuple native_group_norm_backward( c10::nullopt /* layout */, c10::nullopt /* device */, c10::nullopt /* pin_memory */, - LEGACY_CONTIGUOUS_MEMORY_FORMAT); + at::MemoryFormat::Contiguous); } if (grad_input_mask[2]) { dbeta = at::native::empty_like( @@ -91,7 +94,7 @@ std::tuple native_group_norm_backward( c10::nullopt /* layout */, c10::nullopt /* device */, c10::nullopt /* pin_memory */, - LEGACY_CONTIGUOUS_MEMORY_FORMAT); + at::MemoryFormat::Contiguous); } GroupNormBackwardKernel( X.device().type(), @@ -153,7 +156,9 @@ Tensor group_norm( c10::multiply_integers(input_shape.cbegin() + 2, input_shape.cend()); const Tensor kEmpty; - const auto& X = input.is_contiguous() ? input : input.contiguous(); + auto memory_format = input.suggest_memory_format(); + const auto& X = input.device().is_cpu() ? + input.contiguous(memory_format) : input.contiguous(); const auto& gamma = weight.defined() ? weight.contiguous() : kEmpty; const auto& beta = bias.defined() ? bias.contiguous() : kEmpty; TORCH_CHECK(!gamma.defined() || gamma.numel() == C); diff --git a/aten/src/ATen/native/metal/MetalContext.mm b/aten/src/ATen/native/metal/MetalContext.mm index 80ee55efa591e..f71d35f97a866 100644 --- a/aten/src/ATen/native/metal/MetalContext.mm +++ b/aten/src/ATen/native/metal/MetalContext.mm @@ -37,9 +37,6 @@ + (instancetype)sharedInstance { - (BOOL)available { #if !defined(__APPLE__) return false; -#elif TARGET_IPHONE_SIMULATOR - // TODO[T90135707]: Enable Metal on iOS Simulators - return false; #elif TARGET_OS_IPHONE if (!MPSSupportsMTLDevice(_device)) { return false; @@ -47,9 +44,6 @@ - (BOOL)available { if ([UIDevice currentDevice].systemVersion.floatValue < 11.0) { return false; } - if (![_device supportsFeatureSet:MTLFeatureSet_iOS_GPUFamily3_v2]) { - return false; - } #elif TARGET_OS_MAC if (!MPSSupportsMTLDevice(_device)) { return false; diff --git a/aten/src/ATen/native/metal/MetalPrepackOpContext.h b/aten/src/ATen/native/metal/MetalPrepackOpContext.h index e6b3f0b78a518..5976d7af23e53 100644 --- a/aten/src/ATen/native/metal/MetalPrepackOpContext.h +++ b/aten/src/ATen/native/metal/MetalPrepackOpContext.h @@ -21,14 +21,14 @@ class Conv2dOpContext : public torch::jit::CustomClassHolder { public: SerializationTypeConv2dPrePack pack() { return std::make_tuple( - weight, - bias, - stride, - padding, - dilation, - groups, - output_min, - output_max); + weight_, + bias_, + stride_, + padding_, + dilation_, + groups_, + output_min_, + output_max_); } Conv2dOpContext() = delete; Conv2dOpContext( @@ -40,32 +40,81 @@ class Conv2dOpContext : public torch::jit::CustomClassHolder { int64_t groups, const c10::optional& output_min, const c10::optional& output_max) - : weight(std::move(weight)), - bias(std::move(bias)), - stride(stride), - padding(padding), - dilation(dilation), - groups(groups), - output_min(output_min), - output_max(output_max) {} + : weight_(std::move(weight)), + bias_(std::move(bias)), + stride_(stride), + padding_(padding), + dilation_(dilation), + groups_(groups), + output_min_(output_min), + output_max_(output_max) {} void release_resources() override { - if (releaseCallback) { - releaseCallback(conv2dOp); - conv2dOp = nullptr; + if (releaseCallback_) { + releaseCallback_(conv2dOp_); + conv2dOp_ = nullptr; } } - Tensor weight; - c10::optional bias; - std::vector stride; - std::vector padding; - std::vector dilation; - int64_t groups; - c10::optional output_min; - c10::optional output_max; - void* conv2dOp = nullptr; // reserved to hold MPSCNNConv2dOp objects - std::function releaseCallback = nullptr; + const Tensor& get_weight() const { + return weight_; + } + + const c10::optional& get_bias() const { + return bias_; + } + + const std::vector& get_stride() const { + return stride_; + } + + const std::vector& get_padding() const { + return padding_; + } + + const std::vector& get_dilation() const { + return dilation_; + } + + int64_t get_groups() const { + return groups_; + } + + const c10::optional& get_output_min() const { + return output_min_; + } + + const c10::optional& get_output_max() const { + return output_max_; + } + + void set_conv2dOpPtr(void* ptr) { + conv2dOp_ = ptr; + } + + void* get_conv2dOpPtr() const { + return conv2dOp_; + } + + void set_releaseCallback(const std::function& func) { + releaseCallback_ = func; + } + + std::function& get_releaseCallback() { + return releaseCallback_; + } + + private: + Tensor weight_; + c10::optional bias_; + std::vector stride_; + std::vector padding_; + std::vector dilation_; + int64_t groups_; + c10::optional output_min_; + c10::optional output_max_; + std::function releaseCallback_ = nullptr; + void* conv2dOp_ = nullptr; // reserved to hold MPSCNNConv2dOp objects }; using SerializationTypeLinearPrePack = std::tuple< diff --git a/aten/src/ATen/native/metal/MetalShaders.h b/aten/src/ATen/native/metal/MetalShaders.h index 5c256723a59e5..0ee703f2ee261 100644 --- a/aten/src/ATen/native/metal/MetalShaders.h +++ b/aten/src/ATen/native/metal/MetalShaders.h @@ -393,31 +393,32 @@ kernel void clamp(texture2d_array in_arr[[texture(0), functi } } -kernel void hardswish(texture2d_array in[[texture(0)]], - texture2d_array out[[texture(1)]], +constant bool hardswish_is_arr = (ushort_arg_0 > 1 || ushort_arg_1 > 4); +constant bool hardswish_is_tex = !hardswish_is_arr; +kernel void hardswish(texture2d_array in_arr[[texture(0), function_constant(hardswish_is_arr)]], + texture2d in_tex[[texture(0), function_constant(hardswish_is_tex)]], + texture2d_array out_arr[[texture(1), function_constant(hardswish_is_arr)]], + texture2d out_tex[[texture(1), function_constant(hardswish_is_tex)]], ushort3 gid[[thread_position_in_grid]]) { - if (gid.x >= out.get_width() || gid.y >= out.get_height()) { + const ushort oH = ushort_arg_2; + const ushort oW = ushort_arg_3; + if (gid.x >= oW || gid.y >= oH) { return; } ushort2 gid_ = gid.xy; - half4 value = in.read(gid_, gid.z); - half4 mask1 = half4(value < 3.0); - half4 mask2 = half4(value > -3.0); - half4 outval = mask2*(mask1*(value*(value + 3.0)/6.0) + (1 - mask1)*value); - out.write(outval, gid_, gid.z); -} - -kernel void hardswish_nonarray(texture2d in[[texture(0)]], - texture2d out[[texture(1)]], - ushort2 gid[[thread_position_in_grid]]) { - if (gid.x >= out.get_width() || gid.y >= out.get_height()) { - return; + if (hardswish_is_arr) { + half4 value = in_arr.read(gid_, gid.z); + half4 mask1 = half4(value < 3.0); + half4 mask2 = half4(value > -3.0); + half4 outval = mask2*(mask1*(value*(value + 3.0)/6.0) + (1 - mask1)*value); + out_arr.write(outval, gid_, gid.z); + } else { + half4 value = in_tex.read(gid_); + half4 mask1 = half4(value < 3); + half4 mask2 = half4(value > -3.0); + half4 outval = mask2*(mask1*(value*(value + 3.0)/6.0) + (1 - mask1)*value); + out_tex.write(outval, gid_); } - half4 value = in.read(gid); - half4 mask1 = half4(value < 3); - half4 mask2 = half4(value > -3.0); - half4 outval = mask2*(mask1*(value*(value + 3.0)/6.0) + (1 - mask1)*value); - out.write(outval, gid); } constant bool out_is_arr = (ushort_arg_3 > 1 || ushort_arg_2 > 4); diff --git a/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h index ee992d9db5abd..599f2ceb64f4c 100644 --- a/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h +++ b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h @@ -41,6 +41,7 @@ bool test_softmax(); bool test_sigmoid(); bool test_hardsigmoid(); bool test_hardswish(); +bool test_hardswish2(); bool test_upsampling_nearest2d_vec(); bool test_upsampling_nearest2d_vec2(); bool test_adaptive_avg_pool2d(); diff --git a/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm index 69497a976a130..5a8f6de86996b 100644 --- a/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm +++ b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm @@ -262,6 +262,18 @@ bool test_hardswish() { }); } +bool test_hardswish2() { + __block std::vector size{1, 3, 44, 44}; + return TEST(size, __PRETTY_FUNCTION__, ^bool { + auto X = + at::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat)) * 12 - 6; + auto X2 = X.metal(); + auto Y1 = at::hardswish_(X); + auto Y2 = at::hardswish_(X2).cpu(); + return almostEqual(Y1, Y2); + }); +} + bool test_addmm() { bool result = true; for (int i = 0; i < ITER_COUNT; ++i) { diff --git a/aten/src/ATen/native/metal/mpscnn/tests/MetalOpTestRunner.mm b/aten/src/ATen/native/metal/mpscnn/tests/MetalOpTestRunner.mm index d8b69adcc9d1e..5e749983c822d 100644 --- a/aten/src/ATen/native/metal/mpscnn/tests/MetalOpTestRunner.mm +++ b/aten/src/ATen/native/metal/mpscnn/tests/MetalOpTestRunner.mm @@ -69,19 +69,22 @@ - (void)registerTests { REG_TEST("test_sigmoid", test_sigmoid); REG_TEST("test_hardsigmoid", test_hardsigmoid); REG_TEST("test_hardswish", test_hardswish); + REG_TEST("test_hardswish2", test_hardswish2); REG_TEST("test_upsampling_nearest2d_vec", test_upsampling_nearest2d_vec); REG_TEST("test_upsampling_nearest2d_vec2", test_upsampling_nearest2d_vec2); REG_TEST("test_adaptive_avg_pool2d", test_adaptive_avg_pool2d); REG_TEST("test_hardtanh_", test_hardtanh_); REG_TEST("test_hardtanh", test_hardtanh); REG_TEST("test_reshape", test_reshape); + REG_TEST("test_chunk", test_chunk); + REG_TEST("test_chunk3", test_chunk3); + REG_TEST("test_reflection_pad2d", test_reflection_pad2d); +#if !TARGET_IPHONE_SIMULATOR REG_TEST("test_mean_dim", test_mean_dim); REG_TEST("test_mean_dim2", test_mean_dim2); REG_TEST("test_mean_dim3", test_mean_dim3); - REG_TEST("test_chunk", test_chunk); REG_TEST("test_chunk2", test_chunk2); - REG_TEST("test_chunk3", test_chunk3); - REG_TEST("test_reflection_pad2d", test_reflection_pad2d); +#endif } - (NSDictionary*)tests { diff --git a/aten/src/ATen/native/metal/ops/MetalConvolution.mm b/aten/src/ATen/native/metal/ops/MetalConvolution.mm index c726382dde45f..4f07f5f77161d 100644 --- a/aten/src/ATen/native/metal/ops/MetalConvolution.mm +++ b/aten/src/ATen/native/metal/ops/MetalConvolution.mm @@ -55,28 +55,28 @@ Tensor conv2d( Tensor conv2d(const Tensor& input, Conv2dOpContext& context) { MPSImage* X = imageFromTensor(input); Conv2DParams params{input.sizes(), - context.weight.sizes(), - context.padding, - context.stride, - context.dilation, - context.groups}; + context.get_weight().sizes(), + context.get_padding(), + context.get_stride(), + context.get_dilation(), + context.get_groups()}; auto outputSize = params.output_sizes(); if(c10::multiply_integers(outputSize) == 0){ return makeTensor({outputSize}, input.options()); } - MPSCNNConvOp* op = (__bridge MPSCNNConvOp*)(context.conv2dOp); - NeuronType nt = neuronType(context.output_min, context.output_max); + MPSCNNConvOp* op = (__bridge MPSCNNConvOp*)(context.get_conv2dOpPtr()); + NeuronType nt = neuronType(context.get_output_min(), context.get_output_max()); if (!op) { - float* w = context.weight.data_ptr(); - float* b = context.bias.has_value() ? ((*context.bias).data_ptr()) + float* w = context.get_weight().data_ptr(); + float* b = context.get_bias().has_value() ? ((*context.get_bias()).data_ptr()) : nullptr; op = [MPSCNNConvOp conv2d:params weights:w bias:b neuronFilter:nt]; - context.conv2dOp = (void*)CFBridgingRetain(op); - context.releaseCallback = ^(void* res) { + context.set_conv2dOpPtr((void*)CFBridgingRetain(op)); + context.set_releaseCallback(^(void* res) { if (res) { CFBridgingRelease(res); } - }; + }); } MetalTensorImplStorage mt{outputSize}; MetalCommandBuffer* commandBuffer = getCommandBuffer(input); @@ -86,8 +86,8 @@ Tensor conv2d(const Tensor& input, Conv2dOpContext& context) { // fuse hardtanh with convolution if (nt == NeuronType::Clamp) { MPSImage* Y2 = createTemporaryImage(commandBuffer, [Y1 sizes]); - float min = context.output_min.value().toFloat(); - float max = context.output_max.value().toFloat(); + float min = context.get_output_min().value().toFloat(); + float max = context.get_output_max().value().toFloat(); MPSCNNClampOp* clampOp = [MPSCNNClampOp newWithTextures:@[ Y1, Y2 ] Args:@[ @(min), @(max) ]]; [clampOp encode:commandBuffer.buffer]; diff --git a/aten/src/ATen/native/metal/ops/MetalHardswish.mm b/aten/src/ATen/native/metal/ops/MetalHardswish.mm index 8d3526a4c6b2a..d571e483233dd 100644 --- a/aten/src/ATen/native/metal/ops/MetalHardswish.mm +++ b/aten/src/ATen/native/metal/ops/MetalHardswish.mm @@ -24,9 +24,9 @@ id encoder = [commandBuffer.buffer computeCommandEncoder]; id state = [[MetalContext sharedInstance] - specializedPipelineState:mpscnn::kernelFor( - X, "hardswish", "hardswish_nonarray") + specializedPipelineState:"hardswish" Constants:@[ + @(X.numberOfImages), @(X.featureChannels), @(X.height), @(X.width) diff --git a/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp b/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp index d78fe079ed442..28e20e90b2997 100644 --- a/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp +++ b/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp @@ -120,6 +120,8 @@ std::tuple miopen_batch_norm( save_mean.data_ptr(), save_var.data_ptr())); } else { + save_mean = at::empty({0}, weight_t.options()); + save_var = at::empty({0}, weight_t.options()); MIOPEN_CHECK(miopenBatchNormalizationForwardInference( handle, mode, &one, &zero, idesc.desc(), input->data_ptr(), diff --git a/aten/src/ATen/native/mkldnn/Matmul.cpp b/aten/src/ATen/native/mkldnn/Matmul.cpp new file mode 100644 index 0000000000000..5327ce821ff1e --- /dev/null +++ b/aten/src/ATen/native/mkldnn/Matmul.cpp @@ -0,0 +1,99 @@ +#include +#include +#include +#include +#if !AT_MKLDNN_ENABLED() + +namespace at { +namespace native { + +void mkldnn_matmul( + const Tensor &mat1, + const Tensor &mat2, + Tensor &result, + float beta, + float alpha) { + TORCH_CHECK(false, "mkldnn_matmul: ATen not compiled with MKLDNN support"); +} +} // namespace native +} // namespace at + +#else // AT_MKLDNN_EBABLED + +#include +#include + +namespace at { +namespace native { + +void mkldnn_matmul( + const Tensor &mat1, + const Tensor &mat2, + const Tensor &result, + float beta, + float alpha) { + TORCH_CHECK((mat1.dim() == 2 && mat2.dim() == 2) || (mat1.dim() == 3 && mat2.dim() == 3), + "mkldnn_matmul: expect mat1 to be 2-D or 3-D tensor"); + TORCH_CHECK(mat1.scalar_type() == at::kBFloat16 && + mat2.scalar_type() == at::kBFloat16 && + result.scalar_type() == at::kBFloat16, "mkldnn_matmul: only enabled for bf16 path"); + TORCH_CHECK(mkldnn_bf16_device_check(), + "mkldnn_matmul: mkldnn_matmul bf16 path needs the cpu support avx512bw, avx512vl and avx512dq"); + ideep::attr_t op_attr; + // "addmm", "addbmm" "baddbmm" in pytorch allow bias to be 2-D or 3-D tensor + // but mkldnn matmul primitive only support bias be 1-D tensors + // to address their differences, we use mkldnn post ops to perform a fused "add" after matrix multiplication is over + if (beta != 0.0f) op_attr = ideep::attr_t::fuse_sum(); + // If alpha = 0, dose not need actually do gemm computation + if (alpha == 0) + return; + + auto is_mkldnn_optimized_format = [&](const Tensor& t) { + if (t.is_contiguous()) return true; + const auto sizes = t.sizes(); + const auto strides = t.strides(); + if (t.dim() == 2){ + return strides[0] == 1 && strides[1] == sizes[0]; + } else { + // dim = 3 + return strides[0] == sizes[1] * sizes[2] && strides[1] == 1 && strides[2] == sizes[1]; + } + }; + + // Mkldnn only optimized for contiguous or transposed (transpose last 2 dim if 3-D tensor) format now + // Will remove this "contiguous" after mkldnn have fully supported + Tensor mat1_ = is_mkldnn_optimized_format(mat1) ? mat1 : mat1.contiguous(); + Tensor mat2_ = is_mkldnn_optimized_format(mat2) ? mat2 : mat2.contiguous(); + Tensor mat1_reshaped = mat1_; + Tensor mat2_reshaped = mat2_; + if (result.dim() == 2 && mat1.dim() == 3 && mat2.dim() == 3){ + // addbmm(batch1*batch2) [b,n,m] * [b,m,p] = [n,p] can be treated as: + // [n, b*m] * [b*m, p] = [n, p] + // For batch1: reorder from [b, n, m] to [n, b, m], reshape to [n, b*m] + // For batch2: reshape from [b, m, p] to [b*m, p] + auto mat1_size = mat1.sizes(); + auto mat2_size = mat2.sizes(); + mat1_ = mat1_size[0] > 1 ? mat1_.transpose(0, 1) : mat1_; + mat1_reshaped = mat1_.reshape({mat1_size[1], mat1_size[0] * mat1_size[2]}); + mat2_reshaped = mat2_.reshape({mat2_size[0] * mat2_size[1], mat2_size[2]}); + } + + // mkldnn_matmul only proceed CPU tensor + const ideep::tensor x = itensor_view_from_dense(mat1_reshaped); + const ideep::tensor w = itensor_view_from_dense(mat2_reshaped); + ideep::tensor y = itensor_view_from_dense(result); + ideep::matmul_forward::compute(x, w, y, alpha, beta, + ideep::scale_t(), ideep::scale_t(), ideep::scale_t(), op_attr); + if (y.get_data_handle() != result.data_ptr()){ + // ideep will query onednn expect format of output + // if given output format is not expected, ideep will re-init an output buffer + // under this case, we need copy the re-inited buffer back to given buffer + ideep::tensor public_y = itensor_view_from_dense(result); + y.reorder_to(public_y); + } +} + +} // namespace native +} // namespace at + +#endif // AT_MKLDNN_EBABLED diff --git a/aten/src/ATen/native/mkldnn/Matmul.h b/aten/src/ATen/native/mkldnn/Matmul.h new file mode 100644 index 0000000000000..8cd5b5a9b3aeb --- /dev/null +++ b/aten/src/ATen/native/mkldnn/Matmul.h @@ -0,0 +1,17 @@ +#pragma once + +#include +#include + +namespace at { namespace native { + +// result = beta * result + alpha * gemm(mat1, mat2) +// need mat, mat2 to be 2-D or 3-D Tensors +TORCH_API void mkldnn_matmul( + const Tensor &mat1, + const Tensor &mat2, + const Tensor &result, + float beta=1, + float alpha=1); + +}} diff --git a/aten/src/ATen/native/mkldnn/Normalization.cpp b/aten/src/ATen/native/mkldnn/Normalization.cpp index 9836f3560d038..f01bbb3d2b4bd 100644 --- a/aten/src/ATen/native/mkldnn/Normalization.cpp +++ b/aten/src/ATen/native/mkldnn/Normalization.cpp @@ -25,6 +25,13 @@ std::tuple mkldnn_batch_norm_backward( TORCH_CHECK(false, "mkldnn_batch_norm_backward: ATen not compiled with MKLDNN support"); } +std::tuple mkldnn_layer_norm_last_index_weight_bias_f32( + const Tensor& input, + IntArrayRef normalized_shape, const Tensor& weight, const Tensor& bias, + double eps, bool inplace) { + TORCH_CHECK(false, "mkldnn_layer_norm_last_index_weight_bias_f32: ATen not compiled with MKLDNN support"); +} + } // namespace native } // namespace at @@ -32,10 +39,54 @@ std::tuple mkldnn_batch_norm_backward( #include #include +#include +#include namespace at { namespace native { +std::tuple mkldnn_layer_norm_last_index_weight_bias_f32( + const Tensor& input, + IntArrayRef normalized_shape, const Tensor& weight, const Tensor& bias, + double eps, bool inplace) { + + TORCH_INTERNAL_ASSERT(normalized_shape.size() == 1, "only accept shapes with the last dimension"); + TORCH_INTERNAL_ASSERT(input.scalar_type() == at::kFloat); + auto M_N = at::native::_check_layer_norm_inputs(input, normalized_shape, weight, bias); + auto M = M_N.first; + + auto mean = empty_mkldnn( + {M}, + input.scalar_type(), + input.options().layout_opt(), + input.options().device_opt(), + input.options().pinned_memory_opt()); + auto rstd = empty_mkldnn( + {M}, + input.scalar_type(), + input.options().layout_opt(), + input.options().device_opt(), + input.options().pinned_memory_opt()); + + auto mean_it = at::native::itensor_from_mkldnn(mean); + auto rstd_it = at::native::itensor_from_mkldnn(rstd); + + auto input_it = at::native::itensor_from_mkldnn(input); + auto weight_it = at::native::itensor_from_mkldnn(weight); + auto bias_it = at::native::itensor_from_mkldnn(bias); + + auto out_it = inplace ? input_it : ideep::tensor(input_it.get_desc()); + ideep::layer_normalization_forward::compute(input_it, weight_it, bias_it, out_it, mean_it, rstd_it, static_cast(eps)); + + auto dst = at::native::new_with_itensor_mkldnn( + std::move(out_it), + optTypeMetaToScalarType(input.options().dtype_opt()), + input.options().device_opt()); + + return std::make_tuple(dst, mean, rstd); +} + + std::tuple mkldnn_batch_norm( const Tensor& input, const c10::optional& weight_opt, const c10::optional& bias_opt, const c10::optional& running_mean_opt, const c10::optional& running_var_opt, bool train, diff --git a/aten/src/ATen/native/mkldnn/Utils.h b/aten/src/ATen/native/mkldnn/Utils.h index f2e4e8f9056df..49d51b286c097 100644 --- a/aten/src/ATen/native/mkldnn/Utils.h +++ b/aten/src/ATen/native/mkldnn/Utils.h @@ -1,11 +1,19 @@ #pragma once +#include #include +#include #include #include + namespace at { namespace native { +std::tuple mkldnn_layer_norm_last_index_weight_bias_f32( + const Tensor& input, + IntArrayRef normalized_shape, const Tensor& weight, const Tensor& bias, + double eps, bool inplace = false); + std::vector pool_output_sizes( IntArrayRef input_size, IntArrayRef kernel_size, @@ -21,4 +29,29 @@ inline bool mkldnn_bf16_device_check() { && cpuinfo_has_x86_avx512vl() && cpuinfo_has_x86_avx512dq(); } +inline bool use_mkldnn_bf16_gemm( + const Tensor& mat1, + const Tensor& mat2, + const c10::optional& result_opt) { + c10::MaybeOwned result_maybe_owned = at::borrow_from_optional_tensor(result_opt); + const Tensor& result = *result_maybe_owned; + + static const int64_t mkldnn_gemm_min_size = 16 * 16 * 16; + // if dim = 2, mat1's size = (m * n), mat2's size = (n * k) + // else dim = 3, mat1's size = (b * m * n), mat2's size = (b * n * k) + // only m * n * k are large enough we can get benefit from mkldnn optimized gemm kernel + // if some cases pytorch dose not have default impl for bf16 (such as "dot"), will use mkldnn impl anyway + int64_t m = mat1.dim() == 2? mat1.size(0) : mat1.size(1); + int64_t n = mat1.dim() == 2? mat1.size(1) : mat1.size(2); + int64_t k = mat2.dim() == 2? mat2.size(1) : mat2.size(2); + return ( + mat1.scalar_type() == kBFloat16 && + mat2.scalar_type() == kBFloat16 && + (!result.defined() || result.scalar_type() == kBFloat16) && + mat1.numel() != 0 && + mat2.numel() != 0 && + mkldnn_bf16_device_check() && + m * n * k >= mkldnn_gemm_min_size); +} + } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 40245cc7607af..dbacca2750850 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -487,6 +487,8 @@ - func: all.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator structured: True + precomputed: + - dim -> int dim dispatch: CPU, CUDA: all_out @@ -508,6 +510,8 @@ - func: any.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator structured: True + precomputed: + - dim -> int dim dispatch: CPU, CUDA: any_out @@ -970,6 +974,15 @@ - func: cat.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!) +# alias for torch.cat +- func: concat(Tensor[] tensors, int dim=0) -> Tensor + +- func: concat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!) + +- func: concat.names(Tensor[] tensors, Dimname dim) -> Tensor + +- func: concat.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!) + - func: block_diag(Tensor[] tensors) -> Tensor variants: function @@ -6597,14 +6610,13 @@ variants: method, function - func: gather.out(Tensor self, int dim, Tensor index, *, bool sparse_grad=False, Tensor(a!) out) -> Tensor(a!) + structured: True dispatch: - CPU: gather_out_cpu_cuda - CUDA: gather_out_cpu_cuda + CPU, CUDA: gather_out - func: gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor variants: method, function - dispatch: - CPU, CUDA: gather + structured_delegate: gather.out - func: gather_backward(Tensor grad, Tensor self, int dim, Tensor index, bool sparse_grad) -> Tensor variants: function @@ -6652,19 +6664,19 @@ device_check: NoCheck # TensorIterator variants: method -- func: cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, int ignore_index=-100) -> Tensor +- func: cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, int ignore_index=-100, float label_smoothing=0.0) -> Tensor python_module: nn - func: lstsq.X(Tensor self, Tensor A, *, Tensor(a!) X, Tensor(b!) qr) -> (Tensor(a!) solution, Tensor(b!) QR) dispatch: CPU: legacy_lstsq_out - CUDA: legacy::cuda::_th_gels_out + CUDA: legacy_lstsq_out_cuda - func: lstsq(Tensor self, Tensor A) -> (Tensor solution, Tensor QR) variants: method, function dispatch: CPU: legacy_lstsq - CUDA: legacy::cuda::_th_gels + CUDA: legacy_lstsq_cuda - func: triangular_solve.X(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False, *, Tensor(a!) X, Tensor(b!) M) -> (Tensor(a!) solution, Tensor(b!) cloned_coefficient) dispatch: @@ -8460,14 +8472,16 @@ CompositeExplicitAutograd: elu_ - func: glu.out(Tensor self, int dim=-1, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase python_module: nn dispatch: CPU, CUDA: glu_out - func: glu(Tensor self, int dim=-1) -> Tensor + structured_delegate: glu.out + device_check: NoCheck # TensorIterator python_module: nn - dispatch: - CPU, CUDA: glu - func: glu_backward.grad_input(Tensor grad_output, Tensor self, int dim, *, Tensor(a!) grad_input) -> Tensor(a!) python_module: nn @@ -8814,6 +8828,10 @@ - func: avg_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!) python_module: nn structured: True + precomputed: + - kernel_size -> int kH, int kW + - stride -> int dH, int dW + - padding -> int padH, int padW dispatch: CPU: avg_pool2d_out_cpu CUDA: avg_pool2d_out_cuda @@ -9498,13 +9516,13 @@ python_module: nn dispatch: CPU: slow_conv2d_forward_out_cpu - CUDA: legacy::cuda::_thnn_conv2d_forward_out + CUDA: slow_conv2d_forward_out_cuda - func: thnn_conv2d_forward(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, int[2] padding) -> (Tensor output, Tensor finput, Tensor fgrad_input) python_module: nn dispatch: CPU: slow_conv2d_forward_cpu - CUDA: legacy::cuda::_thnn_conv2d_forward + CUDA: slow_conv2d_forward_cuda - func: thnn_conv2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, Tensor finput, Tensor fgrad_input, *, Tensor(a!) grad_input, Tensor(b!) grad_weight, Tensor(c!) grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!)) python_module: nn @@ -10254,6 +10272,14 @@ dispatch: CPU, CUDA: linalg_lstsq_out +# torch.linalg.matmul, alias for torch.matmul +- func: linalg_matmul(Tensor self, Tensor other) -> Tensor + python_module: linalg + variants: function + +- func: linalg_matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + - func: linalg_slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet) python_module: linalg variants: function diff --git a/aten/src/ATen/native/quantized/cpu/qconv.cpp b/aten/src/ATen/native/quantized/cpu/qconv.cpp index bf5c596a9e0d2..3c0d79acac18c 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv.cpp @@ -563,6 +563,8 @@ at::Tensor PackedConvWeightsQnnp::apply_impl( const at::Tensor& act, double output_scale, int64_t output_zero_point) { + // QNNPack is not thread safe + std::lock_guard lock(qnnp_mutex_); const std::string func_name = transpose() ? "quantized::conv_transpose" : "quantized::conv"; TORCH_CHECK(!(kReluFused && transpose()), diff --git a/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp index dff28b141f6b1..87294c11adda0 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp @@ -276,9 +276,8 @@ c10::intrusive_ptr> PackedConvWeightsQnnp< // during the first invocation of operator run. Refer to qconv.cpp for more // details. TODO Update to actually call pre-pack here once bias is removed // from pre-packing step. - c10::intrusive_ptr> ret_ptr = - c10::make_intrusive>( - PackedConvWeightsQnnp{ + auto ret_ptr = + c10::intrusive_ptr>::make( nullptr, /* PrePackConvWeights */ weight_contig, /* int8_t weight */ bias_fp32.contiguous(), /* fp32 bias */ @@ -289,10 +288,10 @@ c10::intrusive_ptr> PackedConvWeightsQnnp< groups, transpose, c10::nullopt, /* input_scale */ - {kernel_h, kernel_w}, + std::vector{kernel_h, kernel_w}, w_scales, std::move(w_zero_points), - is_per_channel}); + is_per_channel); return ret_ptr; } diff --git a/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp b/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp index 7adf05a1782ce..6aae3ba02ae09 100644 --- a/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp +++ b/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp @@ -141,9 +141,10 @@ at::Tensor& embedding_lookup_fallback_impl( } template -at::Tensor& embedding_bag_4bit_impl( +at::Tensor& embedding_bag_nbit_impl( at::Tensor& output, const at::Tensor& weight, + const int bit_width, const at::Tensor& indices, const at::Tensor& offsets, bool pruned_weights, @@ -174,8 +175,9 @@ at::Tensor& embedding_bag_4bit_impl( const auto weight_sizes = weight.sizes(); const int64_t weight_size = weight_sizes[1]; + int NUM_ELEM_PER_BYTE = 8 / bit_width; const int64_t D = - (weight_size - 4) * 2; // NB: 2-byte fp16 scale and 2-byte zero_offset + (weight_size - 2 * sizeof(at::Half)) * NUM_ELEM_PER_BYTE; // NB: 2-byte fp16 scale and 2-byte zero_offset const int64_t M = offsets.sizes()[0]; int64_t output_size = M - 1; @@ -211,7 +213,7 @@ at::Tensor& embedding_bag_4bit_impl( if (!pruned_weights || fallback_to_no_sparse) { // Generate the fbgemm kernel auto kernel = fbgemm::GenerateEmbeddingSpMDMNBit( - /*bit rate=*/4, + /*bit rate=*/bit_width, /*block size=*/block_size, /*has weights=*/per_sample_weights_.has_value(), /*normalize_by_lengths=*/false, @@ -234,11 +236,13 @@ at::Tensor& embedding_bag_4bit_impl( TORCH_CHECK( success, - "FBGEMM GenerateEmbeddingSpMDMNBit kernel failed for 4-bit input"); + "FBGEMM GenerateEmbeddingSpMDMNBit kernel failed for ", + bit_width, + "-bit input"); } else { auto kernel = fbgemm::GenerateEmbeddingSpMDMNBitRowWiseSparse( - /*bit rate=*/4, + /*bit rate=*/bit_width, /*block_size=*/block_size, /*has weights=*/per_sample_weights_.has_value(), /*normalize_by_lengths=*/false, @@ -260,11 +264,14 @@ at::Tensor& embedding_bag_4bit_impl( /*compressed_indices_table=*/compressed_indices_mapping_data); TORCH_CHECK( success, - "FBGEMM GenerateEmbeddingSpMDMNBitRowWiseSparse kernel failed for 4-bit input"); + "FBGEMM GenerateEmbeddingSpMDMNBitRowWiseSparse kernel failed for ", + bit_width, + "-bit input"); } return output; #else - return embedding_lookup_fallback_impl( + if (bit_width == 4) { + return embedding_lookup_fallback_impl( weight, indices, offsets, @@ -275,6 +282,19 @@ at::Tensor& embedding_bag_4bit_impl( output_size, include_last_offset, (pruned_weights && !fallback_to_no_sparse)); + } + // bit_width == 2 + return embedding_lookup_fallback_impl( + weight, + indices, + offsets, + per_sample_weights_, + compressed_indices_mapping, + output, + D, + output_size, + include_last_offset, + (pruned_weights && !fallback_to_no_sparse)); #endif } @@ -519,9 +539,10 @@ at::Tensor& embedding_bag_byte_helper( is_embedding_op); } -at::Tensor& embedding_bag_4bit_helper( +at::Tensor& _embedding_bag_nbit_helper( at::Tensor& output, const at::Tensor& weight, + const int bit_width, const at::Tensor& indices, const c10::optional& offsets_in, bool pruned_weights, @@ -529,6 +550,10 @@ at::Tensor& embedding_bag_4bit_helper( const c10::optional& compressed_indices_mapping, bool include_last_offset) { c10::MaybeOwned offsets; + TORCH_CHECK( + bit_width == 4 || bit_width == 2, + "qembedding/qembedding_bag operator supports bit_width 2 or 4, got ", + bit_width); TORCH_CHECK( indices.dim() == 1 || indices.dim() == 2, "qembedding/qembedding_bag operator supports 1 or 2d indices, got ", @@ -539,14 +564,14 @@ at::Tensor& embedding_bag_4bit_helper( if (indices.dim() == 2) { TORCH_CHECK( !offsets_in.has_value(), - "embedding_bag_4bit operator: input is 2D, then offsets has to be None, as input is treated is a mini-batch of fixed length sequences."); + "embedding_bag_4bit/embedding_bag_2bit operator: input is 2D, then offsets has to be None, as input is treated is a mini-batch of fixed length sequences."); offsets = c10::MaybeOwned::owned(at::arange( 0, indices.numel(), indices.sizes()[1], indices.scalar_type())); } else { TORCH_CHECK( offsets_in.has_value(), - "embedding_bag_4bit operator expects offsets to be set for 1D indices."); + "embedding_bag_4bit/embedding_bag_2bit operator expects offsets to be set for 1D indices."); offsets = c10::MaybeOwned::borrowed(offsets_in.value()); } @@ -568,9 +593,10 @@ at::Tensor& embedding_bag_4bit_helper( // Using helper function to support different type combination without the // need to cast, which can be additional performance overhead if (indices.scalar_type() == at::kInt && offsets->scalar_type() == at::kInt) { - return embedding_bag_4bit_impl( + return embedding_bag_nbit_impl( output, weight, + bit_width, indices, *offsets, pruned_weights, @@ -579,9 +605,10 @@ at::Tensor& embedding_bag_4bit_helper( include_last_offset); } else if ( indices.scalar_type() == at::kInt && offsets->scalar_type() == at::kLong) { - return embedding_bag_4bit_impl( + return embedding_bag_nbit_impl( output, weight, + bit_width, indices, *offsets, pruned_weights, @@ -590,9 +617,10 @@ at::Tensor& embedding_bag_4bit_helper( include_last_offset); } else if ( indices.scalar_type() == at::kLong && offsets->scalar_type() == at::kInt) { - return embedding_bag_4bit_impl( + return embedding_bag_nbit_impl( output, weight, + bit_width, indices, *offsets, pruned_weights, @@ -600,9 +628,10 @@ at::Tensor& embedding_bag_4bit_helper( compressed_indices_mapping, include_last_offset); } - return embedding_bag_4bit_impl( + return embedding_bag_nbit_impl( output, weight, + bit_width, indices, *offsets, pruned_weights, @@ -650,9 +679,10 @@ at::Tensor PackedEmbeddingBagWeight::embeddingbag_4bit( } auto output = at::empty({0}, packed_w.options().dtype(at::kFloat)); - return embedding_bag_4bit_helper( + return _embedding_bag_nbit_helper( output, packed_w, + 4, indices, offsets_in, pruned_weights, @@ -709,9 +739,44 @@ Tensor& embedding_bag_4bit_rowwise_offsets_out( per_sample_weights_.value().scalar_type(), " instead") } - return embedding_bag_4bit_helper( + return _embedding_bag_nbit_helper( + output, + weight, + 4, + indices, + offsets_in, + pruned_weights, + per_sample_weights_.has_value() + ? per_sample_weights_.value().to(at::kFloat) + : per_sample_weights_, + compressed_indices_mapping, + include_last_offset); +} + +Tensor& embedding_bag_2bit_rowwise_offsets_out( + Tensor& output, + const Tensor& weight, + const Tensor& indices, + const c10::optional& offsets_in, + const bool /* scale_grad_by_freq */, + const int64_t /* mode */, + bool pruned_weights, + const c10::optional& per_sample_weights_, + const c10::optional& compressed_indices_mapping, + bool include_last_offset) { + + if (per_sample_weights_.has_value()) { + TORCH_CHECK( + (per_sample_weights_.value().scalar_type() == at::kFloat || + per_sample_weights_.value().scalar_type() == at::kHalf), + "Expect fp32 or fp16 weights, but found", + per_sample_weights_.value().scalar_type(), + " instead") + } + return _embedding_bag_nbit_helper( output, weight, + 2, indices, offsets_in, pruned_weights, @@ -784,6 +849,33 @@ Tensor embedding_bag_4bit_rowwise_offsets( return output; } +Tensor embedding_bag_2bit_rowwise_offsets( + const Tensor& weight, + const Tensor& indices, + const c10::optional& offsets_in, + const bool /* scale_grad_by_freq */, + const int64_t /* mode */, + bool pruned_weights, + const c10::optional& per_sample_weights_, + const c10::optional& compressed_indices_mapping, + bool include_last_offset) { + + auto output = create_empty_from(weight, at::kFloat); + embedding_bag_2bit_rowwise_offsets_out( + output, + weight, + indices, + offsets_in, + false, // unused scale_grad_by_freq + 0, // unused mode + pruned_weights, + per_sample_weights_, + compressed_indices_mapping, + include_last_offset + ); + return output; +} + template class QEmbeddingBag final { public: @@ -869,6 +961,9 @@ TORCH_LIBRARY_IMPL(quantized, CPU, m) { m.impl( TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit_rowwise_offsets"), embedding_bag_4bit_rowwise_offsets); + m.impl( + TORCH_SELECTIVE_NAME("quantized::embedding_bag_2bit_rowwise_offsets"), + embedding_bag_2bit_rowwise_offsets); } } // namespace } // namespace native diff --git a/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp index 5d9abce940f58..614e274b5493d 100644 --- a/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp @@ -1,3 +1,5 @@ +#include + #include #include #include @@ -122,7 +124,6 @@ c10::intrusive_ptr PackedEmbeddingBagWeight::prepack( namespace at { namespace native { -namespace { // Note - This is a temporary pack function for embedding bag which quantizes // and packs the float weight tensor. In the next step it will be replaced by a @@ -184,7 +185,7 @@ namespace { // // [[50. , 60.00000035], // [70. , 80.00000035]]]) -Tensor qembeddingbag_byte_prepack(const Tensor& weight) { +Tensor& qembeddingbag_byte_prepack_out(Tensor& output, const Tensor& weight) { // The "last" dimension of an N-Dimensioned batch of embedding bags is // quantization channel. E.g. for a 2D embedding bag, this has // [ row, col ] dimensions, for batched of embedding bags, dimensions might be @@ -208,17 +209,12 @@ Tensor qembeddingbag_byte_prepack(const Tensor& weight) { const int32_t embedding_cols = weight_sizes[cols_dim]; // Add 8 bytes per column to store FP32 scale and zero_point per row. const int32_t output_columns = embedding_cols + 2 * sizeof(float); - Tensor weight_contig = weight.contiguous(weight.suggest_memory_format()); + const auto weight_contig = weight.expect_contiguous(weight.suggest_memory_format()); // Adjust output dimensions to account for FP32 scale and zero_points. std::vector output_shape = weight_sizes.vec(); output_shape[cols_dim] = output_columns; - - // Allocate output packed weights - auto output = at::empty( - output_shape, - weight_contig.options().dtype(at::kByte), - weight_contig.suggest_memory_format()); + at::native::resize_(output, output_shape, c10::nullopt); auto* output_data = output.data_ptr(); #ifdef USE_FBGEMM @@ -246,10 +242,9 @@ Tensor qembeddingbag_byte_prepack(const Tensor& weight) { } #else - const auto float_weight = weight_contig.scalar_type() == at::ScalarType::Half - ? weight_contig.to(at::ScalarType::Float) - : weight_contig; - const auto weight_data = float_weight.data_ptr(); + const auto weight_data = weight_contig->scalar_type() == at::ScalarType::Half + ? weight_contig->to(at::ScalarType::Float).data_ptr() + : weight_contig->data_ptr(); constexpr float kEpsilon = 1e-8f; for (auto row: c10::irange(embedding_rows)) { const float* input_row = weight_data + row * embedding_cols; @@ -276,6 +271,21 @@ Tensor qembeddingbag_byte_prepack(const Tensor& weight) { return output; } +Tensor qembeddingbag_byte_prepack(const Tensor& weight) { + const auto weight_contig = weight.expect_contiguous(weight.suggest_memory_format()); + auto output = at::detail::empty_cpu( + {0}, + at::kByte, + weight_contig->layout(), + weight_contig->device(), + c10::nullopt, + c10::nullopt); + qembeddingbag_byte_prepack_out(output, weight); + return output; +} + +namespace { + // TODO: Extend support to N-D batched embeddings, similar to qembeddingbag_byte_prepack Tensor _qembeddingbag_nbit_prepack_helper( const Tensor& weight, diff --git a/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.h b/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.h new file mode 100644 index 0000000000000..c52cbae4f2c80 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.h @@ -0,0 +1,11 @@ +#include + +namespace at { +namespace native { + +Tensor& qembeddingbag_byte_prepack_out(Tensor& output, const Tensor& weight); + +Tensor qembeddingbag_byte_prepack(const Tensor& weight); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp index da64197fb4577..3331a0387111c 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp @@ -349,6 +349,12 @@ at::Tensor PackedLinearWeightsQnnp::apply_dynamic_impl(at::Tensor input) { TORCH_INTERNAL_ASSERT( runStatus == pytorch_qnnp_status_success, "failed to run QNNPACK Linear operator"); + + // Call the relu operator here until qlinear dynamic in QNNPACK + // supports it natively. + if (ReluFused) { + output.relu_(); + } return output; } @@ -445,8 +451,14 @@ class QLinearDynamicFp16 final { TORCH_CHECK( fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM."); - TORCH_INTERNAL_ASSERT(!ReluFused); - return packed_weight->apply_dynamic(std::move(input)); + auto output = packed_weight->apply_dynamic(std::move(input)); + + // Call the relu operator here until fp16 linear dynamic in FBGEMM + // supports it natively. + if (ReluFused) { + output.relu_(); + } + return output; } #else // USE_FBGEMM static at::Tensor run( @@ -465,6 +477,7 @@ TORCH_LIBRARY_IMPL(quantized, CPU, m) { m.impl(TORCH_SELECTIVE_NAME("quantized::linear_dynamic"), TORCH_FN(QLinearDynamicInt8::run)); m.impl(TORCH_SELECTIVE_NAME("quantized::linear_relu_dynamic"), TORCH_FN(QLinearDynamicInt8::run)); m.impl(TORCH_SELECTIVE_NAME("quantized::linear_dynamic_fp16"), TORCH_FN(QLinearDynamicFp16::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::linear_relu_dynamic_fp16"), TORCH_FN(QLinearDynamicFp16::run)); } TORCH_LIBRARY_IMPL(_quantized, CPU, m) { diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/include/pack_block_sparse.h b/aten/src/ATen/native/quantized/cpu/qnnpack/include/pack_block_sparse.h index 0f329296bc18b..62fdef2cdf9b2 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/include/pack_block_sparse.h +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/include/pack_block_sparse.h @@ -8,7 +8,6 @@ #pragma once #include -#include #include #include #include @@ -33,25 +32,7 @@ typedef struct BCSRMatrix { #endif uint32_t col_block_size; // input features block size uint32_t row_block_size; // output features block size - void print() { - std::cout << "row block size:" << row_block_size << std::endl; - std::cout << "col block size:" << col_block_size << std::endl; - std::cout << "row ptr\n"; - for (const auto& t : row_values) { - std::cout << t << ", "; - } - std::cout << std::endl; - std::cout << "col indices\n"; - for (const auto& t : col_indices) { - std::cout << t << ", "; - } - std::cout << std::endl; - std::cout << "Actual values\n"; - for (const auto& t : values) { - std::cout << (uint32_t)t << ", "; - } - std::cout << std::endl; - } + void print() const; } BCSRMatrix; std::unique_ptr generateBlockCSRMatrix( diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/pack_block_sparse.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/src/pack_block_sparse.cc index ca694df3aba45..6a6134023bfc8 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/pack_block_sparse.cc +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/pack_block_sparse.cc @@ -6,6 +6,7 @@ * LICENSE file in the root directory of this source tree. */ #include +#include #include @@ -78,4 +79,24 @@ std::unique_ptr generateBlockCSRMatrix( bcsr_mat.col_block_size = col_block_size; return bcsr_mat_ptr; } + +void BCSRMatrix::print() const { + std::cout << "row block size:" << row_block_size << std::endl; + std::cout << "col block size:" << col_block_size << std::endl; + std::cout << "row ptr\n"; + for (const auto& t : row_values) { + std::cout << t << ", "; + } + std::cout << std::endl; + std::cout << "col indices\n"; + for (const auto& t : col_indices) { + std::cout << t << ", "; + } + std::cout << std::endl; + std::cout << "Actual values\n"; + for (const auto& t : values) { + std::cout << (uint32_t)t << ", "; + } + std::cout << std::endl; +} } // namsepace qnnpack diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack_utils.h b/aten/src/ATen/native/quantized/cpu/qnnpack_utils.h index 161be5a2f8fa3..91ede920b87e2 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack_utils.h +++ b/aten/src/ATen/native/quantized/cpu/qnnpack_utils.h @@ -292,6 +292,7 @@ struct PackedConvWeightsQnnp : public ConvPackedParamsBase { } private: + std::mutex qnnp_mutex_; template at::Tensor apply_impl( const at::Tensor& input, diff --git a/aten/src/ATen/native/quantized/cuda/embedding_bag.cu b/aten/src/ATen/native/quantized/cuda/embedding_bag.cu index 6d44ce0f41873..55b0b0d4f36d0 100644 --- a/aten/src/ATen/native/quantized/cuda/embedding_bag.cu +++ b/aten/src/ATen/native/quantized/cuda/embedding_bag.cu @@ -56,15 +56,15 @@ dequantize_intx(uint32_t packedVals, float2 scale_bias, uint8_t offset_bits) { template __forceinline__ __device__ void -accumulate_packed_intx(float4* acc, uint32_t packedVals, float2 scale_bias) { +accumulate_packed_intx(float4* acc, uint32_t packedVals, float2 scale_bias, float sample_weight) { constexpr uint8_t dims_per_byte = 8 / bits_per_dim; for (uint8_t i = 0; i < dims_per_byte; i++) { float4 res = dequantize_intx(packedVals, scale_bias, 4 * bits_per_dim * i /* offset_bits */); // Accumulate in float32. - acc[i].x += res.x; - acc[i].y += res.y; - acc[i].z += res.z; - acc[i].w += res.w; + acc[i].x += (res.x * sample_weight); + acc[i].y += (res.y * sample_weight); + acc[i].z += (res.z * sample_weight); + acc[i].w += (res.w * sample_weight); } } @@ -77,7 +77,7 @@ __global__ void embedding_bag_nbits_rowwise_offsets_kernel( const PackedTensorAccessor32 indices, const PackedTensorAccessor32 offsets, const bool /* pruned_weights */, - const c10::optional& per_sample_weights_, + const PackedTensorAccessor32 per_sample_weights_, const c10::optional& compressed_indices_mapping, const bool include_last_offset, PackedTensorAccessor32 output) { @@ -96,6 +96,8 @@ __global__ void embedding_bag_nbits_rowwise_offsets_kernel( const int32_t D_bytes = weight.size(1); + bool use_per_sample = per_sample_weights_.size(0) > 0; + int64_t indices_start = offsets[t * B + b]; int64_t indices_end; if (include_last_offset) { @@ -124,6 +126,7 @@ __global__ void embedding_bag_nbits_rowwise_offsets_kernel( } for (int32_t l = indices_start; l < indices_end; ++l) { int64_t idx = indices[l]; + float sample_weight = use_per_sample ? per_sample_weights_[l] : 1.0f; const uint8_t* __restrict__ row = &weights[idx * D_bytes]; float2 scale_bias; if (fp32_scale_bias) { @@ -138,7 +141,7 @@ __global__ void embedding_bag_nbits_rowwise_offsets_kernel( uint32_t v0 = reinterpret_cast(&row[byte_offset])[0]; - accumulate_packed_intx(accumulator, v0, scale_bias); + accumulate_packed_intx(accumulator, v0, scale_bias, sample_weight); } @@ -204,9 +207,11 @@ at::Tensor& embedding_bag_byte_impl( const int D = weight_sizes[1] - 8; // NB: -8 to account for scale and bias const int64_t M = offsets.sizes()[0]; TORCH_CHECK(D % 4 == 0); - TORCH_CHECK( - !per_sample_weights_.has_value(), - "Per sample weights not yet implemented for embedding_bag_byte_rowwise_offsets_cuda"); + if(per_sample_weights_.has_value()) { + TORCH_CHECK(per_sample_weights_.value().scalar_type() == at::kFloat, + "Per sample weights expected scalar type ", at::kFloat, " but got ", + per_sample_weights_.value().scalar_type()); + } TORCH_CHECK( !compressed_indices_mapping.has_value(), "Compressed indices mapping not yet implemented for embedding_bag_byte_rowwise_offsets_cuda"); @@ -215,6 +220,13 @@ at::Tensor& embedding_bag_byte_impl( int64_t output_size = include_last_offset ? M - 1 : M; + at::Tensor sample_weights; + if (per_sample_weights_.has_value()) { + sample_weights = per_sample_weights_.value(); + } else { + sample_weights = create_empty_from(output, kFloat); + } + const std::vector shape = {output_size, D}; at::native::resize_(output, shape, c10::nullopt); AT_DISPATCH_INDEX_TYPES( @@ -228,7 +240,7 @@ at::Tensor& embedding_bag_byte_impl( indices.packed_accessor32(), offsets.packed_accessor32(), false /* pruned_weights */, - per_sample_weights_, + sample_weights.packed_accessor32(), compressed_indices_mapping, include_last_offset, output.packed_accessor32()); @@ -377,9 +389,11 @@ at::Tensor& embedding_bag_4bit_impl( const int D = 2*(weight_sizes[1] - 4); // NB: -4 to account for scale and bias @fp16 const int64_t M = offsets.sizes()[0]; TORCH_CHECK(D % 8 == 0); - TORCH_CHECK( - !per_sample_weights_.has_value(), - "Per sample weights not yet implemented for embedding_bag_byte_rowwise_offsets_cuda"); + if(per_sample_weights_.has_value()) { + TORCH_CHECK(per_sample_weights_.value().scalar_type() == at::kFloat, + "Per sample weights expected scalar type ", at::kFloat, " but got ", + per_sample_weights_.value().scalar_type()); + } TORCH_CHECK( !compressed_indices_mapping.has_value(), "Compressed indices mapping not yet implemented for embedding_bag_byte_rowwise_offsets_cuda"); @@ -388,6 +402,13 @@ at::Tensor& embedding_bag_4bit_impl( int64_t output_size = include_last_offset ? M - 1 : M; + at::Tensor sample_weights; + if (per_sample_weights_.has_value()) { + sample_weights = per_sample_weights_.value(); + } else { + sample_weights = create_empty_from(output, kFloat); + } + const std::vector shape = {output_size, D}; at::native::resize_(output, shape, c10::nullopt); AT_DISPATCH_INDEX_TYPES( @@ -401,7 +422,7 @@ at::Tensor& embedding_bag_4bit_impl( indices.packed_accessor32(), offsets.packed_accessor32(), false /* pruned_weights */, - per_sample_weights_, + sample_weights.packed_accessor32(), compressed_indices_mapping, include_last_offset, output.packed_accessor32()); diff --git a/aten/src/ATen/native/quantized/library.cpp b/aten/src/ATen/native/quantized/library.cpp index 7cdb5cb35817a..3dcf75b1ccb32 100644 --- a/aten/src/ATen/native/quantized/library.cpp +++ b/aten/src/ATen/native/quantized/library.cpp @@ -128,6 +128,7 @@ TORCH_LIBRARY(quantized, m) { m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_2bit_unpack(Tensor weight) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_2bit_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_byte(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, bool pruned_weights=False) -> Tensor")); @@ -141,6 +142,7 @@ TORCH_LIBRARY(quantized, m) { m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_dynamic(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, bool reduce_range=False) -> Tensor Y")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_relu_dynamic(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, bool reduce_range=False) -> Tensor Y")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_dynamic_fp16(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack) -> Tensor Y")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_relu_dynamic_fp16(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack) -> Tensor Y")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_prepack(Tensor W, Tensor? B=None) -> __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_prepack_fp16(Tensor W, Tensor? B=None) -> __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_prepack_legacy(Tensor W, Tensor? B=None) -> Tensor W_prepack")); diff --git a/aten/src/ATen/native/sparse/cuda/SoftMax.cu b/aten/src/ATen/native/sparse/cuda/SoftMax.cu index d5bc66b7fb23b..c55ea3b540b5a 100644 --- a/aten/src/ATen/native/sparse/cuda/SoftMax.cu +++ b/aten/src/ATen/native/sparse/cuda/SoftMax.cu @@ -11,7 +11,7 @@ #include #include #include -#include +#include #include diff --git a/aten/src/ATen/native/sparse/cuda/SparseBlasLegacy.cpp b/aten/src/ATen/native/sparse/cuda/SparseBlasLegacy.cpp new file mode 100644 index 0000000000000..b13e7fe595d8f --- /dev/null +++ b/aten/src/ATen/native/sparse/cuda/SparseBlasLegacy.cpp @@ -0,0 +1,74 @@ +/* +Functions here use deprecated cuSPARSE API that was removed in CUDA 11. +This file will be removed eventually. +*/ +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +void s_addmm_out_csr_sparse_dense_cuda_worker(int64_t nnz, int64_t m, int64_t n, int64_t k, const Tensor& r_, const Scalar& beta, const Tensor& t, const Scalar& alpha, const Tensor& crow_indices, const Tensor& col_indices, const Tensor& values, const Tensor& dense) { + TORCH_INTERNAL_ASSERT(nnz > 0); + + // No half support, so we don't have to use CUDATypeConversion + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES( + values.scalar_type(), "addmm_sparse_cuda", [&] { + scalar_t cast_beta = beta.to(); + scalar_t cast_alpha = alpha.to(); + Tensor r__; + if (cast_beta == scalar_t(0)) { + r_.zero_(); + } else if (!at::sparse::is_same_tensor(t, r_)) { + r_.copy_(t); + } + if (r_.stride(0) == 1 && r_.stride(1) == r_.size(0)) { + r__ = r_; + } else { + // Note: This storage arrangement is preferred due to most of the CUDA kernels handle only contiguous tensors + r__ = r_.transpose(0, 1).clone(at::MemoryFormat::Contiguous); + r__.transpose_(0, 1); + } + TORCH_INTERNAL_ASSERT(r__.transpose(-1, -2).is_contiguous()); + Tensor dense_; + char transpose_dense; + if (dense.stride(0) == 1 && dense.stride(1) == dense.size(0)) { + transpose_dense = 'n'; + dense_ = dense; + } else if (dense.stride(1) == 1 && dense.stride(0) == dense.size(1)) { + transpose_dense = 't'; + dense_ = dense; + } else { + transpose_dense = 't'; + dense_ = dense.contiguous(); + } + + sparse::cuda::csrmm2( + 'n', + transpose_dense, + m, + n, + k, + nnz, + cast_alpha, + values.data_ptr(), + crow_indices.data_ptr(), + col_indices.data_ptr(), + dense_.data_ptr(), + (transpose_dense == 'n' ? dense_.stride(1) : dense_.stride(0)), + cast_beta, + r__.data_ptr(), + r__.stride(1)); + + if (!at::sparse::is_same_tensor(r__, r_)) { + r_.copy_(r__); + } + } + ); +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/sparse/cuda/SparseBlasLegacy.h b/aten/src/ATen/native/sparse/cuda/SparseBlasLegacy.h new file mode 100644 index 0000000000000..67eaffb13a75c --- /dev/null +++ b/aten/src/ATen/native/sparse/cuda/SparseBlasLegacy.h @@ -0,0 +1,18 @@ +#pragma once + +#include +#include + +/* +Functions here use deprecated cuSPARSE API that was removed in CUDA 11. +Here only 32-bit indices sparse indices are supported. +This file will be removed eventually. +*/ + +namespace at { +namespace native { + +void s_addmm_out_csr_sparse_dense_cuda_worker(int64_t nnz, int64_t m, int64_t n, int64_t k, const Tensor& r_, const Scalar& beta, const Tensor& t, const Scalar& alpha, const Tensor& crow_indices, const Tensor& col_indices, const Tensor& values, const Tensor& dense); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cu b/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp similarity index 99% rename from aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cu rename to aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp index dd03e2bfeacbe..db0088a084c6d 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp @@ -2,7 +2,7 @@ #include #include #include -#include +#include #include #include @@ -14,7 +14,7 @@ // Using these APIs in any other systems will result in compile-time or run-time failures. // Their support will be extended in the next releases. -#if defined(__CUDACC__) && (CUSPARSE_VERSION >= 11000 || (!defined(_MSC_VER) && CUSPARSE_VERSION >= 10301)) +#if defined(CUDART_VERSION) && (CUSPARSE_VERSION >= 11000 || (!defined(_MSC_VER) && CUSPARSE_VERSION >= 10301)) #define IS_SPMM_AVAILABLE() 1 #else #define IS_SPMM_AVAILABLE() 0 diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cuh b/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.h similarity index 100% rename from aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cuh rename to aten/src/ATen/native/sparse/cuda/SparseCUDABlas.h diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu index 511e69ef4b408..0331f5e4d932e 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu @@ -5,8 +5,9 @@ #include #include #include +#include #include -#include +#include #include #include #include @@ -50,64 +51,6 @@ namespace { } } -void s_addmm_out_csr_sparse_dense_cuda_worker(int64_t nnz, int64_t m, int64_t n, int64_t k, Tensor& r_, const Scalar& beta, const Tensor& t, const Scalar& alpha, Tensor& crow_indices, Tensor& col_indices, Tensor& values, const Tensor& dense) { - TORCH_INTERNAL_ASSERT(nnz > 0); - - // No half support, so we don't have to use CUDATypeConversion - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES( - values.scalar_type(), "addmm_sparse_cuda", [&] { - scalar_t cast_beta = beta.to(); - scalar_t cast_alpha = alpha.to(); - Tensor r__; - if (cast_beta == scalar_t(0)) { - r_.zero_(); - } else if (!is_same_tensor(t, r_)) { - r_.copy_(t); - } - if(r_.stride(0) == 1 && r_.stride(1) == r_.size(0)) { - r__ = r_; - } else { - // Note: This storage arrangement is preferred due to most of the CUDA kernels handle only contiguous tensors - r__ = r_.transpose(0, 1).clone(at::MemoryFormat::Contiguous); - r__.transpose_(0, 1); - } - Tensor dense_; - char transpose_dense; - if(dense.stride(0) == 1 && dense.stride(1) == dense.size(0)) { - transpose_dense = 'n'; - dense_ = dense; - } else if(dense.stride(1) == 1 && dense.stride(0) == dense.size(1)) { - transpose_dense = 't'; - dense_ = dense; - } else { - transpose_dense = 't'; - dense_ = dense.contiguous(); - } - - sparse::cuda::csrmm2( - 'n', - transpose_dense, - m, - n, - k, - nnz, - cast_alpha, - values.data_ptr(), - crow_indices.data_ptr(), - col_indices.data_ptr(), - dense_.data_ptr(), - (transpose_dense == 'n' ? dense_.stride(1) : dense_.stride(0)), - cast_beta, - r__.data_ptr(), - r__.stride(1)); - - if (!is_same_tensor(r__, r_)) { - r_.copy_(r__); - } - } - ); -} - // NB: Deleted spaddcmul (aka addcmul_, but not actually wired up), spaddcdiv (not // wired at all) diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cuh b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cuh index 1a99e818e1bad..9448b2aa46b6c 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cuh +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cuh @@ -6,8 +6,6 @@ namespace at { namespace native { -void s_addmm_out_csr_sparse_dense_cuda_worker(int64_t nnz, int64_t m, int64_t n, int64_t k, Tensor& r_, const Scalar& beta, const Tensor& t, const Scalar& alpha, Tensor& crow_indices, Tensor& col_indices, Tensor& values, const Tensor& dense); - void s_addmm_out_sparse_dense_cuda_worker(int64_t nnz, int64_t m, int64_t n, int64_t k, Tensor& r_, const Scalar& beta, const Tensor& t, const Scalar& alpha, Tensor& indices, Tensor& values, const Tensor& dense); }} // namespace at::native diff --git a/aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu b/aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu index ea765e076fb04..b21d892fcdf84 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu @@ -19,7 +19,8 @@ #include #include -#include +#include +#include #include #include diff --git a/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu b/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu index 2d041de6ea411..d5f31a1980bac 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu @@ -18,7 +18,7 @@ #include #include #include -#include +#include #include #include diff --git a/aten/src/ATen/native/vulkan/glsl/leaky_relu.glsl b/aten/src/ATen/native/vulkan/glsl/leaky_relu.glsl new file mode 100644 index 0000000000000..f947e78f1843d --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/leaky_relu.glsl @@ -0,0 +1,28 @@ +#version 450 core +#define PRECISION $precision + +layout(std430) buffer; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0) uniform PRECISION restrict writeonly image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; +layout(set = 0, binding = 2) uniform PRECISION restrict Block { + ivec4 size; + float negative_slope; +} uBlock; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.size.xyz))) { + const vec4 inval = texelFetch(uInput, pos, 0); + const vec4 negative_values = vec4(lessThan(inval, vec4(0.0f))); + const vec4 positive_values = vec4(1.0) - negative_values; + const vec4 mask = negative_values * vec4(uBlock.negative_slope) + positive_values; + const vec4 outval = inval * mask; + imageStore(uOutput, pos, outval); + } +} diff --git a/aten/src/ATen/native/vulkan/glsl/leaky_relu_.glsl b/aten/src/ATen/native/vulkan/glsl/leaky_relu_.glsl new file mode 100644 index 0000000000000..345e66942c155 --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/leaky_relu_.glsl @@ -0,0 +1,27 @@ +#version 450 core +#define PRECISION $precision + +layout(std430) buffer; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba16f) uniform PRECISION restrict image3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION restrict Block { + ivec4 size; + float negative_slope; +} uBlock; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (all(lessThan(pos, uBlock.size.xyz))) { + const vec4 inval = imageLoad(uOutput, pos); + const vec4 negative_values = vec4(lessThan(inval, vec4(0.0f))); + const vec4 positive_values = vec4(1.0) - negative_values; + const vec4 mask = negative_values * vec4(uBlock.negative_slope) + positive_values; + const vec4 outval = inval * mask; + imageStore(uOutput, pos, outval); + } +} diff --git a/aten/src/ATen/native/vulkan/ops/Clamp.cpp b/aten/src/ATen/native/vulkan/ops/Clamp.cpp index c6f046e84fd17..a6e65607fb07c 100644 --- a/aten/src/ATen/native/vulkan/ops/Clamp.cpp +++ b/aten/src/ATen/native/vulkan/ops/Clamp.cpp @@ -207,7 +207,7 @@ Tensor& activation_( TORCH_CHECK( self.is_vulkan(), - "Vulkan: In-place clamp is only supported on Vulkan tensors."); + "Vulkan: In-place operator is only supported on Vulkan tensors."); vTensor& v_self = convert(self); @@ -289,9 +289,10 @@ Tensor& hardsigmoid_(Tensor& self) { return ops::activation_(self, VK_KERNEL(hardsigmoid_)); } -Tensor hardshrink( +Tensor activation_scalar( const Tensor& self_arg, - const Scalar& lambd) { + const Scalar& scalar_arg, + const api::Shader::Descriptor& shader_descriptor) { api::Context* const context = api::context(); const Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan(); @@ -310,11 +311,11 @@ Tensor hardshrink( const struct Block final { uvec3 extents; uint32_t _; - float lambd; + float scalar_value; } block { v_output.extents(), 0u, - lambd.to(), + scalar_arg.to(), }; context->dispatch( @@ -324,7 +325,7 @@ Tensor hardshrink( VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, }, - VK_KERNEL(hardshrink), + shader_descriptor, v_output.extents(), context->gpu().adapter->local_work_group_size(), // Write-only access bypasses synchronization but inserts appropriate @@ -351,14 +352,15 @@ Tensor hardshrink( return convert(v_output); } -Tensor& hardshrink_( +Tensor& activation_scalar_( Tensor& self, - const Scalar& lambd) { + const Scalar& scalar_arg, + const api::Shader::Descriptor& shader_descriptor) { api::Context* const context = api::context(); TORCH_CHECK( self.is_vulkan(), - "Vulkan: In-place hardshrink is only supported on Vulkan tensors."); + "Vulkan: In-place operator is only supported on Vulkan tensors."); vTensor& v_self = convert(self); @@ -369,11 +371,11 @@ Tensor& hardshrink_( const struct Block final { uvec3 extents; uint32_t _; - float lambd; + float scalar_value; } block { v_self.extents(), 0u, - lambd.to(), + scalar_arg.to(), }; context->dispatch( @@ -382,7 +384,7 @@ Tensor& hardshrink_( VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, }, - VK_KERNEL(hardshrink_), + shader_descriptor, v_self.extents(), context->gpu().adapter->local_work_group_size(), // Read-Write access triggers an async synchronization if necessory @@ -404,6 +406,30 @@ Tensor& hardshrink_( return self; } +Tensor hardshrink( + const Tensor& self_arg, + const Scalar& lambd) { + return ops::activation_scalar(self_arg, lambd, VK_KERNEL(hardshrink)); +} + +Tensor& hardshrink_( + Tensor& self, + const Scalar& lambd) { + return ops::activation_scalar_(self, lambd, VK_KERNEL(hardshrink_)); +} + +Tensor leaky_relu( + const Tensor& self_arg, + const Scalar& negative_slope) { + return ops::activation_scalar(self_arg, negative_slope, VK_KERNEL(leaky_relu)); +} + +Tensor& leaky_relu_( + Tensor& self, + const Scalar& negative_slope) { + return ops::activation_scalar_(self, negative_slope, VK_KERNEL(leaky_relu_)); +} + Tensor sigmoid(const Tensor& self) { return ops::activation(self, VK_KERNEL(sigmoid)); } @@ -427,12 +453,14 @@ TORCH_LIBRARY_IMPL(aten, Vulkan, m) { m.impl(TORCH_SELECTIVE_NAME("aten::clamp_"), TORCH_FN(clamp_)); m.impl(TORCH_SELECTIVE_NAME("aten::hardsigmoid"), hardsigmoid); m.impl(TORCH_SELECTIVE_NAME("aten::hardsigmoid_"), hardsigmoid_); - m.impl(TORCH_SELECTIVE_NAME("aten::hardshrink"), TORCH_FN(hardshrink)); - m.impl(TORCH_SELECTIVE_NAME("aten::hardshrink_"), TORCH_FN(hardshrink_)); + m.impl(TORCH_SELECTIVE_NAME("aten::hardshrink"), hardshrink); + m.impl(TORCH_SELECTIVE_NAME("aten::hardshrink_"), hardshrink_); m.impl(TORCH_SELECTIVE_NAME("aten::hardswish"), hardswish); m.impl(TORCH_SELECTIVE_NAME("aten::hardswish_"), hardswish_); m.impl(TORCH_SELECTIVE_NAME("aten::hardtanh"), hardtanh); m.impl(TORCH_SELECTIVE_NAME("aten::hardtanh_"), hardtanh_); + m.impl(TORCH_SELECTIVE_NAME("aten::leaky_relu"), leaky_relu); + m.impl(TORCH_SELECTIVE_NAME("aten::leaky_relu_"), leaky_relu_); m.impl(TORCH_SELECTIVE_NAME("aten::sigmoid"), sigmoid); m.impl(TORCH_SELECTIVE_NAME("aten::sigmoid_"), sigmoid_); m.impl(TORCH_SELECTIVE_NAME("aten::tanh"), tanh); diff --git a/aten/src/ATen/native/xnnpack/Convolution.cpp b/aten/src/ATen/native/xnnpack/Convolution.cpp index 8c5d99a242196..f46052d9c5ef6 100644 --- a/aten/src/ATen/native/xnnpack/Convolution.cpp +++ b/aten/src/ATen/native/xnnpack/Convolution.cpp @@ -425,6 +425,21 @@ Tensor conv2d_clamp_run( return op_context->run(input); } +// Op is registered to have Any argument as we plan to reuse it for prepacked conv2d of other backends +std::tuple, IntArrayRef, IntArrayRef, IntArrayRef, int64_t> +unpack_prepacked_sizes_conv2d(const IValue& ivalue) { + auto op_context = ivalue.toCustomClass(); + const auto tuple = op_context->unpack(); + const auto& bias = std::get<1>(tuple); + return std::make_tuple( + std::get<0>(tuple).sizes(), + (bias && bias->defined()) ? c10::optional(bias->sizes()) : c10::nullopt, + std::get<2>(tuple), + std::get<3>(tuple), + std::get<4>(tuple), + std::get<5>(tuple)); +} + Tensor conv2d_transpose_clamp_run( const Tensor& input, const c10::intrusive_ptr& op_context) { diff --git a/aten/src/ATen/native/xnnpack/Convolution.h b/aten/src/ATen/native/xnnpack/Convolution.h index 403f26cdec70e..b89059de2c615 100644 --- a/aten/src/ATen/native/xnnpack/Convolution.h +++ b/aten/src/ATen/native/xnnpack/Convolution.h @@ -39,6 +39,9 @@ Tensor conv2d_clamp_run( const Tensor& input, const c10::intrusive_ptr& op_context); +std::tuple, IntArrayRef, IntArrayRef, IntArrayRef, int64_t> +unpack_prepacked_sizes_conv2d(const IValue& ivalue); + Tensor conv2d_transpose_clamp_run( const Tensor& input, const c10::intrusive_ptr& op_context); diff --git a/aten/src/ATen/native/xnnpack/Linear.cpp b/aten/src/ATen/native/xnnpack/Linear.cpp index 9a459b660d6fb..19c474f34cef9 100644 --- a/aten/src/ATen/native/xnnpack/Linear.cpp +++ b/aten/src/ATen/native/xnnpack/Linear.cpp @@ -180,6 +180,16 @@ Tensor linear_clamp_run( return op_context->run(input); } +std::tuple> +unpack_prepacked_sizes_linear(const IValue& ivalue) { + auto op_context = ivalue.toCustomClass(); + const auto tuple = op_context->unpack(); + const auto& bias = std::get<1>(tuple); + return std::make_tuple( + std::get<0>(tuple).sizes(), + (bias && bias->defined()) ? c10::optional(bias->sizes()) : c10::nullopt); +} + } // namespace linear } // namespace internal diff --git a/aten/src/ATen/native/xnnpack/Linear.h b/aten/src/ATen/native/xnnpack/Linear.h index 3e4df0466d261..d25f63bafa739 100644 --- a/aten/src/ATen/native/xnnpack/Linear.h +++ b/aten/src/ATen/native/xnnpack/Linear.h @@ -20,6 +20,9 @@ c10::intrusive_ptr createLinearClampPrePackOpContext( Tensor linear_clamp_run(const Tensor& input, const c10::intrusive_ptr& op_context); +std::tuple> +unpack_prepacked_sizes_linear(const IValue& ivalue); + ContextLinear create( const Tensor& weight, const c10::optional& bias, diff --git a/aten/src/ATen/native/xnnpack/RegisterOpContextClass.cpp b/aten/src/ATen/native/xnnpack/RegisterOpContextClass.cpp index 03ac612aa12d0..f09c2dc22a39c 100644 --- a/aten/src/ATen/native/xnnpack/RegisterOpContextClass.cpp +++ b/aten/src/ATen/native/xnnpack/RegisterOpContextClass.cpp @@ -80,7 +80,10 @@ TORCH_LIBRARY(xnnpack, m) { } +// Registration using the TORCH_LIBRARY def gives dispatching errors when there is no tensor input TORCH_LIBRARY(prepacked, m) { + m.def(TORCH_SELECTIVE_SCHEMA("prepacked::unpack_prepacked_sizes_conv2d(Any W_prepack) -> (int[], int[]?, int[], int[], int[], int)"), [](const IValue& inp) { return internal::convolution2d::unpack_prepacked_sizes_conv2d(inp);}); + m.def(TORCH_SELECTIVE_SCHEMA("prepacked::unpack_prepacked_sizes_linear(Any W_prepack) -> (int[], int[]?)"), [](const IValue& inp) { return internal::linear::unpack_prepacked_sizes_linear(inp);}); m.def(TORCH_SELECTIVE_SCHEMA("prepacked::linear_clamp_prepack(Tensor W, Tensor? B=None, Scalar? output_min=None, Scalar? output_max=None) -> __torch__.torch.classes.xnnpack.LinearOpContext")); m.def(TORCH_SELECTIVE_SCHEMA("prepacked::linear_clamp_run(Tensor X, __torch__.torch.classes.xnnpack.LinearOpContext W_prepack) -> Tensor Y")); m.def(TORCH_SELECTIVE_SCHEMA("prepacked::conv2d_clamp_prepack(Tensor W, Tensor? B, int[2] stride, int[2] padding, int[2] dilation, int groups, Scalar? output_min=None, Scalar? output_max=None) -> __torch__.torch.classes.xnnpack.Conv2dOpContext")); diff --git a/aten/src/ATen/templates/RegisterDispatchKey.cpp b/aten/src/ATen/templates/RegisterDispatchKey.cpp index c702a68063c31..16caf5326c711 100644 --- a/aten/src/ATen/templates/RegisterDispatchKey.cpp +++ b/aten/src/ATen/templates/RegisterDispatchKey.cpp @@ -33,7 +33,6 @@ #include #include $extra_cuda_headers -$legacy_th_headers $external_backend_headers $namespaced_headers @@ -44,6 +43,8 @@ namespace at { // at namespace already. namespace { +${dispatch_helpers} + ${dispatch_anonymous_definitions} TORCH_LIBRARY_IMPL(aten, ${DispatchKey}, m) { diff --git a/aten/src/ATen/templates/TensorBody.h b/aten/src/ATen/templates/TensorBody.h index be14980fb2d14..95312ff5d10f3 100644 --- a/aten/src/ATen/templates/TensorBody.h +++ b/aten/src/ATen/templates/TensorBody.h @@ -52,6 +52,7 @@ struct Node; namespace at { +class OptionalTensorRef; class Tensor; using TensorList = ArrayRef; @@ -96,6 +97,7 @@ class TORCH_API Tensor { explicit Tensor(unsafe_borrow_t, const Tensor& rhs) : impl_(c10::intrusive_ptr::reclaim(rhs.impl_.get())) {} friend MaybeOwnedTraits; + friend OptionalTensorRef; public: Tensor(){}; @@ -492,6 +494,12 @@ class TORCH_API Tensor { return impl_->is_mlc(); } + /// Returns if a `Tensor` is ort tensor. + bool is_ort() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_ort(); + } + /// Returns if a `Tensor` is vulkan tensor. bool is_vulkan() const { // NB: this is not a native function to avoid dispatching overhead. diff --git a/aten/src/ATen/test/extension_backend_test.cpp b/aten/src/ATen/test/extension_backend_test.cpp index 531507e96697e..9b215a90ae74a 100644 --- a/aten/src/ATen/test/extension_backend_test.cpp +++ b/aten/src/ATen/test/extension_backend_test.cpp @@ -6,6 +6,11 @@ #include +// NB. These tests use the ORT dispatch key to test backend dispatching +// machinery, but these tests are not specific to ORT at all. The ORT +// backend is fully out-of-tree, so it's safe to use this key for +// in-tree tests. + using namespace at; static int test_int; @@ -17,16 +22,16 @@ Tensor empty_override(IntArrayRef size, c10::optional dtype, c10::op Storage( Storage::use_byte_size_t(), 0, - at::DataPtr(nullptr, Device(DeviceType::MSNPU, 1)), + at::DataPtr(nullptr, Device(DeviceType::ORT, 1)), nullptr, false), - DispatchKey::MSNPU, + DispatchKey::ORT, caffe2::TypeMeta::Make()); return Tensor(std::move(tensor_impl)); } Tensor add_override(const Tensor & a, const Tensor & b , const Scalar& c) { - auto out = empty({5, 5}, at::kMSNPU); // Don't return self as-is + auto out = empty({5, 5}, at::kORT); // Don't return self as-is test_int = 2; return out; } @@ -42,28 +47,28 @@ Tensor empty_strided_override( return empty_override(size, dtype, layout, device, pin_memory, c10::nullopt); } -TORCH_LIBRARY_IMPL(aten, MSNPU, m) { +TORCH_LIBRARY_IMPL(aten, ORT, m) { m.impl("aten::empty.memory_format", empty_override); m.impl("aten::empty_strided", empty_strided_override); m.impl("aten::add.Tensor", add_override); } TEST(BackendExtensionTest, TestRegisterOp) { - Tensor a = empty({5, 5}, at::kMSNPU); - ASSERT_EQ(a.device().type(), at::kMSNPU); + Tensor a = empty({5, 5}, at::kORT); + ASSERT_EQ(a.device().type(), at::kORT); ASSERT_EQ(a.device().index(), 1); ASSERT_EQ(a.dtype(), caffe2::TypeMeta::Make()); ASSERT_EQ(test_int, 1); - Tensor b = empty_like(a, at::kMSNPU); - ASSERT_EQ(b.device().type(), at::kMSNPU); + Tensor b = empty_like(a, at::kORT); + ASSERT_EQ(b.device().type(), at::kORT); ASSERT_EQ(b.device().index(), 1); ASSERT_EQ(b.dtype(), caffe2::TypeMeta::Make()); add(a, b); ASSERT_EQ(test_int, 2); - // Ensure that non-MSNPU operator still works + // Ensure that non-ORT operator still works Tensor d = empty({5, 5}, at::kCPU); ASSERT_EQ(d.device().type(), at::kCPU); } diff --git a/aten/src/ATen/test/ivalue_test.cpp b/aten/src/ATen/test/ivalue_test.cpp index 915e267347170..3ae18390f8f6e 100644 --- a/aten/src/ATen/test/ivalue_test.cpp +++ b/aten/src/ATen/test/ivalue_test.cpp @@ -90,6 +90,18 @@ TEST(IValueTest, Basic) { ASSERT_EQ(complex_tuple.toTuple()->elements()[1], foo1); } +TEST(IValueTest, BasicStorage) { + at::Storage emptyStorage; + at::Storage nonemptyStorage(at::rand({3, 4}).storage()); + IValue ivEmpty(emptyStorage); + IValue ivNonempty(nonemptyStorage); + + ASSERT_TRUE(ivEmpty.isStorage()); + ASSERT_TRUE(ivNonempty.isStorage()); + ASSERT_EQ(emptyStorage.unsafeGetStorageImpl(), ivEmpty.toStorage().unsafeGetStorageImpl()); + ASSERT_EQ(nonemptyStorage.unsafeGetStorageImpl(), ivNonempty.toStorage().unsafeGetStorageImpl()); +} + TEST(IValueTest, ComplexDict) { typedef c10::complex c_type; c10::Dict m; @@ -102,21 +114,70 @@ TEST(IValueTest, ComplexDict) { ASSERT_EQ(m_.at(num1), 2 * num1); ASSERT_EQ(m_.at(num2), 2 * num2); } -static std::array makeSampleIValues() { - return { at::rand({3, 4}), "hello", 42, true, 1.5 }; -} -static std::array makeMoreSampleIValues() { - return { at::rand({3, 4}), "goodbye", 23, false, 0.5 }; +// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) +static std::array makeSampleIValues() { + return { + IValue(), + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + at::rand({3, 4}), + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + at::rand({3, 4}).storage(), + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + 1.5, + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + c10::complex(2.5, -0.5), + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + 42, + true, + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + std::make_tuple(23, "hello"), + "hello", + c10::make_intrusive(), + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + c10::List({1, 2, 3}), + c10::Dict(), + c10::make_intrusive(FloatType::get()), + c10::Device(c10::DeviceType::CPU, 0), + c10::Stream(c10::Stream::DEFAULT, c10::Device(c10::DeviceType::CPU, 0)), + c10::make_intrusive(c10::StrongTypePtr(nullptr, ClassType::create("class1", {})), 1), + }; } +// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) +static std::array makeMoreSampleIValues() { + return { + IValue(), + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + at::rand({3, 4}), + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + at::rand({3, 4}).storage(), + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + 2.5, + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + c10::complex(2.7, -0.3), + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + 43, + false, + std::make_tuple(1, "goodbye"), + "goodbye", + c10::make_intrusive(), + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + c10::List({4, 5, 6}), + c10::Dict(), + c10::make_intrusive(IntType::get()), + c10::Device(c10::DeviceType::CUDA, 2), + c10::Stream(c10::Stream::DEFAULT, c10::Device(c10::DeviceType::CUDA, 1)), + c10::make_intrusive(c10::StrongTypePtr(nullptr, ClassType::create("class2", {})), 2), + };} + // IValue::operator== doesn't seem to work on Tensors. #define EXPECT_IVALUE_EQ(a, b) \ EXPECT_EQ((a).isTensor(), (b).isTensor()); \ if ((a).isTensor()) { \ - EXPECT_TRUE(a.toTensor().equal(b.toTensor())); \ + EXPECT_TRUE((a).toTensor().equal((b).toTensor())); \ } else { \ - EXPECT_EQ(a, b); \ + EXPECT_EQ((a), (b)); \ } TEST(IValueTest, Swap) { @@ -580,13 +641,31 @@ TEST(IValueTest, IdentityComparisonAndHashing) { ASSERT_EQ(sampleIValues.size(), moreSampleIValues.size()); for (int ii = 0; ii < sampleIValues.size(); ++ii) { - // Constant strings will have the same pointer value. - if (sampleIValues[ii].isPtrType() && !sampleIValues[ii].isString()) { - EXPECT_NE(sampleIValues[ii].hash(), sampleIValues2[ii].hash()); - } else { - EXPECT_EQ(sampleIValues[ii].hash(), sampleIValues2[ii].hash()); + if (sampleIValues[ii].isComplexDouble() || + sampleIValues[ii].isBlob() || + sampleIValues[ii].isList() || + sampleIValues[ii].isFuture() || + sampleIValues[ii].isStream() || + sampleIValues[ii].isObject() || + sampleIValues[ii].isGenericDict()) { + // Not hashable. + continue; + } + // Tuples may or may not have the same hash across instantiations. + if (!sampleIValues[ii].isTuple()) { + // Constant strings will have the same pointer value. + if (sampleIValues[ii].isPtrType() && !sampleIValues[ii].isString()) { + EXPECT_NE(sampleIValues[ii].hash(), sampleIValues2[ii].hash()) + << " at index " << ii; + } else { + EXPECT_EQ(sampleIValues[ii].hash(), sampleIValues2[ii].hash()) + << " at index " << ii; + } + } + if (!sampleIValues[ii].isNone() && !moreSampleIValues[ii].isNone()) { + EXPECT_NE(sampleIValues[ii].hash(), moreSampleIValues[ii].hash()) + << " at index " << ii; } - EXPECT_NE(sampleIValues[ii].hash(), moreSampleIValues[ii].hash()); } } @@ -656,5 +735,13 @@ TEST(IValueTest, ScalarBool) { EXPECT_TRUE(actual.toBool()); } +TEST(IValueTest, ToWeakAndBack) { + auto sampleInputs = makeSampleIValues(); + for (const auto& sample: sampleInputs) { + WeakIValue weak(sample); + EXPECT_IVALUE_EQ(sample, weak.lock()); + } +} + // TODO(gmagogsfm): Add type conversion test? } // namespace c10 diff --git a/aten/src/ATen/test/vulkan_api_test.cpp b/aten/src/ATen/test/vulkan_api_test.cpp index 474aa36c40cca..d4b466aa920f2 100644 --- a/aten/src/ATen/test/vulkan_api_test.cpp +++ b/aten/src/ATen/test/vulkan_api_test.cpp @@ -942,7 +942,7 @@ TEST(VulkanAPITest, hardshrink) { } for (const auto lambd_value : {-4.2, -1.0, -0.42, 0.0, 0.42, 1.0, 4.2, 42.42}) { - const auto in_cpu = at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat)); + const auto in_cpu = (at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat)) - 0.5) * 20; // between -10 and +10 const auto in_vulkan = in_cpu.vulkan(); const auto out_cpu = at::hardshrink(in_cpu, lambd_value); @@ -964,7 +964,7 @@ TEST(VulkanAPITest, hardshrink_) { } for (const auto lambd_value : {-4.2, -1.0, -0.42, 0.0, 0.42, 1.0, 4.2, 42.42}) { - const auto cpu = at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat)); + const auto cpu = (at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat)) - 0.5) * 20; // between -10 and +10 const auto vulkan = cpu.vulkan(); cpu.hardshrink(lambd_value); @@ -979,6 +979,49 @@ TEST(VulkanAPITest, hardshrink_) { } } +TEST(VulkanAPITest, leaky_relu) { + if (!at::is_vulkan_available()) { + return; + } + + for (const auto negative_slope : {0.01, 0.001, 1.0, -0.001}) { + const auto in_cpu = at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat)); + const auto in_vulkan = in_cpu.vulkan(); + + const auto out_cpu = at::leaky_relu(in_cpu, negative_slope); + const auto out_vulkan = at::leaky_relu(in_vulkan, negative_slope); + + const auto check = almostEqual(out_cpu, out_vulkan.cpu()); + + if (!check) { + showRtol(out_cpu, out_vulkan.cpu()); + } + + ASSERT_TRUE(check); + } +} + +TEST(VulkanAPITest, leaky_relu_) { + if (!at::is_vulkan_available()) { + return; + } + + for (const auto negative_slope : {0.01, 0.001, 1.0, -0.001}) { + auto cpu = at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat)); + auto vulkan = cpu.vulkan(); + + at::leaky_relu_(cpu, negative_slope); + at::leaky_relu_(vulkan, negative_slope); + + const auto check = almostEqual(cpu, vulkan.cpu()); + if (!check) { + showRtol(cpu, vulkan.cpu()); + } + + ASSERT_TRUE(check); + } +} + TEST(VulkanAPITest, hardswish) { if (!at::is_vulkan_available()) { return; diff --git a/aten/src/README.md b/aten/src/README.md index e3e01515afb0f..183ec09a97efd 100644 --- a/aten/src/README.md +++ b/aten/src/README.md @@ -7,7 +7,6 @@ multiple variants of the library, summarized here: * TH = TorcH * THC = TorcH Cuda * THCS = TorcH Cuda Sparse (now defunct) -* THCUNN = TorcH CUda Neural Network (see cunn) * THNN = TorcH Neural Network (now defunct) * THS = TorcH Sparse (now defunct) diff --git a/aten/src/THC/CMakeLists.txt b/aten/src/THC/CMakeLists.txt index 786506027ea8f..ab7f72b2f41d4 100644 --- a/aten/src/THC/CMakeLists.txt +++ b/aten/src/THC/CMakeLists.txt @@ -17,10 +17,7 @@ set(ATen_CUDA_SRCS ${ATen_CUDA_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/THCStorageCopy.cu ${CMAKE_CURRENT_SOURCE_DIR}/THCTensor.cu ${CMAKE_CURRENT_SOURCE_DIR}/THCTensorCopy.cu - ${CMAKE_CURRENT_SOURCE_DIR}/THCTensorMath.cu ${CMAKE_CURRENT_SOURCE_DIR}/THCTensorMathMagma.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/THCTensorMathPairwise.cu - ${CMAKE_CURRENT_SOURCE_DIR}/THCTensorMathReduce.cu PARENT_SCOPE) install(FILES @@ -33,8 +30,6 @@ install(FILES THCTensor.h THCTensorCopy.h THCTensorCopy.hpp - THCTensorMath.h - THCApply.cuh THCReduceApplyUtils.cuh THCTensorMathReduce.cuh THCAsmUtils.cuh @@ -66,7 +61,6 @@ install(FILES THCNumerics.cuh THCTensorInfo.cuh THCTensorTypeUtils.cuh - THCTensorMathMagma.h THCThrustAllocator.cuh # See Note [TH abstraction violation] THCTensor.hpp @@ -86,10 +80,4 @@ install(FILES generic/THCStorageCopy.h generic/THCTensorCopy.cu generic/THCTensorCopy.h - generic/THCTensorMath.h - generic/THCTensorMath.cu - generic/THCTensorMathMagma.h - generic/THCTensorMathMagma.cpp - generic/THCTensorMathPairwise.h - generic/THCTensorMathPairwise.cu DESTINATION "${ATEN_INSTALL_INCLUDE_SUBDIR}/THC/generic") diff --git a/aten/src/THC/THC.h b/aten/src/THC/THC.h index 717442db9eaa1..59e2f5de69fe0 100644 --- a/aten/src/THC/THC.h +++ b/aten/src/THC/THC.h @@ -11,6 +11,5 @@ #include #include -#include #endif diff --git a/aten/src/THC/THCApply.cuh b/aten/src/THC/THCApply.cuh deleted file mode 100644 index e424b2406ee3c..0000000000000 --- a/aten/src/THC/THCApply.cuh +++ /dev/null @@ -1,760 +0,0 @@ -#ifndef THC_APPLY_INC -#define THC_APPLY_INC - -#include -#include -#include -#include -#include -#include - -// -// This file contains pointwise operation functions and kernels that -// work on both contiguous and non-contiguous tensor arguments of -// arbitrary (up to MAX_CUTORCH_DIMS) dimensioned arguments without -// copying or temporary storage. -// - -// Rearrange dimensions for pointwise operations so that strides are in -// decreasing order as much as possible, so that kernels have better memory -// access patterns. -// -// For example, consider a binary operation on two "transposed" 2-dim tensors: -// sizes: 256 512 -// aInfo->strides: 1 256 -// bInfo->strides: 1 256 -// -// Given this, each concurrent memory access inside kernelPointwiseApply2() is -// exactly 256 elements apart, resulting in poor performance. -// -// This function exchanges dimensions so that memory access is contiguous: -// sizes: 512 256 -// aInfo->strides: 256 1 -// bInfo->strides: 256 1 -// -// (Actually, it becomes even better because now collapseDims() can turn each -// input into one contiguous array.) -// -// In general, given M (<=3) TensorInfo's with N dimensions, we can view each -// strides[i] (0 <= i < N) as an M-tuple. Given each pair i < j, we exchange -// strides[i] and [j] if -// (1) strides[i][k] < strides[j][k] for some k (0 <= k < M) -// (exchanging them will benefit input #k), and -// (2) strides[i][k] <= strieds[j][k] for all k -// (exchanging them will not make any input worse). -template -void rearrangeDims(TensorInfo* aInfo, - TensorInfo* bInfo = nullptr, - TensorInfo* cInfo = nullptr) { - int numInfos = 1; - int dims = aInfo->dims; - IndexType *sizes[3] = { aInfo->sizes, }; - IndexType *strides[3] = { aInfo->strides, }; - - if (bInfo != nullptr) { - ++numInfos; - if (bInfo->dims != dims) return; - sizes[1] = bInfo->sizes; - strides[1] = bInfo->strides; - } - - if (cInfo != nullptr) { - ++numInfos; - if (cInfo->dims != dims) return; - sizes[2] = cInfo->sizes; - strides[2] = cInfo->strides; - } - - // Bail out if sizes do not match: we are using "deprecated pointwise - // behavior" among tensors of different shapes but same number of elements. - for (int i = 1; i < numInfos; ++i) { - for (int j = 0; j < dims; ++j) { - if (sizes[i][j] != sizes[0][j]) return; - } - } - - for (int i = 0; i < dims - 1; ++i) { - // No need to consider dimensions of size 1. - if (sizes[0][i] == 1) continue; - - for (int j = i + 1; j < dims; ++j) { - if (sizes[0][j] == 1) continue; - - // Compare the relative sizes of strides between dim #i and dim #j. - bool hasIncreasingStrides = false; - bool hasDecreasingStrides = false; - - for (int k = 0; k < numInfos; k++) { - IndexType stride_i = strides[k][i]; - IndexType stride_j = strides[k][j]; - if (stride_i < stride_j) { - hasIncreasingStrides = true; - } else if (stride_i > stride_j) { - hasDecreasingStrides = true; - } - } - - if (hasIncreasingStrides && !hasDecreasingStrides) { - for (int k = 0; k < numInfos; k++) { - IndexType size = sizes[k][i]; - sizes[k][i] = sizes[k][j]; - sizes[k][j] = size; - - IndexType stride = strides[k][i]; - strides[k][i] = strides[k][j]; - strides[k][j] = stride; - } - } - } - } -} - -// Threads per block for our apply kernel -// FIXME: use occupancy calculator instead -#define THC_APPLY_THREADS_PER_BLOCK (32 * 16) -#define THC_APPLY_BLOCKS_PER_SM 4 -template -#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__ -C10_LAUNCH_BOUNDS_2(THC_APPLY_THREADS_PER_BLOCK, THC_APPLY_BLOCKS_PER_SM) -#endif -__global__ void -kernelPointwiseApply1(const OffsetInfo a, - IndexType totalElements, - Op op) { - // NOTE: The two typecasts below are essential when IndexType is 64-bit; - // without them, results are silently truncated to 32 bits! - for (IndexType linearIndex = (IndexType) blockIdx.x * blockDim.x + threadIdx.x; - linearIndex < totalElements; - linearIndex += (IndexType) gridDim.x * blockDim.x) { - op(a.get(linearIndex)); - } -} - -template -#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__ -C10_LAUNCH_BOUNDS_2(THC_APPLY_THREADS_PER_BLOCK, THC_APPLY_BLOCKS_PER_SM) -#endif -__global__ void -kernelPointwiseApply2(const OffsetInfo a, - const OffsetInfo b, - IndexType totalElements, - Op op) { - for (IndexType linearIndex = (IndexType) blockIdx.x * blockDim.x + threadIdx.x; - linearIndex < totalElements; - linearIndex += (IndexType) gridDim.x * blockDim.x) { - op(a.get(linearIndex), b.get(linearIndex)); - } -} - -template -#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__ -C10_LAUNCH_BOUNDS_2(THC_APPLY_THREADS_PER_BLOCK, THC_APPLY_BLOCKS_PER_SM) -#endif -__global__ void -kernelPointwiseApply3(const OffsetInfo a, - const OffsetInfo b, - const OffsetInfo c, - IndexType totalElements, - Op op) { - for (IndexType linearIndex = (IndexType) blockIdx.x * blockDim.x + threadIdx.x; - linearIndex < totalElements; - linearIndex += (IndexType) gridDim.x * blockDim.x) { - op(a.get(linearIndex), b.get(linearIndex), c.get(linearIndex)); - } -} - -inline dim3 getApplyBlock() { - return dim3(THC_APPLY_THREADS_PER_BLOCK); -} - -inline bool getApplyGrid(THCState* state, uint64_t totalElements, dim3& grid, int curDevice) { - if (curDevice == -1) return false; - - uint64_t numBlocks = THCCeilDiv(totalElements, static_cast(THC_APPLY_THREADS_PER_BLOCK)); - uint64_t maxGridX = at::cuda::getDeviceProperties(curDevice)->maxGridSize[0]; - if (numBlocks > maxGridX) - numBlocks = maxGridX; - - // For 32-bit indices, make sure that gridDim.x * blockDim.x fits in 32 bits. - if (totalElements <= INT32_MAX && - numBlocks > INT32_MAX / THC_APPLY_THREADS_PER_BLOCK) - numBlocks = INT32_MAX / THC_APPLY_THREADS_PER_BLOCK; - - grid = dim3(numBlocks); - return true; -} - -template -bool THC_pointwiseApply1(THCState* state, - TensorTypeA* a, - const Op& op, - TensorArgType aType = ReadWrite) { - if (THCTensor_nDimensionLegacyAll(state, a) > MAX_CUTORCH_DIMS) { - return false; - } - - if (THCTensor_nDimensionLegacyAll(state, a) == 0) { - // Zero-dim tensor; do nothing - return true; - } - - const dim3 block = getApplyBlock(); - - dim3 grid; - ptrdiff_t totalElements = THCTensor_nElement(state, a); - - int curDevice = -1; - cudaGetDevice(&curDevice); - if (!getApplyGrid(state, totalElements, grid, curDevice)) { - return false; - } - - /* - Expands readable/writable tensors whose indices may be "overlapped." - This ensures that each element of the tensor is operated on once and only - once. - */ - TensorTypeA* oldA = NULL; - - if (aType == ReadWrite && - THCTensor_maybeOverlappingIndices(state, a)) { - // Must perform in contiguous space - oldA = a; - a = (TensorTypeA*)THCTensor_newContiguous(state, a); - } - - // It is possible that the tensor dimensions are able to be collapsed, - // and thus we can reduce the actual code complexity of the copy by - // exploiting this knowledge statically, since the div/mod is the - // most expensive part of the operation, more so than memory accesses. - // For instance, when copying a non-contiguous to a contiguous tensor - // (or vice versa), the contiguous tensor can be collapsed to one - // dimension, and the loop to translate the linear index to the array - // index can be similarly collapsed. That is what this unrolling is for. -#define HANDLE_CASE(TYPE, A) \ - kernelPointwiseApply1 \ - <<>>( \ - OffsetInfo(aInfo), (TYPE) totalElements, op); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); - -#define HANDLE_A_CASE(TYPE, A) { \ - switch (A) { \ - case 1: \ - HANDLE_CASE(TYPE, 1); \ - break; \ - case 2: \ - HANDLE_CASE(TYPE, 2); \ - break; \ - default: \ - HANDLE_CASE(TYPE, -1); \ - break; \ - } \ -} - - // Can we use 32-bit integer math in the kernel (the linear ID for the copy - // and the resulting non-linear offset is all computable using 32-bit math?) - // We also use unsigned index math in the kernel, as signed div/mod has - // additional overhead. - if (THCTensor_canUse32BitIndexMath(state, a)) { - TensorInfo aInfo = - getTensorInfo(state, a); - rearrangeDims(&aInfo); - aInfo.collapseDims(); -#if CUDA_VERSION < 9000 - if (!aInfo.isContiguous()) { - grid.x = min(at::cuda::getCurrentDeviceProperties()->multiProcessorCount * THC_APPLY_BLOCKS_PER_SM , grid.x); - } -#endif - HANDLE_A_CASE(unsigned int, aInfo.dims); - } else { - TensorInfo aInfo = - getTensorInfo(state, a); - rearrangeDims(&aInfo); - aInfo.collapseDims(); - - /* - Only instantiates the all 1D special case and the fallback all nD case for - large (64-bit indexed) tensors to reduce compilation time. - */ - if (aInfo.dims == 1) { - OffsetInfo - aOffset(aInfo); - kernelPointwiseApply1 - <<>>( - aOffset, (uint64_t) totalElements, op); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - } else { - -#if CUDA_VERSION < 9000 - grid.x = min(at::cuda::getCurrentDeviceProperties()->multiProcessorCount * THC_APPLY_BLOCKS_PER_SM , grid.x); -#endif - OffsetInfo - aOffset(aInfo); - kernelPointwiseApply1 - <<>>( - aOffset, (uint64_t) totalElements, op); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - } - } -#undef HANDLE_CASE -#undef HANDLE_A_CASE - - if (oldA) { - // Ignore overlaps when copying back; if we use THCTensor_copy - // instead, it will recursively try and invoke ourselves to make - // oldA contiguous. - THCTensor_copyIgnoringOverlaps(state, oldA, a); - THCTensor_free(state, a); - a = oldA; - } - - return true; -} - -template -bool THC_pointwiseApply2(THCState* state, - TensorTypeA* a, - TensorTypeB* b, - const Op& op, - TensorArgType aType = ReadWrite, - TensorArgType bType = ReadOnly) { - ptrdiff_t totalElements = THCTensor_nElement(state, a); - if (totalElements != THCTensor_nElement(state, b)) { - return false; - } - - if (THCTensor_nDimensionLegacyAll(state, a) > MAX_CUTORCH_DIMS || - THCTensor_nDimensionLegacyAll(state, b) > MAX_CUTORCH_DIMS) { - return false; - } - - if (THCTensor_nDimensionLegacyAll(state, a) == 0) { - // Zero-dim tensor; do nothing - return true; - } - - const dim3 block = getApplyBlock(); - - dim3 grid; - int curDevice = -1; - cudaGetDevice(&curDevice); - if (!getApplyGrid(state, totalElements, grid, curDevice)) { - return false; - } - - /* - Expands readable/writable tensors whose indices may be "overlapped." - This ensures that each element of the tensor is operated on once and only - once. - */ - TensorTypeA* oldA = NULL; - TensorTypeB* oldB = NULL; - - if (aType == ReadWrite && - THCTensor_maybeOverlappingIndices(state, a)) { - // Must perform in contiguous space - oldA = a; - a = (TensorTypeA*)THCTensor_newContiguous(state, a); - } - if (bType == ReadWrite && - THCTensor_maybeOverlappingIndices(state, b)) { - // Must perform in contiguous space - oldB = b; - b = (TensorTypeB*)THCTensor_newContiguous(state, b); - } - - // It is possible that the tensor dimensions are able to be collapsed, - // and thus we can reduce the actual code complexity of the copy by - // exploiting this knowledge statically, since the div/mod is the - // most expensive part of the operation, more so than memory accesses. - // For instance, when copying a non-contiguous to a contiguous tensor - // (or vice versa), the contiguous tensor can be collapsed to one - // dimension, and the loop to translate the linear index to the array - // index can be similarly collapsed. That is what this unrolling is for. -#define HANDLE_CASE(TYPE, A, B) \ - kernelPointwiseApply2 \ - <<>>( \ - OffsetInfo(aInfo), \ - OffsetInfo(bInfo), \ - (TYPE) totalElements, op); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); - - -#define HANDLE_B_CASE(TYPE, A, B) { \ - switch (B) { \ - case 1: \ - HANDLE_CASE(TYPE, A, 1); \ - break; \ - case 2: \ - HANDLE_CASE(TYPE, A, 2); \ - break; \ - default: \ - HANDLE_CASE(TYPE, A, -1); \ - break; \ - } \ -} - -#define HANDLE_A_CASE(TYPE, A, B) { \ - switch (A) { \ - case 1: \ - HANDLE_B_CASE(TYPE, 1, B); \ - break; \ - case 2: \ - HANDLE_B_CASE(TYPE, 2, B); \ - break; \ - default: \ - HANDLE_B_CASE(TYPE, -1, B); \ - break; \ - } \ -} - - if (THCTensor_canUse32BitIndexMath(state, a) && - THCTensor_canUse32BitIndexMath(state, b)) { - TensorInfo aInfo = - getTensorInfo(state, a); - - TensorInfo bInfo = - getTensorInfo(state, b); - - rearrangeDims(&aInfo, &bInfo); - aInfo.collapseDims(); - bInfo.collapseDims(); -#if CUDA_VERSION < 9000 - if (!(aInfo.isContiguous() && bInfo.isContiguous())) - grid.x = min(at::cuda::getCurrentDeviceProperties()->multiProcessorCount * THC_APPLY_BLOCKS_PER_SM , grid.x); -#endif - - HANDLE_A_CASE(unsigned int, aInfo.dims, bInfo.dims); - } else { - TensorInfo aInfo = - getTensorInfo(state, a); - - TensorInfo bInfo = - getTensorInfo(state, b); - - rearrangeDims(&aInfo, &bInfo); - aInfo.collapseDims(); - bInfo.collapseDims(); - - /* - Only instantiates the all 1D special case and the fallback all nD case for - large (64-bit indexed) tensors to reduce compilation time. - */ - if (aInfo.dims == 1 && bInfo.dims == 1) { - OffsetInfo - aOffset(aInfo); - OffsetInfo - bOffset(bInfo); - kernelPointwiseApply2 - <<>>( - aOffset, bOffset, (uint64_t) totalElements, op); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - } else { -#if CUDA_VERSION < 9000 - grid.x = min(at::cuda::getCurrentDeviceProperties()->multiProcessorCount * THC_APPLY_BLOCKS_PER_SM , grid.x); -#endif - OffsetInfo - aOffset(aInfo); - OffsetInfo - bOffset(bInfo); - kernelPointwiseApply2 - <<>>( - aOffset, bOffset, (uint64_t) totalElements, op); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - } - } -#undef HANDLE_CASE -#undef HANDLE_B_CASE -#undef HANDLE_A_CASE - - if (oldA) { - // Ignore overlaps when copying back; if we use THCTensor_copy - // instead, it will recursively try and invoke ourselves to make - // oldA contiguous. - THCTensor_copyIgnoringOverlaps(state, oldA, a); - THCTensor_free(state, a); - a = oldA; - } - - if (oldB) { - // Ignore overlaps when copying back; if we use THCTensor_copy - // instead, it will recursively try and invoke ourselves to make - // oldB contiguous. - THCTensor_copyIgnoringOverlaps(state, oldB, b); - THCTensor_free(state, b); - b = oldB; - } - - return true; -} - -template -bool THC_pointwiseApply3(THCState* state, - TensorTypeA* a, - TensorTypeB* b, - TensorTypeC* c, - const Op& op, - TensorArgType aType = ReadWrite, - TensorArgType bType = ReadOnly, - TensorArgType cType = ReadOnly) { - ptrdiff_t totalElements = THCTensor_nElement(state, a); - - if (totalElements != THCTensor_nElement(state, b) || - totalElements != THCTensor_nElement(state, c)) { - return false; - } - - if (THCTensor_nDimensionLegacyAll(state, a) > MAX_CUTORCH_DIMS || - THCTensor_nDimensionLegacyAll(state, b) > MAX_CUTORCH_DIMS || - THCTensor_nDimensionLegacyAll(state, c) > MAX_CUTORCH_DIMS) { - return false; - } - - if (THCTensor_nDimensionLegacyAll(state, a) == 0) { - // Zero-dim tensor; do nothing - return true; - } - - const dim3 block = getApplyBlock(); - - dim3 grid; - int curDevice = -1; - cudaGetDevice(&curDevice); - if (!getApplyGrid(state, totalElements, grid, curDevice)) { - return false; - } - - /* - Expands readable/writable tensors whose indices may be "overlapped." - This ensures that each element of the tensor is operated on once and only - once. - */ - TensorTypeA* oldA = NULL; - TensorTypeB* oldB = NULL; - TensorTypeC* oldC = NULL; - - if (aType == ReadWrite && - THCTensor_maybeOverlappingIndices(state, a)) { - // Must perform in contiguous space - oldA = a; - a = (TensorTypeA*)THCTensor_newContiguous(state, a); - } - if (bType == ReadWrite && - THCTensor_maybeOverlappingIndices(state, b)) { - // Must perform in contiguous space - oldB = b; - b = (TensorTypeB*)THCTensor_newContiguous(state, b); - } - if (cType == ReadWrite && - THCTensor_maybeOverlappingIndices(state, c)) { - // Must perform in contiguous space - oldC = c; - c = (TensorTypeC*)THCTensor_newContiguous(state, c); - } - -#define HANDLE_CASE(TYPE, A, B, C) \ - kernelPointwiseApply3 \ - <<>>( \ - OffsetInfo \ - (aInfo), \ - OffsetInfo \ - (bInfo), \ - OffsetInfo \ - (cInfo), \ - (TYPE) totalElements, op); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); - -#define HANDLE_C_CASE(TYPE, A, B, C) { \ - switch (C) { \ - case 1: \ - HANDLE_CASE(TYPE, A, B, 1); \ - break; \ - case 2: \ - HANDLE_CASE(TYPE, A, B, 2); \ - break; \ - default: \ - HANDLE_CASE(TYPE, A, B, -1); \ - break; \ - } \ -} - -#define HANDLE_B_CASE(TYPE, A, B, C) { \ - switch (B) { \ - case 1: \ - HANDLE_C_CASE(TYPE, A, 1, C); \ - break; \ - case 2: \ - HANDLE_C_CASE(TYPE, A, 2, C); \ - break; \ - default: \ - HANDLE_C_CASE(TYPE, A, -1, C); \ - break; \ - } \ -} - -#define HANDLE_A_CASE(TYPE, A, B, C) { \ - switch (A) { \ - case 1: \ - HANDLE_B_CASE(TYPE, 1, B, C); \ - break; \ - case 2: \ - HANDLE_B_CASE(TYPE, 2, B, C); \ - break; \ - default: \ - HANDLE_B_CASE(TYPE, -1, B, C); \ - break; \ - } \ -} - - if (THCTensor_canUse32BitIndexMath(state, a) && - THCTensor_canUse32BitIndexMath(state, b) && - THCTensor_canUse32BitIndexMath(state, c)) { - TensorInfo aInfo = - getTensorInfo(state, a); - - TensorInfo bInfo = - getTensorInfo(state, b); - - TensorInfo cInfo = - getTensorInfo(state, c); - - rearrangeDims(&aInfo, &bInfo, &cInfo); - aInfo.collapseDims(); - bInfo.collapseDims(); - cInfo.collapseDims(); - -#if CUDA_VERSION < 9000 - if (!(aInfo.isContiguous() && bInfo.isContiguous() && cInfo.isContiguous())) - grid.x = min(at::cuda::getCurrentDeviceProperties()->multiProcessorCount * THC_APPLY_BLOCKS_PER_SM , grid.x); -#endif - HANDLE_A_CASE(unsigned int, aInfo.dims, bInfo.dims, cInfo.dims); - } else { - TensorInfo aInfo = - getTensorInfo(state, a); - - TensorInfo bInfo = - getTensorInfo(state, b); - - TensorInfo cInfo = - getTensorInfo(state, c); - - rearrangeDims(&aInfo, &bInfo, &cInfo); - aInfo.collapseDims(); - bInfo.collapseDims(); - cInfo.collapseDims(); - - /* - Only instantiates the all 1D special case and the fallback all nD case for - large (64-bit indexed) tensors to reduce compilation time. - */ - if (aInfo.dims == 1 && bInfo.dims == 1 && cInfo.dims == 1) { - OffsetInfo - aOffset(aInfo); - OffsetInfo - bOffset(bInfo); - OffsetInfo - cOffset(cInfo); - kernelPointwiseApply3 - <<>>( - aOffset, bOffset, cOffset, (uint64_t) totalElements, op); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - } else { -#if CUDA_VERSION < 9000 - grid.x = min(at::cuda::getCurrentDeviceProperties()->multiProcessorCount * THC_APPLY_BLOCKS_PER_SM , grid.x); -#endif - - OffsetInfo - aOffset(aInfo); - OffsetInfo - bOffset(bInfo); - OffsetInfo - cOffset(cInfo); - kernelPointwiseApply3 - <<>>( - aOffset, bOffset, cOffset, (uint64_t) totalElements, op); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - } - } -#undef HANDLE_CASE -#undef HANDLE_C_CASE -#undef HANDLE_B_CASE -#undef HANDLE_A_CASE - - if (oldA) { - // Ignore overlaps when copying back; if we use THCTensor_copy - // instead, it will recursively try and invoke ourselves to make - // oldA contiguous. - THCTensor_copyIgnoringOverlaps(state, oldA, a); - THCTensor_free(state, a); - a = oldA; - } - - if (oldB) { - // Ignore overlaps when copying back; if we use THCTensor_copy - // instead, it will recursively try and invoke ourselves to make - // oldB contiguous. - THCTensor_copyIgnoringOverlaps(state, oldB, b); - THCTensor_free(state, b); - b = oldB; - } - - if (oldC) { - // Ignore overlaps when copying back; if we use THCTensor_copy - // instead, it will recursively try and invoke ourselves to make - // oldC contiguous. - THCTensor_copyIgnoringOverlaps(state, oldC, c); - THCTensor_free(state, c); - c = oldC; - } - - return true; -} - -#undef THC_APPLY_THREADS_PER_BLOCK -#undef THC_APPLY_BLOCKS_PER_SM - -#endif // THC_APPLY_INC diff --git a/aten/src/THC/THCTensorCopy.cu b/aten/src/THC/THCTensorCopy.cu index f4db80dfeb86a..fa1df622aff7c 100644 --- a/aten/src/THC/THCTensorCopy.cu +++ b/aten/src/THC/THCTensorCopy.cu @@ -1,35 +1,6 @@ -#include -#include -#include +#include +#include #include -#include -#include - -// Copy operator for the pointwise apply kernel -template -struct CopyOp { - __device__ __forceinline__ void operator()(T* dst, T* src) { -#if __CUDA_ARCH__ >= 350 - *dst = c10::static_cast_with_inter_type::apply(*src); -#else - *dst = c10::static_cast_with_inter_type::apply(*src); -#endif - } -}; - -template <> -struct CopyOp { - __device__ __forceinline__ void operator()(bool* dst, bool* src) { - *dst = ScalarConvert::to(*src); - } -}; - -template <> -struct CopyOp { - __device__ __forceinline__ void operator()(at::BFloat16* dst, at::BFloat16* src) { - *dst = ScalarConvert::to(*src); - } -}; #include #include diff --git a/aten/src/THC/THCTensorMath.cu b/aten/src/THC/THCTensorMath.cu deleted file mode 100644 index 418bfa9e14919..0000000000000 --- a/aten/src/THC/THCTensorMath.cu +++ /dev/null @@ -1,39 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include - - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__ -#include -#endif -#include - -template -struct TensorFillOp { - TensorFillOp(T v) : val(v) {} - __device__ __forceinline__ void operator()(T* v) { *v = val; } - - const T val; -}; - -#include -#include - -#include -#include - -#include -#include diff --git a/aten/src/THC/THCTensorMath.h b/aten/src/THC/THCTensorMath.h deleted file mode 100644 index 422a423959457..0000000000000 --- a/aten/src/THC/THCTensorMath.h +++ /dev/null @@ -1,25 +0,0 @@ -#ifndef TH_CUDA_TENSOR_MATH_INC -#define TH_CUDA_TENSOR_MATH_INC - -#include -#include - -#include -#include - -#include -#include - -#include -#include - -#include -#include - -#include -#include - -#include -#include - -#endif diff --git a/aten/src/THC/THCTensorMathMagma.cpp b/aten/src/THC/THCTensorMathMagma.cpp index ca0cc8a621282..43607531bd60e 100644 --- a/aten/src/THC/THCTensorMathMagma.cpp +++ b/aten/src/THC/THCTensorMathMagma.cpp @@ -1,23 +1,10 @@ #include -#include -#include -#include -#include -#include -#include -#include #include #ifdef USE_MAGMA #include #endif -#ifndef DIVUP -#define DIVUP(x, y) (((x) + (y) - 1) / (y)) -#endif - -#define NoMagma(name) "No CUDA implementation of '" #name "'. Install MAGMA and rebuild cutorch (http://icl.cs.utk.edu/magma/)" - namespace { void _THCMagma_init() { #ifdef USE_MAGMA @@ -31,6 +18,3 @@ struct Initializer { }; } initializer; } // anonymous namespace - -#include -#include diff --git a/aten/src/THC/THCTensorMathMagma.h b/aten/src/THC/THCTensorMathMagma.h deleted file mode 100644 index 1fb5821afce56..0000000000000 --- a/aten/src/THC/THCTensorMathMagma.h +++ /dev/null @@ -1,20 +0,0 @@ -#ifndef THC_TENSOR_MATH_MAGMA_CUH -#define THC_TENSOR_MATH_MAGMA_CUH - -#ifdef USE_MAGMA -#include -#endif - -#ifdef USE_MAGMA -template -static inline T* th_magma_malloc_pinned(size_t n) -{ - void* ptr; - if (MAGMA_SUCCESS != magma_malloc_pinned(&ptr, n * sizeof(T))) - THError("$ Torch: not enough memory: you tried to allocate %dGB. Buy new RAM!", n/268435456); - return reinterpret_cast(ptr); -} - -#endif - -#endif // THC_TENSOR_MATH_MAGMA_CUH diff --git a/aten/src/THC/THCTensorMathPairwise.cu b/aten/src/THC/THCTensorMathPairwise.cu deleted file mode 100644 index 6fd026aa8966d..0000000000000 --- a/aten/src/THC/THCTensorMathPairwise.cu +++ /dev/null @@ -1,24 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include - -template -struct TensorMulConstantOp { - TensorMulConstantOp(T v) : val(v) {} - __device__ __forceinline__ void operator()(T* out, T* in) { - *out = *in * val; - } - - __device__ __forceinline__ void operator()(T* v) { - *v *= val; - } - - const T val; -}; - -#include -#include diff --git a/aten/src/THC/THCTensorMathReduce.cu b/aten/src/THC/THCTensorMathReduce.cu deleted file mode 100644 index 1a2c626537156..0000000000000 --- a/aten/src/THC/THCTensorMathReduce.cu +++ /dev/null @@ -1,2 +0,0 @@ -#include -#include diff --git a/aten/src/THC/generic/THCTensorCopy.cu b/aten/src/THC/generic/THCTensorCopy.cu index 3941ef9599206..4301bccc0539b 100644 --- a/aten/src/THC/generic/THCTensorCopy.cu +++ b/aten/src/THC/generic/THCTensorCopy.cu @@ -44,22 +44,4 @@ void THCTensor_freeCopyTo(THCState *state, THCTensor *self, THCTensor THCTensor_free(state, self); } -template <> -void THCTensor_copyIgnoringOverlaps(THCState* state, THCTensor* dst, THCTensor* src) { - // Called when we are copying into an overlapping index `dst`, but - // we don't care which writer wins. Hacky but it works. - // This is itself invoked by pointwiseApply2 / THCTensor_copy in - // case that there are write overlaps. - // FIXME: really, overlapping writes should be illegal/an error in Torch - THC_pointwiseApply2( - state, dst, src, - CopyOp(), - ReadOnly, /* ignore overwrites */ - ReadOnly); -} - -void THCTensor_(copyIgnoringOverlaps)(THCState* state, THCTensor* dst, THCTensor* src) { - THCTensor_copyIgnoringOverlaps(state, dst, src); -} - #endif diff --git a/aten/src/THC/generic/THCTensorMath.cu b/aten/src/THC/generic/THCTensorMath.cu deleted file mode 100644 index d07a3e3a62cdc..0000000000000 --- a/aten/src/THC/generic/THCTensorMath.cu +++ /dev/null @@ -1,70 +0,0 @@ -#ifndef THC_GENERIC_FILE -#define THC_GENERIC_FILE "THC/generic/THCTensorMath.cu" -#else - -#include - -#include -#include - -void THCTensor_(fill)(THCState* state, THCTensor *self_, scalar_t value) -{ - THCAssertSameGPU(THCTensor_(checkGPU)(state, 1, self_)); - - if (!THC_pointwiseApply1( - state, self_, TensorFillOp(value))) { - THArgCheck(false, 1, CUTORCH_DIM_WARNING); - } - - THCudaCheck(cudaGetLastError()); -} - -void THCTensor_(zero)(THCState *state, THCTensor *self_) -{ - THCAssertSameGPU(THCTensor_(checkGPU)(state, 1, self_)); - if (THCTensor_(isContiguous)(state, self_)) { - THCudaCheck(cudaMemsetAsync(THCTensor_(data)(state, self_), - 0, - sizeof(scalar_t) * THCTensor_(nElement)(state, self_), - c10::cuda::getCurrentCUDAStream())); - } else { - if (!THC_pointwiseApply1( - state, self_, - TensorFillOp(ScalarConvert::to(0)))) { - THArgCheck(false, 1, CUTORCH_DIM_WARNING); - } - } - - THCudaCheck(cudaGetLastError()); -} - -ptrdiff_t -THCTensor_(numel)(THCState *state, THCTensor *t) -{ - return THCTensor_(nElement)(state, t); -} - -void THCTensor_(check_shape_except_dim)(THCState *state, - THCTensor *first, THCTensor *second, int dimension, int index); -inline void THCTensor_(check_shape_except_dim)(THCState *state, - THCTensor *first, THCTensor *second, int dimension, int index) -{ - int first_dims = first->dim(); - int second_dims = second->dim(); - THArgCheck(first_dims == second_dims, 0, - "Tensors must have same number of dimensions: got %d and %d", - first_dims, second_dims); - for (int dim = 0; dim < first_dims; dim++) { - if (dim == dimension) { - continue; - } - int64_t first_dim_size = THCTensor_(size)(state, first, dim); - int64_t second_dim_size = THCTensor_(size)(state, second, dim); - THArgCheck(first_dim_size == second_dim_size, 0, - "Sizes of tensors must match except in dimension %d. Got %lld and %lld in dimension %d (The offending index is %d)", - dimension, (long long)first_dim_size, (long long)second_dim_size, dim, index); - } -} - - -#endif diff --git a/aten/src/THC/generic/THCTensorMath.h b/aten/src/THC/generic/THCTensorMath.h deleted file mode 100644 index 58ec1567aed9b..0000000000000 --- a/aten/src/THC/generic/THCTensorMath.h +++ /dev/null @@ -1,10 +0,0 @@ -#ifndef THC_GENERIC_FILE -#define THC_GENERIC_FILE "THC/generic/THCTensorMath.h" -#else - -TORCH_CUDA_CU_API void THCTensor_( - fill)(THCState* state, THCTensor* self, scalar_t value); -TORCH_CUDA_CU_API void THCTensor_(zero)(THCState* state, THCTensor* self); -TORCH_CUDA_CU_API ptrdiff_t THCTensor_(numel)(THCState* state, THCTensor* t); - -#endif diff --git a/aten/src/THC/generic/THCTensorMathMagma.cpp b/aten/src/THC/generic/THCTensorMathMagma.cpp deleted file mode 100644 index 0d94fc320e53b..0000000000000 --- a/aten/src/THC/generic/THCTensorMathMagma.cpp +++ /dev/null @@ -1,83 +0,0 @@ -#ifndef THC_GENERIC_FILE -#define THC_GENERIC_FILE "THC/generic/THCTensorMathMagma.cpp" -#else - -#include - -#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) - -static THCTensor* THCTensor_(newColumnMajor)(THCState *state, THCTensor *self, THCTensor *src) -{ - THAssert(src->dim() == 2); - if (self == src && self->stride(0) == 1 && self->stride(1) == self->size(0)) - { - THCTensor_(retain)(state, self); - return self; - } - - if (self == src) - self = THCTensor_(new)(state); - else - THCTensor_(retain)(state, self); - - int64_t size[2] = { src->size(0), src->size(1) }; - int64_t stride[2] = { 1, src->size(0) }; - - THCTensor_(resizeNd)(state, self, 2, size, stride); - THCTensor_(copy)(state, self, src); - return self; -} - -void THCTensor_(gels)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_) -{ -#ifdef USE_MAGMA - THArgCheck(!a_->is_empty() && a_->dim() == 2, 1, "A should be (non-empty) 2 dimensional"); - THArgCheck(!b_->is_empty() && b_->dim() == 2, 1, "b should be (non-empty) 2 dimensional"); - TORCH_CHECK(a_->size(0) == b_->size(0), "Expected A and b to have same size " - "at dim 0, but A has ", a_->size(0), " rows and B has ", b_->size(0), " rows"); - THArgCheck(a_->size(0) >= a_->size(1), 2, "Expected A with shape (m x n) to have " - "m >= n. The case for m < n is not implemented yet."); - - THCTensor *a = THCTensor_(newColumnMajor)(state, ra_, a_); - THCTensor *b = THCTensor_(newColumnMajor)(state, rb_, b_); - scalar_t *a_data = THCTensor_(data)(state, a); - scalar_t *b_data = THCTensor_(data)(state, b); - - int64_t m = a->size(0); - int64_t n = a->size(1); - int64_t nrhs = b->size(1); - scalar_t wkopt; - - int info; - { - at::native::MagmaStreamSyncGuard guard; -#if defined(THC_REAL_IS_FLOAT) - magma_sgels_gpu(MagmaNoTrans, m, n, nrhs, a_data, m, b_data, m, &wkopt, -1, &info); -#else - magma_dgels_gpu(MagmaNoTrans, m, n, nrhs, a_data, m, b_data, m, &wkopt, -1, &info); -#endif - - scalar_t *hwork = th_magma_malloc_pinned((size_t)wkopt); - -#if defined(THC_REAL_IS_FLOAT) - magma_sgels_gpu(MagmaNoTrans, m, n, nrhs, a_data, m, b_data, m, hwork, (int)wkopt, &info); -#else - magma_dgels_gpu(MagmaNoTrans, m, n, nrhs, a_data, m, b_data, m, hwork, (int)wkopt, &info); -#endif - - magma_free_pinned(hwork); - } - - if (info != 0) - THError("MAGMA gels : Argument %d : illegal value", -info); - - THCTensor_(freeCopyTo)(state, a, ra_); - THCTensor_(freeCopyTo)(state, b, rb_); -#else - THError(NoMagma(gels)); -#endif -} - -#endif - -#endif diff --git a/aten/src/THC/generic/THCTensorMathMagma.h b/aten/src/THC/generic/THCTensorMathMagma.h deleted file mode 100644 index 585d02ceff7a7..0000000000000 --- a/aten/src/THC/generic/THCTensorMathMagma.h +++ /dev/null @@ -1,17 +0,0 @@ -#ifndef THC_GENERIC_FILE -#define THC_GENERIC_FILE "THC/generic/THCTensorMathMagma.h" -#else - -#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) - -// MAGMA (i.e. CUDA implementation of LAPACK functions) -TORCH_CUDA_CU_API void THCTensor_(gels)( - THCState* state, - THCTensor* rb_, - THCTensor* ra_, - THCTensor* b_, - THCTensor* a_); - -#endif // defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) - -#endif diff --git a/aten/src/THC/generic/THCTensorMathPairwise.cu b/aten/src/THC/generic/THCTensorMathPairwise.cu deleted file mode 100644 index aba731c725423..0000000000000 --- a/aten/src/THC/generic/THCTensorMathPairwise.cu +++ /dev/null @@ -1,29 +0,0 @@ -#ifndef THC_GENERIC_FILE -#define THC_GENERIC_FILE "THC/generic/THCTensorMathPairwise.cu" -#else - -#include - -#if !defined(THC_REAL_IS_BOOL) - -void THCTensor_(mul)(THCState *state, THCTensor *self_, THCTensor *src_, scalar_t value) -{ - THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src_)); - if (self_ == src_) { - if (!THC_pointwiseApply1(state, self_, TensorMulConstantOp(value))) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } - } else { - THCTensor_(resizeAs)(state, self_, src_); - - if (!THC_pointwiseApply2(state, self_, src_, TensorMulConstantOp(value))) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } - } - - THCudaCheck(cudaGetLastError()); -} - -#endif - -#endif diff --git a/aten/src/THC/generic/THCTensorMathPairwise.h b/aten/src/THC/generic/THCTensorMathPairwise.h deleted file mode 100644 index deeafb1291fbd..0000000000000 --- a/aten/src/THC/generic/THCTensorMathPairwise.h +++ /dev/null @@ -1,15 +0,0 @@ -#ifndef THC_GENERIC_FILE -#define THC_GENERIC_FILE "THC/generic/THCTensorMathPairwise.h" -#else - -TORCH_CUDA_CU_API int THCTensor_( - equal)(THCState* state, THCTensor* self, THCTensor* src); - -#if !defined(THC_REAL_IS_BOOL) - -TORCH_CUDA_CU_API void THCTensor_( - mul)(THCState* state, THCTensor* self, THCTensor* src, scalar_t value); - -#endif - -#endif diff --git a/aten/src/THCUNN/CMakeLists.txt b/aten/src/THCUNN/CMakeLists.txt deleted file mode 100644 index 55197277b3779..0000000000000 --- a/aten/src/THCUNN/CMakeLists.txt +++ /dev/null @@ -1,11 +0,0 @@ -set(ATen_CUDA_SRCS ${ATen_CUDA_SRCS} -${CMAKE_CURRENT_SOURCE_DIR}/SpatialConvolutionMM.cu -PARENT_SCOPE) - -set(ATen_CUDA_INCLUDE ${ATen_CUDA_INCLUDE} - "${CMAKE_CURRENT_SOURCE_DIR}" -PARENT_SCOPE) - -install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} - DESTINATION ${ATEN_INSTALL_INCLUDE_SUBDIR} - FILES_MATCHING PATTERN "*.h" PATTERN "*.cuh") diff --git a/aten/src/THCUNN/README.md b/aten/src/THCUNN/README.md deleted file mode 100644 index 5c4662322cbb5..0000000000000 --- a/aten/src/THCUNN/README.md +++ /dev/null @@ -1,26 +0,0 @@ -# THCUNN - -THCUNN is a library that gathers nn's C implementations of neural network modules. It's entirely free of Lua dependency and therefore can be used in any application that has a C FFI. Please note that it only contains quite low level functions; most users will want to use ATen, which provides a C++ wrapper around these functions. - - -Looking to add an implementation? Consider writing an ATen native function -instead! See [../ATen/native](../ATen/native). - -## Links - -* [API reference](doc/api_reference.md) -* [Style guidelines](doc/style_guidelines.md) - -## API - -THCUNN is a purely functional library. It provides 2-3 functions for each module, that perform the most important operations: - -* **updateOutput** - applies the module to an input -* **updateGradInput** - accepts gradient w.r.t. output and previous module input, and computes a gradient w.r.t. that input -* **accGradParameters** - *(optional, only modules with parameters)* accepts gradient w.r.t. output and previous module input, and computes gradient w.r.t. the parameters - -For information on argument types please check the [API reference](doc/api_reference.md). - -## Developer docs - -* [Style guidelines](doc/style_guidelines.md) diff --git a/aten/src/THCUNN/SharedMem.cuh b/aten/src/THCUNN/SharedMem.cuh deleted file mode 100644 index 8d83d9f9a9c58..0000000000000 --- a/aten/src/THCUNN/SharedMem.cuh +++ /dev/null @@ -1,43 +0,0 @@ -// Based on the simpleTempltes CUDA example - -#ifndef THCUNN_SHAREDMEM_H -#define THCUNN_SHAREDMEM_H - -template -struct SharedMem { - __device__ T *getPointer() - { - extern __device__ void error(void); - error(); - return NULL; - } -}; - -template <> -struct SharedMem -{ - __device__ half *getPointer() { - extern __shared__ half s_half[]; - return s_half; - } -}; - -template <> -struct SharedMem -{ - __device__ float *getPointer() { - extern __shared__ float s_float[]; - return s_float; - } -}; - -template <> -struct SharedMem -{ - __device__ double *getPointer() { - extern __shared__ double s_double[]; - return s_double; - } -}; - -#endif diff --git a/aten/src/THCUNN/SpatialConvolutionMM.cu b/aten/src/THCUNN/SpatialConvolutionMM.cu deleted file mode 100644 index 020bfa1ebf8ce..0000000000000 --- a/aten/src/THCUNN/SpatialConvolutionMM.cu +++ /dev/null @@ -1,13 +0,0 @@ -#include -#include -#include -#include - -#include -#include - -#include -#include - -#include -#include diff --git a/aten/src/THCUNN/THCHalfAutoNumerics.cuh b/aten/src/THCUNN/THCHalfAutoNumerics.cuh deleted file mode 100644 index 62691b9df7c21..0000000000000 --- a/aten/src/THCUNN/THCHalfAutoNumerics.cuh +++ /dev/null @@ -1,38 +0,0 @@ -#ifndef THC_HALF_AUTO_NUMERICS_INC -#define THC_HALF_AUTO_NUMERICS_INC - -#include -#include - -// WARNING: THCNumerics is being deprecated. Read the comments and function usage -// in THCNumerics to learn about the deprecation -// -// Half numerics functions defined as free functions, so cunn code can be -// written generically, i.e. without excessive calling of THCNumerics functions. - -// these functions should move to THCNumerics - -inline __host__ __device__ THHalf fmaxType(THHalf x, THHalf y) { - return THCNumerics::ge(x, y) ? x : y; -} - -inline __host__ __device__ float fmaxType(float x, THHalf y) { - return fmaxf(x, ScalarConvert::to(y)); -} - -inline __host__ __device__ float fmaxType(float x, float y) { - return fmaxf(x, y); -} - -inline __host__ __device__ double fmaxType(double x, double y) { - return fmax(x, y); -} - - -// arithmetic functions - -inline __host__ __device__ THHalf pow(THHalf a, THHalf b) { - return THCNumerics::pow(a, b); -} - -#endif diff --git a/aten/src/THCUNN/THCUNN.h b/aten/src/THCUNN/THCUNN.h deleted file mode 100644 index a4392ddaba166..0000000000000 --- a/aten/src/THCUNN/THCUNN.h +++ /dev/null @@ -1,13 +0,0 @@ -#include - -#define THCIndexTensor THCudaLongTensor -#define THCIndexTensor_(NAME) THCudaLongTensor_ ## NAME -typedef int64_t THCIndex_t; - -#define THNN_(NAME) TH_CONCAT_3(THNN_, CReal, NAME) - -#include -#include - -#include -#include diff --git a/aten/src/THCUNN/common.h b/aten/src/THCUNN/common.h deleted file mode 100644 index 69b7f3a4d3fa8..0000000000000 --- a/aten/src/THCUNN/common.h +++ /dev/null @@ -1,83 +0,0 @@ -#ifndef THCUNN_COMMON_H -#define THCUNN_COMMON_H - -#define THCUNN_assertSameGPU(...) THAssertMsg(THCTensor_(checkGPU)(__VA_ARGS__), \ - "Some of weight/gradient/input tensors are located on different GPUs. Please move them to a single one.") - -// Use 1024 threads per block, which requires cuda sm_2x or above -const int CUDA_NUM_THREADS = 1024; - -// CUDA: number of blocks for threads. -inline int GET_BLOCKS(const int64_t N) -{ - // Round up division for positive number - auto block_num = N / CUDA_NUM_THREADS + (N % CUDA_NUM_THREADS == 0 ? 0 : 1); - - constexpr int64_t max_int = std::numeric_limits::max(); - THAssertMsg(block_num <= max_int, "Can't schedule too many blocks on CUDA device"); - - return static_cast(block_num); -} - -#define THCUNN_resizeAs_indices(STATE, I1, I2) \ - if (!I1->sizes().equals(I2->sizes())) \ - { \ - THCudaLongTensor_resizeAs(STATE, I1, I2); \ - } - -#define THCUNN_check_shape(STATE, I1, I2) \ - if (I1 != NULL && I2 != NULL && !THCTensor_(isSameSizeAs)(STATE, I1, I2)) \ - { \ - THCDescBuff s1 = THCTensor_(sizeDesc)(STATE, I1); \ - THCDescBuff s2 = THCTensor_(sizeDesc)(STATE, I2); \ - THError(#I1 " and " #I2 " shapes do not match: " \ - #I1 " %s, " #I2 " %s", s1.str, s2.str); \ - } - - -#define THCUNN_check_shape_indices(STATE, I1, I2) \ - if (!I1->sizes().equals(I2->sizes())) \ - { \ - THCDescBuff s1 = THCIndexTensor_(sizeDesc)(STATE, I1); \ - THCDescBuff s2 = THCTensor_(sizeDesc)(STATE, I2); \ - THError(#I1 " and " #I2 " shapes do not match: " \ - #I1 " %s, " #I2 " %s", s1.str, s2.str); \ - } - -#define THCUNN_check_nElement(STATE, I1, I2) \ - if (I1 != NULL && I2 != NULL ) { \ - ptrdiff_t n1 = THCTensor_(nElement)(STATE, I1); \ - ptrdiff_t n2 = THCTensor_(nElement)(STATE, I2); \ - if (n1 != n2) \ - { \ - THCDescBuff s1 = THCTensor_(sizeDesc)(state, I1); \ - THCDescBuff s2 = THCTensor_(sizeDesc)(state, I2); \ - THError(#I1 " and " #I2 " have different number of elements: " \ - #I1 "%s has %ld elements, while " \ - #I2 "%s has %ld elements", s1.str, n1, s2.str, n2); \ - } \ - } - -#define THCUNN_check_dim_size(STATE, T, DIM, DIM_SIZE, SIZE) \ - if (THCTensor_(nDimensionLegacyNoScalars)(STATE, T) != DIM || \ - THCTensor_(sizeLegacyNoScalars)(STATE, T, DIM_SIZE) != SIZE) { \ - THCDescBuff s1 = THCTensor_(sizeDesc)(state, T); \ - THError("Need " #T " of dimension %d and " #T ".size[%d] == %d" \ - " but got " #T " to be of shape: %s", DIM, DIM_SIZE, SIZE, s1.str); \ - } - -#define THCUNN_check_dim_size_indices(STATE, T, DIM, DIM_SIZE, SIZE) \ - if (THCIndexTensor_(nDimensionLegacyNoScalars)(STATE, T) != DIM || \ - THCIndexTensor_(sizeLegacyNoScalars)(STATE, T, DIM_SIZE) != SIZE) { \ - THCDescBuff s1 = THCIndexTensor_(sizeDesc)(state, T); \ - THError("Need " #T " of dimension %d and " #T ".size[%d] == %d" \ - " but got " #T " to be of shape: %s", DIM, DIM_SIZE, SIZE, s1.str); \ - } - -#define THCUNN_argCheck(STATE, COND, ARG, T, FORMAT) \ - if (!(COND)) { \ - THCDescBuff s1 = THCTensor_(sizeDesc)(state, T); \ - THArgCheck(COND, ARG, FORMAT, s1.str); \ - } - -#endif diff --git a/aten/src/THCUNN/doc/api_reference.md b/aten/src/THCUNN/doc/api_reference.md deleted file mode 100644 index 3f49b9b6d1ce6..0000000000000 --- a/aten/src/THCUNN/doc/api_reference.md +++ /dev/null @@ -1,26 +0,0 @@ -# API docs - -This document describes the conventions behind the THCUNN API. - -### The API - -All functions provided by THCUNN are stored in `aten/src/THCUNN/generic/THCUNN.h`. -Look at this file. - -### Note on function names - -Please remember, that because C doesn't support function overloading, functions taking different tensor types have different names. So e.g. for an Abs module, there are actually two updateOutput functions: - -* `void THNN_FloatAbs_updateOutput(...)` -* `void THNN_DoubleAbs_updateOutput(...)` - -In these docs such function will be referred to as `void THCUNN_Abs_updateOutput(...)`, and it's up to developer to add a type prefix. `real` is an alias for that type. - -### Argument types - -Some arguments have additional tags placed in square brackets in their header declarations: - -* **[OUT]** - This is the output argument. It will be reshaped if needed. -* **[OPTIONAL]** - This argument is optional and can be safely set to NULL -* **[BUFFER]** - A buffer. `updateGradInput` and `accGradParameters` should get the same buffers that were used in `updateOutput` call. -* **[MODIFIED]** - Some functions accept an `inplace` flag. If set to true, this argument might be modified (in addition to the output). diff --git a/aten/src/THCUNN/doc/style_guidelines.md b/aten/src/THCUNN/doc/style_guidelines.md deleted file mode 100644 index 086db8bcbe28a..0000000000000 --- a/aten/src/THCUNN/doc/style_guidelines.md +++ /dev/null @@ -1,64 +0,0 @@ -## API design guidelines - -Functions should return `void`. - -All functions should accept arguments in the following order. `...` represent any module-specific parameters or buffers, disregarding whether they are used for writing or reading. Arguments in `...` below should be ordered like this: -``` -[weight], [bias], [any buffers], [additional arguments], [optional arguments] -``` - -### Modules -``` -updateOutput: state, input, output, ... -updateGradInput: state, input, gradOutput, gradInput, ... -accGradParameters: state, input, gradOutput, [gradWeight], [gradBias], ... -``` - -e.g. -```C -void THNN_(ClassNLLCriterion_updateGradInput)( - THCState *state, - THCTensor *input, - THCIndexTensor *target, - THCTensor *gradOutput, - THCTensor *gradInput, - int64_t reduction, - THCTensor *weights, - THCTensor *total_weight, - int64_t ignore_index) -``` - -### Criterions -``` -updateOutput: state, input, target, output, ... -updateGradInput: state, input, target, gradInput, ... -``` - -e.g. - -```C -void THNN_(ClassNLLCriterion_updateOutput)( - THCState *state, - THCTensor *input, - THCIndexTensor *target, - THCTensor *output, - int64_t reduction, - THCTensor *weights, - THCTensor *total_weight, - int64_t ignore_index) -``` - -## Code style guide - -```C -void THNN_(GatedLinear_updateOutput)( - THCState *state, - THCTensor *input, - THCTensor *output, - int dim) -//<- 10 -> -``` - -All arguments should start on a new line after function name, and they should be indented using 10 spaces. - -Use 2 spaces for block indentation. diff --git a/aten/src/THCUNN/generic/SpatialConvolutionMM.cu b/aten/src/THCUNN/generic/SpatialConvolutionMM.cu deleted file mode 100644 index af492b3e7da02..0000000000000 --- a/aten/src/THCUNN/generic/SpatialConvolutionMM.cu +++ /dev/null @@ -1,499 +0,0 @@ -#ifndef THC_GENERIC_FILE -#define THC_GENERIC_FILE "THCUNN/generic/SpatialConvolutionMM.cu" -#else - -#include -#include - -static inline void THNN_(SpatialConvolutionMM_shapeCheck)( - THCState *state, - THCTensor *input, THCTensor *gradOutput, - THCTensor *weight, THCTensor *bias, - int kH, int kW, int dH, int dW, int padH, int padW, - int weight_nullable) { - THArgCheck(kW > 0 && kH > 0, 9, - "kernel size should be greater than zero, but got kH: %d kW: %d", kH, kW); - THArgCheck(dW > 0 && dH > 0, 11, - "stride should be greater than zero, but got dH: %d dW: %d", dH, dW); - - if (weight != NULL) { - THCUNN_argCheck(state, !weight->is_empty() && (weight->dim() == 2 || weight->dim() == 4), 5, weight, - "non-empty 2D or 4D weight tensor expected, but got: %s"); - if (bias != NULL) { - THCUNN_check_dim_size(state, bias, 1, 0, weight->size(0)); - } - } else if (!weight_nullable) { - THError("weight tensor is expected to be non-nullable"); - } - - int ndim = input->dim(); - int dimf = 0; - int dimh = 1; - int dimw = 2; - - if (ndim == 4) { - dimf++; - dimh++; - dimw++; - } - - // Allow for empty batch size but not other dimensions - bool valid_empty = false; - if (ndim == 3) { - valid_empty = input->size(0) == 0 && input->size(1) != 0 && input->size(2) != 0; - } else if (ndim == 4) { - valid_empty = input->size(0) == 0 && input->size(1) != 0 && input->size(2) != 0 && input->size(3) != 0; - } - - - THCUNN_argCheck(state, (!input->is_empty() || valid_empty) && (ndim == 3 || ndim == 4), 2, input, - "non-empty 3D or 4D input tensor expected but got: %s"); - - int64_t inputHeight = input->size(dimh); - int64_t inputWidth = input->size(dimw); - - int64_t exactInputHeight = inputHeight + 2 * padH; - int64_t exactInputWidth = inputWidth + 2 * padW; - - if (exactInputHeight < kH || exactInputWidth < kW) { - THError("Calculated padded input size per channel: (%ld x %ld). " - "Kernel size: (%d x %d). Kernel size can't be greater than actual input size", - exactInputHeight, exactInputWidth, kH, kW); - } - - int64_t outputHeight = div_rtn(exactInputHeight - kH, dH) + 1; - int64_t outputWidth = div_rtn(exactInputWidth - kW, dW) + 1; - - if (outputWidth < 1 || outputHeight < 1) { - THError("Given input size per channel: (%ld x %ld). " - "Calculated output size per channel: (%ld x %ld). Output size is too small", - inputHeight, inputWidth, outputHeight, outputWidth); - } - - if (weight != NULL) { - int64_t nInputPlane = weight->size(1); - if (weight->dim() == 2) { - nInputPlane /= (kH * kW); - } - THCUNN_check_dim_size(state, input, ndim, dimf, nInputPlane); - } - - if (gradOutput != NULL) { - if (weight != NULL) { - int64_t nOutputPlane = weight->size(0); - THCUNN_check_dim_size(state, gradOutput, ndim, dimf, nOutputPlane); - } else if (bias != NULL) { - int64_t nOutputPlane = bias->dim() == 0 ? 1 : bias->size(0); - THCUNN_check_dim_size(state, gradOutput, ndim, dimf, nOutputPlane); - } - THCUNN_check_dim_size(state, gradOutput, ndim, dimh, outputHeight); - THCUNN_check_dim_size(state, gradOutput, ndim, dimw, outputWidth); - } -} - -static THCTensor* THNN_(newViewWeightMM2d)(THCState *state, THCTensor *weight) { - weight = THCTensor_(newContiguous)(state, weight); - if (weight->dim() == 4) { - int64_t s1 = weight->size(0); - int64_t s2 = weight->size(1) * weight->size(2) * weight->size(3); - THCTensor *old_weight = weight; - weight = THTensor_wrap(weight).view({s1, s2}).unsafeReleaseTensorImpl(); - THCTensor_(free)(state, old_weight); - } - return weight; -} - -void THNN_(SpatialConvolutionMM_updateOutput)( - THCState *state, - THCTensor *input, - THCTensor *output, - THCTensor *weight, - THCTensor *bias, - THCTensor *columns, - THCTensor *ones, - int kW, int kH, - int dW, int dH, - int padW, int padH) { - THCUNN_assertSameGPU(state, 5, input, output, weight, columns, ones); - if (bias) { - THCUNN_assertSameGPU(state, 2, weight, bias); - } - weight = THNN_(newViewWeightMM2d)(state, weight); - THNN_(SpatialConvolutionMM_shapeCheck) - (state, input, NULL, weight, bias, kH, kW, dH, dW, padH, padW, 0); - THArgCheck(!bias || THCTensor_(isContiguous)(state, bias), 5, - "bias tensor has to be contiguous"); - - int ndim = input->dim(); - int dimf = 0; - int dimh = 1; - int dimw = 2; - - if (ndim == 4) { - dimf++; - dimh++; - dimw++; - } - - int64_t nInputPlane = input->size(dimf); - int64_t inputHeight = input->size(dimh); - int64_t inputWidth = input->size(dimw); - int64_t nOutputPlane = weight->size(0); - int64_t outputHeight = (inputHeight + 2*padH - kH) / dH + 1; - int64_t outputWidth = (inputWidth + 2*padW - kW) / dW + 1; - - - input = THCTensor_(newContiguous)(state, input); - int is_batch = 1; - if (input->dim() == 3) { - // Force batch - is_batch = 0; - THCTensor_(resize4d)(state, input, 1, input->size(0), input->size(1), input->size(2)); - } - - // Batch size + input planes - int64_t batchSize = input->size(0); - - // Resize output - THCTensor_(resize4d)(state, output, batchSize, nOutputPlane, outputHeight, outputWidth); - - // Resize temporary columns - THCTensor_(resize2d)(state, columns, nInputPlane*kW*kH, outputHeight*outputWidth); - - // Define a buffer of ones, for bias accumulation - // Note: this buffer can be shared with other modules, it only ever gets increased, - // and always contains ones. - if (bias) { - if (ones->dim() != 2 || ones->size(0)*ones->size(1) < outputHeight*outputWidth) { - // Resize plane and fill with ones... - THCTensor_(resize2d)(state, ones, outputHeight, outputWidth); - THCTensor_(fill)(state, ones, ScalarConvert::to(1)); - } - } - - // Helpers - THCTensor *input_n = THCTensor_(new)(state); - THCTensor *output_n = THCTensor_(new)(state); - - // For each elt in batch, do: - for (int elt = 0; elt < batchSize; elt ++) { - // Matrix mulitply per output: - THCTensor_(select)(state, input_n, input, 0, elt); - THCTensor_(select)(state, output_n, output, 0, elt); - - // Do Bias first: - // M,N,K are dims of matrix A and B - // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm) - int64_t m_ = nOutputPlane; - int64_t n_ = outputHeight * outputWidth; - int64_t k_ = 1; - - // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices) - if (bias) { - at::cuda::blas::gemm( - 't', 'n', - n_, m_, k_, - ScalarConvert::to(1), - THCTensor_(data)(state, ones), k_, - THCTensor_(data)(state, bias), k_, - ScalarConvert::to(0), - THCTensor_(data)(state, output_n), n_ - ); - } else { - THCTensor_(zero)(state, output_n); - } - - if (kW != 1 || kH != 1 || dW != 1 || dH != 1 || padH != 0 || padW != 0) { - // Extract columns: - at::native::im2col( - c10::cuda::getCurrentCUDAStream(), - THCTensor_(data)(state, input_n), - nInputPlane, inputHeight, inputWidth, - outputHeight, outputWidth, - kH, kW, padH, padW, dH, dW, - 1, 1, - columns->data() - ); - } - - // M,N,K are dims of matrix A and B - // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm) - int64_t m = nOutputPlane; - int64_t n = columns->size(1); - int64_t k = nInputPlane*kH*kW; - - // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices) - auto gemm_in_ptr = - (kW != 1 || kH != 1 || dW != 1 || dH != 1 || padH != 0 || padW != 0) - ? THCTensor_(data)(state, columns) - : THCTensor_(data)(state, input_n); - at::cuda::blas::gemm( - 'n', 'n', - n, m, k, - ScalarConvert::to(1), - gemm_in_ptr, n, - THCTensor_(data)(state, weight), k, - ScalarConvert::to(1), - THCTensor_(data)(state, output_n), n - ); - } - - // Free - THCTensor_(free)(state, input_n); - THCTensor_(free)(state, output_n); - - // Resize output - if (is_batch == 0) { - THCTensor_(resize3d)(state, output, nOutputPlane, outputHeight, outputWidth); - THCTensor_(resize3d)(state, input, nInputPlane, inputHeight, inputWidth); - } - - THCTensor_(free)(state, input); - THCTensor_(free)(state, weight); -} - -void THNN_(SpatialConvolutionMM_updateGradInput)( - THCState *state, - THCTensor *input, - THCTensor *gradOutput, - THCTensor *gradInput, - THCTensor *weight, - THCTensor *gradColumns, - THCTensor *ones, - int kW, int kH, - int dW, int dH, - int padW, int padH) { - THCUNN_assertSameGPU(state, 5, input, gradOutput, weight, - gradColumns, gradInput); - weight = THNN_(newViewWeightMM2d)(state, weight); - - THNN_(SpatialConvolutionMM_shapeCheck) - (state, input, gradOutput, weight, NULL, kH, kW, dH, dW, padH, padW, 0); - - // Params - int nInputPlane = weight->dim() == 2 ? weight->size(1)/(kW*kH) : weight->size(1); - int nOutputPlane = weight->size(0); - - input = THCTensor_(newContiguous)(state, input); - gradOutput = THCTensor_(newContiguous)(state, gradOutput); - - int is_batch = 1; - if (input->dim() == 3) { - // Force batch - is_batch = 0; - THCTensor_(resize4d)(state, input, 1, input->size(0), input->size(1), input->size(2)); - THCTensor_(resize4d)(state, gradOutput, 1, gradOutput->size(0), gradOutput->size(1), gradOutput->size(2)); - } - - int64_t inputWidth = input->size(3); - int64_t inputHeight = input->size(2); - int64_t outputWidth = (inputWidth + 2*padW - kW) / dW + 1; - int64_t outputHeight = (inputHeight + 2*padH - kH) / dH + 1; - - // Batch size + input planes - int64_t batchSize = input->size(0); - - // Resize output - THCTensor_(resize4d)(state, gradInput, batchSize, nInputPlane, inputHeight, inputWidth); - - // Resize temporary columns - THCTensor_(resize2d)(state, gradColumns, nInputPlane*kW*kH, outputHeight*outputWidth); - - // Helpers - THCTensor *gradInput_n = THCTensor_(new)(state); - THCTensor *gradOutput_n = THCTensor_(new)(state); - - // For each elt in batch, do: - for (int elt = 0; elt < batchSize; elt ++) { - // Matrix mulitply per sample: - THCTensor_(select)(state, gradInput_n, gradInput, 0, elt); - THCTensor_(select)(state, gradOutput_n, gradOutput, 0, elt); - - // M,N,K are dims of matrix A and B - // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm) - int64_t m = nInputPlane*kW*kH; - int64_t n = gradColumns->size(1); - int64_t k = nOutputPlane; - - // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices) - at::cuda::blas::gemm( - 'n', 't', - n, m, k, - ScalarConvert::to(1), - THCTensor_(data)(state, gradOutput_n), n, - THCTensor_(data)(state, weight), m, - ScalarConvert::to(0), - THCTensor_(data)(state, gradColumns), n - ); - - // Unpack columns back into input: - at::native::col2im( - c10::cuda::getCurrentCUDAStream(), - THCTensor_(data)(state, gradColumns), - nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, kH, kW, padH, padW, dH, dW, - 1, 1, THCTensor_(data)(state, gradInput_n) - ); - } - - // Free - THCTensor_(free)(state, gradInput_n); - THCTensor_(free)(state, gradOutput_n); - THCTensor_(free)(state, weight); - - // Resize output - if (is_batch == 0) { - THCTensor_(resize3d)(state, gradOutput, nOutputPlane, outputHeight, outputWidth); - THCTensor_(resize3d)(state, input, nInputPlane, inputHeight, inputWidth); - THCTensor_(resize3d)(state, gradInput, nInputPlane, inputHeight, inputWidth); - } - - THCTensor_(free)(state, input); - THCTensor_(free)(state, gradOutput); -} - -void THNN_(SpatialConvolutionMM_accGradParameters)( - THCState *state, - THCTensor *input, - THCTensor *gradOutput, - THCTensor *gradWeight, - THCTensor *gradBias, - THCTensor *columns, - THCTensor *ones, - int kW, int kH, - int dW, int dH, - int padW, int padH, - accreal scale_) { - scalar_t scale = ScalarConvert::to(scale_); - THCUNN_assertSameGPU(state, 5, input, gradOutput, gradWeight, gradBias, columns, ones); - if (gradWeight) { - THArgCheck(THCTensor_(isContiguous)(state, gradWeight), 4, "gradWeight needs to be contiguous"); - gradWeight = THNN_(newViewWeightMM2d)(state, gradWeight); - } - if (gradBias) { - THArgCheck(THCTensor_(isContiguous)(state, gradBias), 5, "gradBias needs to be contiguous"); - THArgCheck(THCTensor_(isContiguous)(state, ones), 7, "ones needs to be contiguous"); - } - - THNN_(SpatialConvolutionMM_shapeCheck) - (state, input, gradOutput, gradWeight, gradBias, kH, kW, dH, dW, padH, padW, 1); - - // Params - input = THCTensor_(newContiguous)(state, input); - gradOutput = THCTensor_(newContiguous)(state, gradOutput); - - int is_batch = 1; - if (input->dim() == 3) { - // Force batch - is_batch = 0; - THCTensor_(resize4d)(state, input, 1, input->size(0), input->size(1), input->size(2)); - THCTensor_(resize4d)(state, gradOutput, 1, gradOutput->size(0), gradOutput->size(1), gradOutput->size(2)); - } - - int64_t nInputPlane = input->size(1); - int64_t nOutputPlane = gradOutput->size(1); - - int64_t inputWidth = input->size(3); - int64_t inputHeight = input->size(2); - int64_t outputWidth = (inputWidth + 2*padW - kW) / dW + 1; - int64_t outputHeight = (inputHeight + 2*padH - kH) / dH + 1; - - // Batch size + input planes - int64_t batchSize = input->size(0); - - // Define a buffer of ones, for bias accumulation - if (ones->dim() != 2 || ones->size(0)*ones->size(1) < outputHeight*outputWidth) { - // Resize plane and fill with ones... - THCTensor_(resize2d)(state, ones, outputHeight, outputWidth); - THCTensor_(fill)(state, ones, ScalarConvert::to(1)); - } - - // Resize temporary columns - THCTensor_(resize2d)(state, columns, nInputPlane*kW*kH, outputHeight*outputWidth); - - // Helpers - THCTensor *input_n = THCTensor_(new)(state); - THCTensor *gradOutput_n = THCTensor_(new)(state); - - // For each elt in batch, do: - for (int elt = 0; elt < batchSize; elt ++) { - // Matrix mulitply per output: - THCTensor_(select)(state, gradOutput_n, gradOutput, 0, elt); - - // Do Weight: - if (gradWeight) { - // Matrix mulitply per output: - THCTensor_(select)(state, input_n, input, 0, elt); - - if (kW != 1 || kH != 1 || dW != 1 || dH != 1 || padH != 0 || padW != 0) { - // Extract columns: - at::native::im2col( - c10::cuda::getCurrentCUDAStream(), - THCTensor_(data)(state, input_n), - nInputPlane, inputHeight, inputWidth, - outputHeight, outputWidth, - kH, kW, padH, padW, dH, dW, - 1, 1, - columns->data() - ); - } - - // M,N,K are dims of matrix A and B - // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm) - int64_t m = nOutputPlane; - int64_t n = nInputPlane*kW*kH; - int64_t k = columns->size(1); - - // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices) - auto gemm_in_ptr = - (kW != 1 || kH != 1 || dW != 1 || dH != 1 || padH != 0 || padW != 0) - ? THCTensor_(data)(state, columns) - : THCTensor_(data)(state, input_n); - at::cuda::blas::gemm( - 't', 'n', - n, m, k, - scale, - gemm_in_ptr, k, - THCTensor_(data)(state, gradOutput_n), k, - ScalarConvert::to(1), - THCTensor_(data)(state, gradWeight), n - ); - } - - // Do Bias: - if (gradBias) { - // M,N,K are dims of matrix A and B - // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm) - int64_t m_ = nOutputPlane; - int64_t k_ = outputHeight * outputWidth; - - // Do GEMV (note: this is a bit confusing because gemv assumes column-major matrices) - //#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) || defined(THC_REAL_IS_BFLOAT16) - at::cuda::blas::gemv( - 't', - k_, m_, - scale, - THCTensor_(data)(state, gradOutput_n), k_, - THCTensor_(data)(state, ones), 1, - ScalarConvert::to(1), - THCTensor_(data)(state, gradBias), 1 - ); - } - } - - // Free - THCTensor_(free)(state, input_n); - THCTensor_(free)(state, gradOutput_n); - if (gradWeight) - THCTensor_(free)(state, gradWeight); - - // Resize - if (is_batch == 0) { - THCTensor_(resize3d)(state, gradOutput, nOutputPlane, outputHeight, outputWidth); - THCTensor_(resize3d)(state, input, nInputPlane, inputHeight, inputWidth); - } - - THCTensor_(free)(state, input); - THCTensor_(free)(state, gradOutput); -} - -#endif diff --git a/aten/src/THCUNN/generic/THCUNN.h b/aten/src/THCUNN/generic/THCUNN.h deleted file mode 100644 index 87a6105293057..0000000000000 --- a/aten/src/THCUNN/generic/THCUNN.h +++ /dev/null @@ -1,75 +0,0 @@ -#ifndef THC_GENERIC_FILE -#define THC_GENERIC_FILE "THCUNN/generic/THCUNN.h" -#else - -#include -#include - -TORCH_CUDA_CU_API void THNN_(MultiMarginCriterion_updateOutput)( - THCState* state, - THCTensor* input, - THCIndexTensor* target, - THCTensor* output, - int64_t reduction, - int p, - THCTensor* weights, // [OPTIONAL] - accreal margin); - -TORCH_CUDA_CU_API void THNN_(MultiMarginCriterion_updateGradInput)( - THCState* state, - THCTensor* input, - THCIndexTensor* target, - THCTensor* gradOutput, - THCTensor* gradInput, - int64_t reduction, - int p, - THCTensor* weights, // [OPTIONAL] - accreal margin); - -TORCH_CUDA_CU_API void THNN_(SpatialConvolutionMM_updateOutput)( - THCState* state, - THCTensor* input, - THCTensor* output, - THCTensor* weight, - THCTensor* bias, // [OPTIONAL] - THCTensor* columns, - THCTensor* ones, - int kW, - int kH, - int dW, - int dH, - int padW, - int padH); - -TORCH_CUDA_CU_API void THNN_(SpatialConvolutionMM_updateGradInput)( - THCState* state, - THCTensor* input, - THCTensor* gradOutput, - THCTensor* gradInput, - THCTensor* weight, - THCTensor* columns, - THCTensor* ones, - int kW, - int kH, - int dW, - int dH, - int padW, - int padH); - -TORCH_CUDA_CU_API void THNN_(SpatialConvolutionMM_accGradParameters)( - THCState* state, - THCTensor* input, - THCTensor* gradOutput, - THCTensor* gradWeight, - THCTensor* gradBias, // [OPTIONAL] - THCTensor* columns, - THCTensor* ones, - int kW, - int kH, - int dW, - int dH, - int padW, - int padH, - accreal scale); - -#endif diff --git a/benchmarks/cpp/tensorexpr/CMakeLists.txt b/benchmarks/cpp/tensorexpr/CMakeLists.txt index 789c81fcf6526..a06502eb29053 100644 --- a/benchmarks/cpp/tensorexpr/CMakeLists.txt +++ b/benchmarks/cpp/tensorexpr/CMakeLists.txt @@ -6,6 +6,7 @@ add_executable( bench_batchnorm.cpp bench_concat.cpp bench_compile.cpp + bench_signed_log1p.cpp bench_fuser_overhead.cpp bench_gemm.cpp bench_parallel.cpp diff --git a/benchmarks/cpp/tensorexpr/bench_approx.cpp b/benchmarks/cpp/tensorexpr/bench_approx.cpp index 1f09b1dbac5c1..425d19faabc30 100644 --- a/benchmarks/cpp/tensorexpr/bench_approx.cpp +++ b/benchmarks/cpp/tensorexpr/bench_approx.cpp @@ -10,30 +10,29 @@ using namespace torch::jit; using namespace torch::jit::tensorexpr; -void vectorize(tensorexpr::LoopNest* ln, tensorexpr::Tensor* target, int width) { +void vectorize(tensorexpr::LoopNest* ln, tensorexpr::Tensor target, int width) { auto loops = ln->getLoopStmtsFor(target); - For *inner, *tail; + ForPtr inner, tail; ln->splitWithTail(loops[0], width, &inner, &tail); ln->vectorize(inner); } -void optimizePointwise(tensorexpr::LoopNest* ln, tensorexpr::Tensor* target) { - std::vector loops = ln->getLoopStmtsFor(target); - For *inner, *tail; +void optimizePointwise(tensorexpr::LoopNest* ln, tensorexpr::Tensor target) { + std::vector loops = ln->getLoopStmtsFor(target); + ForPtr inner, tail; ln->splitWithTail(loops[0], 16 * 8, &inner, &tail); - For* outer = loops[0]; + ForPtr outer = loops[0]; ln->vectorize(inner); ln->splitWithTail(outer, 8, &inner, &tail); - Stmt* unrolled; + StmtPtr unrolled; LoopNest::unroll(inner, &unrolled); } static void relu_nnc(benchmark::State& state) { - KernelScope ks; auto N = VarHandle("N", kInt); Placeholder A("A", kFloat, {N}); auto clamp = 0; - torch::jit::tensorexpr::Tensor* B = Compute("B", {N}, [&](const VarHandle& i){ + torch::jit::tensorexpr::Tensor B = Compute("B", {N}, [&](const VarHandle& i){ auto A_elem = [&]() { auto elem = A.load(i); auto min = FloatImm::make(clamp); @@ -44,7 +43,7 @@ static void relu_nnc(benchmark::State& state) { LoopNest ln({B}); optimizePointwise(&ln, B); ln.prepareForCodegen(); - Stmt* s = ln.root_stmt(); + StmtPtr s = ln.root_stmt(); s = torch::jit::tensorexpr::IRSimplifier::simplify(s); std::vector args; args.emplace_back(B); @@ -64,17 +63,16 @@ static void relu_nnc(benchmark::State& state) { } static void log_nnc_sleef(benchmark::State& state) { - KernelScope ks; auto N = VarHandle("N", kInt); Placeholder A("A", kFloat, {N}); - torch::jit::tensorexpr::Tensor* B = + torch::jit::tensorexpr::Tensor B = Compute("B", {N}, [&](const VarHandle& i) { return log(A.load(i)); }); LoopNest ln({B}); ln.prepareForCodegen(); vectorize(&ln, B, 8); - Stmt* s = ln.root_stmt(); + StmtPtr s = ln.root_stmt(); s = torch::jit::tensorexpr::IRSimplifier::simplify(s); std::vector args; args.emplace_back(B); @@ -94,17 +92,16 @@ static void log_nnc_sleef(benchmark::State& state) { } static void log_nnc_fast(benchmark::State& state) { - KernelScope ks; auto N = VarHandle("N", kInt); Placeholder A("A", kFloat, {N}); - torch::jit::tensorexpr::Tensor* B = + torch::jit::tensorexpr::Tensor B = Compute("B", {N}, [&](const VarHandle& i) { return fast_log(A.load(i)); }); LoopNest ln({B}); optimizePointwise(&ln, B); ln.prepareForCodegen(); - Stmt* s = ln.root_stmt(); + StmtPtr s = ln.root_stmt(); s = torch::jit::tensorexpr::IRSimplifier::simplify(s); std::vector args; args.emplace_back(B); @@ -124,17 +121,16 @@ static void log_nnc_fast(benchmark::State& state) { } static void log_nnc_vml(benchmark::State& state) { - KernelScope ks; auto N = VarHandle("N", kInt); Placeholder A("A", kFloat, {N}); - torch::jit::tensorexpr::Tensor* B = + torch::jit::tensorexpr::Tensor B = Compute("B", {N}, [&](const VarHandle& i) { return log_vml(A.load(i)); }); LoopNest ln({B}); vectorize(&ln, B, 8); ln.prepareForCodegen(); - Stmt* s = ln.root_stmt(); + StmtPtr s = ln.root_stmt(); s = torch::jit::tensorexpr::IRSimplifier::simplify(s); std::vector args; args.emplace_back(B); @@ -164,11 +160,10 @@ static void log_aten(benchmark::State& state) { } static void logit_nnc_sleef(benchmark::State& state) { - KernelScope ks; auto N = VarHandle("N", kInt); Placeholder A("A", kFloat, {N}); auto clamp = 1e-6f; - tensorexpr::Tensor* B = Compute("B", {N}, [&](const VarHandle& i) { + tensorexpr::Tensor B = Compute("B", {N}, [&](const VarHandle& i) { auto A_elem = [&]() { auto elem = A.load(i); auto min = FloatImm::make(clamp); @@ -181,7 +176,7 @@ static void logit_nnc_sleef(benchmark::State& state) { LoopNest ln({B}); ln.prepareForCodegen(); optimizePointwise(&ln, B); - Stmt* s = ln.root_stmt(); + StmtPtr s = ln.root_stmt(); s = torch::jit::tensorexpr::IRSimplifier::simplify(s); std::vector args; args.emplace_back(B); @@ -201,11 +196,10 @@ static void logit_nnc_sleef(benchmark::State& state) { } static void logit_nnc_fast(benchmark::State& state) { - KernelScope ks; auto N = VarHandle("N", kInt); Placeholder A("A", kFloat, {N}); auto clamp = 1e-6f; - tensorexpr::Tensor* B = Compute("B", {N}, [&](const VarHandle& i) { + tensorexpr::Tensor B = Compute("B", {N}, [&](const VarHandle& i) { auto A_elem = [&]() { auto elem = A.load(i); auto min = FloatImm::make(clamp); @@ -218,7 +212,7 @@ static void logit_nnc_fast(benchmark::State& state) { LoopNest ln({B}); ln.prepareForCodegen(); optimizePointwise(&ln, B); - Stmt* s = ln.root_stmt(); + StmtPtr s = ln.root_stmt(); s = torch::jit::tensorexpr::IRSimplifier::simplify(s); std::vector args; args.emplace_back(B); @@ -238,11 +232,10 @@ static void logit_nnc_fast(benchmark::State& state) { } static void logit_nnc_vml(benchmark::State& state) { - KernelScope ks; auto N = VarHandle("N", kInt); Placeholder A("A", kFloat, {N}); auto clamp = 1e-6f; - tensorexpr::Tensor* B = Compute("B", {N}, [&](const VarHandle& i) { + tensorexpr::Tensor B = Compute("B", {N}, [&](const VarHandle& i) { auto A_elem = [&]() { auto elem = A.load(i); auto min = FloatImm::make(clamp); @@ -255,7 +248,7 @@ static void logit_nnc_vml(benchmark::State& state) { LoopNest ln({B}); ln.prepareForCodegen(); vectorize(&ln, B, 16); - Stmt* s = ln.root_stmt(); + StmtPtr s = ln.root_stmt(); s = torch::jit::tensorexpr::IRSimplifier::simplify(s); std::vector args; args.emplace_back(B); @@ -316,17 +309,16 @@ static void logit_caffe2(benchmark::State& state) { } static void tanh_nnc_fast(benchmark::State& state) { - KernelScope ks; auto N = VarHandle("N", kInt); Placeholder A("A", kFloat, {N}); - torch::jit::tensorexpr::Tensor* B = + torch::jit::tensorexpr::Tensor B = Compute("B", {N}, [&](const VarHandle& i) { return fast_tanh(A.load(i)); }); LoopNest ln({B}); optimizePointwise(&ln, B); ln.prepareForCodegen(); - Stmt* s = ln.root_stmt(); + StmtPtr s = ln.root_stmt(); s = torch::jit::tensorexpr::IRSimplifier::simplify(s); std::vector args; args.emplace_back(B); diff --git a/benchmarks/cpp/tensorexpr/bench_batchnorm.cpp b/benchmarks/cpp/tensorexpr/bench_batchnorm.cpp index 434cd6bfdbb8e..702ed1cf3ab9d 100644 --- a/benchmarks/cpp/tensorexpr/bench_batchnorm.cpp +++ b/benchmarks/cpp/tensorexpr/bench_batchnorm.cpp @@ -74,7 +74,6 @@ BENCHMARK_DEFINE_F(BatchNorm, ATen)(benchmark::State& state) { } BENCHMARK_DEFINE_F(BatchNorm, NNC)(benchmark::State& state) { - KernelScope ks; Placeholder input("input", kFloat, {N_, C_, H_, W_}); Placeholder weight("weight", kFloat, {C_}); @@ -84,7 +83,7 @@ BENCHMARK_DEFINE_F(BatchNorm, NNC)(benchmark::State& state) { VarHandle eps("eps", kFloat); using axis = const VarHandle&; - Tensor* output = Compute( + Tensor output = Compute( "output", {{N_, "N"}, {C_, "C"}, {H_, "H"}, {W_, "W"}}, [&](axis n, axis c, axis h, axis w) { @@ -105,7 +104,7 @@ BENCHMARK_DEFINE_F(BatchNorm, NNC)(benchmark::State& state) { loops = nest.getLoopStmtsFor(output); loops[0]->set_parallel(); nest.prepareForCodegen(); - Stmt* s = IRSimplifier::simplify(nest.root_stmt()); + StmtPtr s = IRSimplifier::simplify(nest.root_stmt()); LLVMCodeGen cg(s, {input, weight, bias, mean, var, output, eps}); std::vector args; @@ -137,7 +136,6 @@ BENCHMARK_DEFINE_F(BatchNorm, ATenRelu)(benchmark::State& state) { } BENCHMARK_DEFINE_F(BatchNorm, NNCRelu)(benchmark::State& state) { - KernelScope ks; Placeholder input("input", kFloat, {N_, C_, H_, W_}); Placeholder weight("weight", kFloat, {C_}); @@ -147,7 +145,7 @@ BENCHMARK_DEFINE_F(BatchNorm, NNCRelu)(benchmark::State& state) { VarHandle eps("eps", kFloat); using axis = const VarHandle&; - Tensor* output = Compute( + Tensor output = Compute( "output", {{N_, "N"}, {C_, "C"}, {H_, "H"}, {W_, "W"}}, [&](axis n, axis c, axis h, axis w) { @@ -163,7 +161,7 @@ BENCHMARK_DEFINE_F(BatchNorm, NNCRelu)(benchmark::State& state) { }); LoopNest nest({output}); nest.prepareForCodegen(); - Stmt* s = IRSimplifier::simplify(nest.root_stmt()); + StmtPtr s = IRSimplifier::simplify(nest.root_stmt()); LLVMCodeGen cg(s, {input, weight, bias, mean, var, output, eps}); std::vector args; diff --git a/benchmarks/cpp/tensorexpr/bench_compile.cpp b/benchmarks/cpp/tensorexpr/bench_compile.cpp index cc84e65a545b2..f204377ab8126 100644 --- a/benchmarks/cpp/tensorexpr/bench_compile.cpp +++ b/benchmarks/cpp/tensorexpr/bench_compile.cpp @@ -10,60 +10,58 @@ namespace te = torch::jit::tensorexpr; static void BM_CompileSwish(benchmark::State& state) { for (auto _ : state) { constexpr int N = 512; - te::KernelScope ks; te::VarHandle n("n", te::kInt); te::Placeholder A(te::BufHandle("A", {N}, te::kFloat)); - te::Tensor* relu = te::Compute("relu", {{n, "n"}}, [&](const te::VarHandle& i) { + te::Tensor relu = te::Compute("relu", {{n, "n"}}, [&](const te::VarHandle& i) { return te::Max::make(A.load(i), 0.f, false); }); - te::Tensor* min6 = te::Compute("min6", {{n, "n"}}, [&](const te::VarHandle& i) { - return te::Min::make(relu->load(i), 6.f, false); + te::Tensor min6 = te::Compute("min6", {{n, "n"}}, [&](const te::VarHandle& i) { + return te::Min::make(relu.load(i), 6.f, false); }); - te::Tensor* plus3 = te::Compute("plus3", {{n, "n"}}, [&](const te::VarHandle& i) { - return min6->load(i) + 3.f; + te::Tensor plus3 = te::Compute("plus3", {{n, "n"}}, [&](const te::VarHandle& i) { + return min6.load(i) + 3.f; }); - te::Tensor* times = te::Compute("times", {{n, "n"}}, [&](const te::VarHandle& i) { - return A.load(i) * plus3->load(i); + te::Tensor times = te::Compute("times", {{n, "n"}}, [&](const te::VarHandle& i) { + return A.load(i) * plus3.load(i); }); - te::Tensor* sixth = te::Compute("sixth", {{n, "n"}}, [&](const te::VarHandle& i) { - return times->load(i) * 1.f / 6.f; + te::Tensor sixth = te::Compute("sixth", {{n, "n"}}, [&](const te::VarHandle& i) { + return times.load(i) * 1.f / 6.f; }); te::LoopNest nest({sixth}, {relu, min6, plus3, times, sixth}); for (auto tensor : {relu, min6, plus3, times}) { - nest.computeInline(tensor->buf()); + nest.computeInline(tensor.buf()); } nest.prepareForCodegen(); - te::Stmt* s = te::IRSimplifier::simplify(nest.root_stmt()); + te::StmtPtr s = te::IRSimplifier::simplify(nest.root_stmt()); te::LLVMCodeGen cg(s, {A, sixth, n}); } } static void BM_CompileSwishLLVMOnly(benchmark::State& state) { constexpr int N = 512; - te::KernelScope ks; te::VarHandle n("n", te::kInt); te::Placeholder A(te::BufHandle("A", {N}, te::kFloat)); - te::Tensor* relu = te::Compute("relu", {{n, "n"}}, [&](const te::VarHandle& i) { + te::Tensor relu = te::Compute("relu", {{n, "n"}}, [&](const te::VarHandle& i) { return te::Max::make(A.load(i), 0.f, false); }); - te::Tensor* min6 = te::Compute("min6", {{n, "n"}}, [&](const te::VarHandle& i) { - return te::Min::make(relu->load(i), 6.f, false); + te::Tensor min6 = te::Compute("min6", {{n, "n"}}, [&](const te::VarHandle& i) { + return te::Min::make(relu.load(i), 6.f, false); }); - te::Tensor* plus3 = te::Compute("plus3", {{n, "n"}}, [&](const te::VarHandle& i) { - return min6->load(i) + 3.f; + te::Tensor plus3 = te::Compute("plus3", {{n, "n"}}, [&](const te::VarHandle& i) { + return min6.load(i) + 3.f; }); - te::Tensor* times = te::Compute("times", {{n, "n"}}, [&](const te::VarHandle& i) { - return A.load(i) * plus3->load(i); + te::Tensor times = te::Compute("times", {{n, "n"}}, [&](const te::VarHandle& i) { + return A.load(i) * plus3.load(i); }); - te::Tensor* sixth = te::Compute("sixth", {{n, "n"}}, [&](const te::VarHandle& i) { - return times->load(i) * 1.f / 6.f; + te::Tensor sixth = te::Compute("sixth", {{n, "n"}}, [&](const te::VarHandle& i) { + return times.load(i) * 1.f / 6.f; }); te::LoopNest nest({sixth}, {relu, min6, plus3, times, sixth}); for (auto tensor : {relu, min6, plus3, times}) { - nest.computeInline(tensor->buf()); + nest.computeInline(tensor.buf()); } nest.prepareForCodegen(); - te::Stmt* s = te::IRSimplifier::simplify(nest.root_stmt()); + te::StmtPtr s = te::IRSimplifier::simplify(nest.root_stmt()); for (auto _ : state) { te::LLVMCodeGen cg(s, {A, sixth, n}); } diff --git a/benchmarks/cpp/tensorexpr/bench_concat.cpp b/benchmarks/cpp/tensorexpr/bench_concat.cpp index a437967a09497..c108c867acbf4 100644 --- a/benchmarks/cpp/tensorexpr/bench_concat.cpp +++ b/benchmarks/cpp/tensorexpr/bench_concat.cpp @@ -47,7 +47,6 @@ class ConcatBench : public benchmark::Fixture { } void runNNC(benchmark::State& state) { - KernelScope ks; size_t num_inputs = inputs_.size(); size_t num_dims = 2; @@ -60,7 +59,7 @@ class ConcatBench : public benchmark::Fixture { {input_sizes_[i][0], input_sizes_[i][1]})); } - Tensor* output = Compute( + Tensor output = Compute( "aten_cat", {{output_size_[0], "M"}, {output_size_[1], "N"}}, [&](const VarHandle& m, const VarHandle& n) { @@ -83,7 +82,7 @@ class ConcatBench : public benchmark::Fixture { }); LoopNest nest({output}); nest.prepareForCodegen(); - Stmt* s = IRSimplifier::simplify(nest.root_stmt()); + StmtPtr s = IRSimplifier::simplify(nest.root_stmt()); std::vector buf_args(inputs.begin(), inputs.end()); buf_args.push_back(output); LLVMCodeGen cg(s, buf_args); @@ -101,54 +100,57 @@ class ConcatBench : public benchmark::Fixture { } void runNNCLoop(benchmark::State& state) { - KernelScope ks; size_t num_inputs = inputs_.size(); size_t num_dims = 2; TORCH_INTERNAL_ASSERT(concat_dim_ == 1); - auto output_buf = new Buf( - new Var("aten_cat", kHandle), - {new IntImm(output_size_[0]), new IntImm(output_size_[1])}, + auto output_buf = alloc( + alloc("aten_cat", kHandle), + std::vector( + {alloc(output_size_[0]), alloc(output_size_[1])}), kFloat); std::vector inputs; - std::vector for_stmts(num_inputs); + std::vector for_stmts(num_inputs); int cumulative_input_sizes = 0; for (size_t i = 0; i < num_inputs; ++i) { inputs.emplace_back(Placeholder( "input" + std::to_string(i), kFloat, {input_sizes_[i][0], input_sizes_[i][1]})); - std::vector for_vars(num_inputs); + std::vector for_vars(num_inputs); for (size_t d = 0; d < num_dims; ++d) { for_vars[d] = - new Var("i" + std::to_string(i) + "_" + std::to_string(d), kInt); + alloc("i" + std::to_string(i) + "_" + std::to_string(d), kInt); } - auto store = new Store( + auto store = alloc( output_buf, - {for_vars[0], - new Add(for_vars[1], new IntImm(cumulative_input_sizes))}, - new Load(inputs[i].data(), {for_vars[0], for_vars[1]})); - auto for_st = new For( + std::vector( + {for_vars[0], + alloc(for_vars[1], alloc(cumulative_input_sizes))}), + alloc( + inputs[i].data(), + std::vector({for_vars[0], for_vars[1]}))); + auto for_st = alloc( for_vars[0], - new IntImm(0), - new IntImm(input_sizes_[i][0]), - new For( + alloc(0), + alloc(input_sizes_[i][0]), + alloc( for_vars[1], - new IntImm(0), - new IntImm(input_sizes_[i][1]), + alloc(0), + alloc(input_sizes_[i][1]), store)); for_stmts[i] = for_st; cumulative_input_sizes += input_sizes_[i][1]; } - auto output = new Tensor(output_buf, new Block(for_stmts)); + auto output = Tensor(output_buf, alloc(for_stmts)); LoopNest nest({output}); nest.prepareForCodegen(); nest.vectorizeInnerLoops(); - Stmt* s = IRSimplifier::simplify(nest.root_stmt()); + StmtPtr s = IRSimplifier::simplify(nest.root_stmt()); std::vector buf_args(inputs.begin(), inputs.end()); buf_args.push_back(output); LLVMCodeGen cg(s, buf_args); diff --git a/benchmarks/cpp/tensorexpr/bench_gemm.cpp b/benchmarks/cpp/tensorexpr/bench_gemm.cpp index 792d457c2f23a..ec13b09025eea 100644 --- a/benchmarks/cpp/tensorexpr/bench_gemm.cpp +++ b/benchmarks/cpp/tensorexpr/bench_gemm.cpp @@ -40,11 +40,10 @@ BENCHMARK_DEFINE_F(Gemm, Torch)(benchmark::State& state) { } BENCHMARK_DEFINE_F(Gemm, TensorExprNoopt)(benchmark::State& state) { - te::KernelScope ks; te::Placeholder AP(te::BufHandle("A", {M, K}, te::kFloat)); te::Placeholder BP(te::BufHandle("B", {K, N}, te::kFloat)); - te::Tensor* CT = te::Reduce( + te::Tensor CT = te::Reduce( "gemm", {{M, "M"}, {N, "N"}}, te::Sum(), @@ -54,7 +53,7 @@ BENCHMARK_DEFINE_F(Gemm, TensorExprNoopt)(benchmark::State& state) { {{K, "K"}}); te::LoopNest loop({CT}); loop.prepareForCodegen(); - te::Stmt* s = loop.root_stmt(); + te::StmtPtr s = loop.root_stmt(); s = te::IRSimplifier::simplify(s); auto cg = CreateCodeGen("llvm_codegen", s, {AP, BP, CT}); @@ -64,11 +63,10 @@ BENCHMARK_DEFINE_F(Gemm, TensorExprNoopt)(benchmark::State& state) { } BENCHMARK_DEFINE_F(Gemm, TensorExprTile32x32)(benchmark::State& state) { - te::KernelScope ks; te::Placeholder AP(te::BufHandle("A", {M, K}, te::kFloat)); te::Placeholder BP(te::BufHandle("B", {K, N}, te::kFloat)); - te::Tensor* CT = te::Reduce( + te::Tensor CT = te::Reduce( "gemm", {{M, "M"}, {N, "N"}}, te::Sum(), @@ -80,41 +78,41 @@ BENCHMARK_DEFINE_F(Gemm, TensorExprTile32x32)(benchmark::State& state) { { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* m = loops[0]; + te::ForPtr m = loops[0]; loop.splitWithMask(m, 32); } { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* n = loops[2]; + te::ForPtr n = loops[2]; loop.splitWithMask(n, 32); } // mo, mi, no, ni, k -> // mo, no, mi, ni, k { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* mi = loops[1]; - te::For* no = loops[2]; + te::ForPtr mi = loops[1]; + te::ForPtr no = loops[2]; loop.reorderAxis(mi, no); } // mo, no, mi, ni, k -> // mo, no, mi, k, ni { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* ni = loops[3]; - te::For* k = loops[4]; + te::ForPtr ni = loops[3]; + te::ForPtr k = loops[4]; loop.reorderAxis(ni, k); } // mo, no, mi, k, ni -> // mo, no, k, mi, ni { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* mi = loops[2]; - te::For* k = loops[3]; + te::ForPtr mi = loops[2]; + te::ForPtr k = loops[3]; loop.reorderAxis(mi, k); } loop.prepareForCodegen(); - te::Stmt* s = loop.root_stmt(); + te::StmtPtr s = loop.root_stmt(); s = te::IRSimplifier::simplify(s); auto cg = CreateCodeGen("llvm_codegen", s, {AP, BP, CT}); @@ -124,11 +122,10 @@ BENCHMARK_DEFINE_F(Gemm, TensorExprTile32x32)(benchmark::State& state) { } BENCHMARK_DEFINE_F(Gemm, TensorExprTile4x16)(benchmark::State& state) { - te::KernelScope ks; te::Placeholder AP(te::BufHandle("A", {M, K}, te::kFloat)); te::Placeholder BP(te::BufHandle("B", {K, N}, te::kFloat)); - te::Tensor* CT = te::Reduce( + te::Tensor CT = te::Reduce( "gemm", {{M, "M"}, {N, "N"}}, te::Sum(), @@ -140,41 +137,41 @@ BENCHMARK_DEFINE_F(Gemm, TensorExprTile4x16)(benchmark::State& state) { { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* m = loops[0]; + te::ForPtr m = loops[0]; loop.splitWithMask(m, 4); } { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* n = loops[2]; + te::ForPtr n = loops[2]; loop.splitWithMask(n, 16); } // mo, mi, no, ni, k -> // mo, no, mi, ni, k { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* mi = loops[1]; - te::For* no = loops[2]; + te::ForPtr mi = loops[1]; + te::ForPtr no = loops[2]; loop.reorderAxis(mi, no); } // mo, no, mi, ni, k -> // mo, no, mi, k, ni { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* ni = loops[3]; - te::For* k = loops[4]; + te::ForPtr ni = loops[3]; + te::ForPtr k = loops[4]; loop.reorderAxis(ni, k); } // mo, no, mi, k, ni -> // mo, no, k, mi, ni { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* mi = loops[2]; - te::For* k = loops[3]; + te::ForPtr mi = loops[2]; + te::ForPtr k = loops[3]; loop.reorderAxis(mi, k); } loop.prepareForCodegen(); - te::Stmt* s = loop.root_stmt(); + te::StmtPtr s = loop.root_stmt(); s = te::IRSimplifier::simplify(s); auto cg = CreateCodeGen("llvm_codegen", s, {AP, BP, CT}); @@ -184,11 +181,10 @@ BENCHMARK_DEFINE_F(Gemm, TensorExprTile4x16)(benchmark::State& state) { } BENCHMARK_DEFINE_F(Gemm, TensorExprTile4x16VecUnroll)(benchmark::State& state) { - te::KernelScope ks; te::Placeholder AP(te::BufHandle("A", {M, K}, te::kFloat)); te::Placeholder BP(te::BufHandle("B", {K, N}, te::kFloat)); - te::Tensor* CT = te::Reduce( + te::Tensor CT = te::Reduce( "gemm", {{M, "M"}, {N, "N"}}, te::Sum(), @@ -200,49 +196,49 @@ BENCHMARK_DEFINE_F(Gemm, TensorExprTile4x16VecUnroll)(benchmark::State& state) { { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* m = loops[0]; + te::ForPtr m = loops[0]; loop.splitWithMask(m, 4); } { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* n = loops[2]; + te::ForPtr n = loops[2]; loop.splitWithMask(n, 16); } // mo, mi, no, ni, k -> // mo, no, mi, ni, k { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* mi = loops[1]; - te::For* no = loops[2]; + te::ForPtr mi = loops[1]; + te::ForPtr no = loops[2]; loop.reorderAxis(mi, no); } // mo, no, mi, ni, k -> // mo, no, mi, k, ni { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* ni = loops[3]; - te::For* k = loops[4]; + te::ForPtr ni = loops[3]; + te::ForPtr k = loops[4]; loop.reorderAxis(ni, k); } // mo, no, mi, k, ni -> // mo, no, k, mi, ni { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* mi = loops[2]; - te::For* k = loops[3]; + te::ForPtr mi = loops[2]; + te::ForPtr k = loops[3]; loop.reorderAxis(mi, k); } { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* mi = loops[3]; - te::For* ni = loops[4]; - te::Stmt* unrolled; + te::ForPtr mi = loops[3]; + te::ForPtr ni = loops[4]; + te::StmtPtr unrolled; loop.vectorize(ni); loop.unroll(mi, &unrolled); } loop.prepareForCodegen(); - te::Stmt* s = loop.root_stmt(); + te::StmtPtr s = loop.root_stmt(); s = te::IRSimplifier::simplify(s); auto cg = CreateCodeGen("llvm_codegen", s, {AP, BP, CT}); @@ -252,11 +248,10 @@ BENCHMARK_DEFINE_F(Gemm, TensorExprTile4x16VecUnroll)(benchmark::State& state) { } BENCHMARK_DEFINE_F(Gemm, TensorExprTile4x16Cache)(benchmark::State& state) { - te::KernelScope ks; te::Placeholder AP(te::BufHandle("A", {M, K}, te::kFloat)); te::Placeholder BP(te::BufHandle("B", {K, N}, te::kFloat)); - te::Tensor* CT = te::Reduce( + te::Tensor CT = te::Reduce( "gemm", {{M, "M"}, {N, "N"}}, te::Sum(), @@ -268,45 +263,45 @@ BENCHMARK_DEFINE_F(Gemm, TensorExprTile4x16Cache)(benchmark::State& state) { { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* m = loops[0]; + te::ForPtr m = loops[0]; loop.splitWithMask(m, 4); } { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* n = loops[2]; + te::ForPtr n = loops[2]; loop.splitWithMask(n, 16); } // mo, mi, no, ni, k -> // mo, no, mi, ni, k { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* mi = loops[1]; - te::For* no = loops[2]; + te::ForPtr mi = loops[1]; + te::ForPtr no = loops[2]; loop.reorderAxis(mi, no); } // mo, no, mi, ni, k -> // mo, no, mi, k, ni { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* ni = loops[3]; - te::For* k = loops[4]; + te::ForPtr ni = loops[3]; + te::ForPtr k = loops[4]; loop.reorderAxis(ni, k); } // mo, no, mi, k, ni -> // mo, no, k, mi, ni { auto const& loops = loop.getLoopStmtsFor(CT); - te::For* mi = loops[2]; - te::For* k = loops[3]; + te::ForPtr mi = loops[2]; + te::ForPtr k = loops[3]; loop.reorderAxis(mi, k); } { auto const& loops = loop.getLoopStmtsFor(CT); - loop.cacheAccesses(CT->buf(), "C_regs", loops[2]); + loop.cacheAccesses(CT.buf(), "C_regs", loops[2]); } loop.prepareForCodegen(); - te::Stmt* s = loop.root_stmt(); + te::StmtPtr s = loop.root_stmt(); s = te::IRSimplifier::simplify(s); auto cg = CreateCodeGen("llvm_codegen", s, {AP, BP, CT}); diff --git a/benchmarks/cpp/tensorexpr/bench_ops.py b/benchmarks/cpp/tensorexpr/bench_ops.py index ca40e5d3c7459..12d766ae74862 100644 --- a/benchmarks/cpp/tensorexpr/bench_ops.py +++ b/benchmarks/cpp/tensorexpr/bench_ops.py @@ -59,7 +59,7 @@ def hardswish(x): traced(x) # Validate result. - torch.testing.assert_allclose(op(x), traced(x)) + torch.testing.assert_close(op(x), traced(x)) # Benchmark. bench_iters = 100 @@ -94,7 +94,7 @@ def test_batch_norm(): traced(x, y, z) # Validate result. - torch.testing.assert_allclose(op(x, y, z), traced(x, y, z)) + torch.testing.assert_close(op(x, y, z), traced(x, y, z)) # Benchmark. bench_iters = 100 diff --git a/benchmarks/cpp/tensorexpr/bench_parallel.cpp b/benchmarks/cpp/tensorexpr/bench_parallel.cpp index fee326cdd4bd4..178a8795edd03 100644 --- a/benchmarks/cpp/tensorexpr/bench_parallel.cpp +++ b/benchmarks/cpp/tensorexpr/bench_parallel.cpp @@ -35,19 +35,18 @@ class ParallelAdd : public benchmark::Fixture { }; BENCHMARK_DEFINE_F(ParallelAdd, Simple)(benchmark::State& state) { - KernelScope kernel_scope; Placeholder a_buf("a", kFloat, {M}); Placeholder b_buf("b", kFloat, {M}); - Tensor* c_tensor = Compute( + Tensor c_tensor = Compute( "c", {{M, "m"}}, [&](const VarHandle& m) { return a_buf.load(m) + b_buf.load(m); }); LoopNest loop_nest({c_tensor}); auto const& loops = loop_nest.getLoopStmtsFor(c_tensor); - For* m = loops[0]; + ForPtr m = loops[0]; m->set_parallel(); loop_nest.prepareForCodegen(); - Stmt* stmt = loop_nest.root_stmt(); + StmtPtr stmt = loop_nest.root_stmt(); LLVMCodeGen cg(stmt, {c_tensor, a_buf, b_buf}); float* a_ptr = A.data_ptr(); diff --git a/benchmarks/cpp/tensorexpr/bench_reduce.cpp b/benchmarks/cpp/tensorexpr/bench_reduce.cpp index acd46ac1de410..e053317feca60 100644 --- a/benchmarks/cpp/tensorexpr/bench_reduce.cpp +++ b/benchmarks/cpp/tensorexpr/bench_reduce.cpp @@ -217,12 +217,11 @@ BENCHMARK_REGISTER_F(Reduce1D, NativeTiled)->Args({1 << 24}); #endif // USE_AVX2 BENCHMARK_DEFINE_F(Reduce1D, TeNaive)(benchmark::State& state) { - te::KernelScope ks; int M = A.numel(); te::Placeholder AP(te::BufHandle("A", {M}, te::kFloat)); - te::Tensor* BT = te::Reduce( + te::Tensor BT = te::Reduce( "reduce_full", {{1, "N"}}, te::Sum(), @@ -233,7 +232,7 @@ BENCHMARK_DEFINE_F(Reduce1D, TeNaive)(benchmark::State& state) { te::LoopNest loop({BT}); loop.prepareForCodegen(); - te::Stmt* s = loop.root_stmt(); + te::StmtPtr s = loop.root_stmt(); s = te::IRSimplifier::simplify(s); auto cg = CreateCodeGen("llvm_codegen", s, {AP, BT}); @@ -250,12 +249,11 @@ BENCHMARK_DEFINE_F(Reduce1D, TeNaive)(benchmark::State& state) { BENCHMARK_REGISTER_F(Reduce1D, TeNaive)->Args({1 << 24}); BENCHMARK_DEFINE_F(Reduce1D, TeSplitTail)(benchmark::State& state) { - te::KernelScope ks; int M = A.numel(); te::Placeholder AP(te::BufHandle("A", {M}, te::kFloat)); - te::Tensor* BT = te::Reduce( + te::Tensor BT = te::Reduce( "reduce_full", {{1, "N"}}, te::Sum(), @@ -269,12 +267,12 @@ BENCHMARK_DEFINE_F(Reduce1D, TeSplitTail)(benchmark::State& state) { { auto const& loops = loop.getLoopStmtsFor(BT); - te::For* m = loops[1]; + te::ForPtr m = loops[1]; loop.splitWithTail(m, kChunkSize); } loop.prepareForCodegen(); - te::Stmt* s = loop.root_stmt(); + te::StmtPtr s = loop.root_stmt(); s = te::IRSimplifier::simplify(s); auto cg = CreateCodeGen("llvm_codegen", s, {AP, BT}); @@ -291,12 +289,11 @@ BENCHMARK_DEFINE_F(Reduce1D, TeSplitTail)(benchmark::State& state) { BENCHMARK_REGISTER_F(Reduce1D, TeSplitTail)->Args({1 << 24}); BENCHMARK_DEFINE_F(Reduce1D, TeSplitMask)(benchmark::State& state) { - te::KernelScope ks; int M = A.numel(); te::Placeholder AP(te::BufHandle("A", {M}, te::kFloat)); - te::Tensor* BT = te::Reduce( + te::Tensor BT = te::Reduce( "reduce_full", {{1, "N"}}, te::Sum(), @@ -310,12 +307,12 @@ BENCHMARK_DEFINE_F(Reduce1D, TeSplitMask)(benchmark::State& state) { { auto const& loops = loop.getLoopStmtsFor(BT); - te::For* m = loops[1]; + te::ForPtr m = loops[1]; loop.splitWithMask(m, kChunkSize); } loop.prepareForCodegen(); - te::Stmt* s = loop.root_stmt(); + te::StmtPtr s = loop.root_stmt(); s = te::IRSimplifier::simplify(s); auto cg = CreateCodeGen("llvm_codegen", s, {AP, BT}); @@ -332,14 +329,13 @@ BENCHMARK_DEFINE_F(Reduce1D, TeSplitMask)(benchmark::State& state) { BENCHMARK_REGISTER_F(Reduce1D, TeSplitMask)->Args({1 << 24}); BENCHMARK_DEFINE_F(Reduce1D, TeRfactorV1)(benchmark::State& state) { - te::KernelScope ks; int M = A.numel(); const int kChunkSize = 8; TORCH_CHECK(M % kChunkSize == 0); te::Placeholder AP(te::BufHandle("A", {M}, te::kFloat)); - te::Tensor* BT = te::Reduce( + te::Tensor BT = te::Reduce( "reduce_full", {}, te::Sum(), @@ -349,17 +345,17 @@ BENCHMARK_DEFINE_F(Reduce1D, TeRfactorV1)(benchmark::State& state) { {{M, "M"}}); te::LoopNest loop({BT}); - te::Buf* rfac_buf; + te::BufPtr rfac_buf; auto loops = loop.getLoopStmtsFor(BT); TORCH_CHECK(loops.size() == 1); - te::For* mi; + te::ForPtr mi; loop.splitWithMask(loops.at(0), kChunkSize, &mi); - te::For* mo = loops.at(0); + te::ForPtr mo = loops.at(0); loop.reorderAxis(mo, mi); loops = loop.getLoopStmtsFor(BT); - auto bt_body = const_cast(loop.getAllWritesToBuf(BT->buf())[1]); + auto bt_body = loop.getAllWritesToBuf(BT.buf())[1]; TORCH_CHECK(loop.rfactor(bt_body, loops.at(0), &rfac_buf)); loop.reorderAxis(loops.at(0), loops.at(1)); @@ -368,7 +364,7 @@ BENCHMARK_DEFINE_F(Reduce1D, TeRfactorV1)(benchmark::State& state) { loop.vectorize(loops.at(1)); loop.prepareForCodegen(); - te::Stmt* s = loop.root_stmt(); + te::StmtPtr s = loop.root_stmt(); s = te::IRSimplifier::simplify(s); auto cg = CreateCodeGen("llvm_codegen", s, {AP, BT}); @@ -385,17 +381,16 @@ BENCHMARK_DEFINE_F(Reduce1D, TeRfactorV1)(benchmark::State& state) { BENCHMARK_REGISTER_F(Reduce1D, TeRfactorV1)->Args({1 << 24}); BENCHMARK_DEFINE_F(Reduce1D, Op)(benchmark::State& state) { - te::KernelScope ks; const int M = A.numel(); const int kChunkSize = 8; te::Placeholder a("A", te::kFloat, {M}); - te::Tensor* b = te::computeSum({a.handle(), te::IntList({0}), false}, at::kFloat); + te::Tensor b = te::computeSum({a.handle(), te::IntList({0}), false}, at::kFloat); te::LoopNest nest({b}); auto loops = nest.getLoopStmtsFor(b); - te::For *mi, *mo; - te::Buf *rf; + te::ForPtr mi, mo; + te::BufPtr rf; nest.splitWithMask(loops[0], kChunkSize, &mi); loops = nest.reorder({loops[0], mi}, {1, 0}); nest.rfactor(nest.getLoopBodyFor(b), loops[0], &rf); @@ -450,10 +445,9 @@ BENCHMARK_REGISTER_F(Reduce2DCol, Torch) ->Args({1 << 12, 1 << 12}); BENCHMARK_DEFINE_F(Reduce2DCol, OpSchedule)(benchmark::State& state) { - te::KernelScope ks; constexpr int kCacheSize = 1 << 12; te::Placeholder a("A", te::kFloat, {M, N}); - te::Tensor* b = te::computeSum({a.handle(), te::IntList({0}), false}, at::kFloat); + te::Tensor b = te::computeSum({a.handle(), te::IntList({0}), false}, at::kFloat); te::LoopNest nest({b}); auto sch = state.range(2); @@ -557,17 +551,16 @@ BENCHMARK_REGISTER_F(Reduce2DRow, Hand) ->Args({1 << 18, 1 << 6}); BENCHMARK_DEFINE_F(Reduce2DRow, OpSchedule)(benchmark::State& state) { - te::KernelScope ks; constexpr int kChunkSize = 8; te::Placeholder a("A", te::kFloat, {M, N}); - te::Tensor* b = te::computeSum({a.handle(), te::IntList({1}), false}, at::kFloat); + te::Tensor b = te::computeSum({a.handle(), te::IntList({1}), false}, at::kFloat); te::LoopNest nest({b}); auto sch = state.range(2); if (sch == 1) { auto loops = nest.getLoopStmtsFor(b); - te::For *mi, *mo; - te::Buf *rf; + te::ForPtr mi, mo; + te::BufPtr rf; nest.splitWithMask(loops[1], kChunkSize, &mi); loops = nest.reorder({loops[1], mi}, {1, 0}); TORCH_CHECK(nest.rfactor(nest.getLoopBodyFor(b), loops[0], &rf)); @@ -583,8 +576,8 @@ BENCHMARK_DEFINE_F(Reduce2DRow, OpSchedule)(benchmark::State& state) { nest.reorderAxis(loops[1], loops[2]); } else if (sch == 3) { auto loops = nest.getLoopStmtsFor(b); - te::For *mi, *mo; - te::Buf *rf; + te::ForPtr mi, mo; + te::BufPtr rf; nest.splitWithMask(loops[1], kChunkSize, &mi); loops = nest.reorder({loops[1], mi}, {1, 0}); TORCH_CHECK(nest.rfactor(nest.getLoopBodyFor(b), loops[0], &rf)); diff --git a/benchmarks/cpp/tensorexpr/bench_signed_log1p.cpp b/benchmarks/cpp/tensorexpr/bench_signed_log1p.cpp new file mode 100644 index 0000000000000..44781f58c9027 --- /dev/null +++ b/benchmarks/cpp/tensorexpr/bench_signed_log1p.cpp @@ -0,0 +1,120 @@ +#include + +#include +#include +#include +#include +#include +#include +#include + +using namespace torch::jit::tensorexpr; + +namespace { + +class SignedLog1pBench : public benchmark::Fixture { + public: + void SetUp(const benchmark::State& state) override { + input_size_ = {state.range(0), state.range(1)}; + input_size_int_ = {state.range(0), state.range(1)}; + input_ = torch::rand(input_size_); + ref_ = signedLog1p(input_); + } + + void TearDown(benchmark::State& state) override { + TORCH_CHECK(at::allclose(ref_, output_)); + state.counters["GB/s"] = benchmark::Counter( + uint64_t(state.iterations()) * 2 * output_.nbytes(), + benchmark::Counter::kIsRate); + } + + at::Tensor signedLog1p(const at::Tensor& inp) { + auto sign = at::sign(inp); + auto log1p = at::log1p(at::abs(inp)); + return sign * log1p; + } + + void runATen(benchmark::State& state) { + for (auto _ : state) { + output_ = signedLog1p(input_); + } + } + + void runNNC(benchmark::State& state) { + Placeholder input_ph( + "input", kFloat, {input_size_int_[0], input_size_int_[1]}); + Tensor abs_result = Compute( + "aten_abs", + {{input_size_int_[0], "M"}, {input_size_int_[1], "N"}}, + [&](const VarHandle& m, const VarHandle& n) { + return abs(input_ph.load(m, n)); + }); + Tensor log1p_result = Compute( + "aten_log1p", + {{input_size_int_[0], "M"}, {input_size_int_[1], "N"}}, + [&](const VarHandle& m, const VarHandle& n) { + return log1p(abs_result.load(m, n)); + }); + Tensor sign = Compute( + "aten_sign", + {{input_size_int_[0], "M"}, {input_size_int_[1], "N"}}, + [&](const VarHandle& m, const VarHandle& n) { + return CompareSelect::make( + input_ph.load(m, n), + ExprHandle(0.0f), + ExprHandle(-1), + ExprHandle(1), + kLT); + }); + Tensor output = Compute( + "aten_mul", + {{input_size_int_[0], "M"}, {input_size_int_[1], "N"}}, + [&](const VarHandle& m, const VarHandle& n) { + return sign.load(m, n) * log1p_result.load(m, n); + }); + LoopNest nest({output}, {abs_result, log1p_result, sign, output}); + GRAPH_DEBUG("Original Stmt: ", *nest.root_stmt()); + nest.inlineIntermediateBufs(true); + nest.prepareForCodegen(); + nest.simplify(); + nest.vectorizeInnerLoops(); + nest.simplify(); + GRAPH_DEBUG("Final stmt: ", *nest.root_stmt()); + + // StmtPtr s = IRSimplifier::simplify(nest.root_stmt()); + std::vector buf_args; + buf_args.push_back(input_ph); + buf_args.push_back(output); + LLVMCodeGen cg(nest.root_stmt(), buf_args); + + std::vector call_args; + for (auto _ : state) { + output_ = at::empty_like(ref_); + call_args.clear(); + call_args.push_back(input_.data_ptr()); + call_args.push_back(output_.data_ptr()); + cg.call(call_args); + } + } + + private: + std::vector input_size_; + std::vector input_size_int_; + at::Tensor input_; + at::Tensor output_; + at::Tensor ref_; +}; + +} // namespace + +BENCHMARK_DEFINE_F(SignedLog1pBench, ATen)(benchmark::State& state) { + runATen(state); +} + +BENCHMARK_DEFINE_F(SignedLog1pBench, NNC)(benchmark::State& state) { + runNNC(state); +} + +BENCHMARK_REGISTER_F(SignedLog1pBench, ATen)->Args({10, 1467}); + +BENCHMARK_REGISTER_F(SignedLog1pBench, NNC)->Args({10, 1467}); diff --git a/benchmarks/operator_benchmark/pt/matrix_mult_test.py b/benchmarks/operator_benchmark/pt/matrix_mult_test.py new file mode 100644 index 0000000000000..ad7d42318140d --- /dev/null +++ b/benchmarks/operator_benchmark/pt/matrix_mult_test.py @@ -0,0 +1,119 @@ +import operator_benchmark as op_bench +import torch + +""" +Microbenchmarks for batch matrix mult with einsum and torch.bmm. +""" + +batch_mm_configs_short = op_bench.config_list( + attr_names=["B", "M", "N", "K"], + attrs=[ + [4, 5, 3, 2], + [32, 25, 20, 30], + [128, 100, 120, 110], + ], + cross_product_configs={ + 'device': ['cpu', 'cuda'], + }, + tags=["short"], +) + +batch_mm_configs_long = op_bench.config_list( + attr_names=["B", "M", "N", "K"], + attrs=[ + [128, 256, 128, 256], + [512, 1024, 1024, 512], + ], + cross_product_configs={ + 'device': ['cpu', 'cuda'], + }, + tags=["long"], +) + +batch_mm_op_list = op_bench.op_list( + attr_names=['op_name', 'op_func'], + attrs=[ + ['einsum_bmm', torch.einsum], + ['bmm', torch.bmm], + ], +) + +class BatchMatrixMultBenchmark(op_bench.TorchBenchmarkBase): + def init(self, B, M, N, K, device, op_func): + self.inputs = { + "input_one": torch.rand(B, M, N, device=device), + "input_two": torch.rand(B, N, K, device=device) + } + self.op_func = op_func + + def forward(self, input_one, input_two): + if self.op_func.__name__ == "einsum": + return torch.einsum('bij,bjk->bik', input_one, input_two) + else: + return torch.bmm(input_one, input_two) + + +""" +Microbenchmarks for element-wise matrix mult with einsum and torch.mul. +""" + +batch_elementwise_configs_short = op_bench.config_list( + attr_names=["B", "M", "N"], + attrs=[ + [4, 5, 3], + [32, 25, 20], + [100, 90, 110], + ], + cross_product_configs={ + 'device': ['cpu', 'cuda'], + }, + tags=["short"], +) + + +batch_elementwise_configs_long = op_bench.cross_product_configs( + B=[128, 512, 1024], + M=[128, 512, 1024], + N=[128, 512, 1024], + device=['cpu', 'cuda'], + tags=['long'] +) + +batch_elementwise_op_list = op_bench.op_list( + attr_names=['op_name', 'op_func'], + attrs=[ + ['einsum_elementwise', torch.einsum], + ['mul', torch.mul], + ], +) + +class BatchElementWiseBenchmark(op_bench.TorchBenchmarkBase): + def init(self, B, M, N, device, op_func): + self.inputs = { + "input_one": torch.rand(B, M, N, device=device), + "input_two": torch.rand(B, M, N, device=device) + } + self.op_func = op_func + + def forward(self, input_one, input_two): + if self.op_func.__name__ == "einsum": + return torch.einsum('bij,bij->bij', input_one, input_two) + else: + return torch.mul(input_one, input_two) + + +op_bench.generate_pt_tests_from_op_list( + batch_mm_op_list, + batch_mm_configs_short + batch_mm_configs_long, + BatchMatrixMultBenchmark, +) + +op_bench.generate_pt_tests_from_op_list( + batch_elementwise_op_list, + batch_elementwise_configs_short + batch_elementwise_configs_long, + BatchElementWiseBenchmark, +) + + +if __name__ == "__main__": + op_bench.benchmark_runner.main() diff --git a/benchmarks/static_runtime/test_scripts.h b/benchmarks/static_runtime/test_scripts.h index 6045a1c2f9772..b17ddeda45dff 100644 --- a/benchmarks/static_runtime/test_scripts.h +++ b/benchmarks/static_runtime/test_scripts.h @@ -26,6 +26,11 @@ alias of the model output. */ +const auto abs_script = R"JIT( + def forward(self, a): + return a.abs().clone() +)JIT"; + const auto list_construct_script = R"JIT( def forward(self, a, b): return [a, b] @@ -133,6 +138,22 @@ const auto reshape_inplace_script = R"JIT( return (d, e, f) )JIT"; +const auto reshape_inplace_script_1 = R"JIT( + def forward(self, inp: Tensor, shape: List[int], flag: bool): + if flag: + a = inp + inp + b = a.reshape(shape) + c = b.sigmoid() + else: + a = inp * inp + b = a.sigmoid_() + c = b.reshape(shape) + d = c + c + e = a + a + f = b + b + return (d, e, f) +)JIT"; + const auto sigmoid_inplace_script = R"JIT( def forward(self, inp: Tensor): a = torch.sigmoid(inp, out=inp).clone() @@ -286,6 +307,18 @@ const auto to_script_4 = R"JIT( return (c) )JIT"; +const auto detach_script_0 = R"JIT( + def forward(self, input: Tensor): + a = input.detach() + return input is a +)JIT"; + +const auto detach_script_1 = R"JIT( + def forward(self, input: Tensor): + a = input.detach() + return a.clone() +)JIT"; + const std::string embedding_bag_default = R"JIT( def forward(self, a: Tensor, b: Tensor, c: Tensor): return torch.embedding_bag(a, b, c) @@ -316,6 +349,12 @@ const std::string embedding_bag_max_last_offset = R"JIT( return torch.embedding_bag(a, b, c, False, 2, False, None, True) )JIT"; +const auto expand_as_script = R"JIT( + def forward(self, input: Tensor, other:Tensor): + a = input.expand_as(other) + return a.clone() +)JIT"; + const auto sign_tensor = R"JIT( def forward(self, input: Tensor): return torch.sign(input).clone() @@ -570,6 +609,11 @@ const auto var_cat_script = R"JIT( return torch.cat([inp1, inp2], dim).clone() )JIT"; +const auto var_stack_script = R"JIT( + def forward(self, inp1: Tensor, inp2: Tensor, dim: int): + return torch.stack([inp1, inp2], dim).clone() +)JIT"; + const auto isinstance_int_script = R"JIT( def forward(self, a: Any): return isinstance(a, int) @@ -632,24 +676,46 @@ const auto argmin_with_keep_dim_script = R"JIT( return torch.argmin(a, dim, True).clone() )JIT"; -const auto getitem_tensor_script = R"JIT( +const auto softmax_script = R"JIT( + def forward(self, a: Tensor, dim: int): + return torch.softmax(a, dim).clone() +)JIT"; + +const auto softmax_script_with_dtype = R"JIT( + def forward(self, a: Tensor, dim: int, dtype: int): + return torch.softmax(a, dim, dtype=dtype).clone() +)JIT"; + +const auto getitem_dict_tensor_script = R"JIT( def forward(self, key: Tensor): d = {key: 1} return d[key] )JIT"; -const auto getitem_int_script = R"JIT( +const auto getitem_dict_int_script = R"JIT( def forward(self, key: int): d = {key: 1} return d[key] )JIT"; -const auto getitem_str_script = R"JIT( +const auto getitem_dict_str_script = R"JIT( def forward(self, key: str): d = {key: 1} return d[key] )JIT"; +const auto getitem_list_int_script = R"JIT( + def forward(self, idx: int): + lst = [1, 2, 3] + return lst[idx] +)JIT"; + +const auto getitem_list_tensor_script = R"JIT( + def forward(self, tensor: Tensor, idx: int): + lst = [tensor, tensor] + return lst[idx] +)JIT"; + const auto transpose_script = R"JIT( def forward(self, a: Tensor, dim1: int, dim2: int): return torch.transpose(a, dim1, dim2).clone() @@ -695,3 +761,80 @@ const auto append_tensor_script = R"JIT( lst.append(a) return lst )JIT"; + +const auto nonzero_tensor = R"JIT( + def forward(self, input: Tensor): + a = torch.nonzero(input).clone() + return (a) +)JIT"; + +const std::string quantize_script = R"IR( + graph(%input: Tensor, %weights: Tensor): + %scale: float = prim::Constant[value=1.]() + %zero_point: int = prim::Constant[value=1]() + %bias: None = prim::Constant() + %packed_params = quantized::linear_prepack(%weights, %bias) + %1254 = quantized::linear(%input, %packed_params, %scale, %zero_point) + %1249: Tensor = aten::dequantize(%1254) + return (%1249) +)IR"; + +const auto fmod_tensor = R"JIT( + def forward(self, a: Tensor, b: Tensor): + return torch.fmod(a, b).clone() +)JIT"; + +const auto fmod_scalar = R"JIT( + def forward(self, a: Tensor, b: int): + return torch.fmod(a, b).clone() +)JIT"; + +const std::string embedding_bag_byte_prepack_script = R"IR( + graph(%input: Tensor): + %none : None = prim::Constant() + %output: Tensor = quantized::embedding_bag_byte_prepack(%input) + %res: Tensor = aten::clone(%output, %none) + return (%res) +)IR"; + +const auto linalg_norm_ord_scalar = R"JIT( + def forward(self, a: Tensor, ord: int, dim: List[int], keepdim: bool, dtype: int): + return torch.linalg_norm(a, ord, dim, keepdim, dtype=dtype).clone() +)JIT"; + +const auto linalg_norm_ord_str = R"JIT( + def forward(self, a: Tensor, ord: str, dim: List[int], keepdim: bool, dtype: int): + return torch.linalg_norm(a, ord, dim, keepdim, dtype=dtype).clone() +)JIT"; + +const std::string cat_script = R"IR( + graph(%a: Tensor, %b: Tensor, %dim: int): + %ten_list: Tensor[] = prim::ListConstruct(%a, %b) + %1 : int = prim::Constant[value=0]() + %2 : int = prim::Constant[value=1]() + %3 : int = prim::Constant[value=1]() + %ten_list2 : Tensor[] = aten::slice(%ten_list, %1, %2, %3) + %ret: Tensor = aten::cat(%ten_list2, %dim) + return (%ret) +)IR"; + +const auto cumsum_script = R"JIT( + def forward(self, a: Tensor, dim: int): + return torch.cumsum(a, dim).clone() +)JIT"; + +const auto cumsum_script_dtype = R"JIT( + def forward(self, a: Tensor, dim: int, dtype: int): + return torch.cumsum(a, dim, dtype=dtype).clone() +)JIT"; + +const std::string signed_log1p_script = R"IR( + graph(%input): + %0 : Tensor = aten::sign(%input) + %1 : Tensor = aten::abs(%input) + %2 : Tensor = aten::log1p(%1) + %3 : Tensor = aten::mul(%0, %2) + %none : NoneType = prim::Constant() + %res : Tensor = aten::clone(%3, %none) + return (%res) +)IR"; diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc index 7af49d6c8fa63..5eb3dfe28bd84 100644 --- a/benchmarks/static_runtime/test_static_runtime.cc +++ b/benchmarks/static_runtime/test_static_runtime.cc @@ -1,5 +1,7 @@ +#include #include #include +#include #include #include #include @@ -69,6 +71,7 @@ Node* getNodeWithKind(const StaticModule& smodule, const std::string& kind) { TEST(StaticRuntime, InPlace) { EXPECT_TRUE(testHasInplaceOp(reshape_inplace_script)); + EXPECT_TRUE(testHasInplaceOp(reshape_inplace_script_1)); EXPECT_TRUE(testHasInplaceOp(sigmoid_inplace_script)); EXPECT_FALSE(testHasInplaceOp(sigmoid_out_script)); } @@ -208,6 +211,13 @@ TEST(StaticRuntime, EmbeddingBag) { } TEST(StaticRuntime, LayerNorm) { +#ifdef FBCODE_CAFFE2 + script::Module module("module"); + module.define(layer_norm_with_weights); + torch::jit::StaticModule smodule(module); + ASSERT_EQ(getNodeWithKind(smodule, "aten::layer_norm"), nullptr); + ASSERT_NE(getNodeWithKind(smodule, "static_runtime::layer_norm"), nullptr); +#endif const auto a = torch::rand({1, 2, 2, 2}); const auto b = torch::rand({3, 2, 2, 2}); for (int normalized_size : {2, 3}) { @@ -256,6 +266,15 @@ TEST(StaticRuntime, Addmm) { testStaticRuntime(addmm_script, args, args1); } +TEST(StaticRuntime, IndividualOps_Abs) { + auto a = at::randn({2, 3}); + auto b = at::randn({4, 2, 3}); + std::vector args{a}; + std::vector args2{b}; + testStaticRuntime(abs_script, args); + testStaticRuntime(abs_script, args, args2); +} + TEST(StaticRuntime, IndividualOps_Binary) { auto a = at::randn({2, 3}); auto b = at::ones({2, 3}); @@ -589,6 +608,28 @@ TEST(StaticRuntime, IndividualOps_to) { test_to(at::ScalarType::Half, false, true, c10::MemoryFormat::ChannelsLast); } +TEST(StaticRuntime, IndividualOps_Detach) { + auto a = at::randn({4, 3, 1, 2}); + auto b = at::randn({3, 2, 2}); + std::vector args{a}; + std::vector args2{b}; + testStaticRuntime(detach_script_0, args); + testStaticRuntime(detach_script_0, args, args2); + testStaticRuntime(detach_script_1, args); + testStaticRuntime(detach_script_1, args, args2); +} + +TEST(StaticRuntime, IndividualOps_ExpandAs) { + auto a = at::randn({3,1}); + auto b = at::randn({3,2}); + auto c = at::randn({4,1}); + auto d = at::randn({4,2}); + std::vector args{a, b}; + std::vector args2{c, d}; + testStaticRuntime(expand_as_script, args); + testStaticRuntime(expand_as_script, args, args2); +} + TEST(StaticRuntime, IndividualOps_Full) { auto dtype = at::ScalarType::Int; auto cpu = at::Device(DeviceType::CPU); @@ -1043,19 +1084,45 @@ TEST(StaticRuntime, IndividualOps_Argmin) { testStaticRuntime(argmin_with_keep_dim_script, args_a, args_b); } -TEST(StaticRuntime, IndividualOps_GetItem) { +TEST(StaticRuntime, IndividualOps_Softmax) { + auto a = at::randn({2, 3}); + auto b = at::randn({3, 3, 3}); + + testStaticRuntime(softmax_script, {a, 0}); + testStaticRuntime(softmax_script, {a, 1}); + + testStaticRuntime(softmax_script, {b, 0}); + testStaticRuntime(softmax_script, {b, 1}); + testStaticRuntime(softmax_script, {b, 2}); + + testStaticRuntime(softmax_script_with_dtype, {a, 1, at::ScalarType::Float}); + testStaticRuntime(softmax_script_with_dtype, {b, 1, at::ScalarType::Float}); +} + +TEST(StaticRuntime, IndividualOps_GetItem_Dict) { int int_key = 0; std::string str_key = "str"; // No need to test these multiple times, args are not tensors - testStaticRuntime(getitem_int_script, {int_key}); - testStaticRuntime(getitem_str_script, {str_key}); + testStaticRuntime(getitem_dict_int_script, {int_key}); + testStaticRuntime(getitem_dict_str_script, {str_key}); auto a = torch::tensor({1}); auto b = torch::tensor({1, 1}); - testStaticRuntime(getitem_tensor_script, {a}); - testStaticRuntime(getitem_tensor_script, {a}, {b}); + testStaticRuntime(getitem_dict_tensor_script, {a}); + testStaticRuntime(getitem_dict_tensor_script, {a}, {b}); +} + +TEST(StaticRuntime, IndividualOps_GetItem_List) { + testStaticRuntime(getitem_list_int_script, {1}); + testStaticRuntime(getitem_list_int_script, {-1}); + + auto a = torch::tensor({1}); + auto b = torch::tensor({1, 1}); + + testStaticRuntime(getitem_list_tensor_script, {a, 1}); + testStaticRuntime(getitem_list_tensor_script, {a, 1}, {b, -1}); } TEST(StaticRuntime, IndividualOps_Transpose) { @@ -1150,3 +1217,150 @@ TEST(StaticRuntime, IndividualOps_Append) { testStaticRuntime(append_tensor_script, args_tensor); testStaticRuntime(append_tensor_script, args_tensor, args_tensor_large); } + +TEST(StaticRuntime, QuantizedLinear) { + at::Tensor weight = + at::quantize_per_tensor(torch::randn({3, 2}), 2, 3, torch::kQInt8); + at::Tensor input = + at::quantize_per_tensor(torch::randn({3, 2}), 2, 3, torch::kQUInt8); + + at::Tensor weight_2 = + at::quantize_per_tensor(torch::randn({4, 3}), 2, 3, torch::kQInt8); + at::Tensor input_2 = + at::quantize_per_tensor(torch::randn({4, 3}), 2, 3, torch::kQUInt8); + + testStaticRuntime(quantize_script, {input, weight}, {input_2, weight_2}); +} + +TEST(StaticRuntime, IndividualOps_VarStack) { + // 2D tensors - stack dim = 0 + std::vector args1 = {at::randn({6, 6}), at::randn({6, 6}), 0}; + testStaticRuntime(var_stack_script, args1); + + // 3D tensors - stack dim = 1 + std::vector args2 = {at::randn({4, 5, 6}), at::randn({4, 5, 6}), 1}; + testStaticRuntime(var_stack_script, args2); + + // 3D tensors - stack dim = 2 + std::vector args3 = {at::randn({4, 5, 6}), at::randn({4, 5, 6}), 2}; + testStaticRuntime(var_stack_script, args3); + + testStaticRuntime(var_stack_script, args1, args2); +} + +TEST(StaticRuntime, IndividualOps_FmodTensor) { + // fmod tensor version + auto a = at::randn({2, 3}); + auto b = at::randn({2, 3}); + std::vector args0{a, b}; + testStaticRuntime(fmod_tensor, args0); + + // check for dynamic shapes + auto c = at::randn({4, 3, 2}); + auto d = at::randn({4, 3, 2}); + std::vector args1{c, d}; + testStaticRuntime(fmod_tensor, args0, args1); +} + +TEST(StaticRuntime, IndividualOps_FmodScalar) { + auto a = at::randn({2, 3}); + + // fmod scalar version + std::vector args2{a, 3}; + testStaticRuntime(fmod_scalar, args2); + + // check for dynamic shapes + auto c = at::randn({4, 3, 2}); + std::vector args3{c, 4}; + testStaticRuntime(fmod_scalar, args2, args3); +} + +TEST(StaticRuntime, QEmbeddingBagByteUnpack) { + auto a = torch::randn({8, 16}, at::ScalarType::Float); + auto b = torch::randn({8*2, 16*2}, at::ScalarType::Float); + + testStaticRuntime(embedding_bag_byte_prepack_script, {a}); + testStaticRuntime(embedding_bag_byte_prepack_script, {a},{b}); +} + +TEST(StaticRuntime, IndividualOps_LinalgNorm_ScalarOrd) { + auto a = at::randn({2, 3}); + auto dim = std::vector({1}); + auto dtype = at::ScalarType::Float; + + std::vector args0{a, 4, dim, true, dtype}; + testStaticRuntime(linalg_norm_ord_scalar, args0); + + auto b = at::randn({4, 5}); + std::vector args1{b, 4, dim, true, dtype}; + testStaticRuntime(linalg_norm_ord_scalar, args0, args1); +} + +TEST(StaticRuntime, IndividualOps_LinalgNorm_StringOrd) { + auto a = at::randn({2, 3}); + auto dim = std::vector({0, 1}); + auto dtype = at::ScalarType::Float; + + std::vector args0{a, "fro", dim, true, dtype}; + testStaticRuntime(linalg_norm_ord_str, args0); + + auto b = at::randn({4, 5}); + std::vector args1{b, "fro", dim, true, dtype}; + testStaticRuntime(linalg_norm_ord_str, args0, args1); +} + +TEST(StaticRuntime, IndividualOps_Cat) { + auto graph = std::make_shared(); + std::unordered_map vmap; + parseIR(cat_script, graph.get(), vmap); + torch::jit::StaticModule smodule(graph); + ASSERT_TRUE(getNodeWithKind(smodule, "aten::cat")); + + auto a = at::randn({2, 4}); + auto b = at::randn({3, 4}); + std::vector args0{a, b, 0}; + + testStaticRuntime(cat_script, args0); + + auto c = at::randn({3, 4}); + auto d = at::randn({3, 5}); + std::vector args1{c, d, 1}; + testStaticRuntime(cat_script, args0, args1); +} + +TEST(StaticRuntime, IndividualOps_Cumsum) { + auto a = at::randn({2, 3}); + std::vector args0{a, 0}; + testStaticRuntime(cumsum_script, args0); + + auto b = at::randn({4, 3}); + std::vector args1{b, 1}; + testStaticRuntime(cumsum_script, args0, args1); +} + +TEST(StaticRuntime, IndividualOps_CumsumDtype) { + auto a = at::randn({1, 2}); + auto dtype = at::ScalarType::Float; + std::vector args0{a, 0, dtype}; + testStaticRuntime(cumsum_script_dtype, args0); + + auto b = at::randn({3, 4}); + std::vector args1{b, 1, dtype}; + testStaticRuntime(cumsum_script_dtype, args0, args1); +} + +TEST(StaticRuntime, IndividualOps_Nonzero) { + auto a = at::randint(0, 2, {2, 3}); + testStaticRuntime(nonzero_tensor, {a}); + + auto b = at::randint(0, 2, {4, 3, 2}); + testStaticRuntime(nonzero_tensor, {a}, {b}); +} + +TEST(StaticRuntime, SignedLog1p) { + std::vector args1 = {at::randn({2, 2})}; + testStaticRuntime(signed_log1p_script, args1, {}, true); + + std::vector args2 = {at::randn({3, 3, 3})}; + testStaticRuntime(signed_log1p_script, args1, args2, true); +} diff --git a/binaries/benchmark_helper.cc b/binaries/benchmark_helper.cc index b0e1ae06be8d8..7690e356adaa0 100644 --- a/binaries/benchmark_helper.cc +++ b/binaries/benchmark_helper.cc @@ -16,6 +16,7 @@ #include #include +#include #include #include #ifdef _WIN32 diff --git a/c10/CMakeLists.txt b/c10/CMakeLists.txt index 3d2d4352ffef4..23a0e024d35ed 100644 --- a/c10/CMakeLists.txt +++ b/c10/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.5 FATAL_ERROR) +cmake_minimum_required(VERSION 3.10 FATAL_ERROR) project(c10 CXX) set(CMAKE_CXX_STANDARD 14) diff --git a/c10/core/AutogradState.cpp b/c10/core/AutogradState.cpp new file mode 100644 index 0000000000000..4667acb435193 --- /dev/null +++ b/c10/core/AutogradState.cpp @@ -0,0 +1,21 @@ +#include + +namespace c10 { + +namespace { +// By default, grad mode is enabled and inference mode is disabled +thread_local AutogradState autograd_state_tls = AutogradState( + /* grad_mode */ true, + /* inference_mode */ false, + /* fw_grad_mode */ true); +} // namespace + +AutogradState& AutogradState::get_tls_state() { + return autograd_state_tls; +} + +void AutogradState::set_tls_state(AutogradState state) { + autograd_state_tls = state; +} + +} // namespace c10 diff --git a/c10/core/AutogradState.h b/c10/core/AutogradState.h new file mode 100644 index 0000000000000..a1d13a42891da --- /dev/null +++ b/c10/core/AutogradState.h @@ -0,0 +1,50 @@ +#pragma once + +#include + +#include + +namespace c10 { + +// Structure used to pack all the thread local boolean +// flags used by autograd +struct C10_API AutogradState { + static AutogradState& get_tls_state(); + static void set_tls_state(AutogradState state); + + AutogradState(bool grad_mode, bool inference_mode, bool fw_grad_mode) + : grad_mode_(grad_mode), + inference_mode_(inference_mode), + fw_grad_mode_(fw_grad_mode) {} + + void set_grad_mode(bool enabled) { + grad_mode_ = enabled; + } + + void set_fw_grad_mode(bool enabled) { + fw_grad_mode_ = enabled; + } + + void set_inference_mode(bool enabled) { + inference_mode_ = enabled; + } + + bool get_grad_mode() const { + return grad_mode_; + } + + bool get_fw_grad_mode() const { + return fw_grad_mode_; + } + + bool get_inference_mode() const { + return inference_mode_; + } + + private: + bool grad_mode_ : 1; + bool inference_mode_ : 1; + bool fw_grad_mode_ : 1; +}; + +} // namespace c10 diff --git a/c10/core/Backend.h b/c10/core/Backend.h index 2f071345311f2..e17a1bc4226c6 100644 --- a/c10/core/Backend.h +++ b/c10/core/Backend.h @@ -40,7 +40,7 @@ enum class Backend { SparseHIP, SparseVE, SparseXPU, - MSNPU, + ORT, XLA, Vulkan, Metal, @@ -66,8 +66,8 @@ static inline Backend dispatchKeyToBackend(DispatchKey t) { return Backend::VE; } else if (t == DispatchKey::FPGA) { return Backend::FPGA; - } else if (t == DispatchKey::MSNPU) { - return Backend::MSNPU; + } else if (t == DispatchKey::ORT) { + return Backend::ORT; } else if (t == DispatchKey::XLA || t == DispatchKey::AutogradXLA) { return Backend::XLA; } else if (t == DispatchKey::Lazy || t == DispatchKey::AutogradLazy) { @@ -123,8 +123,8 @@ static inline DispatchKey backendToDispatchKey(Backend b) { return DispatchKey::VE; case Backend::FPGA: return DispatchKey::FPGA; - case Backend::MSNPU: - return DispatchKey::MSNPU; + case Backend::ORT: + return DispatchKey::ORT; case Backend::XLA: return DispatchKey::XLA; case Backend::Lazy: @@ -178,8 +178,8 @@ static inline DeviceType backendToDeviceType(Backend b) { return DeviceType::VE; case Backend::FPGA: return DeviceType::FPGA; - case Backend::MSNPU: - return DeviceType::MSNPU; + case Backend::ORT: + return DeviceType::ORT; case Backend::XLA: return DeviceType::XLA; case Backend::Lazy: @@ -235,8 +235,8 @@ static inline const char* toString(Backend b) { return "FPGA"; case Backend::XPU: return "XPU"; - case Backend::MSNPU: - return "MSNPU"; + case Backend::ORT: + return "ORT"; case Backend::XLA: return "XLA"; case Backend::Lazy: diff --git a/c10/core/Device.cpp b/c10/core/Device.cpp index ee6f1b473fe08..2531e3942271a 100644 --- a/c10/core/Device.cpp +++ b/c10/core/Device.cpp @@ -4,28 +4,13 @@ #include #include +#include #include #include -#include #include #include #include -// Check if compiler has working std::regex implementation -// -// Test below is adapted from https://stackoverflow.com/a/41186162 -#if defined(_MSVC_LANG) && _MSVC_LANG >= 201103L -// Compiler has working regex. MSVC has erroneous __cplusplus. -#elif __cplusplus >= 201103L && \ - (!defined(__GLIBCXX__) || (__cplusplus >= 201402L) || \ - (defined(_GLIBCXX_REGEX_DFS_QUANTIFIERS_LIMIT) || \ - defined(_GLIBCXX_REGEX_STATE_LIMIT) || \ - (defined(_GLIBCXX_RELEASE) && _GLIBCXX_RELEASE > 4))) -// Compiler has working regex. -#else -static_assert(false, "Compiler does not have proper regex support."); -#endif - namespace c10 { namespace { DeviceType parse_type(const std::string& device_string) { @@ -43,7 +28,7 @@ DeviceType parse_type(const std::string& device_string) { {"hip", DeviceType::HIP}, {"ve", DeviceType::VE}, {"fpga", DeviceType::FPGA}, - {"msnpu", DeviceType::MSNPU}, + {"ort", DeviceType::ORT}, {"xla", DeviceType::XLA}, {"lazy", DeviceType::Lazy}, {"vulkan", DeviceType::Vulkan}, @@ -62,36 +47,87 @@ DeviceType parse_type(const std::string& device_string) { } TORCH_CHECK( false, - "Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, ve, msnpu, mlc, xla, lazy, vulkan, meta, hpu device type at start of device string: ", + "Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, ve, ort, mlc, xla, lazy, vulkan, meta, hpu device type at start of device string: ", device_string); } +enum DeviceStringParsingState { START, INDEX_START, INDEX_REST, ERROR }; + } // namespace Device::Device(const std::string& device_string) : Device(Type::CPU) { TORCH_CHECK(!device_string.empty(), "Device string must not be empty"); - // We assume gcc 5+, so we can use proper regex. - static const std::regex regex("([a-zA-Z_]+)(?::([1-9]\\d*|0))?"); - std::smatch match; - TORCH_CHECK( - std::regex_match(device_string, match, regex), - "Invalid device string: '", - device_string, - "'"); - type_ = parse_type(match[1].str()); - if (match[2].matched) { - try { - index_ = c10::stoi(match[2].str()); - } catch (const std::exception&) { - TORCH_CHECK( - false, - "Could not parse device index '", - match[2].str(), - "' in device string '", - device_string, - "'"); + std::string device_name, device_index_str; + DeviceStringParsingState pstate = DeviceStringParsingState::START; + + // The code below tries to match the string in the variable + // device_string against the regular expression: + // ([a-zA-Z_]+)(?::([1-9]\\d*|0))? + for (size_t i = 0; + pstate != DeviceStringParsingState::ERROR && i < device_string.size(); + ++i) { + const char ch = device_string.at(i); + switch (pstate) { + case DeviceStringParsingState::START: + if (ch != ':') { + if (isalpha(ch) || ch == '_') { + device_name.push_back(ch); + } else { + pstate = DeviceStringParsingState::ERROR; + } + } else { + pstate = DeviceStringParsingState::INDEX_START; + } + break; + + case DeviceStringParsingState::INDEX_START: + if (isdigit(ch)) { + device_index_str.push_back(ch); + pstate = DeviceStringParsingState::INDEX_REST; + } else { + pstate = DeviceStringParsingState::ERROR; + } + break; + + case DeviceStringParsingState::INDEX_REST: + if (device_index_str.at(0) == '0') { + pstate = DeviceStringParsingState::ERROR; + break; + } + if (isdigit(ch)) { + device_index_str.push_back(ch); + } else { + pstate = DeviceStringParsingState::ERROR; + } + break; + + case DeviceStringParsingState::ERROR: + // Execution won't reach here. + break; + } + } + + const bool has_error = device_name.empty() || + pstate == DeviceStringParsingState::ERROR || + (pstate == DeviceStringParsingState::INDEX_START && + device_index_str.empty()); + + TORCH_CHECK(!has_error, "Invalid device string: '", device_string, "'"); + + try { + if (!device_index_str.empty()) { + index_ = c10::stoi(device_index_str); } + } catch (const std::exception&) { + TORCH_CHECK( + false, + "Could not parse device index '", + device_index_str, + "' in device string '", + device_string, + "'"); } + type_ = parse_type(device_name); validate(); } diff --git a/c10/core/DeviceType.cpp b/c10/core/DeviceType.cpp index 4ff939806f980..4635acdb148c2 100644 --- a/c10/core/DeviceType.cpp +++ b/c10/core/DeviceType.cpp @@ -25,8 +25,8 @@ std::string DeviceTypeName(DeviceType d, bool lower_case) { return lower_case ? "ve" : "VE"; case DeviceType::FPGA: return lower_case ? "fpga" : "FPGA"; - case DeviceType::MSNPU: - return lower_case ? "msnpu" : "MSNPU"; + case DeviceType::ORT: + return lower_case ? "ort" : "ORT"; case DeviceType::XLA: return lower_case ? "xla" : "XLA"; case DeviceType::Lazy: @@ -75,7 +75,7 @@ bool isValidDeviceType(DeviceType d) { case DeviceType::HIP: case DeviceType::VE: case DeviceType::FPGA: - case DeviceType::MSNPU: + case DeviceType::ORT: case DeviceType::XLA: case DeviceType::Lazy: case DeviceType::MLC: diff --git a/c10/core/DeviceType.h b/c10/core/DeviceType.h index 2ae028d144026..c6bd56914d6d1 100644 --- a/c10/core/DeviceType.h +++ b/c10/core/DeviceType.h @@ -21,7 +21,7 @@ enum class DeviceType : int8_t { IDEEP = 5, // IDEEP. HIP = 6, // AMD HIP FPGA = 7, // FPGA - MSNPU = 8, // MSNPU + ORT = 8, // ONNX Runtime / Microsoft XLA = 9, // XLA / TPU Vulkan = 10, // Vulkan Metal = 11, // Metal @@ -42,7 +42,7 @@ constexpr DeviceType kCPU = DeviceType::CPU; constexpr DeviceType kCUDA = DeviceType::CUDA; constexpr DeviceType kHIP = DeviceType::HIP; constexpr DeviceType kFPGA = DeviceType::FPGA; -constexpr DeviceType kMSNPU = DeviceType::MSNPU; +constexpr DeviceType kORT = DeviceType::ORT; constexpr DeviceType kXLA = DeviceType::XLA; constexpr DeviceType kMLC = DeviceType::MLC; constexpr DeviceType kMeta = DeviceType::Meta; diff --git a/c10/core/DispatchKey.cpp b/c10/core/DispatchKey.cpp index 5c414484b38fd..18aa4fc32fb64 100644 --- a/c10/core/DispatchKey.cpp +++ b/c10/core/DispatchKey.cpp @@ -19,8 +19,8 @@ const char* toString(DispatchKey t) { return "FPGA"; case DispatchKey::XPU: return "XPU"; - case DispatchKey::MSNPU: - return "MSNPU"; + case DispatchKey::ORT: + return "ORT"; case DispatchKey::XLA: return "XLA"; case DispatchKey::Lazy: diff --git a/c10/core/DispatchKey.h b/c10/core/DispatchKey.h index 9f21838ddb4a3..07222b79ee964 100644 --- a/c10/core/DispatchKey.h +++ b/c10/core/DispatchKey.h @@ -3,7 +3,7 @@ #include #include #include -#include +#include #include #include @@ -59,8 +59,15 @@ enum class DispatchKey : uint8_t { // CUDA] FPGA, // Xilinx support lives out of tree at // https://gitlab.com/pytorch-complex/vitis_kernels - MSNPU, // unused externally, but tested at - // test/cpp_extensions/msnpu_extension.cpp + + // ONNX Runtime, lives out of tree at https://github.com/pytorch/ort and + // https://github.com/microsoft/onnxruntime, and is also used to test general + // backend/extension machinery in the core. cf: + // - test/cpp_extensions/ort_extension.cpp + // - test/test_torch.py + // - aten/src/ATen/test/extension_backend_test.cpp + ORT, + XLA, // lives out of tree at https://github.com/pytorch/xla MLC, // lives out of tree at https://github.com/pytorch/MLCompute Vulkan, @@ -114,7 +121,7 @@ enum class DispatchKey : uint8_t { // Here are reserved backends for user-defined backends, see Note [Private use // DispatchKey] - // To see some example about how to use this, check out MSNPU + // To see some example about how to use this, check out ORT PrivateUse1, PrivateUse2, PrivateUse3, diff --git a/c10/core/DispatchKeySet.cpp b/c10/core/DispatchKeySet.cpp index b796114d4a608..21433d4ace8d7 100644 --- a/c10/core/DispatchKeySet.cpp +++ b/c10/core/DispatchKeySet.cpp @@ -19,6 +19,7 @@ constexpr DispatchKeySet backend_dispatch_keyset = autogradother_backends | DispatchKey::PrivateUse3, DispatchKey::MLC, DispatchKey::HPU, + DispatchKey::ORT, DispatchKey::Meta, }); @@ -31,8 +32,8 @@ bool isBackendDispatchKey(DispatchKey t) { // math_dispatch_keyset contains all keys in backend_dispatch_keyset and // autograd_dispatch_keyset Alias key DispatchKey::CompositeImplicitAutograd // maps to math_dispatch_keyset. -constexpr DispatchKeySet math_dispatch_keyset = - backend_dispatch_keyset | autograd_dispatch_keyset; +constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset | + autograd_dispatch_keyset | DispatchKeySet({DispatchKey::FuncTorchBatched}); DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) { TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined); diff --git a/c10/core/DispatchKeySet.h b/c10/core/DispatchKeySet.h index 0d3a25ea9d8d1..b1f5f04524d19 100644 --- a/c10/core/DispatchKeySet.h +++ b/c10/core/DispatchKeySet.h @@ -248,7 +248,7 @@ constexpr DispatchKeySet autogradother_backends = DispatchKeySet( {DispatchKey::HIP, DispatchKey::VE, DispatchKey::FPGA, - DispatchKey::MSNPU, + DispatchKey::ORT, DispatchKey::Vulkan, DispatchKey::Metal, DispatchKey::QuantizedCPU, diff --git a/c10/core/GradMode.cpp b/c10/core/GradMode.cpp index 32747a6698afa..c2ea8698732d7 100644 --- a/c10/core/GradMode.cpp +++ b/c10/core/GradMode.cpp @@ -4,13 +4,11 @@ namespace c10 { -thread_local bool GradMode_enabled = true; - bool GradMode::is_enabled() { - return GradMode_enabled; + return AutogradState::get_tls_state().get_grad_mode(); } void GradMode::set_enabled(bool enabled) { - GradMode_enabled = enabled; + AutogradState::get_tls_state().set_grad_mode(enabled); } } // namespace c10 diff --git a/c10/core/GradMode.h b/c10/core/GradMode.h index 1168bb1ae67c3..d83ff6d0d0d3b 100644 --- a/c10/core/GradMode.h +++ b/c10/core/GradMode.h @@ -1,5 +1,6 @@ #pragma once +#include #include namespace c10 { @@ -27,4 +28,17 @@ struct TORCH_API NoGradGuard : public AutoGradMode { NoGradGuard() : AutoGradMode(/*enabled=*/false) {} }; +// A RAII, thread local (!) guard that enables or disables forward grad mode +// upon construction, and sets it back to the original value upon destruction. +struct TORCH_API AutoFwGradMode { + AutoFwGradMode(bool enabled) + : prev_mode(AutogradState::get_tls_state().get_fw_grad_mode()) { + AutogradState::get_tls_state().set_fw_grad_mode(enabled); + } + ~AutoFwGradMode() { + AutogradState::get_tls_state().set_fw_grad_mode(prev_mode); + } + bool prev_mode; +}; + } // namespace c10 diff --git a/c10/core/InferenceMode.cpp b/c10/core/InferenceMode.cpp index b588ab4da54b5..59eca760cf504 100644 --- a/c10/core/InferenceMode.cpp +++ b/c10/core/InferenceMode.cpp @@ -2,18 +2,12 @@ #include namespace c10 { -thread_local bool InferenceMode_enabled = false; - // Invariant: // is_enabled() == // !c10::impl::tls_is_dispatch_key_included(DispatchKey::ADInplaceOrView); // InferenceMode::is_enabled() is in perf critical path (TensorImpl constructor) // so it worths a separate TLS to skip the DispatchKeySet check. bool InferenceMode::is_enabled() { - return InferenceMode_enabled; -} - -void InferenceMode::_set_enabled(bool enabled) { - InferenceMode_enabled = enabled; + return AutogradState::get_tls_state().get_inference_mode(); } } // namespace c10 diff --git a/c10/core/InferenceMode.h b/c10/core/InferenceMode.h index 7a9c2c593a453..704c43b522c6d 100644 --- a/c10/core/InferenceMode.h +++ b/c10/core/InferenceMode.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -50,10 +51,14 @@ struct TORCH_API InferenceMode { // are applicable to InferenceMode as well, e.g. // `tensorTypeInCurrentExecutionContext` in interpreter.cpp. InferenceMode(bool enabled = true) - : prev_mode(InferenceMode::is_enabled()), - prev_keyset(c10::impl::tls_local_dispatch_key_set()), - grad_mode(at::AutoGradMode(!enabled)) { - _set_enabled(enabled); + : prev_mode(AutogradState::get_tls_state()), + prev_keyset(c10::impl::tls_local_dispatch_key_set()) { + // Enabling inference mode means disabling grad modes + // And disabling inference mode means enabling grad modes + AutogradState::set_tls_state(AutogradState( + /* grad_mode */ !enabled, + /* inference_mode */ enabled, + /* fw_grad_mode */ !enabled)); DispatchKeySet included = enabled ? prev_keyset.included_.remove(c10::DispatchKey::ADInplaceOrView) : prev_keyset.included_.add(c10::DispatchKey::ADInplaceOrView); @@ -67,17 +72,13 @@ struct TORCH_API InferenceMode { } ~InferenceMode() { - _set_enabled(prev_mode); + AutogradState::set_tls_state(prev_mode); c10::impl::_force_tls_local_dispatch_key_set(prev_keyset); } static bool is_enabled(); - // _set_enabled() is not user facing and should be only used in - // ThreadLocalState.cpp. - static void _set_enabled(bool enabled); private: - bool prev_mode; + AutogradState prev_mode; c10::impl::LocalDispatchKeySet prev_keyset; - at::AutoGradMode grad_mode; }; } // namespace c10 diff --git a/c10/core/Layout.h b/c10/core/Layout.h index 44168ebca4360..f37ceb18a835d 100644 --- a/c10/core/Layout.h +++ b/c10/core/Layout.h @@ -3,7 +3,7 @@ #include #include -#include +#include namespace c10 { enum class Layout : int8_t { Strided, Sparse, SparseCsr, Mkldnn, NumOptions }; diff --git a/c10/core/MemoryFormat.h b/c10/core/MemoryFormat.h index ba4e056e1e6c8..8cafde1b5c5e7 100644 --- a/c10/core/MemoryFormat.h +++ b/c10/core/MemoryFormat.h @@ -4,7 +4,7 @@ #include #include -#include +#include // Memory format is not the property of a Tensor. It is the way to tell an // operator how the result should be organized in memory and nothing more. That diff --git a/c10/core/ScalarType.h b/c10/core/ScalarType.h index d652db5a215c6..f7b07100365fa 100644 --- a/c10/core/ScalarType.h +++ b/c10/core/ScalarType.h @@ -12,7 +12,7 @@ #include #include -#include +#include namespace c10 { diff --git a/c10/core/StorageImpl.h b/c10/core/StorageImpl.h index ff29b68dc4dad..bea717d7ee50f 100644 --- a/c10/core/StorageImpl.h +++ b/c10/core/StorageImpl.h @@ -68,7 +68,7 @@ struct C10_API StorageImpl final : public c10::intrusive_ptr_target { StorageImpl() = delete; StorageImpl(StorageImpl&& other) = default; StorageImpl(const StorageImpl&) = delete; - ~StorageImpl() = default; + ~StorageImpl() override = default; void reset() { data_ptr_.clear(); diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index 65d7af38e3599..7051e36b35516 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -873,6 +873,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return key_set_.has(DispatchKey::MLC); } + bool is_ort() const { + return key_set_.has(DispatchKey::ORT); + } + // TODO: remove this once we don't automatically enabled Autograd dispatch // keys // in TensorImpl constructor. diff --git a/c10/core/TensorOptions.h b/c10/core/TensorOptions.h index fff9433e270f7..287b2fa41b2a3 100644 --- a/c10/core/TensorOptions.h +++ b/c10/core/TensorOptions.h @@ -663,8 +663,8 @@ inline DispatchKey computeDispatchKey( return DispatchKey::VE; case DeviceType::FPGA: return DispatchKey::FPGA; - case DeviceType::MSNPU: - return DispatchKey::MSNPU; + case DeviceType::ORT: + return DispatchKey::ORT; case DeviceType::XLA: return DispatchKey::XLA; case DeviceType::Lazy: @@ -790,10 +790,8 @@ inline DeviceType dispatchKeyToDeviceType(DispatchKey dispatch_key) { case DispatchKey::HPU: case DispatchKey::AutogradHPU: return DeviceType::HPU; - - // stuff that isn't real - case DispatchKey::MSNPU: - return DeviceType::MSNPU; + case DispatchKey::ORT: + return DeviceType::ORT; default: TORCH_CHECK( false, diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index 055375352ee08..659fea351d467 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -308,6 +308,8 @@ cudaError_t cudaMallocMaybeCapturing(void** p, size_t size) { } else { // It's ok to capture cudaMallocs, as long as we never cudaFree those // addresses before replay. + // Capturing cudaMalloc behaves nicely: it gives the graph new VA, + // but is ignored (won't leakily allocate new memory) in replays. at::cuda::CUDAStreamCaptureModeGuard g{cudaStreamCaptureModeRelaxed}; return cudaMalloc(p, size); } diff --git a/c10/mobile/CPUProfilingAllocator.h b/c10/mobile/CPUProfilingAllocator.h index 5112691a62d6f..bb080d9df97c3 100644 --- a/c10/mobile/CPUProfilingAllocator.h +++ b/c10/mobile/CPUProfilingAllocator.h @@ -50,7 +50,7 @@ class C10_API AllocationPlanner { private: AllocationPlan* allocation_plan_{nullptr}; // Maps allocated ptr to its allocation id. - // This is used when freeing the memory to lookup the allocation id + // This is used when freeing the memory to look up the allocation id // in order to establish the lifetime of a particular allocation. ska::flat_hash_map allocation_ptr_to_id_; uint64_t allocation_id_{0}; diff --git a/c10/test/util/optional_test.cpp b/c10/test/util/optional_test.cpp index 1e34377282898..ac976b4b16f79 100644 --- a/c10/test/util/optional_test.cpp +++ b/c10/test/util/optional_test.cpp @@ -1,5 +1,6 @@ #include +#include #include #include @@ -8,6 +9,14 @@ namespace { +using testing::Eq; +using testing::Ge; +using testing::Gt; +using testing::Le; +using testing::Lt; +using testing::Ne; +using testing::Not; + template class OptionalTest : public ::testing::Test { public: @@ -90,4 +99,87 @@ TYPED_TEST(OptionalTest, Initialized) { } } +class SelfCompareTest : public testing::TestWithParam> {}; + +TEST_P(SelfCompareTest, SelfCompare) { + c10::optional x = GetParam(); + EXPECT_THAT(x, Eq(x)); + EXPECT_THAT(x, Le(x)); + EXPECT_THAT(x, Ge(x)); + EXPECT_THAT(x, Not(Ne(x))); + EXPECT_THAT(x, Not(Lt(x))); + EXPECT_THAT(x, Not(Gt(x))); +} + +INSTANTIATE_TEST_CASE_P( + nullopt, + SelfCompareTest, + testing::Values(c10::nullopt)); +INSTANTIATE_TEST_CASE_P( + int, + SelfCompareTest, + testing::Values(c10::make_optional(2))); + +TEST(OptionalTest, Nullopt) { + c10::optional x = 2; + + EXPECT_THAT(c10::nullopt, Not(Eq(x))); + EXPECT_THAT(x, Not(Eq(c10::nullopt))); + + EXPECT_THAT(x, Ne(c10::nullopt)); + EXPECT_THAT(c10::nullopt, Ne(x)); + + EXPECT_THAT(x, Not(Lt(c10::nullopt))); + EXPECT_THAT(c10::nullopt, Lt(x)); + + EXPECT_THAT(x, Not(Le(c10::nullopt))); + EXPECT_THAT(c10::nullopt, Le(x)); + + EXPECT_THAT(x, Gt(c10::nullopt)); + EXPECT_THAT(c10::nullopt, Not(Gt(x))); + + EXPECT_THAT(x, Ge(c10::nullopt)); + EXPECT_THAT(c10::nullopt, Not(Ge(x))); +} + +// Ensure comparisons work... +using CmpTestTypes = testing::Types< + // between two optionals + std::pair, c10::optional>, + + // between an optional and a value + std::pair, int>, + // between a value and an optional + std::pair>, + + // between an optional and a differently typed value + std::pair, long>, + // between a differently typed value and an optional + std::pair>>; +template +class CmpTest : public testing::Test {}; +TYPED_TEST_CASE(CmpTest, CmpTestTypes); + +TYPED_TEST(CmpTest, Cmp) { + TypeParam pair = {2, 3}; + auto x = pair.first; + auto y = pair.second; + + EXPECT_THAT(x, Not(Eq(y))); + + EXPECT_THAT(x, Ne(y)); + + EXPECT_THAT(x, Lt(y)); + EXPECT_THAT(y, Not(Lt(x))); + + EXPECT_THAT(x, Le(y)); + EXPECT_THAT(y, Not(Le(x))); + + EXPECT_THAT(x, Not(Gt(y))); + EXPECT_THAT(y, Gt(x)); + + EXPECT_THAT(x, Not(Ge(y))); + EXPECT_THAT(y, Ge(x)); +} + } // namespace diff --git a/c10/util/BFloat16-math.h b/c10/util/BFloat16-math.h index 2760100db6e98..a7b8426ced36a 100644 --- a/c10/util/BFloat16-math.h +++ b/c10/util/BFloat16-math.h @@ -57,6 +57,12 @@ inline c10::BFloat16 sin(c10::BFloat16 a) { inline c10::BFloat16 tan(c10::BFloat16 a) { return std::tan(float(a)); } +inline c10::BFloat16 sinh(c10::BFloat16 a) { + return std::sinh(float(a)); +} +inline c10::BFloat16 cosh(c10::BFloat16 a) { + return std::cosh(float(a)); +} inline c10::BFloat16 tanh(c10::BFloat16 a) { return std::tanh(float(a)); } diff --git a/c10/util/Backtrace.cpp b/c10/util/Backtrace.cpp index d978f32cd00e0..2c5e2e4cdca16 100644 --- a/c10/util/Backtrace.cpp +++ b/c10/util/Backtrace.cpp @@ -16,8 +16,13 @@ #if SUPPORTS_BACKTRACE #include +#ifdef C10_ANDROID +#include +#include +#else #include #endif +#endif #ifdef FBCODE_CAFFE2 #include @@ -25,6 +30,59 @@ namespace c10 { +#if SUPPORTS_BACKTRACE && defined(C10_ANDROID) + +struct AndroidBacktraceState { + std::vector buffer; +}; + +_Unwind_Reason_Code android_unwind_callback( + struct _Unwind_Context* context, + void* arg) { + AndroidBacktraceState* state = (AndroidBacktraceState*)arg; + uintptr_t pc = _Unwind_GetIP(context); + if (pc) { + state->buffer.emplace_back(reinterpret_cast(pc)); + } + return _URC_NO_REASON; +} + +void dump_stack( + std::ostream& os, + size_t frames_to_skip, + size_t maximum_number_of_frames) { + AndroidBacktraceState state; + + _Unwind_Backtrace(android_unwind_callback, &state); + + int idx = 0; + char* demangled = nullptr; + size_t length = 0; + + for (const void* addr : state.buffer) { + const char* symbol = ""; + + Dl_info info; + if (dladdr(addr, &info) && info.dli_sname) { + symbol = info.dli_sname; + } + + int status = 0; + demangled = __cxxabiv1::__cxa_demangle( + /*mangled_name*/ symbol, + /*output_buffer*/ demangled, + /*length*/ &length, + /*status*/ &status); + + os << " frame #" << idx++ << "\t" + << ((demangled != NULL && status == 0) ? demangled : symbol) << "[" + << addr << "]\t" << std::endl; + } + free(demangled); +} + +#endif /* SUPPORTS_BACKTRACE && defined(C10_ANDROID) */ + #if SUPPORTS_BACKTRACE namespace { @@ -42,6 +100,7 @@ struct FrameInformation { std::string object_file; }; +#ifndef C10_ANDROID bool is_python_frame(const FrameInformation& frame) { return frame.object_file == "python" || frame.object_file == "python3" || (frame.object_file.find("libpython") != std::string::npos); @@ -113,6 +172,7 @@ c10::optional parse_frame_information( frame.function_name = demangle(mangled_function_name.c_str()); return frame; } +#endif /* !defined(C10_ANDROID) */ } // anonymous namespace #elif defined(_MSC_VER) namespace { @@ -178,7 +238,7 @@ std::string get_backtrace( facebook::process::StackTrace st; return st.toString(); -#elif SUPPORTS_BACKTRACE +#elif SUPPORTS_BACKTRACE && !defined(C10_ANDROID) // We always skip this frame (backtrace). frames_to_skip += 1; @@ -249,6 +309,13 @@ std::string get_backtrace( } return stream.str(); + +#elif SUPPORTS_BACKTRACE && defined(C10_ANDROID) + + std::ostringstream oss; + dump_stack(oss, frames_to_skip, maximum_number_of_frames); + return oss.str().c_str(); + #elif defined(_MSC_VER) // !SUPPORTS_BACKTRACE // This backtrace retrieval is implemented on Windows via the Windows // API using `CaptureStackBackTrace`, `SymFromAddr` and diff --git a/c10/util/Bitset.h b/c10/util/Bitset.h index 6f7c4b9a1d78b..bed04a438abea 100644 --- a/c10/util/Bitset.h +++ b/c10/util/Bitset.h @@ -3,7 +3,6 @@ #include #include #include -#include #if defined(_MSC_VER) #include #endif diff --git a/c10/util/Optional.h b/c10/util/Optional.h index 5e0684bb7d2f5..7044c798d2de4 100644 --- a/c10/util/Optional.h +++ b/c10/util/Optional.h @@ -1049,63 +1049,63 @@ constexpr bool operator>=(nullopt_t, const optional& x) noexcept { } // 20.5.10, Comparison with T -template -constexpr bool operator==(const optional& x, const T& v) { +template +constexpr bool operator==(const optional& x, const U& v) { return bool(x) ? *x == v : false; } -template -constexpr bool operator==(const T& v, const optional& x) { +template +constexpr bool operator==(const U& v, const optional& x) { return bool(x) ? v == *x : false; } -template -constexpr bool operator!=(const optional& x, const T& v) { +template +constexpr bool operator!=(const optional& x, const U& v) { return bool(x) ? *x != v : true; } -template -constexpr bool operator!=(const T& v, const optional& x) { +template +constexpr bool operator!=(const U& v, const optional& x) { return bool(x) ? v != *x : true; } -template -constexpr bool operator<(const optional& x, const T& v) { +template +constexpr bool operator<(const optional& x, const U& v) { return bool(x) ? *x < v : true; } -template -constexpr bool operator>(const T& v, const optional& x) { +template +constexpr bool operator>(const U& v, const optional& x) { return bool(x) ? v > *x : true; } -template -constexpr bool operator>(const optional& x, const T& v) { +template +constexpr bool operator>(const optional& x, const U& v) { return bool(x) ? *x > v : false; } -template -constexpr bool operator<(const T& v, const optional& x) { +template +constexpr bool operator<(const U& v, const optional& x) { return bool(x) ? v < *x : false; } -template -constexpr bool operator>=(const optional& x, const T& v) { +template +constexpr bool operator>=(const optional& x, const U& v) { return bool(x) ? *x >= v : false; } -template -constexpr bool operator<=(const T& v, const optional& x) { +template +constexpr bool operator<=(const U& v, const optional& x) { return bool(x) ? v <= *x : false; } -template -constexpr bool operator<=(const optional& x, const T& v) { +template +constexpr bool operator<=(const optional& x, const U& v) { return bool(x) ? *x <= v : true; } -template -constexpr bool operator>=(const T& v, const optional& x) { +template +constexpr bool operator>=(const U& v, const optional& x) { return bool(x) ? v >= *x : true; } diff --git a/c10/util/complex.h b/c10/util/complex.h index 2a565f8f2bf8f..67ed463febd94 100644 --- a/c10/util/complex.h +++ b/c10/util/complex.h @@ -1,7 +1,6 @@ #pragma once #include -#include #include diff --git a/c10/util/either.h b/c10/util/either.h index da765b9a9bb17..757663f5896fb 100644 --- a/c10/util/either.h +++ b/c10/util/either.h @@ -6,7 +6,6 @@ #include #include #include -#include namespace c10 { /** diff --git a/c10/util/typeid.h b/c10/util/typeid.h index e6a5822a3e7ce..240c69e92400e 100644 --- a/c10/util/typeid.h +++ b/c10/util/typeid.h @@ -4,7 +4,6 @@ #include #include #include -#include #include #include #include diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 619455421f282..8b403a7c4014e 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -175,7 +175,7 @@ endif() if(BUILD_SPLIT_CUDA) # Splitting the source files that'll be in torch_cuda between torch_cuda_cu and torch_cuda_cpp foreach(tmp ${Caffe2_GPU_SRCS}) - if("${tmp}" MATCHES "(.*aten.*\\.cu|.*(b|B)las.*|.*((s|S)olver|Register.*CUDA|Legacy|THC|TensorShapeCUDA).*\\.cpp)" AND NOT "${tmp}" MATCHES ".*(THC((CachingHost)?Allocator|General)).*") + if("${tmp}" MATCHES "(.*aten.*\\.cu|.*(b|B)las.*|.*((s|S)olver|Register.*CUDA|Legacy|THC|TensorShapeCUDA|BatchLinearAlgebra).*\\.cpp)" AND NOT "${tmp}" MATCHES ".*(THC((CachingHost)?Allocator|General)).*") # Currently, torch_cuda_cu will have all the .cu files in aten, as well as some others that depend on those files list(APPEND Caffe2_GPU_SRCS_CU ${tmp}) else() @@ -397,7 +397,9 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) "${TORCH_SRC_DIR}/csrc/autograd/generated/python_functions_3.cpp" "${TORCH_SRC_DIR}/csrc/autograd/generated/python_functions_4.cpp" "${TORCH_SRC_DIR}/csrc/autograd/generated/python_variable_methods.cpp" - "${TORCH_SRC_DIR}/csrc/autograd/generated/python_torch_functions.cpp" + "${TORCH_SRC_DIR}/csrc/autograd/generated/python_torch_functions_0.cpp" + "${TORCH_SRC_DIR}/csrc/autograd/generated/python_torch_functions_1.cpp" + "${TORCH_SRC_DIR}/csrc/autograd/generated/python_torch_functions_2.cpp" "${TORCH_SRC_DIR}/csrc/autograd/generated/python_nn_functions.cpp" "${TORCH_SRC_DIR}/csrc/autograd/generated/python_fft_functions.cpp" "${TORCH_SRC_DIR}/csrc/autograd/generated/python_linalg_functions.cpp" @@ -529,11 +531,6 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) set_source_files_properties(${TORCH_SRC_DIR}/csrc/jit/tensorexpr/llvm_codegen.cpp PROPERTIES COMPILE_FLAGS -Wno-init-list-lifetime) endif() - # Pass path to PocketFFT - if(AT_POCKETFFT_ENABLED) - set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/mkl/SpectralOps.cpp PROPERTIES INCLUDE_DIRECTORIES "${POCKETFFT_INCLUDE_DIR}") - endif() - if(NOT INTERN_DISABLE_MOBILE_INTERP) set(MOBILE_SRCS ${TORCH_SRC_DIR}/csrc/jit/mobile/function.cpp @@ -795,6 +792,17 @@ if(USE_PRECOMPILED_HEADERS) PROPERTIES SKIP_PRECOMPILE_HEADERS ON) endif() +# Pass path to PocketFFT +if(AT_POCKETFFT_ENABLED) + if(CMAKE_VERSION VERSION_LESS "3.11") + target_include_directories(torch_cpu PRIVATE "${POCKETFFT_INCLUDE_DIR}") + else() + set_source_files_properties( + "${PROJECT_SOURCE_DIR}/aten/src/ATen/native/mkl/SpectralOps.cpp" + PROPERTIES INCLUDE_DIRECTORIES "${POCKETFFT_INCLUDE_DIR}") + endif() +endif() + if(CMAKE_COMPILER_IS_GNUCXX AND BUILD_LIBTORCH_CPU_WITH_DEBUG) # To enable debug fission we need to build libtorch_cpu with debug info on, # but this increases link time and peak memory usage if we use the @@ -1042,27 +1050,10 @@ if(USE_TBB) target_link_libraries(torch_cpu PUBLIC TBB::tbb) endif() - -if(LINUX) - find_library(BREAKPAD_LIB breakpad_client - PATHS /usr/local/lib/) - find_path(BREAKPAD_INCLUDE_DIR breakpad - PATHS /usr/local/include/) - - if(BREAKPAD_LIB AND BREAKPAD_INCLUDE_DIR) - message(STATUS "found breakpad library") - target_link_libraries(torch_cpu PRIVATE ${BREAKPAD_LIB}) - target_compile_definitions(torch_cpu PRIVATE ADD_BREAKPAD_SIGNAL_HANDLER) - target_include_directories(torch_cpu PRIVATE ${BREAKPAD_INCLUDE_DIR}/breakpad) - else() - if(BREAKPAD_INCLUDE_DIR) - message(STATUS "breakpad_client library not found") - elseif(BREAKPAD_LIB) - message(STATUS "breakpad include path not found") - else() - message(STATUS "breakpad_client library and include path not found") - endif() - endif() +if(USE_BREAKPAD) + target_compile_definitions(torch_cpu PRIVATE ADD_BREAKPAD_SIGNAL_HANDLER) + target_include_directories(torch_cpu PRIVATE ${CMAKE_CURRENT_LIST_DIR}/../third_party ${CMAKE_CURRENT_LIST_DIR}/../third_party/breakpad/src) + target_link_libraries(torch_cpu PRIVATE breakpad) endif() target_include_directories(torch_cpu PRIVATE ${ATen_CPU_INCLUDE}) diff --git a/caffe2/core/init.cc b/caffe2/core/init.cc index 529665869b3e1..bafbc825f8b79 100644 --- a/caffe2/core/init.cc +++ b/caffe2/core/init.cc @@ -3,6 +3,7 @@ #include "caffe2/core/scope_guard.h" #include +#include #include C10_DEFINE_bool( diff --git a/caffe2/core/operator.cc b/caffe2/core/operator.cc index 846ab8ab55b46..e25c92a6d6075 100644 --- a/caffe2/core/operator.cc +++ b/caffe2/core/operator.cc @@ -1,6 +1,7 @@ #include "caffe2/core/operator.h" #include +#include #include "caffe2/core/init.h" #include "caffe2/core/logging.h" @@ -355,6 +356,17 @@ void SetOpEnginePref( } } +DeviceTypeRegisterer::DeviceTypeRegisterer(DeviceType type, RegistryFunction func) { + if (gDeviceTypeRegistry()->count(type)) { + std::cerr << "Device type " << DeviceTypeName(type) + << "registered twice. This should not happen. Did you have " + "duplicated numbers assigned to different devices?"; + std::exit(1); + } + // Calling the registry function to get the actual registry pointer. + gDeviceTypeRegistry()->emplace(type, func()); +} + unique_ptr CreateOperator( const OperatorDef& operator_def, Workspace* ws, @@ -819,7 +831,7 @@ std::function GetOperatorLogger() { } c10::optional OperatorBase::argumentIndexWithName( - const std::string& name) const { + c10::string_view name) const { #if defined(EXPOSE_C2_OPS) || \ !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) return getFunctionSchema().argumentIndexWithName(name); diff --git a/caffe2/core/operator.h b/caffe2/core/operator.h index fc9a6769c4e65..15d1ead352762 100644 --- a/caffe2/core/operator.h +++ b/caffe2/core/operator.h @@ -15,6 +15,7 @@ #include #include +#include #include #include #include "caffe2/core/blob.h" @@ -97,7 +98,7 @@ class TORCH_API OperatorBase : public Observable { /** @brief Checks if the operator has an argument of the given name. */ - inline bool HasArgument(const string& name) const { + inline bool HasArgument(c10::string_view name) const { if (isLegacyOperator()) { CAFFE_ENFORCE(operator_def_, "operator_def was null!"); return ArgumentHelper::HasArgument(*operator_def_, name); @@ -108,7 +109,7 @@ class TORCH_API OperatorBase : public Observable { // Functions that deal with arguments. Basically, this allows us to map an // argument name to a specific type of argument that we are trying to access. template - inline T GetSingleArgument(const string& name, const T& default_value) const { + inline T GetSingleArgument(c10::string_view name, const T& default_value) const { if (isLegacyOperator()) { CAFFE_ENFORCE(operator_def_, "operator_def was null!"); return ArgumentHelper::GetSingleArgument( @@ -126,7 +127,7 @@ class TORCH_API OperatorBase : public Observable { } template - inline bool HasSingleArgumentOfType(const string& name) const { + inline bool HasSingleArgumentOfType(c10::string_view name) const { CAFFE_ENFORCE(operator_def_, "operator_def was null!"); return ArgumentHelper::HasSingleArgumentOfType( *operator_def_, name); @@ -141,7 +142,7 @@ class TORCH_API OperatorBase : public Observable { template inline vector GetRepeatedArgument( - const string& name, + c10::string_view name, const vector& default_value = {}) const; // Get the inputs and outputs as specific types. @@ -654,7 +655,7 @@ class TORCH_API OperatorBase : public Observable { } } - c10::optional argumentIndexWithName(const std::string& name) const; + c10::optional argumentIndexWithName(c10::string_view name) const; // An event used by asynchronous execution. std::unique_ptr event_; @@ -664,7 +665,7 @@ class TORCH_API OperatorBase : public Observable { template <> inline NetDef OperatorBase::GetSingleArgument( - const std::string& name, + c10::string_view name, const NetDef& default_value) const { if (isLegacyOperator()) { CAFFE_ENFORCE(operator_def_, "operator_def was null!"); @@ -756,7 +757,7 @@ inline vector OperatorBase::GetVectorFromIValueList( template inline vector OperatorBase::GetRepeatedArgument( - const string& name, + c10::string_view name, const vector& default_value) const { if (isLegacyOperator()) { CAFFE_ENFORCE(operator_def_, "operator_def was null!"); @@ -778,7 +779,7 @@ inline vector OperatorBase::GetRepeatedArgument( // int16_t. We need to load it as List and transform to int16_t. template <> inline vector OperatorBase::GetRepeatedArgument( - const string& name, + c10::string_view name, const vector& default_value) const { if (isLegacyOperator()) { CAFFE_ENFORCE(operator_def_, "operator_def was null!"); @@ -1330,16 +1331,7 @@ typedef c10::Registry< TORCH_API std::map* gDeviceTypeRegistry(); struct TORCH_API DeviceTypeRegisterer { - explicit DeviceTypeRegisterer(DeviceType type, RegistryFunction func) { - if (gDeviceTypeRegistry()->count(type)) { - std::cerr << "Device type " << DeviceTypeName(type) - << "registered twice. This should not happen. Did you have " - "duplicated numbers assigned to different devices?"; - std::exit(1); - } - // Calling the registry function to get the actual registry pointer. - gDeviceTypeRegistry()->emplace(type, func()); - } + explicit DeviceTypeRegisterer(DeviceType type, RegistryFunction func); }; #if defined(_MSC_VER) diff --git a/caffe2/core/operator_schema.cc b/caffe2/core/operator_schema.cc index fbfb8f404d359..29d0b3e78d9a4 100644 --- a/caffe2/core/operator_schema.cc +++ b/caffe2/core/operator_schema.cc @@ -1,6 +1,8 @@ #include "caffe2/core/operator_schema.h" #include "caffe2/core/logging.h" +#include + #include namespace caffe2 { @@ -520,6 +522,22 @@ C10_EXPORT std::ostream& operator<<(std::ostream& out, const OpSchema& schema) { return out; } +OpSchema& OpSchemaRegistry::NewSchema(const string& key, const string& file, const int line) { + auto& m = map(); + auto it = m.find(key); + if (it != m.end()) { + const auto& schema = it->second; + std::ios_base::Init init; + std::cerr << "Trying to register schema with name " << key + << " from file " << file << " line " << line + << ", but it is already registered from file " << schema.file() + << " line " << schema.line(); + abort(); + } + m.emplace(key, OpSchema(key, file, line)); + return m[key]; +} + CaffeMap& OpSchemaRegistry::map() { static CaffeMap map; return map; diff --git a/caffe2/core/operator_schema.h b/caffe2/core/operator_schema.h index b19d5be079af2..0d048eb8d26e9 100644 --- a/caffe2/core/operator_schema.h +++ b/caffe2/core/operator_schema.h @@ -6,12 +6,13 @@ #include #include #include -#include #include +#include #include "c10/util/Registry.h" #include "caffe2/core/common.h" #include "caffe2/core/logging.h" +#include "caffe2/core/types.h" #include "caffe2/proto/caffe2_pb.h" #include "caffe2/utils/filler.h" #include "caffe2/utils/proto_utils.h" @@ -273,8 +274,8 @@ class TORCH_API OpSchema { OpSchema& Arg(const char* name, const char* description, bool required = false); -#define DECLARE_STANDARD_ARG(name, str) \ - static const char* Arg_##name; \ +#define DECLARE_STANDARD_ARG(name, str) \ + static const char* Arg_##name; \ OpSchema& Arg##name(const char* description); DECLARE_STANDARD_ARG(IsTest, is_test) @@ -339,7 +340,9 @@ class TORCH_API OpSchema { return inplace_enforced_(x, y); } - TORCH_API friend std::ostream& operator<<(std::ostream& out, const OpSchema& schema); + TORCH_API friend std::ostream& operator<<( + std::ostream& out, + const OpSchema& schema); const std::vector& args() const { return args_; @@ -460,21 +463,7 @@ class TORCH_API OpSchema { class TORCH_API OpSchemaRegistry { public: static OpSchema& - NewSchema(const string& key, const string& file, const int line) { - auto& m = map(); - auto it = m.find(key); - if (it != m.end()) { - const auto& schema = it->second; - std::ios_base::Init init; - std::cerr << "Trying to register schema with name " << key - << " from file " << file << " line " << line - << ", but it is already registered from file " << schema.file() - << " line " << schema.line(); - abort(); - } - m.emplace(key, OpSchema(key, file, line)); - return m[key]; - } + NewSchema(const string& key, const string& file, const int line); static const OpSchema* Schema(const string& key) { auto& m = map(); @@ -576,8 +565,10 @@ OpSchema::Cost PointwiseCostInference( } c.flops = nElemX * OpsPerPoint; - c.bytes_read = nElemRead * sizeof(X.data_type()); - c.bytes_written = nElemX * sizeof(X.data_type()); + auto const& X_element_size_byte = + DataTypeToTypeMeta(X.data_type()).itemsize(); + c.bytes_read = nElemRead * X_element_size_byte; + c.bytes_written = nElemX * X_element_size_byte; return c; } diff --git a/caffe2/core/plan_executor_test.cc b/caffe2/core/plan_executor_test.cc index 6f0c237a8b086..7a54403805ecb 100644 --- a/caffe2/core/plan_executor_test.cc +++ b/caffe2/core/plan_executor_test.cc @@ -290,6 +290,8 @@ TEST(PlanExecutorTest, BlockingErrorPlan) { #endif #endif + testing::GTEST_FLAG(death_test_style) = "threadsafe"; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) ASSERT_DEATH( [] { diff --git a/caffe2/operators/batch_matmul_op.cc b/caffe2/operators/batch_matmul_op.cc index 32799ced10671..205acf74f1572 100644 --- a/caffe2/operators/batch_matmul_op.cc +++ b/caffe2/operators/batch_matmul_op.cc @@ -1,6 +1,7 @@ #include "caffe2/operators/batch_matmul_op.h" #include "caffe2/core/operator_schema.h" +#include "caffe2/core/types.h" namespace caffe2 { @@ -116,9 +117,13 @@ OpSchema::Cost CostInferenceForBatchMatMul( K = in[0].dims(ndims_A - 1); } + auto const& A_element_size_byte = + DataTypeToTypeMeta(A.data_type()).itemsize(); + auto const& Y_element_size_byte = + DataTypeToTypeMeta(Y.data_type()).itemsize(); c.flops = 2 * nElemY * K; - c.bytes_read = (nElemA + nElemB) * sizeof(A.data_type()); - c.bytes_written = nElemY * sizeof(Y.data_type()); + c.bytes_read = (nElemA + nElemB) * A_element_size_byte; + c.bytes_written = nElemY * Y_element_size_byte; c.params_bytes = 0; return c; } @@ -180,72 +185,76 @@ class GetBatchMatMulGradient : public GradientMakerBase { auto no_trans_arg = vector(); auto trans_a_arg = vector{MakeArgument("trans_a", 1)}; auto trans_b_arg = vector{MakeArgument("trans_b", 1)}; - auto trans_both_arg = vector{MakeArgument("trans_a", 1), - MakeArgument("trans_b", 1)}; + auto trans_both_arg = vector{ + MakeArgument("trans_a", 1), MakeArgument("trans_b", 1)}; if (trans_a) { if (trans_b) { // A'B': // dA = B'G', dB = G'A' - return vector{CreateOperatorDef( - "BatchMatMul", - "", - vector{I(1), GO(0)}, - vector{GI(0)}, - trans_both_arg), - CreateOperatorDef( - "BatchMatMul", - "", - vector{GO(0), I(0)}, - vector{GI(1)}, - trans_both_arg)}; + return vector{ + CreateOperatorDef( + "BatchMatMul", + "", + vector{I(1), GO(0)}, + vector{GI(0)}, + trans_both_arg), + CreateOperatorDef( + "BatchMatMul", + "", + vector{GO(0), I(0)}, + vector{GI(1)}, + trans_both_arg)}; } else { // A'B: // dA = BG', dB = AG - return vector{CreateOperatorDef( - "BatchMatMul", - "", - vector{I(1), GO(0)}, - vector{GI(0)}, - trans_b_arg), - CreateOperatorDef( - "BatchMatMul", - "", - vector{I(0), GO(0)}, - vector{GI(1)}, - no_trans_arg)}; + return vector{ + CreateOperatorDef( + "BatchMatMul", + "", + vector{I(1), GO(0)}, + vector{GI(0)}, + trans_b_arg), + CreateOperatorDef( + "BatchMatMul", + "", + vector{I(0), GO(0)}, + vector{GI(1)}, + no_trans_arg)}; } } else { if (trans_b) { // AB': // dA = GB, dB = G'A - return vector{CreateOperatorDef( - "BatchMatMul", - "", - vector{GO(0), I(1)}, - vector{GI(0)}, - no_trans_arg), - CreateOperatorDef( - "BatchMatMul", - "", - vector{GO(0), I(0)}, - vector{GI(1)}, - trans_a_arg)}; + return vector{ + CreateOperatorDef( + "BatchMatMul", + "", + vector{GO(0), I(1)}, + vector{GI(0)}, + no_trans_arg), + CreateOperatorDef( + "BatchMatMul", + "", + vector{GO(0), I(0)}, + vector{GI(1)}, + trans_a_arg)}; } else { // AB: // dA = GB', dB = A'G - return vector{CreateOperatorDef( - "BatchMatMul", - "", - vector{GO(0), I(1)}, - vector{GI(0)}, - trans_b_arg), - CreateOperatorDef( - "BatchMatMul", - "", - vector{I(0), GO(0)}, - vector{GI(1)}, - trans_a_arg)}; + return vector{ + CreateOperatorDef( + "BatchMatMul", + "", + vector{GO(0), I(1)}, + vector{GI(0)}, + trans_b_arg), + CreateOperatorDef( + "BatchMatMul", + "", + vector{I(0), GO(0)}, + vector{GI(1)}, + trans_a_arg)}; } } } diff --git a/caffe2/operators/concat_split_op.cc b/caffe2/operators/concat_split_op.cc index 8eceb5ab4a577..86d6536b8880d 100644 --- a/caffe2/operators/concat_split_op.cc +++ b/caffe2/operators/concat_split_op.cc @@ -101,9 +101,12 @@ OpSchema::Cost CostInferenceForSplit( CAFFE_ENFORCE_GT(in.size(), 0); struct OpSchema::Cost cost; cost.flops = 0; - auto input_bytes_count = nElemFromDim(in[0]) * sizeof(in[0].data_type()); - auto split_bytes_count = - (in.size() == 1) ? 0 : nElemFromDim(in[1]) * sizeof(in[1].data_type()); + auto const& input_0_element_size_byte = + DataTypeToTypeMeta(in[0].data_type()).itemsize(); + auto input_bytes_count = nElemFromDim(in[0]) * input_0_element_size_byte; + auto split_bytes_count = in.size() > 1 + ? nElemFromDim(in[1]) * DataTypeToTypeMeta(in[1].data_type()).itemsize() + : 0; // There can be two input blobs: // (1) actual tensor to be split // (2) lengths of outputs along split axis @@ -329,11 +332,13 @@ OpSchema::Cost CostInferenceForConcat( } auto split_info_bytes_count = in.size() * sizeof(int); + auto const& input_0_element_size_byte = + DataTypeToTypeMeta(in[0].data_type()).itemsize(); struct OpSchema::Cost cost; cost.flops = 0; - cost.bytes_read = nElemRead * sizeof(in[0].data_type()); + cost.bytes_read = nElemRead * input_0_element_size_byte; cost.bytes_written = - size * sizeof(in[0].data_type()) + split_info_bytes_count; + size * input_0_element_size_byte + split_info_bytes_count; cost.params_bytes = 0; return cost; } diff --git a/caffe2/operators/concat_split_op.h b/caffe2/operators/concat_split_op.h index bbe355e50420f..f1e8f10d4d3dc 100644 --- a/caffe2/operators/concat_split_op.h +++ b/caffe2/operators/concat_split_op.h @@ -282,7 +282,7 @@ bool ConcatOp::RunOnDevice() { // We can override default options(Context::GetDeviceType()) // by explicitly passing in device type we want Tensor* split = Output( - 1, std::vector(1, InputSize()), at::dtype().device(CPU)); + 1, at::IntArrayRef({InputSize()}), at::dtype().device(CPU)); int* axis_data = split->template mutable_data(); auto& input_zero = Input(0); int adj_size = input_zero.dim() + (add_axis_ ? 1 : 0); diff --git a/caffe2/operators/conv_pool_op_base.h b/caffe2/operators/conv_pool_op_base.h index 25bd99a92e50f..b356ef952d79c 100644 --- a/caffe2/operators/conv_pool_op_base.h +++ b/caffe2/operators/conv_pool_op_base.h @@ -7,6 +7,7 @@ #include "caffe2/core/context.h" #include "caffe2/core/logging.h" #include "caffe2/core/operator.h" +#include "caffe2/core/types.h" #include "caffe2/proto/caffe2_legacy.pb.h" #include "caffe2/utils/math.h" @@ -519,14 +520,20 @@ class ConvPoolOpBase : public Operator { uint64_t nElemW = nElemFromDim(W); uint64_t nElemBias = inputs.size() > 2 ? nElemFromDim(inputs[2]) : 0; + auto const& X_elemenet_size_byte = + DataTypeToTypeMeta(X.data_type()).itemsize(); + auto const& Y_element_size_byte = + DataTypeToTypeMeta(Y.data_type()).itemsize(); + auto const& W_element_size_byte = + DataTypeToTypeMeta(W.data_type()).itemsize(); + // grouping is NOT properly handled yet c.flops = N * Y_t * Y_h * Y_w * kernel_t * kernel_w * kernel_h * in_channels * out_channels * 2; - c.bytes_read = (nElemX + nElemW + nElemBias) * sizeof(X.data_type()); - c.bytes_written = - N * out_channels * Y_t * Y_h * Y_w * sizeof(Y.data_type()); + c.bytes_read = (nElemX + nElemW + nElemBias) * X_elemenet_size_byte; + c.bytes_written = N * out_channels * Y_t * Y_h * Y_w * Y_element_size_byte; c.params_bytes = out_channels * in_channels * kernel_t * kernel_h * - kernel_w * sizeof(W.data_type()); + kernel_w * W_element_size_byte; return c; } diff --git a/caffe2/operators/distance_op.cc b/caffe2/operators/distance_op.cc index 1529534d8fb2e..9ea8eea5a2725 100644 --- a/caffe2/operators/distance_op.cc +++ b/caffe2/operators/distance_op.cc @@ -1,4 +1,5 @@ #include "caffe2/operators/distance_op.h" +#include "caffe2/core/types.h" #include "caffe2/utils/eigen_utils.h" #ifdef CAFFE2_USE_MKLDNN #include @@ -7,7 +8,7 @@ namespace caffe2 { -template<> +template <> bool SquaredL2DistanceOp::RunOnDevice() { auto& X = Input(0); auto& Y = Input(1); @@ -257,7 +258,9 @@ OpSchema::Cost CostInferenceForDotProduct( CAFFE_ENFORCE_EQ(out[0].dims().size(), 1); struct OpSchema::Cost c = PointwiseCostInference<2>(def, in); - c.bytes_written = out[0].dims(0) * sizeof(out[0].data_type()); + auto const& out_0_element_size_byte = + DataTypeToTypeMeta(out[0].data_type()).itemsize(); + c.bytes_written = out[0].dims(0) * out_0_element_size_byte; c.params_bytes = 0; return c; } @@ -379,10 +382,12 @@ bool DotProductWithPaddingOp::RunOnDevice() { } // L2 -REGISTER_CPU_OPERATOR(SquaredL2Distance, - SquaredL2DistanceOp); -REGISTER_CPU_OPERATOR(SquaredL2DistanceGradient, - SquaredL2DistanceGradientOp); +REGISTER_CPU_OPERATOR( + SquaredL2Distance, + SquaredL2DistanceOp); +REGISTER_CPU_OPERATOR( + SquaredL2DistanceGradient, + SquaredL2DistanceGradientOp); OPERATOR_SCHEMA(SquaredL2Distance) .NumInputs(2) @@ -402,7 +407,8 @@ class GetSquaredL2DistanceGradient : public GradientMakerBase { using GradientMakerBase::GradientMakerBase; vector GetGradientDefs() override { return SingleGradientDef( - "SquaredL2DistanceGradient", "", + "SquaredL2DistanceGradient", + "", vector{I(0), I(1), GO(0)}, vector{GI(0), GI(1)}); } @@ -762,9 +768,9 @@ class GetDotProductWithPaddingGradient : public GradientMakerBase { replicate = GetArgument(Def(), "replicate").i(); } - const auto dot_arg = - vector{MakeArgument("pad_value", pad_value), - MakeArgument("replicate", replicate)}; + const auto dot_arg = vector{ + MakeArgument("pad_value", pad_value), + MakeArgument("replicate", replicate)}; return SingleGradientDef( "DotProductWithPaddingGradient", @@ -775,4 +781,4 @@ class GetDotProductWithPaddingGradient : public GradientMakerBase { } }; REGISTER_GRADIENT(DotProductWithPadding, GetDotProductWithPaddingGradient); -} // namespace caffe2 +} // namespace caffe2 diff --git a/caffe2/operators/fc_inference.cc b/caffe2/operators/fc_inference.cc index a44c230980c7f..ba1b7122cdc9d 100644 --- a/caffe2/operators/fc_inference.cc +++ b/caffe2/operators/fc_inference.cc @@ -1,4 +1,5 @@ #include "caffe2/operators/fc_inference.h" +#include "caffe2/core/types.h" namespace caffe2 { std::vector FCShapeInference( @@ -51,11 +52,12 @@ OpSchema::Cost CostInferenceForFC( ? size_from_dim_(canonical_axis_w, GetDimsVector(in[1])) : size_to_dim_(canonical_axis_w, GetDimsVector(in[1])); - const auto& X = in[0]; + auto const& X_element_size_byte = + DataTypeToTypeMeta(in[0].data_type()).itemsize(); c.flops = M * N * (2 * K + 1); - c.bytes_read = (K * (M + N) + N) * sizeof(X.data_type()); - c.bytes_written = M * N * sizeof(X.data_type()); - c.params_bytes = (K * N + N) * sizeof(X.data_type()); + c.bytes_read = (K * (M + N) + N) * X_element_size_byte; + c.bytes_written = M * N * X_element_size_byte; + c.params_bytes = (K * N + N) * X_element_size_byte; return c; } @@ -94,7 +96,11 @@ OpSchema::Cost CostInferenceForFCGradient( CAFFE_ENFORCE_LT(0, out.size()); const TensorShape dW = out[0]; + auto const& dW_element_size_byte = + DataTypeToTypeMeta(dW.data_type()).itemsize(); const TensorShape db = out[1]; + auto const& db_element_size_byte = + DataTypeToTypeMeta(db.data_type()).itemsize(); auto axis = helper.GetSingleArgument("axis", 1); const auto canonical_axis = canonical_axis_index_(axis, in[0].dims().size()); @@ -111,15 +117,17 @@ OpSchema::Cost CostInferenceForFCGradient( uint64_t size_db = nElemFromDim(db); c.flops = M * N * (2 * K + 1); - c.bytes_written = (size_dW + size_db) * sizeof(float); + c.bytes_written = + size_dW * dW_element_size_byte + size_db * db_element_size_byte; c.params_bytes = (K * N + N) * sizeof(float); if (out.size() == 3) { const TensorShape dX = out[2]; uint64_t size_dX = nElemFromDim(dX); - + auto const& dX_element_size_byte = + DataTypeToTypeMeta(dX.data_type()).itemsize(); c.flops += 2 * M * N * K; - c.bytes_written += size_dX * sizeof(float); + c.bytes_written += size_dX * dX_element_size_byte; } return c; } diff --git a/caffe2/operators/one_hot_ops.cc b/caffe2/operators/one_hot_ops.cc index c3eaf05db0e8f..55c73a5be22c4 100644 --- a/caffe2/operators/one_hot_ops.cc +++ b/caffe2/operators/one_hot_ops.cc @@ -2,6 +2,7 @@ #include "caffe2/core/operator.h" #include "caffe2/core/tensor.h" +#include "caffe2/core/types.h" namespace caffe2 { @@ -78,12 +79,21 @@ OpSchema::Cost CostInferenceForBatchOneHot( const auto& length = in[1]; const auto& values = in[2]; - uint64_t nBytesData = nElemFromDim(data) * sizeof(data.data_type()); - uint64_t nBytesLength = nElemFromDim(length) * sizeof(length.data_type()); - uint64_t nBytesValues = nElemFromDim(values) * sizeof(values.data_type()); + auto const& data_element_size_byte = + DataTypeToTypeMeta(data.data_type()).itemsize(); + auto const& length_element_size_byte = + DataTypeToTypeMeta(length.data_type()).itemsize(); + auto const& values_element_size_byte = + DataTypeToTypeMeta(values.data_type()).itemsize(); + auto const& output_element_size_byte = + DataTypeToTypeMeta(output.data_type()).itemsize(); + + uint64_t nBytesData = nElemFromDim(data) * data_element_size_byte; + uint64_t nBytesLength = nElemFromDim(length) * length_element_size_byte; + uint64_t nBytesValues = nElemFromDim(values) * values_element_size_byte; c.flops = 0; c.bytes_read = nBytesData + nBytesLength + nBytesValues; - c.bytes_written = nElemFromDim(output) * sizeof(output.data_type()); + c.bytes_written = nElemFromDim(output) * output_element_size_byte; c.params_bytes = 0; return c; } @@ -145,15 +155,15 @@ bool BatchBucketOneHotOp::RunOnDevice() { for (int64_t j = 0; j < D; j++) { // here we assume the boundary values for each feature are sorted int64_t lower_bucket_idx = std::lower_bound( - boundaries_offset, - boundaries_offset + lens_data[j], - input_data[pos]) - + boundaries_offset, + boundaries_offset + lens_data[j], + input_data[pos]) - boundaries_offset; int64_t upper_bucket_idx = std::upper_bound( - boundaries_offset, - boundaries_offset + lens_data[j], - input_data[pos]) - + boundaries_offset, + boundaries_offset + lens_data[j], + input_data[pos]) - boundaries_offset; int64_t bucket_idx = (lower_bucket_idx + upper_bucket_idx) / 2; diff --git a/caffe2/operators/utility_ops.cc b/caffe2/operators/utility_ops.cc index 7b2a02fae696b..561da9189b388 100644 --- a/caffe2/operators/utility_ops.cc +++ b/caffe2/operators/utility_ops.cc @@ -1,5 +1,7 @@ #include "caffe2/operators/utility_ops.h" #include +#include +#include "caffe2/core/types.h" #include "caffe2/utils/eigen_utils.h" namespace caffe2 { @@ -33,9 +35,11 @@ OpSchema::Cost CostInferenceForWeightedSum( const auto& nElem = nElemFromDim(X0); const auto& nInputs = in.size(); c.flops = (nInputs - 1) * nElem; - c.bytes_read = (nInputs / 2) * (nElem + 1) * sizeof(X0.data_type()); - c.bytes_written = nElem * sizeof(X0.data_type()); - c.params_bytes = (nInputs / 2) * sizeof(X0.data_type()); + auto const& X0_element_size_byte = + DataTypeToTypeMeta(X0.data_type()).itemsize(); + c.bytes_read = (nInputs / 2) * (nElem + 1) * X0_element_size_byte; + c.bytes_written = nElem * X0_element_size_byte; + c.params_bytes = (nInputs / 2) * X0_element_size_byte; return c; } @@ -47,9 +51,7 @@ REGISTER_CPU_OPERATOR(ResizeLike, ResizeLikeOp); REGISTER_CPU_OPERATOR(SumInt, SumOp); REGISTER_CPU_OPERATOR(WeightedSum, WeightedSumOp); REGISTER_CPU_OPERATOR(WeightedSumGradient, WeightedSumGradientOp); -REGISTER_CPU_OPERATOR( - ScatterWeightedSum, - ScatterWeightedSumOp); +REGISTER_CPU_OPERATOR(ScatterWeightedSum, ScatterWeightedSumOp); REGISTER_CPU_OPERATOR(ScatterAssign, ScatterAssignOp); REGISTER_CPU_OPERATOR(Scatter, ScatterOp); diff --git a/caffe2/proto/caffe2.proto b/caffe2/proto/caffe2.proto index 6e055778578ab..90a2020195f60 100644 --- a/caffe2/proto/caffe2.proto +++ b/caffe2/proto/caffe2.proto @@ -219,7 +219,7 @@ enum DeviceTypeProto { PROTO_IDEEP = 5; // IDEEP. PROTO_HIP = 6; // AMD HIP PROTO_FPGA = 7; // FPGA - PROTO_MSNPU = 8; // MSNPU + PROTO_ORT = 8; // ONNX Runtime PROTO_XLA = 9; // XLA / TPU PROTO_MLC = 10; // ML Compute // Change the following number if you add more devices in the code. diff --git a/caffe2/proto/caffe2_pb2.pyi b/caffe2/proto/caffe2_pb2.pyi index 1258664bee165..f7f4430d7b761 100644 --- a/caffe2/proto/caffe2_pb2.pyi +++ b/caffe2/proto/caffe2_pb2.pyi @@ -23,7 +23,7 @@ class _DeviceTypeProto(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapp PROTO_IDEEP = DeviceTypeProto.V(5) PROTO_HIP = DeviceTypeProto.V(6) PROTO_FPGA = DeviceTypeProto.V(7) - PROTO_MSNPU = DeviceTypeProto.V(8) + PROTO_ORT = DeviceTypeProto.V(8) PROTO_XLA = DeviceTypeProto.V(9) PROTO_MLC = DeviceTypeProto.V(10) PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = DeviceTypeProto.V(11) @@ -37,7 +37,7 @@ PROTO_OPENCL = DeviceTypeProto.V(4) PROTO_IDEEP = DeviceTypeProto.V(5) PROTO_HIP = DeviceTypeProto.V(6) PROTO_FPGA = DeviceTypeProto.V(7) -PROTO_MSNPU = DeviceTypeProto.V(8) +PROTO_ORT = DeviceTypeProto.V(8) PROTO_XLA = DeviceTypeProto.V(9) PROTO_MLC = DeviceTypeProto.V(10) PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = DeviceTypeProto.V(11) diff --git a/caffe2/python/operator_test/concat_op_cost_test.py b/caffe2/python/operator_test/concat_op_cost_test.py index 996b330be4947..7dab4d6bd5d1f 100644 --- a/caffe2/python/operator_test/concat_op_cost_test.py +++ b/caffe2/python/operator_test/concat_op_cost_test.py @@ -7,33 +7,39 @@ class TestConcatOpCost(TestCase): def test_columnwise_concat(self): - workspace.ResetWorkspace() - workspace.FeedBlob("input_1", np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)) - workspace.FeedBlob("input_2", np.array([[7], [8]], dtype=np.int32)) - concat_op = core.CreateOperator( - "Concat", - ["input_1", "input_2"], - ["output", "split_info"], - ) - workspace.RunOperatorOnce(concat_op) + def _test_columnwise_concat_for_type(dtype): + workspace.ResetWorkspace() + workspace.FeedBlob("input_1", np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)) + workspace.FeedBlob("input_2", np.array([[7], [8]], dtype=dtype)) + concat_op = core.CreateOperator( + "Concat", + ["input_1", "input_2"], + ["output", "split_info"], + ) + workspace.RunOperatorOnce(concat_op) - output = workspace.FetchBlob("output") - self.assertTupleEqual(output.shape, (2, 4)) - np.testing.assert_array_equal(output, [[1, 2, 3, 7], [4, 5, 6, 8]]) + output = workspace.FetchBlob("output") + self.assertTupleEqual(output.shape, (2, 4)) + np.testing.assert_array_equal(output, [[1, 2, 3, 7], [4, 5, 6, 8]]) - flops, bytes_written, bytes_read = workspace.GetOperatorCost( - concat_op, concat_op.input - ) + flops, bytes_written, bytes_read = workspace.GetOperatorCost( + concat_op, concat_op.input + ) - self.assertEqual(flops, 0) - self.assertEqual( - bytes_read, - sum(workspace.FetchBlob(b).nbytes for b in concat_op.input), - ) - self.assertEqual( - bytes_written, - sum(workspace.FetchBlob(b).nbytes for b in concat_op.output), - ) + self.assertEqual(flops, 0) + self.assertEqual( + bytes_read, + sum(workspace.FetchBlob(b).nbytes for b in concat_op.input), + ) + self.assertEqual( + bytes_written, + sum(workspace.FetchBlob(b).nbytes for b in concat_op.output), + ) + + [ + _test_columnwise_concat_for_type(t) + for t in [np.int64, np.float, np.half, np.int8] + ] def test_split_then_concat(self): workspace.ResetWorkspace() diff --git a/caffe2/python/workspace_test.py b/caffe2/python/workspace_test.py index afb2065027075..1bf7b607e1b7e 100644 --- a/caffe2/python/workspace_test.py +++ b/caffe2/python/workspace_test.py @@ -60,7 +60,7 @@ def testGetOperatorCost(self): self.assertTupleEqual( op_cost, namedtuple("Cost", ["flops", "bytes_written", "bytes_read"])( - 1152, 256, 2084 + 1152, 256, 4168 ), ) diff --git a/caffe2/serialize/file_adapter.cc b/caffe2/serialize/file_adapter.cc index 701270b566145..1fddce970a84f 100644 --- a/caffe2/serialize/file_adapter.cc +++ b/caffe2/serialize/file_adapter.cc @@ -1,7 +1,8 @@ #include "caffe2/serialize/file_adapter.h" #include +#include #include - +#include #include "caffe2/core/common.h" namespace caffe2 { @@ -10,7 +11,20 @@ namespace serialize { FileAdapter::RAIIFile::RAIIFile(const std::string& file_name) { fp_ = fopen(file_name.c_str(), "rb"); if (fp_ == nullptr) { - AT_ERROR("open file failed, file path: ", file_name); + char buf[1024]; + buf[0] = '\0'; +#if defined(_WIN32) && (defined(__MINGW32__) || defined(_MSC_VER)) + strerror_s(buf, sizeof(buf), errno); +#else + strerror_r(errno, buf, sizeof(buf)); +#endif + AT_ERROR( + "open file failed because of errno ", + errno, + " on fopen: ", + buf, + ", file path: ", + file_name); } } diff --git a/caffe2/serialize/istream_adapter.h b/caffe2/serialize/istream_adapter.h index 8960d5535c885..680c288a15f2e 100644 --- a/caffe2/serialize/istream_adapter.h +++ b/caffe2/serialize/istream_adapter.h @@ -16,7 +16,7 @@ class TORCH_API IStreamAdapter final : public ReadAdapterInterface { size_t size() const override; size_t read(uint64_t pos, void* buf, size_t n, const char* what = "") const override; - ~IStreamAdapter(); + ~IStreamAdapter() override; private: std::istream* istream_; diff --git a/caffe2/serialize/versions.h b/caffe2/serialize/versions.h index 61c8c46666e67..ed5795841d1f9 100644 --- a/caffe2/serialize/versions.h +++ b/caffe2/serialize/versions.h @@ -85,7 +85,7 @@ static_assert(kProducedBytecodeVersion >= kProducedFileFormatVersion, // we should support this model_version. For example, we provide a wrapper to // handle an updated operator. constexpr uint64_t kMinSupportedBytecodeVersion = 0x3L; -constexpr uint64_t kMaxSupportedBytecodeVersion = 0x6L; +constexpr uint64_t kMaxSupportedBytecodeVersion = 0x7L; } // namespace serialize } // namespace caffe2 diff --git a/caffe2/sgd/adagrad_op.cc b/caffe2/sgd/adagrad_op.cc index 0de50f03e62d5..0b6f604b48cdb 100644 --- a/caffe2/sgd/adagrad_op.cc +++ b/caffe2/sgd/adagrad_op.cc @@ -1,4 +1,5 @@ #include "adagrad_op.h" +#include "caffe2/core/types.h" namespace caffe2 { @@ -23,22 +24,30 @@ static OpSchema::Cost CostInferenceForAdagrad( // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) c.flops = grad_size * 10; + auto const& moment_element_size_byte = + DataTypeToTypeMeta(moment.data_type()).itemsize(); + auto const& param_element_size_byte = + DataTypeToTypeMeta(param.data_type()).itemsize(); + auto const& grad_element_size_byte = + DataTypeToTypeMeta(grad.data_type()).itemsize(); + auto const& lr_element_size_byte = + DataTypeToTypeMeta(lr.data_type()).itemsize(); uint64_t bytes_written = - grad_size * (sizeof(param.data_type()) + sizeof(moment.data_type())); + grad_size * param_element_size_byte + moment_element_size_byte; if (output_size == 3) { // also need to output effective learning rate in this case // assume it's the same data type as lr - bytes_written += grad_size * sizeof(lr.data_type()); + bytes_written += grad_size * lr_element_size_byte; } else if (output_size == 4) { // also need to output effective learning rate and updates in this case // assume update is the same data type as param bytes_written += - grad_size * (sizeof(lr.data_type()) + sizeof(param.data_type())); + grad_size * (lr_element_size_byte + param_element_size_byte); } c.bytes_written = bytes_written; c.bytes_read = c.bytes_written + - grad_size * (sizeof(grad.data_type()) + sizeof(lr.data_type())); + grad_size * (grad_element_size_byte + lr_element_size_byte); return c; } @@ -102,10 +111,18 @@ static OpSchema::Cost CostInferenceForSparseAdagrad( // (optimistically count sqrt as one flop). // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) c.flops = grad_size * 7; + auto const& param_element_size_byte = + DataTypeToTypeMeta(param.data_type()).itemsize(); + auto const& moment_element_size_byte = + DataTypeToTypeMeta(moment.data_type()).itemsize(); c.bytes_written = - grad_size * (sizeof(param.data_type()) + sizeof(moment.data_type())); - c.bytes_read = c.bytes_written + grad_size * sizeof(grad.data_type()) + - n * sizeof(indices.data_type()); + grad_size * (param_element_size_byte + moment_element_size_byte); + auto const& grad_element_size_byte = + DataTypeToTypeMeta(grad.data_type()).itemsize(); + auto const& indices_element_size_byte = + DataTypeToTypeMeta(indices.data_type()).itemsize(); + c.bytes_read = c.bytes_written + grad_size * grad_element_size_byte + + n * indices_element_size_byte; return c; } @@ -153,6 +170,16 @@ static OpSchema::Cost CostInferenceForRowWiseSparseAdagrad( OpSchema::Cost c; if (n > 0) { + auto const& param_element_size_byte = + DataTypeToTypeMeta(param.data_type()).itemsize(); + auto const& moment_element_size_byte = + DataTypeToTypeMeta(moment.data_type()).itemsize(); + auto const& grad_element_size_byte = + DataTypeToTypeMeta(grad.data_type()).itemsize(); + auto const& indices_element_size_byte = + DataTypeToTypeMeta(indices.data_type()).itemsize(); + auto const& lr_element_size_byte = + DataTypeToTypeMeta(lr.data_type()).itemsize(); auto block_size = grad_size / n; if (block_size == 1) { // +2: applying weight decay and add to grads @@ -161,22 +188,22 @@ static OpSchema::Cost CostInferenceForRowWiseSparseAdagrad( // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) c.flops = n * 9; c.bytes_written = - n * (sizeof(param.data_type()) + sizeof(moment.data_type())); + n * (param_element_size_byte + moment_element_size_byte); c.bytes_read = c.bytes_written + n * - (sizeof(grad.data_type()) + sizeof(indices.data_type()) + - sizeof(lr.data_type())); + (grad_element_size_byte + indices_element_size_byte + + lr_element_size_byte); } else { // 5 per block (not counting index transforms) // 8 for each value of a block // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) c.flops = n * (5 + (block_size * 8)); - c.bytes_written = - n * sizeof(moment.data_type()) + n * block_size * (param.data_type()); + c.bytes_written = n * moment_element_size_byte + + n * block_size * param_element_size_byte; - c.bytes_read = c.bytes_written + n * (sizeof(lr.data_type())) + + c.bytes_read = c.bytes_written + n * lr_element_size_byte + 2 * n * block_size * - (sizeof(grad.data_type()) + sizeof(param.data_type())); + (grad_element_size_byte + param_element_size_byte); } } return c; diff --git a/caffe2/utils/proto_utils.cc b/caffe2/utils/proto_utils.cc index d2aa59e02b63f..db379462e5347 100644 --- a/caffe2/utils/proto_utils.cc +++ b/caffe2/utils/proto_utils.cc @@ -323,8 +323,12 @@ C10_EXPORT ArgumentHelper::ArgumentHelper(const NetDef& netdef) { } } -C10_EXPORT bool ArgumentHelper::HasArgument(const string& name) const { +C10_EXPORT bool ArgumentHelper::HasArgument(c10::string_view name) const { +#ifdef CAFFE2_ENABLE_REDUCED_STRINGS_IN_ARGUMENT_LOOKUP return arg_map_.count(name); +#else + return arg_map_.count(std::string(name)); +#endif } namespace { @@ -364,18 +368,19 @@ std::ostream& operator<<(std::ostream& output, const NetDef& n) { T, fieldname, enforce_lossless_conversion) \ template <> \ C10_EXPORT T ArgumentHelper::GetSingleArgument( \ - const string& name, const T& default_value) const { \ - if (arg_map_.count(name) == 0) { \ + c10::string_view name, const T& default_value) const { \ + auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name); \ + if (it == arg_map_.end()) { \ VLOG(1) << "Using default parameter value " << default_value \ << " for parameter " << name; \ return default_value; \ } \ CAFFE_ENFORCE( \ - arg_map_.at(name).has_##fieldname(), \ + it->second.has_##fieldname(), \ "Argument ", \ name, \ " does not have the right field: expected field " #fieldname); \ - auto value = arg_map_.at(name).fieldname(); \ + auto value = it->second.fieldname(); \ if (enforce_lossless_conversion) { \ auto supportsConversion = \ SupportsLosslessConversion(value); \ @@ -391,11 +396,12 @@ std::ostream& operator<<(std::ostream& output, const NetDef& n) { } \ template <> \ C10_EXPORT bool ArgumentHelper::HasSingleArgumentOfType( \ - const string& name) const { \ - if (arg_map_.count(name) == 0) { \ + c10::string_view name) const { \ + auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name); \ + if (it == arg_map_.end()) { \ return false; \ } \ - return arg_map_.at(name).has_##fieldname(); \ + return it->second.has_##fieldname(); \ } INSTANTIATE_GET_SINGLE_ARGUMENT(float, f, false) @@ -415,13 +421,14 @@ INSTANTIATE_GET_SINGLE_ARGUMENT(NetDef, n, false) #define INSTANTIATE_GET_REPEATED_ARGUMENT( \ T, fieldname, enforce_lossless_conversion) \ template <> \ - C10_EXPORT std::vector ArgumentHelper::GetRepeatedArgument( \ - const string& name, const std::vector& default_value) const { \ - if (arg_map_.count(name) == 0) { \ + C10_EXPORT std::vector ArgumentHelper::GetRepeatedArgument( \ + c10::string_view name, const std::vector& default_value) const { \ + auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name); \ + if (it == arg_map_.end()) { \ return default_value; \ } \ - std::vector values; \ - for (const auto& v : arg_map_.at(name).fieldname()) { \ + std::vector values; \ + for (const auto& v : it->second.fieldname()) { \ if (enforce_lossless_conversion) { \ auto supportsConversion = \ SupportsLosslessConversion(v); \ @@ -531,7 +538,7 @@ C10_EXPORT bool HasInput(const OperatorDef& op, const std::string& input) { // Return the argument index or -1 if it does not exist. C10_EXPORT int GetArgumentIndex( const google::protobuf::RepeatedPtrField& args, - const string& name) { + c10::string_view name) { int index = 0; for (const Argument& arg : args) { if (arg.name() == name) { @@ -544,7 +551,7 @@ C10_EXPORT int GetArgumentIndex( C10_EXPORT const Argument& GetArgument( const OperatorDef& def, - const string& name) { + c10::string_view name) { int index = GetArgumentIndex(def.arg(), name); if (index != -1) { return def.arg(index); @@ -557,7 +564,7 @@ C10_EXPORT const Argument& GetArgument( } } -C10_EXPORT const Argument& GetArgument(const NetDef& def, const string& name) { +C10_EXPORT const Argument& GetArgument(const NetDef& def, c10::string_view name) { int index = GetArgumentIndex(def.arg(), name); if (index != -1) { return def.arg(index); @@ -572,7 +579,7 @@ C10_EXPORT const Argument& GetArgument(const NetDef& def, const string& name) { C10_EXPORT const Argument* GetArgumentPtr( const OperatorDef& def, - const string& name) { + c10::string_view name) { int index = GetArgumentIndex(def.arg(), name); if (index != -1) { return &def.arg(index); @@ -583,7 +590,7 @@ C10_EXPORT const Argument* GetArgumentPtr( C10_EXPORT const Argument* GetArgumentPtr( const NetDef& def, - const string& name) { + c10::string_view name) { int index = GetArgumentIndex(def.arg(), name); if (index != -1) { return &def.arg(index); @@ -594,7 +601,7 @@ C10_EXPORT const Argument* GetArgumentPtr( C10_EXPORT bool GetFlagArgument( const google::protobuf::RepeatedPtrField& args, - const string& name, + c10::string_view name, bool default_value) { int index = GetArgumentIndex(args, name); if (index != -1) { @@ -609,13 +616,13 @@ C10_EXPORT bool GetFlagArgument( C10_EXPORT bool GetFlagArgument( const OperatorDef& def, - const string& name, + c10::string_view name, bool default_value) { return GetFlagArgument(def.arg(), name, default_value); } C10_EXPORT bool -GetFlagArgument(const NetDef& def, const string& name, bool default_value) { +GetFlagArgument(const NetDef& def, c10::string_view name, bool default_value) { return GetFlagArgument(def.arg(), name, default_value); } diff --git a/caffe2/utils/proto_utils.h b/caffe2/utils/proto_utils.h index 57676982c7851..b5c6b312b3ab3 100644 --- a/caffe2/utils/proto_utils.h +++ b/caffe2/utils/proto_utils.h @@ -8,10 +8,18 @@ #endif // !CAFFE2_USE_LITE_PROTO #include +#include #include "caffe2/utils/proto_wrap.h" #include "caffe2/proto/caffe2_pb.h" +#ifndef C10_ANDROID +#define CAFFE2_ENABLE_REDUCED_STRINGS_IN_ARGUMENT_LOOKUP +#define CAFFE2_ARG_MAP_FIND(map, key) map.find(key) +#else +#define CAFFE2_ARG_MAP_FIND(map, key) map.find(std::string(key)) +#endif + namespace caffe2 { using std::string; @@ -204,40 +212,40 @@ TORCH_API bool HasInput(const OperatorDef& op, const std::string& input); class C10_EXPORT ArgumentHelper { public: template - static bool HasArgument(const Def& def, const string& name) { + static bool HasArgument(const Def& def, c10::string_view name) { return ArgumentHelper(def).HasArgument(name); } template static T GetSingleArgument( const Def& def, - const string& name, + c10::string_view name, const T& default_value) { return ArgumentHelper(def).GetSingleArgument(name, default_value); } template - static bool HasSingleArgumentOfType(const Def& def, const string& name) { + static bool HasSingleArgumentOfType(const Def& def, c10::string_view name) { return ArgumentHelper(def).HasSingleArgumentOfType(name); } template static std::vector GetRepeatedArgument( const Def& def, - const string& name, + c10::string_view name, const std::vector& default_value = std::vector()) { return ArgumentHelper(def).GetRepeatedArgument(name, default_value); } template - static MessageType GetMessageArgument(const Def& def, const string& name) { + static MessageType GetMessageArgument(const Def& def, c10::string_view name) { return ArgumentHelper(def).GetMessageArgument(name); } template static std::vector GetRepeatedMessageArgument( const Def& def, - const string& name) { + c10::string_view name) { return ArgumentHelper(def).GetRepeatedMessageArgument(name); } @@ -255,24 +263,25 @@ class C10_EXPORT ArgumentHelper { explicit ArgumentHelper(const OperatorDef& def); explicit ArgumentHelper(const NetDef& netdef); - bool HasArgument(const string& name) const; + bool HasArgument(c10::string_view name) const; template - T GetSingleArgument(const string& name, const T& default_value) const; + T GetSingleArgument(c10::string_view name, const T& default_value) const; template - bool HasSingleArgumentOfType(const string& name) const; + bool HasSingleArgumentOfType(c10::string_view name) const; template std::vector GetRepeatedArgument( - const string& name, + c10::string_view name, const std::vector& default_value = std::vector()) const; template - MessageType GetMessageArgument(const string& name) const { - CAFFE_ENFORCE(arg_map_.count(name), "Cannot find parameter named ", name); + MessageType GetMessageArgument(c10::string_view name) const { + auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name); + CAFFE_ENFORCE(it != arg_map_.end(), "Cannot find parameter named ", name); MessageType message; - if (arg_map_.at(name).has_s()) { + if (it->second.has_s()) { CAFFE_ENFORCE( - message.ParseFromString(arg_map_.at(name).s()), + message.ParseFromString(it->second.s()), "Failed to parse content from the string"); } else { VLOG(1) << "Return empty message for parameter " << name; @@ -281,42 +290,47 @@ class C10_EXPORT ArgumentHelper { } template - std::vector GetRepeatedMessageArgument(const string& name) const { - CAFFE_ENFORCE(arg_map_.count(name), "Cannot find parameter named ", name); - std::vector messages(arg_map_.at(name).strings_size()); + std::vector GetRepeatedMessageArgument(c10::string_view name) const { + auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name); + CAFFE_ENFORCE(it != arg_map_.end(), "Cannot find parameter named ", name); + std::vector messages(it->second.strings_size()); for (int i = 0; i < messages.size(); ++i) { CAFFE_ENFORCE( - messages[i].ParseFromString(arg_map_.at(name).strings(i)), + messages[i].ParseFromString(it->second.strings(i)), "Failed to parse content from the string"); } return messages; } private: - std::map arg_map_; + std::map +#endif + > arg_map_; }; // **** Arguments Utils ***** // Helper methods to get an argument from OperatorDef or NetDef given argument // name. Throws if argument does not exist. -TORCH_API const Argument& GetArgument(const OperatorDef& def, const string& name); -TORCH_API const Argument& GetArgument(const NetDef& def, const string& name); +TORCH_API const Argument& GetArgument(const OperatorDef& def, c10::string_view name); +TORCH_API const Argument& GetArgument(const NetDef& def, c10::string_view name); // Helper methods to get an argument from OperatorDef or NetDef given argument // name. Returns nullptr if argument does not exist. -TORCH_API const Argument* GetArgumentPtr(const OperatorDef& def, const string& name); -TORCH_API const Argument* GetArgumentPtr(const NetDef& def, const string& name); +TORCH_API const Argument* GetArgumentPtr(const OperatorDef& def, c10::string_view name); +TORCH_API const Argument* GetArgumentPtr(const NetDef& def, c10::string_view name); // Helper methods to query a boolean argument flag from OperatorDef or NetDef // given argument name. If argument does not exist, return default value. // Throws if argument exists but the type is not boolean. TORCH_API bool GetFlagArgument( const OperatorDef& def, - const string& name, + c10::string_view name, bool default_value = false); TORCH_API bool GetFlagArgument( const NetDef& def, - const string& name, + c10::string_view name, bool default_value = false); TORCH_API Argument* GetMutableArgument( diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 2c19dae96c909..b3cc23ccac8f4 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -242,14 +242,15 @@ endif() # --- [ PocketFFT set(AT_POCKETFFT_ENABLED 0) -if(NOT MKL_FOUND) +if(NOT AT_MKL_ENABLED) find_path(POCKETFFT_INCLUDE_DIR NAMES pocketfft_hdronly.h PATHS /usr/local/include - "$ENV{POCKETFFT_HOME}" + ENV POCKETFFT_HOME "${PROJECT_SOURCE_DIR}/third_party/pocketfft" ) - if(POCKETFFT_INCLUDE_DIR AND CMAKE_VERSION VERSION_GREATER "3.9") + if(POCKETFFT_INCLUDE_DIR) set(AT_POCKETFFT_ENABLED 1) + message(STATUS "Using pocketfft in directory: ${POCKETFFT_INCLUDE_DIR}") endif() endif() @@ -1881,6 +1882,10 @@ set_target_properties(fmt-header-only PROPERTIES INTERFACE_COMPILE_FEATURES "") list(APPEND Caffe2_DEPENDENCY_LIBS fmt::fmt-header-only) set(BUILD_SHARED_LIBS ${TEMP_BUILD_SHARED_LIBS} CACHE BOOL "Build shared libs" FORCE) +if(USE_BREAKPAD) + add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/../third_party/breakpad) +endif() + # ---[ Kineto # edge profiler depends on KinetoProfiler but it only does cpu # profiling. Thus we dont need USE_CUDA/USE_ROCM diff --git a/cmake/MiscCheck.cmake b/cmake/MiscCheck.cmake index efac9e0dfa8e6..1497b0044a0b0 100644 --- a/cmake/MiscCheck.cmake +++ b/cmake/MiscCheck.cmake @@ -178,10 +178,12 @@ endif() # -to add all (including unused) symbols into the dynamic symbol # -table. We need this to get symbols when generating backtrace at # -runtime. -check_cxx_compiler_flag("-rdynamic" COMPILER_SUPPORTS_RDYNAMIC) -if(${COMPILER_SUPPORTS_RDYNAMIC}) - set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -rdynamic") - set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -rdynamic") +if(NOT MSVC) + check_cxx_compiler_flag("-rdynamic" COMPILER_SUPPORTS_RDYNAMIC) + if(${COMPILER_SUPPORTS_RDYNAMIC}) + set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -rdynamic") + set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -rdynamic") + endif() endif() # ---[ If we are using msvc, set no warning flags diff --git a/cmake/ProtoBuf.cmake b/cmake/ProtoBuf.cmake index d8a2c279aee47..8d7633c4ab037 100644 --- a/cmake/ProtoBuf.cmake +++ b/cmake/ProtoBuf.cmake @@ -196,7 +196,7 @@ function(caffe2_protobuf_generate_cpp_py srcs_var hdrs_var python_var) # If we remove all reference to these pb.h files from external # libraries and binaries this rewrite can be removed. - COMMAND ${CMAKE_COMMAND} -DFILENAME=${CMAKE_CURRENT_BINARY_DIR}/${fil_we}.pb.h -DNAMESPACES=caffe\;caffe2\;onnx\;torch -DLOCAL_PROTOBUF=${CAFFE2_LINK_LOCAL_PROTOBUF} -P ${PROJECT_SOURCE_DIR}/cmake/ProtoBufPatch.cmake + COMMAND ${CMAKE_COMMAND} -DFILENAME=${CMAKE_CURRENT_BINARY_DIR}/${fil_we}.pb.h -DNAMESPACES=caffe\;caffe2\;onnx\;torch -P ${PROJECT_SOURCE_DIR}/cmake/ProtoBufPatch.cmake DEPENDS ${CAFFE2_PROTOC_EXECUTABLE} ${abs_fil} COMMENT "Running C++/Python protocol buffer compiler on ${fil}" VERBATIM ) @@ -209,7 +209,7 @@ function(caffe2_protobuf_generate_cpp_py srcs_var hdrs_var python_var) COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_CURRENT_BINARY_DIR}" COMMAND ${CAFFE2_PROTOC_EXECUTABLE} -I${PROJECT_SOURCE_DIR} --cpp_out=${DLLEXPORT_STR}${PROJECT_BINARY_DIR} ${abs_fil} COMMAND ${CAFFE2_PROTOC_EXECUTABLE} -I${PROJECT_SOURCE_DIR} --python_out "${PROJECT_BINARY_DIR}" ${abs_fil} - COMMAND ${CMAKE_COMMAND} -DFILENAME=${CMAKE_CURRENT_BINARY_DIR}/${fil_we}.pb.h -DNAMESPACES=caffe\;caffe2\;onnx\;torch -DLOCAL_PROTOBUF=${CAFFE2_LINK_LOCAL_PROTOBUF} -P ${PROJECT_SOURCE_DIR}/cmake/ProtoBufPatch.cmake + COMMAND ${CMAKE_COMMAND} -DFILENAME=${CMAKE_CURRENT_BINARY_DIR}/${fil_we}.pb.h -DNAMESPACES=caffe\;caffe2\;onnx\;torch -DSYSTEM_PROTOBUF=YES -P ${PROJECT_SOURCE_DIR}/cmake/ProtoBufPatch.cmake DEPENDS ${CAFFE2_PROTOC_EXECUTABLE} ${abs_fil} COMMENT "Running C++/Python protocol buffer compiler on ${fil}" VERBATIM ) endif() diff --git a/cmake/ProtoBufPatch.cmake b/cmake/ProtoBufPatch.cmake index 704dcd7da1545..7f1de9a4a1de9 100644 --- a/cmake/ProtoBufPatch.cmake +++ b/cmake/ProtoBufPatch.cmake @@ -4,7 +4,7 @@ file(READ ${FILENAME} content) -if(LOCAL_PROTOBUF) +if(NOT SYSTEM_PROTOBUF) # protobuf-3.6.0 pattern string( REPLACE @@ -77,7 +77,7 @@ if(LOCAL_PROTOBUF) file(WRITE ${SOURCE_FILENAME} "${content_cc}") endif() -endif() +endif(NOT SYSTEM_PROTOBUF) # constexpr int TensorBoundShape_DimType_DimType_ARRAYSIZE = TensorBoundShape_DimType_DimType_MAX + 1; # throws diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index 4de2d79cb9757..99c41f24ab8c8 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -131,6 +131,7 @@ function(caffe2_print_configuration_summary) endif() message(STATUS " USE_METAL : ${USE_METAL}") message(STATUS " USE_PYTORCH_METAL : ${USE_PYTORCH_METAL}") + message(STATUS " USE_PYTORCH_METAL_EXPORT : ${USE_PYTORCH_METAL_EXPORT}") message(STATUS " USE_FFTW : ${USE_FFTW}") message(STATUS " USE_MKL : ${CAFFE2_USE_MKL}") message(STATUS " USE_MKLDNN : ${USE_MKLDNN}") @@ -178,6 +179,7 @@ function(caffe2_print_configuration_summary) message(STATUS " SELECTED_OP_LIST : ${SELECTED_OP_LIST}") endif() message(STATUS " USE_DEPLOY : ${USE_DEPLOY}") + message(STATUS " USE_BREAKPAD : ${USE_BREAKPAD}") message(STATUS " Public Dependencies : ${Caffe2_PUBLIC_DEPENDENCY_LIBS}") message(STATUS " Private Dependencies : ${Caffe2_DEPENDENCY_LIBS}") endfunction() diff --git a/docs/cpp/source/notes/maybe_owned.rst b/docs/cpp/source/notes/maybe_owned.rst new file mode 100644 index 0000000000000..8fa05f1b6aea7 --- /dev/null +++ b/docs/cpp/source/notes/maybe_owned.rst @@ -0,0 +1,59 @@ +MaybeOwned +================== + +``MaybeOwned`` is a C++ smart pointer class that dynamically +encodes whether a Tensor is *owned* or *borrowed*. It is used in +certain performance-sensitive situations to avoid unnecessarily +incrementing a Tensor’s reference count (at a small cost in +overhead from the extra indirection). + +.. warning:: + MaybeOwned must be used with **extreme** care. Claims of (non-)ownership + are not statically checked, and mistakes can cause reference undercounting + and use-after-free crashes. + + Due to this lack of safety net, we discourage the use of MaybeOwned + outside code paths that are known to be highly performance sensitive. + However, if you encounter pre-existing uses of MaybeOwned in code that + you want to modify, it’s critical to understand how to use it correctly. + +The primary use case for ``MaybeOwned`` is a function or method that +dynamically chooses between returning one of its arguments (typically +from a passthrough or “no-op” code path) and returning a freshly constructed +Tensor. Such a function would return a ``MaybeOwned`` in both cases, +the former in a "borrowed" state via a call to ``MaybeOwned::borrowed()``, +and the latter in an "owned" state via a call to ``MaybeOwned::owned()``. + +The canonical example is ``Tensor``'s ``expect_contiguous`` method, which shortcuts +and returns a borrowed self-reference when already contiguous: + +.. code-block:: cpp + + inline c10::MaybeOwned Tensor::expect_contiguous(MemoryFormat memory_format) const & { + if (is_contiguous(memory_format)) { + return c10::MaybeOwned::borrowed(*this); + } else { + return c10::MaybeOwned::owned(__dispatch_contiguous(memory_format)); + } + } + +Using the vocabulary of lifetimes, the essential safety requirement for borrowing +is that a borrowed Tensor must outlive any borrowing references to it. Here, for +example, we can safely borrow ``*this``, but the Tensor returned by +``__dispatch_contiguous()`` is freshly created, and borrowing a reference would +effectively leave it ownerless. + +So, general rules of thumb: + +- When in doubt, don’t use ``MaybeOwned`` at all - in particular, prefer + avoiding using it in code that doesn’t use it already. New usage should only be + introduced when critical (and demonstrable) performance gains result. + +- When modifying or calling code that already uses ``MaybeOwned``, remember + that it's always safe to produce a ``MaybeOwned`` from a Tensor in hand + via a call to ``MaybeOwned::owned()``. This may result in an unnecessary + reference count, but never in misbehavior - so it's always the safer bet, unless + the lifetime of the Tensor you're looking to wrap is crystal clear. + +More details and implementation code can be found at and +. diff --git a/docs/source/_static/img/meshgrid.png b/docs/source/_static/img/meshgrid.png new file mode 100644 index 0000000000000..97ad0661fc218 Binary files /dev/null and b/docs/source/_static/img/meshgrid.png differ diff --git a/docs/source/autograd.rst b/docs/source/autograd.rst index 5958c639813f1..8aace1ef12ab8 100644 --- a/docs/source/autograd.rst +++ b/docs/source/autograd.rst @@ -189,10 +189,10 @@ When creating a new :class:`Function`, the following methods are available to `c :toctree: generated :nosignatures: - function._ContextMethodMixin.mark_dirty - function._ContextMethodMixin.mark_non_differentiable - function._ContextMethodMixin.save_for_backward - function._ContextMethodMixin.set_materialize_grads + function.FunctionCtx.mark_dirty + function.FunctionCtx.mark_non_differentiable + function.FunctionCtx.save_for_backward + function.FunctionCtx.set_materialize_grads .. _grad-check: @@ -252,6 +252,7 @@ You can define how these saved tensors should be packed / unpacked using hooks. A common application is to trade compute for memory by saving those intermediary results to disk or to CPU instead of leaving them on the GPU. This is especially useful if you notice your model fits on GPU during evaluation, but not training. +Also see :ref:`saved-tensors-hooks-doc`. .. autoclass:: torch.autograd.graph.saved_tensors_hooks diff --git a/docs/source/community/contribution_guide.rst b/docs/source/community/contribution_guide.rst index 166aa7526e731..7cba558dbdb54 100644 --- a/docs/source/community/contribution_guide.rst +++ b/docs/source/community/contribution_guide.rst @@ -200,8 +200,8 @@ Triaging issues ~~~~~~~~~~~~~~~ If you feel that an issue could benefit from a particular tag or level -of complexity comment on the issue and share your opinion. If you -feel an issue isn't categorized properly comment and let the team know. +of complexity, comment on the issue and share your opinion. If you +feel an issue isn't categorized properly, comment and let the team know. About open source development ----------------------------- diff --git a/docs/source/community/persons_of_interest.rst b/docs/source/community/persons_of_interest.rst index 5c1fcbf1c7ecb..b1d4954a65768 100644 --- a/docs/source/community/persons_of_interest.rst +++ b/docs/source/community/persons_of_interest.rst @@ -4,50 +4,47 @@ PyTorch Governance | Persons of Interest General Maintainers ------------------- -- Adam Paszke (`apaszke `__) - Soumith Chintala (`soumith `__) - Edward Yang (`ezyang `__) - Greg Chanan (`gchanan `__) - Dmytro Dzhulgakov (`dzhulgakov `__) -- (sunsetting) Sam Gross - (`colesbury `__) +- (emeritus) Sam Gross (`colesbury `__) +- (emeritus) Adam Paszke (`apaszke `__) Module-level maintainers ------------------------ -torch.* -~~~~~~~ - -- Greg Chanan (`gchanan `__) -- Soumith Chintala (`soumith `__) -- [linear algebra] Vishwak Srinivasan (`vishwakftw `__) - torch.nn ~~~~~~~~ -- Adam Paszke (`apaszke `__) - Greg Chanan (`gchanan `__) - Soumith Chintala (`soumith `__) -- Sam Gross (`colesbury `__) +- Joel Schlosser (`jbschlosser `__) +- (emeritus) Sam Gross (`colesbury `__) +- (emeritus) Adam Paszke (`apaszke `__) torch.optim ~~~~~~~~~~~ -- Vincent Quenneville-Belair (`vincentqb `__) - Soumith Chintala (`soumith `__) +- Ilqar Ramazanli (`iramazanli `__) +- (emeritus) Vincent Quenneville-Belair (`vincentqb `__) -Autograd Engine -~~~~~~~~~~~~~~~ +torch.autograd +~~~~~~~~~~~~~~ - Edward Yang (`ezyang `__) - Alban Desmaison (`alband `__) -- Adam Paszke (`apaszke `__) +- (emeritus) Adam Paszke (`apaszke `__) -JIT -~~~ +JIT / TorchScript / FX +~~~~~~~~~~~~~~~~~~~~~~ -- Zach Devito (`zdevito `__) - Michael Suo (`suo `__) +- Yanan Cao (`gmagogsfm `__) +- James Reed (`jamesr66a `__) +- (emeritus) Zach Devito (`zdevito `__) + Distributions & RNG ~~~~~~~~~~~~~~~~~~~ @@ -60,78 +57,122 @@ Distributions & RNG Distributed ~~~~~~~~~~~ -- Pieter Noordhuis (`pietern `__) - Shen Li (`mrshenli `__) -- (proposed) Pritam Damania - (`pritamdamania87 `__) +- Pritam Damania (`pritamdamania87 `__) +- (emeritus) Pieter Noordhuis (`pietern `__) Multiprocessing and DataLoaders ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - Vitaly Fedyunin (`VitalyFedyunin `__) - Simon Wang (`SsnL `__) -- Adam Paszke (`apaszke `__) +- (emeritus) Adam Paszke (`apaszke `__) + +torch.linalg / Linear Algebra +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- Mike Ruberry (`mruberry `__) +- Vishwak Srinivasan (`vishwakftw `__) +- Ivan Yashchuk (`IvanYashchuk `__) + +torch.fft +~~~~~~~~~ + +- Mike Ruberry (`mruberry `__) +- Peter Bell (`peterbell10 `__) + CPU Performance / SIMD ~~~~~~~~~~~~~~~~~~~~~~ -- Xiaoqiang Zheng (`zheng-xq `__) - Vitaly Fedyunin (`VitalyFedyunin `__) -- Sam Gross (`colesbury `__) -- (sunsetting) Christian Puhrsch (`cpuhrsch `__) -- [threading] Ilia Cherniavskii (`ilia-cher `__) +- (emeritus) Xiaoqiang Zheng (`zheng-xq `__) +- (emeritus) Sam Gross (`colesbury `__) +- (emeritus) Christian Puhrsch (`cpuhrsch `__) +- (emeritus) Ilia Cherniavskii (`ilia-cher `__) CUDA ~~~~ - Natalia Gimelshein (`ngimel `__) - Edward Yang (`ezyang `__) -- Xiaoqiang Zheng (`zheng-xq `__) +- Piotr Bialecki (`ptrblck `__) +- (emeritus) Xiaoqiang Zheng (`zheng-xq `__) MKLDNN ~~~~~~ -- Junjie Bai (`bddppq `__) -- Yinghai Lu (`yinghai `__) +- Vitaly Fedyunin (`VitalyFedyunin `__) +- Jianhui Li (`Jianhui-Li `__) +- (emeritus) Junjie Bai (`bddppq `__) +- (emeritus) Yinghai Lu (`yinghai `__) AMD/ROCm/HIP ~~~~~~~~~~~~ -- Junjie Bai (`bddppq `__) -- Johannes M. Dieterich (`iotamudelta `__) +- Peng Sun (`sunway513 `__) +- Jithun Nair (`jithunnair-amd `__) +- Jeff Daily (`jeffdaily `__) +- (emeritus) Junjie Bai (`bddppq `__) Build + CI ~~~~~~~~~~ -- Will Feng (`yf225 `__) -- Edward Yang (`ezyang `__) -- Soumith Chintala (`soumith `__) -- Karl Ostmo (`kostmo `__) -- Hong Xu (`xuhdev `__) +- Nikita Shulga (`malfet `__) +- Eli Uriegas (`seemethere `__) +- Zhuojie Zhou (`zhouzhuojie `__) +- (emeritus) Edward Yang (`ezyang `__) +- (emeritus) Karl Ostmo (`kostmo `__) -Benchmarks -~~~~~~~~~~ +Performance Tools +~~~~~~~~~~~~~~~~~ -- Mingzhe Li (`mingzhe09088 `__) +- Victor Bittorf (`bitfort `__) +- Gisle Dankel (`gdankel `__) +- Taylor Robie (`robieta `__) +- Xu Zhao (`xuzhao9 `__) +- Geeta Chauhan (`chauhang `__) +- (emeritus) Natalia Gimelshein (`ngimel `__) +- (emeritus) Mingzhe Li (`mingzhe09088 `__) C++ API ~~~~~~~ -- Will Feng (`yf225 `__) +- Joel Schlosser (`jbschlosser `__) +- (emeritus) Will Feng (`yf225 `__) C10 utils and operator dispatch ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -- Sebastian Messmer (`smessmer `__) +- Brian Hirsh (`bdhirsh `__) +- Edward Yang (`ezyang `__) - Dmytro Dzhulgakov (`dzhulgakov `__) +- (emeritus) Sebastian Messmer (`smessmer `__) ONNX <-> PyTorch ~~~~~~~~~~~~~~~~ - -- Lu Fang (`houseroad `__) -- Lara Haidar (`lara-hdr `__) -- Spandan Tiwari (`spandantiwari `__) +- Negin Raoof (`neginraoof `__) +- Gary Miguel (`garymm `__) - Bowen Bao (`BowenBao `__) +- (emeritus) Lu Fang (`houseroad `__) +- (emeritus) Lara Haidar (`lara-hdr `__) +- (emeritus) Spandan Tiwari (`spandantiwari `__) + +Mobile / Edge +~~~~~~~~~~~~~ +- David Reiss (`dreiss `__) +- Raziel Guevara (`raziel `__) +- Linbin Yu (`linbinyu `__) +- Ivan Kobzarev (`IvanKobzarev `__) +- Tao Xu (`xta0 `__) + +Model Compression & Optimization +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +- Raghuraman Krishnamoorthi (`raghuramank100 `__) +- Jerry Zhang (`jerryzh168 `__) +- Zafar Takhirov (`z-a-f `__) +- Supriya Rao (`supriyar `__) + Windows ~~~~~~~ @@ -151,31 +192,39 @@ Library-level maintainers XLA ~~~ -- Ailing Zhang (`ailzhang `__) +- Jack Cao (`JackCaoG `__) +- Daniel Sohn (`jysohn23 `__) +- Zach Cain (`zcain117 `__) +- Brian Hirsch (`bdhirsh `__) - Gregory Chanan (`gchanan `__) -- Davide Libenzi (`dlibenzi `__) -- Alex Suhan (`asuhan `__) +- (emeritus) Ailing Zhang (`ailzhang `__) +- (emeritus) Davide Libenzi (`dlibenzi `__) +- (emeritus) Alex Suhan (`asuhan `__) TorchServe ~~~~~~~~~~ -- Geeta Chauhan (`chauhang `__) -- Manoj Rao (`mycpuorg `__) -- Vamshi Dantu (`vdantu `__) -- Dhanasekar Karuppasamy (`dhanainme `__) +- Geeta Chauhan (`chauhang `__) +- Manoj Rao (`mycpuorg `__) +- Vamshi Dantu (`vdantu `__) +- Dhanasekar Karuppasamy (`dhanainme `__) TorchVision ~~~~~~~~~~~ -- Francisco Massa (`fmassa `__) +- Francisco Massa (`fmassa `__) +- Vasilis Vryniotis (`datumbox `__) TorchText ~~~~~~~~~ -- Guanheng George Zhang (`zhangguanheng66 `__) -- Christian Puhrsch (`cpuhrsch `__) +- Parmeet Singh Bhatia (`parmeet `__) +- Steven Liu (`hudeven `__) +- (emeritus) Guanheng George Zhang (`zhangguanheng66 `__) +- (emeritus) Christian Puhrsch (`cpuhrsch `__) TorchAudio ~~~~~~~~~~ -- Vincent QB (`vincentqb `__) +- Moto Hira (`mthrok `__) +- (emeritus) Vincent QB (`vincentqb `__) diff --git a/docs/source/cuda.rst b/docs/source/cuda.rst index d4783c867b82a..75029332aa481 100644 --- a/docs/source/cuda.rst +++ b/docs/source/cuda.rst @@ -71,6 +71,17 @@ Streams and events Stream Event +Graphs (prototype) +------------------ +.. autosummary:: + :toctree: generated + :nosignatures: + + graph_pool_handle + CUDAGraph + graph + make_graphed_callables + Memory management ----------------- .. autosummary:: diff --git a/docs/source/data.rst b/docs/source/data.rst index 9135c87d09262..b03fcb5858531 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -264,6 +264,21 @@ Setting the argument :attr:`num_workers` as a positive integer will turn on multi-process data loading with the specified number of loader worker processes. +.. warning:: + After several iterations, the loader worker processes will consume + the same amount of CPU memory as the parent process for all Python + objects in the parent process which are accessed from the worker + processes. This can be problematic if the Dataset contains a lot of + data (e.g., you are loading a very large list of filenames at Dataset + construction time) and/or you are using a lot of workers (overall + memory usage is ``number of workers * size of parent process``). The + simplest workaround is to replace Python objects with non-refcounted + representations such as Pandas, Numpy or PyArrow objects. Check out + `issue #13246 + `_ + for more details on why this occurs and example code for how to + workaround these problems. + In this mode, each time an iterator of a :class:`~torch.utils.data.DataLoader` is created (e.g., when you call ``enumerate(dataloader)``), :attr:`num_workers` worker processes are created. At this point, the :attr:`dataset`, diff --git a/docs/source/ddp_comm_hooks.rst b/docs/source/ddp_comm_hooks.rst index aed70c0752825..d0f11fe0b0412 100644 --- a/docs/source/ddp_comm_hooks.rst +++ b/docs/source/ddp_comm_hooks.rst @@ -44,11 +44,13 @@ The input ``bucket`` is a :class:`torch.distributed.GradBucket` object. .. currentmodule:: torch.distributed.algorithms.ddp_comm_hooks.default_hooks .. autofunction:: allreduce_hook .. autofunction:: fp16_compress_hook +.. autofunction:: bf16_compress_hook -Additionally, a communication hook wraper is provided to support :meth:`~fp16_compress_hook` as a wrapper, +Additionally, a communication hook wraper is provided to support :meth:`~fp16_compress_hook` or :meth:`~bf16_compress_hook` as a wrapper, which can be combined with other communication hooks. .. autofunction:: fp16_compress_wrapper +.. autofunction:: bf16_compress_wrapper PowerSGD Communication Hook --------------------------- @@ -82,6 +84,18 @@ PowerSGD Hooks .. autofunction:: powerSGD_hook .. autofunction:: batched_powerSGD_hook +Debugging Communication Hooks +----------------------------- + +As the name implies, debugging communication hooks are **only** used for debugging and performance optimization purpose. + +.. currentmodule:: torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks + +.. warning :: + Debugging communication hooks do not necessarily output the correct results. + +.. autofunction:: noop_hook + Acknowledgements ---------------- diff --git a/docs/source/distributed.rst b/docs/source/distributed.rst index 0f4e051bbf4db..c5cd727fa7ea0 100644 --- a/docs/source/distributed.rst +++ b/docs/source/distributed.rst @@ -180,6 +180,8 @@ joined. .. autofunction:: is_nccl_available +.. autofunction:: is_torchelastic_launched + -------------------------------------------------------------------------------- Currently three initialization methods are supported: diff --git a/docs/source/elastic/quickstart.rst b/docs/source/elastic/quickstart.rst index 1d22426d06a8b..f7c1ebf7dd0de 100644 --- a/docs/source/elastic/quickstart.rst +++ b/docs/source/elastic/quickstart.rst @@ -5,13 +5,13 @@ To launch a **fault-tolerant** job, run the following on all nodes. .. code-block:: bash - python -m torch.distributed.run - --nnodes=NUM_NODES - --nproc_per_node=TRAINERS_PER_NODE - --rdzv_id=JOB_ID - --rdzv_backend=c10d - --rdzv_endpoint=HOST_NODE_ADDR - YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...) + torchrun + --nnodes=NUM_NODES + --nproc_per_node=TRAINERS_PER_NODE + --rdzv_id=JOB_ID + --rdzv_backend=c10d + --rdzv_endpoint=HOST_NODE_ADDR + YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...) To launch an **elastic** job, run the following on at least ``MIN_SIZE`` nodes @@ -19,13 +19,13 @@ and at most ``MAX_SIZE`` nodes. .. code-block:: bash - python -m torch.distributed.run - --nnodes=MIN_SIZE:MAX_SIZE - --nproc_per_node=TRAINERS_PER_NODE - --rdzv_id=JOB_ID - --rdzv_backend=c10d - --rdzv_endpoint=HOST_NODE_ADDR - YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...) + torchrun + --nnodes=MIN_SIZE:MAX_SIZE + --nproc_per_node=TRAINERS_PER_NODE + --rdzv_id=JOB_ID + --rdzv_backend=c10d + --rdzv_endpoint=HOST_NODE_ADDR + YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...) ``HOST_NODE_ADDR``, in form [:] (e.g. node1.example.com:29400), specifies the node and the port on which the C10d rendezvous backend should be @@ -46,6 +46,6 @@ ideally you should pick a node that has a high bandwidth. Learn more about writing your distributed training script `here `_. -If ``torch.distributed.run`` does not meet your requirements you may use our -APIs directly for more powerful customization. Start by taking a look at the -`elastic agent `_ API). +If ``torchrun`` does not meet your requirements you may use our APIs directly +for more powerful customization. Start by taking a look at the +`elastic agent `_ API. diff --git a/docs/source/elastic/run.rst b/docs/source/elastic/run.rst index fb870fae41f58..284fc7f755311 100644 --- a/docs/source/elastic/run.rst +++ b/docs/source/elastic/run.rst @@ -1,6 +1,6 @@ .. _launcher-api: -torch.distributed.run (Elastic Launch) +torchrun (Elastic Launch) ====================================== .. automodule:: torch.distributed.run diff --git a/docs/source/elastic/train_script.rst b/docs/source/elastic/train_script.rst index 263f2df659574..04225d79067a8 100644 --- a/docs/source/elastic/train_script.rst +++ b/docs/source/elastic/train_script.rst @@ -4,7 +4,7 @@ Train script ------------- If your train script works with ``torch.distributed.launch`` it will continue -working with ``torch.distributed.run`` with these differences: +working with ``torchrun`` with these differences: 1. No need to manually pass ``RANK``, ``WORLD_SIZE``, ``MASTER_ADDR``, and ``MASTER_PORT``. diff --git a/docs/source/jit.rst b/docs/source/jit.rst index eeb0d2a2c4ac3..8a80b6471e1a7 100644 --- a/docs/source/jit.rst +++ b/docs/source/jit.rst @@ -60,6 +60,7 @@ Creating TorchScript Code ScriptModule ScriptFunction freeze + optimize_for_inference save load ignore @@ -475,7 +476,7 @@ In this case, data-dependent control flow like this can be captured using #print(str(scripted_fn.graph).strip()) for input_tuple in [inputs] + check_inputs: - torch.testing.assert_allclose(fn(*input_tuple), scripted_fn(*input_tuple)) + torch.testing.assert_close(fn(*input_tuple), scripted_fn(*input_tuple)) .. testoutput:: :hide: diff --git a/docs/source/linalg.rst b/docs/source/linalg.rst index 7a286d3d4051e..ffca583b706e9 100644 --- a/docs/source/linalg.rst +++ b/docs/source/linalg.rst @@ -67,6 +67,7 @@ Matrix Products :toctree: generated :nosignatures: + matmul matrix_power multi_dot householder_product diff --git a/docs/source/nn.rst b/docs/source/nn.rst index 07ce4db2f48af..6eca9d4b16b6a 100644 --- a/docs/source/nn.rst +++ b/docs/source/nn.rst @@ -389,6 +389,7 @@ in :func:`torch.nn.utils.parameterize.register_parametrization`. :toctree: generated :nosignatures: + parametrizations.orthogonal parametrizations.spectral_norm Utility functions to parametrize Tensors on existing Modules. @@ -396,7 +397,7 @@ Note that these functions can be used to parametrize a given Parameter or Buffer given a specific function that maps from an input space to the parametrized space. They are not parameterizations that would transform an object into a parameter. See the -`Parametrizations `__ tutorial +`Parametrizations tutorial `_ for more information on how to implement your own parametrizations. .. autosummary:: diff --git a/docs/source/notes/autograd.rst b/docs/source/notes/autograd.rst index 0c1eed3f42457..2a59d976e9a6a 100644 --- a/docs/source/notes/autograd.rst +++ b/docs/source/notes/autograd.rst @@ -36,6 +36,57 @@ flow statements, that can change the overall shape and size of the graph at every iteration. You don't have to encode all possible paths before you launch the training - what you run is what you differentiate. +.. _saved-tensors-doc: + +Saved tensors +^^^^^^^^^^^^^ + +Some operations need intermediary results to be saved during the forward pass +in order to execute the backward pass. For example, the function +:math:`x\mapsto x^2` saves the input :math:`x` to compute the gradient. + +When defining a custom Python :class:`~torch.autograd.Function`, you can use +:func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` to save +tensors during the forward pass and +:attr:`~torch.autograd.function.Function.saved_tensors` to retrieve them +during the backward pass. See :doc:`/notes/extending` for more information. + +For operations that PyTorch defines (e.g. :func:`torch.pow`), tensors are +automatically saved as needed. You can explore (for educational or debugging +purposes) which tensors are saved by a certain ``grad_fn`` by looking for its +attributes starting with the prefix ``_saved``. + +.. code:: + + x = torch.randn(5, requires_grad=True) + y = x.pow(2) + print(x.equal(y.grad_fn._saved_self)) # True + print(x is y.grad_fn._saved_self) # True + + +In the previous code, ``y.grad_fn._saved_self`` refers to the same Tensor object as `x`. +But that may not always be the case. For instance: + +.. code:: + + x = torch.randn(5, requires_grad=True) + y = x.exp() + print(y.equal(y.grad_fn._saved_result)) # True + print(y is y.grad_fn._saved_result) # False + + +Under the hood, to prevent reference cycles, PyTorch has *packed* the tensor +upon saving and *unpacked* it into a different tensor for reading. Here, the +tensor you get from accessing ``y.grad_fn._saved_result`` is a different tensor +object than ``x`` (but they still share the same storage). + +Whether a tensor will be packed into a different tensor object depends on +whether it is an output of its own `grad_fn`, which is an implementation detail +subject to change and that users should not rely on. + +You can control how PyTorch does packing / unpacking with :ref:`saved-tensors-hooks-doc`. + + .. _locally-disable-grad-doc: Locally disabling gradient computation @@ -598,3 +649,151 @@ chain rule: .. math:: \frac{\partial L}{\partial z^*} = 2 * Re(grad\_out^* * \frac{\partial s}{\partial z^{*}}) + +.. _saved-tensors-hooks-doc: + +Hooks for saved tensors +----------------------- + +You can control :ref:`how saved tensors are packed / unpacked +` by defining a pair of ``pack_hook`` / ``unpack_hook`` +hooks. The ``pack_hook`` function should take a tensor as its single argument +but can return any python object (e.g. another tensor, a tuple, or even a +string containing a filename). The ``unpack_hook`` function takes as its single +argument the output of ``pack_hook`` and should return a tensor to be used in +the backward pass. The tensor returned by ``unpack_hook`` only needs to have +the same content as the tensor passed as input to ``pack_hook``. In particular, +any autograd-related metadata can be ignored as they will be overwritten during +unpacking. + +An example of such pair is: + +.. code:: + + class SelfDeletingTempFile(): + def __init__(self): + self.name = os.path.join(tmp_dir, str(uuid.uuid4())) + + def __del__(self): + os.remove(self.name) + + def pack_hook(tensor): + temp_file = SelfDeletingTempFile() + torch.save(tensor, temp_file.name) + return temp_file + + def unpack_hook(temp_file): + return torch.load(temp_file.name) + +Notice that the ``unpack_hook`` should not delete the temporary file because it +might be called multiple times: the temporary file should be alive for as long +as the returned `SelfDeletingTempFile` object is alive. In the above example, +we prevent leaking the temporary file by closing it when it is no longer needed +(on deletion of the `SelfDeletingTempFile` object). + +.. note:: + + We guarantee that ``pack_hook`` will only be called once but ``unpack_hook`` can + be called as many times as the backward pass requires it and we expect it to + return the same data each time. + +.. warning:: + + Performing inplace operations on the input of any of the functions is forbidden + as they may lead to unexpected side-effects. PyTorch will throw an error if the + input to a pack hook is modified inplace but does not catch the case where the + input to an unpack hook is modified inplace. + + +Registering hooks for a saved tensor +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +You can register a pair of hooks on a saved tensor by calling the +:meth:`~torch.autograd.SavedTensor.register_hooks` method on a +:class:`SavedTensor` object. Those objects are exposed as attributes of a +``grad_fn`` and start with the ``_raw_saved_`` prefix. + +.. code:: + + x = torch.randn(5, requires_grad=True) + y = x.pow(2) + y.grad_fn._raw_saved_self.register_hooks(pack_hook, unpack_hook) + +The ``pack_hook`` method is called as soon as the pair is registered. +The ``unpack_hook`` method is called each time the saved tensor needs to be +accessed, either by means of ``y.grad_fn._saved_self`` or during the backward +pass. + +.. warning:: + + If you maintain a reference to a :class:`SavedTensor` after the saved + tensors have been released (i.e. after backward has been called), calling + its :meth:`~torch.autograd.SavedTensor.register_hooks` is forbidden. + PyTorch will throw an error most of the time but it may fail + to do so in some cases and undefined behavior may arise. + +Registering default hooks for saved tensors +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Alternatively, you can use the context-manager +:class:`~torch.autograd.graph.saved_tensors_hooks` to register a pair of +hooks which will be applied to *all* saved tensors that are created in +that context. + +Example: + +.. code:: + + # Only save on disk tensors that have size >= 1000 + SAVE_ON_DISK_THRESHOLD = 1000 + + def pack_hook(x): + if x.numel() < SAVE_ON_DISK_THRESHOLD: + return x + temp_file = SelfDeletingTempFile() + torch.save(tensor, temp_file.name) + return temp_file + + def unpack_hook(tensor_or_sctf): + if isinstance(tensor_or_sctf, torch.Tensor): + return tensor_or_sctf + return torch.load(tensor_or_sctf.name) + + class Model(nn.Module): + def forward(self, x): + with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook): + # ... compute output + output = x + return output + + model = Model() + net = nn.DataParallel(model) + + + +The hooks defined with this context manager are thread-local. +Hence, the following code will not produce the desired effects because the hooks do not go +through `DataParallel`. + +.. code:: + + # Example what NOT to do + + net = nn.DataParallel(model) + with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook): + output = net(input) + + +Note that using those hooks disables all the optimization in place to reduce +Tensor object creation. For example: + +.. code:: + + with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x): + x = torch.randn(5, requires_grad=True) + y = x * x + +Without the hooks, ``x``, ``y.grad_fn._saved_self`` and +``y.grad_fn._saved_other`` all refer to the same tensor object. +With the hooks, PyTorch will pack and unpack `x` into two new tensor objects +that share the same storage with the original `x` (no copy performed). diff --git a/docs/source/notes/cuda.rst b/docs/source/notes/cuda.rst index 264017f0203cc..5d7c0ea48f669 100644 --- a/docs/source/notes/cuda.rst +++ b/docs/source/notes/cuda.rst @@ -262,7 +262,7 @@ have the same stream-semantics relationship as any group of ops:: BC note: Using grads on the default stream ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -In prior versions of Pytorch (1.9 and earlier), the autograd engine always synced +In prior versions of PyTorch (1.9 and earlier), the autograd engine always synced the default stream with all backward ops, so the following pattern:: with torch.cuda.stream(s): @@ -270,7 +270,7 @@ the default stream with all backward ops, so the following pattern:: use grads was safe as long as ``use grads`` happened on the default stream. -In present Pytorch, that pattern is no longer safe. If ``backward()`` +In present PyTorch, that pattern is no longer safe. If ``backward()`` and ``use grads`` are in different stream contexts, you must sync the streams:: with torch.cuda.stream(s): @@ -513,3 +513,452 @@ by GIL of Python interpreter. If you use :class:`~torch.nn.parallel.DistributedDataParallel`, you could use `torch.distributed.launch` utility to launch your program, see :ref:`distributed-launch`. + +.. _cuda-graph-semantics: + +CUDA Graphs +----------- + +A CUDA graph is a record of the work (mostly kernels and their arguments) that a +CUDA stream and its dependent streams perform. +For general principles and details on the underlying CUDA API, see +`Getting Started with CUDA Graphs`_ and the +`Graphs section`_ of the CUDA C Programming Guide. + +PyTorch supports the construction of CUDA graphs using `stream capture`_, which puts a +CUDA stream in *capture mode*. CUDA work issued to a capturing stream doesn't actually +run on the GPU. Instead, the work is recorded in a graph. + +After capture, the graph can be *launched* to run the GPU work as many times as needed. +Each replay runs the same kernels with the same arguments. For pointer arguments this +means the same memory addresses are used. +By filling input memory with new data (e.g., from a new batch) before each replay, +you can rerun the same work on new data. + +Why CUDA Graphs? +^^^^^^^^^^^^^^^^ + +Replaying a graph sacrifices the dynamic flexibility of typical eager execution in exchange for +**greatly reduced CPU overhead**. A graph's arguments and kernels are fixed, so a graph replay +skips all layers of argument setup and kernel dispatch, including Python, C++, and CUDA driver +overheads. Under the hood, a replay submits the entire graph's work to the GPU with +a single call to `cudaGraphLaunch`_. Kernels in a replay also execute slightly faster +on the GPU, but eliding CPU overhead is the main benefit. + +You should try CUDA graphs if all or part of your network is graph-safe (usually this means +static shapes and static control flow, but see the other :ref:`constraints`) +and you suspect its runtime is at least somewhat CPU-limited. + +.. _Getting Started with CUDA Graphs: + https://developer.nvidia.com/blog/cuda-graphs/ +.. _Graphs section: + https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#cuda-graphs +.. _stream capture: + https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture +.. _cudaGraphLaunch: + https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1g1accfe1da0c605a577c22d9751a09597 + +PyTorch API +^^^^^^^^^^^ + +.. warning:: + This API is a prototype and may change in future releases. + +PyTorch exposes graphs via a raw :class:`torch.cuda.CUDAGraph` class +and two convenience wrappers, +:class:`torch.cuda.graph` and +:class:`torch.cuda.make_graphed_callables`. + +:class:`torch.cuda.graph` is a simple, versatile context manager that +captures CUDA work in its context. +Before capture, warm up the workload to be captured by running +a few eager iterations. Warmup must occur on a side stream. +Because the graph reads from and writes to the same memory addresses in every +replay, you must maintain long-lived references to tensors that hold +input and output data during capture. +To run the graph on new input data, copy new data to the capture's input tensor(s), +replay the graph, then read the new output from the capture's output tensor(s). +Example:: + + g = torch.cuda.CUDAGraph() + + # Placeholder input used for capture + static_input = torch.empty((5,), device="cuda") + + # Warmup before capture + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(3): + static_output = static_input * 2 + torch.cuda.current_stream().wait_stream(s) + + # Captures the graph + # To allow capture, automatically sets a side stream as the current stream in the context + with torch.cuda.graph(g): + static_output = static_input * 2 + + # Fills the graph's input memory with new data to compute on + static_input.copy_(torch.full((5,), 3, device="cuda")) + g.replay() + # static_output holds the results + print(static_output) # full of 3 * 2 = 6 + + # Fills the graph's input memory with more data to compute on + static_input.copy_(torch.full((5,), 4, device="cuda")) + g.replay() + print(static_output) # full of 4 * 2 = 8 + +See +:ref:`Whole-network capture`, +:ref:`Usage with torch.cuda.amp`, and +:ref:`Usage with multiple streams` +for realistic and advanced patterns. + +:class:`~torch.cuda.make_graphed_callables` is more sophisticated. +:class:`~torch.cuda.make_graphed_callables` accepts Python functions and +:class:`torch.nn.Module`\s. For each passed function or Module, +it creates separate graphs of the forward-pass and backward-pass work. See +:ref:`Partial-network capture`. + +.. _capture-constraints: + +Constraints +~~~~~~~~~~~ + +A set of ops is *capturable* if it doesn't violate any of the following constraints. + +Constraints apply to all work in a +:class:`torch.cuda.graph` context and all work in the forward and backward passes +of any callable you pass to :func:`torch.cuda.make_graphed_callables`. + +Violating any of these will likely cause a runtime error: + +* Capture must occur on a non-default stream. (This is only a concern if you use the raw + :meth:`CUDAGraph.capture_begin` and + :meth:`CUDAGraph.capture_end` calls. + :class:`~torch.cuda.graph` and + :func:`~torch.cuda.make_graphed_callables` set a side stream for you.) +* Ops that sychronize the CPU with the GPU (e.g., ``.item()`` calls) are prohibited. +* CUDA RNG ops are allowed, but must use default generators. For example, explicitly constructing a + new :class:`torch.Generator` instance and passing it as the ``generator`` argument to an RNG function + is prohibited. + +Violating any of these will likely cause silent numerical errors or undefined behavior: + +* Within a process, only one capture may be underway at a time. +* No non-captured CUDA work may run in this process (on any thread) while capture is underway. +* CPU work is not captured. If the captured ops include CPU work, that work will be elided during replay. +* Every replay reads from and writes to the same (virtual) memory addresses. +* Dynamic control flow (based on CPU or GPU data) is prohibited. +* Dynamic shapes are prohibited. The graph assumes every tensor in the captured op sequence + has the same size and layout in every replay. +* Using multiple streams in a capture is allowed, but there are :ref:`restrictions`. + +Non-constraints +~~~~~~~~~~~~~~~ + +* Once captured, the graph may be replayed on any stream. + +.. _whole-network-capture: + +Whole-network capture +^^^^^^^^^^^^^^^^^^^^^^ + +If your entire network is capturable, you can capture and replay an entire iteration:: + + N, D_in, H, D_out = 640, 4096, 2048, 1024 + model = torch.nn.Sequential(torch.nn.Linear(D_in, H), + torch.nn.Dropout(p=0.2), + torch.nn.Linear(H, D_out), + torch.nn.Dropout(p=0.1)).cuda() + loss_fn = torch.nn.MSELoss() + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + + # Placeholders used for capture + static_input = torch.randn(N, D_in, device='cuda') + static_target = torch.randn(N, D_out, device='cuda') + + # warmup + # Uses static_input and static_target here for convenience, + # but in a real setting, because the warmup includes optimizer.step() + # you must use a few batches of real data. + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for i in range(3): + optimizer.zero_grad(set_to_none=True) + y_pred = model(static_input) + loss = loss_fn(y_pred, static_target) + loss.backward() + optimizer.step() + torch.cuda.current_stream().wait_stream(s) + + # capture + g = torch.cuda.CUDAGraph() + # Sets grads to None before capture, so backward() will create + # .grad attributes with allocations from the graph's private pool + optimizer.zero_grad(set_to_none=True) + with torch.cuda.graph(g): + static_y_pred = model(static_input) + static_loss = loss_fn(static_y_pred, static_target) + static_loss.backward() + optimizer.step() + + real_inputs = [torch.rand_like(static_input) for _ in range(10)] + real_targets = [torch.rand_like(static_target) for _ in range(10)] + + for data, target in zip(real_inputs, real_targets): + # Fills the graph's input memory with new data to compute on + static_input.copy_(data) + static_target.copy_(target) + # replay() includes forward, backward, and step. + # You don't even need to call optimizer.zero_grad() between iterations + # because the captured backward refills static .grad tensors in place. + g.replay() + # Params have been updated. static_y_pred, static_loss, and .grad + # attributes hold values from computing on this iteration's data. + +.. _partial-network-capture: + +Partial-network capture +^^^^^^^^^^^^^^^^^^^^^^^^^ + +If some of your network is unsafe to capture (e.g., due to dynamic control flow, +dynamic shapes, CPU syncs, or essential CPU-side logic), you can run the unsafe +part(s) eagerly and use :func:`torch.cuda.make_graphed_callables` to graph only +the capture-safe part(s). + +By default, callables returned by :func:`~torch.cuda.make_graphed_callables` +are autograd-aware, and can be used in the training loop as direct replacements +for the functions or :class:`nn.Module`\ s you passed. + +:func:`~torch.cuda.make_graphed_callables` internally creates +:class:`~torch.cuda.CUDAGraph` objects, runs warmup iterations, and maintains +static inputs and outputs as needed. Therefore (unlike with +:class:`torch.cuda.graph`) you don't need to handle those manually. + +In the following example, data-dependent dynamic control flow means the +network isn't capturable end-to-end, but +:func:`~torch.cuda.make_graphed_callables` +lets us capture and run graph-safe sections as graphs regardless:: + + N, D_in, H, D_out = 640, 4096, 2048, 1024 + + module1 = torch.nn.Linear(D_in, H).cuda() + module2 = torch.nn.Linear(H, D_out).cuda() + module3 = torch.nn.Linear(H, D_out).cuda() + + loss_fn = torch.nn.MSELoss() + optimizer = torch.optim.SGD(chain(module1.parameters() + + module2.parameters() + + module3.parameters()), + lr=0.1) + + # Sample inputs used for capture + # requires_grad state of sample inputs must match + # requires_grad state of real inputs each callable will see. + x = torch.randn(N, D_in, device='cuda') + h = torch.randn(N, H, device='cuda', requires_grad=True) + + module1 = torch.cuda.make_graphed_callables(module1, (x,)) + module2 = torch.cuda.make_graphed_callables(module2, (h,)) + module3 = torch.cuda.make_graphed_callables(module3, (h,)) + + real_inputs = [torch.rand_like(x) for _ in range(10)] + real_targets = [torch.randn(N, D_out, device="cuda") for _ in range(10)] + + for data, target in zip(real_inputs, real_targets): + optimizer.zero_grad(set_to_none=True) + + tmp = module1(data) # forward ops run as a graph + + if tmp.sum().item() > 0: + tmp = module2(tmp) # forward ops run as a graph + else: + tmp = module3(tmp) # forward ops run as a graph + + loss = loss_fn(tmp, y) + # module2's or module3's (whichever was chosen) backward ops, + # as well as module1's backward ops, run as graphs + loss.backward() + optimizer.step() + +.. _graphs-with-amp: + +Usage with torch.cuda.amp +^^^^^^^^^^^^^^^^^^^^^^^^^ + +For typical optimizers, :meth:`GradScaler.step` syncs +the CPU with the GPU, which is prohibited during capture. To avoid errors, either use +:ref:`partial-network capture`, or (if forward, loss, +and backward are capture-safe) capture forward, loss, and backward but not the +optimizer step:: + + # warmup + # In a real setting, use a few batches of real data. + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for i in range(3): + optimizer.zero_grad(set_to_none=True) + with torch.cuda.amp.autocast(): + y_pred = model(static_input) + loss = loss_fn(y_pred, static_target) + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + torch.cuda.current_stream().wait_stream(s) + + # capture + g = torch.cuda.CUDAGraph() + optimizer.zero_grad(set_to_none=True) + with torch.cuda.graph(g): + with torch.cuda.amp.autocast(): + static_y_pred = model(static_input) + static_loss = loss_fn(static_y_pred, static_target) + scaler.scale(static_loss).backward() + # don't capture scaler.step(optimizer) or scaler.update() + + real_inputs = [torch.rand_like(static_input) for _ in range(10)] + real_targets = [torch.rand_like(static_target) for _ in range(10)] + + for data, target in zip(real_inputs, real_targets): + static_input.copy_(data) + static_target.copy_(target) + g.replay() + # Runs scaler.step and scaler.update eagerly + scaler.step(optimizer) + scaler.update() + +.. _multistream-capture: + +Usage with multiple streams +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Capture mode automatically propagates to any streams that sync with a capturing stream. +Within capture, you may expose parallelism by issuing calls to different streams, +but the overall stream dependency DAG must branch out from the +initial capturing stream after capture begins and rejoin the initial stream +before capture ends:: + + with torch.cuda.graph(g): + # at context manager entrance, torch.cuda.current_stream() + # is the initial capturing stream + + # INCORRECT (does not branch out from or rejoin initial stream) + with torch.cuda.stream(s): + cuda_work() + + # CORRECT: + # branches out from initial stream + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + cuda_work() + # rejoins initial stream before capture ends + torch.cuda.current_stream().wait_stream(s) + +.. note:: + + To avoid confusion for power users looking at replays in nsight systems or nvprof: + Unlike eager execution, the graph interprets a nontrivial stream DAG in capture + as a hint, not a command. During replay, the graph may reorganize independent ops + onto different streams or enqueue them in a different order (while respecting your + original DAG's overall dependencies). + +Usage with DistributedDataParallel +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +NCCL < 2.9.6 +~~~~~~~~~~~~ + +NCCL versions earlier than 2.9.6 don't allow collectives to be captured. +You must use :ref:`partial-network capture`, +which defers allreduces to happen outside graphed sections of backward. + +Call :func:`~torch.cuda.make_graphed_callables` on graphable network sections +*before* wrapping the network with DDP. + +NCCL >= 2.9.6 +~~~~~~~~~~~~~ + +NCCL versions 2.9.6 or later allow collectives in the graph. +Approaches that capture an :ref:`entire backward pass` +are a viable option, but need three setup steps. + +1. Disable DDP's internal async error handling:: + + os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0" + torch.distributed.init_process_group(...) + +2. Before full-backward capture, DDP must be constructed in a side-stream context:: + + with torch.cuda.stream(s): + model = DistributedDataParallel(model) + +3. Your warmup must run at least 11 DDP-enabled eager iterations before capture. + +.. _graph-memory-management: + +Graph memory management +^^^^^^^^^^^^^^^^^^^^^^^ + +A captured graph acts on the same virtual addresses every time it replays. +If PyTorch frees the memory, a later replay can hit an illegal memory access. +If PyTorch reassigns the memory to new tensors, the replay can corrupt the values +seen by those tensors. Therefore, the virtual addresses used by the graph must be +reserved for the graph across replays. The PyTorch caching allocator achieves this +by detecting when capture is underway and satisfying the capture's allocations +from a graph-private memory pool. The private pool stays alive until its +:class:`~torch.cuda.CUDAGraph` object and all tensors created during capture +go out of scope. + +Private pools are maintained automatically. By default, the allocator creates a +separate private pool for each capture. If you capture multiple graphs, +this conservative approach ensures graph replays never corrupt each other's values, +but sometimes needlessly wastes memory. + +To economize the memory stashed in private pools, :class:`torch.cuda.graph` +and :func:`torch.cuda.make_graphed_callables` optionally allow different +captures to share the same private pool. +It's safe for a set of graphs to share a private pool if you know they'll always +be replayed in the same order they were captured, +and never be replayed concurrently. + +Sharing memory across captures with torch.cuda.graph +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +:class:`torch.cuda.graph`'s ``pool`` argument is a hint to use a particular private pool, +and can be used to share memory across graphs as shown:: + + g1 = torch.cuda.CUDAGraph() + g2 = torch.cuda.CUDAGraph() + + # (create static inputs for g1 and g2, run warmups of their workloads...) + + # Captures g1 + with torch.cuda.graph(g1): + static_out_1 = g1_workload(static_in_1) + + # Captures g2, hinting that g2 may share a memory pool with g1 + with torch.cuda.graph(g2, pool=g1.pool()): + static_out_2 = g2_workload(static_in_2) + + static_in_1.copy_(real_data_1) + static_in_2.copy_(real_data_2) + g1.replay() + g2.replay() + +Sharing memory across captures with torch.cuda.make_graphed_callables +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +With :func:`torch.cuda.make_graphed_callables`, if you want to graph several +callables and you know they'll always run in the same order (and never concurrently) +pass them as a tuple in the same order they'll run in the live workload, and +:func:`~torch.cuda.make_graphed_callables` will capture their graphs using a shared +private pool. + +If, in the live workload, your callables will run in an order that occasionally changes, +or if they'll run concurrently, passing them as a tuple to a single invocation of +:func:`~torch.cuda.make_graphed_callables` is not allowed. Instead, you must call +:func:`~torch.cuda.make_graphed_callables` separately for each one. diff --git a/docs/source/notes/extending.rst b/docs/source/notes/extending.rst index 1c89bcf10eb0a..a8d3983f9f0d9 100644 --- a/docs/source/notes/extending.rst +++ b/docs/source/notes/extending.rst @@ -13,60 +13,110 @@ Extending :mod:`torch.autograd` .. currentmodule:: torch.autograd Adding operations to :mod:`~torch.autograd` requires implementing a new -:class:`Function` subclass for each operation. Recall that :class:`Function` s -are what :mod:`~torch.autograd` uses to compute the results and gradients, and -encode the operation history. Every new function requires you to implement 2 methods: - -- :meth:`~Function.forward` - the code that performs the operation. It can take +:class:`Function` subclass for each operation. Recall that Functions +are what :mod:`~torch.autograd` uses to encode the operation history and compute +gradients. + +When to use +^^^^^^^^^^^ +In general, implement a custom function if you want to perform computations in your model +that are not differentiable or rely on non-Pytorch libraries (e.g., NumPy), but +still wish for your operation to chain with other ops and work with the autograd engine. + +In some situations, custom functions can also be used to improve performance and +memory usage: If you implemented your forward and backward passes using a +`C++ extension `_, +you can wrap them in :class:`~Function` to interface with the autograd +engine. If you'd like to reduce the number of buffers saved for the backward pass, +custom functions can be used to combine ops together. + +When not to use +^^^^^^^^^^^^^^^ +If you can already write your function in terms of PyTorch's built-in ops, its +backward graph is (most likely) already able to be recorded by autograd. In this case, you do +not need to implement the backward function yourself. Consider using a plain +old Python function. + +If you need to maintain state, i.e., trainable parameters, you should (also) use a +custom module. See the section below for more information on extending :mod:`torch.nn`. + +If you'd like to alter the gradients during the backward pass or perform a side +effect, consider registering a +`tensor `_ or +`Module `_ hook. + +How to use +^^^^^^^^^^ +Take the following steps: +1. Subclass :class:`~Function` and implement the :meth:`~Function.forward` and +:meth:`~Function.backward` methods. +2. Call the proper methods on the `ctx` argument. +3. Declare whether your function supports double backward. +4. Validate whether your gradients are correct using gradcheck. + +**Step 1:** After subclassing :class:`Function`, you'll need to define 2 methods: + +- :meth:`~Function.forward` is the code that performs the operation. It can take as many arguments as you want, with some of them being optional, if you specify the default values. All kinds of Python objects are accepted here. :class:`Tensor` arguments that track history (i.e., with ``requires_grad=True``) will be converted to ones that don't track history before the call, and their use will be registered in the graph. Note that this logic won't traverse lists/dicts/any other data structures and will only - consider :class:`Tensor` s that are direct arguments to the call. You can + consider tensors that are direct arguments to the call. You can return either a single :class:`Tensor` output, or a :class:`tuple` of - :class:`Tensor` s if there are multiple outputs. Also, please refer to the + tensors if there are multiple outputs. Also, please refer to the docs of :class:`Function` to find descriptions of useful methods that can be called only from :meth:`~Function.forward`. -- :meth:`~Function.backward` - gradient formula. It will be given +- :meth:`~Function.backward` defines the gradient formula. It will be given as many :class:`Tensor` arguments as there were outputs, with each of them - representing gradient w.r.t. that output. It should return as many - :class:`Tensor` s as there were inputs, with each of them containing the - gradient w.r.t. its corresponding input. If your inputs didn't require - gradient (:attr:`~ctx.needs_input_grad` is a tuple of booleans indicating + representing gradient w.r.t. that output. It is important NEVER to modify + these in-place. It should return as many tensors as there + were inputs, with each of them containing the gradient w.r.t. its + corresponding input. If your inputs didn't require gradient + (:attr:`~ctx.needs_input_grad` is a tuple of booleans indicating whether each input needs gradient computation), or were non-:class:`Tensor` objects, you can return :class:`python:None`. Also, if you have optional arguments to :meth:`~Function.forward` you can return more gradients than there were inputs, as long as they're all :any:`python:None`. -.. note:: - - It's the user's responsibility to use the special functions in the forward's `ctx` - properly in order to ensure that the new :class:`Function` works properly with - the autograd engine. - - - :meth:`~torch.autograd.function._ContextMethodMixin.save_for_backward` must be - used when saving input or output of the forward to be used later in the backward. - - :meth:`~torch.autograd.function._ContextMethodMixin.mark_dirty` must be used to - mark any input that is modified inplace by the forward function. - - :meth:`~torch.autograd.function._ContextMethodMixin.mark_non_differentiable` must - be used to tell the engine if an output is not differentiable. - - :meth:`~torch.autograd.function._ContextMethodMixin.set_materialize_grads` can be - used to tell the autograd engine to optimize gradient computations in the cases where - the output does not depend on the input by not materializing grad tensors given to backward - function. That is, if set to False, None object in python or "undefined tensor" (tensor x for - which x.defined() is False) in C++ will not be converted to a tensor filled with zeros prior - to calling backward. However, supporting this optimization means your custom autograd function - has to handle gradients that are represented in this way and is thus opt-in. Default value is True. - -.. note:: - - By default, all the output Tensors that are of differentiable type will be set to - require gradient and have all autograd metadata set for them. If you don't want - them to require gradients, you can use the `mark_non_differentiable` method mentioned - above. For output Tensors that are not of differentiable type (integer types for example), - they won't be marked as requiring gradients. +**Step 2:** It is your responsibility to use the functions in the forward's `ctx` +properly in order to ensure that the new :class:`Function` works properly with +the autograd engine. + +- :meth:`~torch.autograd.function.FunctionCtx.save_for_backward` must be + used when saving input or output tensors of the forward to be used later in the backward. + Anything else, i.e., non-tensors and tensors that are neither input nor output + should be stored directly on `ctx`. +- :meth:`~torch.autograd.function.FunctionCtx.mark_dirty` must be used to + mark any input that is modified inplace by the forward function. +- :meth:`~torch.autograd.function.FunctionCtx.mark_non_differentiable` must + be used to tell the engine if an output is not differentiable. By + default all output tensors that are of differentiable type will be set + to require gradient. Tensors of non-differentiable type (i.e., integral types) + are never marked as requiring gradients. +- :meth:`~torch.autograd.function.FunctionCtx.set_materialize_grads` can be + used to tell the autograd engine to optimize gradient computations in the cases where + the output does not depend on the input by not materializing grad tensors given to backward + function. That is, if set to False, None object in python or "undefined tensor" (tensor x for + which x.defined() is False) in C++ will not be converted to a tensor filled with zeros prior + to calling backward, and so your code will need to handle such objects as if they were + tensors filled with zeros. The default value of this setting is True. + +**Step 3:** If your :class:`~Function` does not support double backward +you should explicitly declare this by decorating backward with the +:func:`~function.once_differentiable`. With this decorator, attempts to +perform double backward through your function will produce an error. +See our double backward tutorial for more information on double backward. + +**Step 4:** It is recommended that you use :func:`torch.autograd.gradcheck` +to check whether your backward function correctly computes gradients of the +forward by computing the Jacobian matrix using your backward function and +comparing the value element-wise with the Jacobian computed numerically using +finite-differencing. + +Example +^^^^^^^ Below you can find code for a ``Linear`` function from :mod:`torch.nn`, with additional comments:: @@ -151,12 +201,12 @@ And here, we optimize the above example by calling set_materialize_grads(False): return grad_output * ctx.constant, None .. note:: - Inputs to ``backward``, i.e., :attr:`grad_output`, can also be Tensors that + Inputs to ``backward``, i.e., :attr:`grad_output`, can also be tensors that track history. So if ``backward`` is implemented with differentiable operations, (e.g., invocation of another custom :class:`~torch.autograd.function`), higher order derivatives will work. - In this case, the Tensors saved with ``save_for_backward`` can also be used - in the backward and have gradients flowing back but Tensors saved in the ``ctx`` + In this case, the tensors saved with ``save_for_backward`` can also be used + in the backward and have gradients flowing back but tensors saved in the ``ctx`` won't have gradients flowing back for them. If you need gradients to flow back for a Tensor saved in the ``ctx``, you should make it an output of the custom ``Function`` and save it with ``save_for_backward``. diff --git a/docs/source/notes/hip.rst b/docs/source/notes/hip.rst index 20f99cb96c5b0..a9c94e2a4febb 100644 --- a/docs/source/notes/hip.rst +++ b/docs/source/notes/hip.rst @@ -119,6 +119,27 @@ torch.distributed backends Currently, only the "nccl" and "gloo" backends for torch.distributed are supported on ROCm. +.. _cuda-api-to_hip-api-mappings: + +CUDA API to HIP API mappings in C++ +----------------------------------- + +Please refer: https://rocmdocs.amd.com/en/latest/Programming_Guides/HIP_API_Guide.html + +NOTE: The CUDA_VERSION macro, cudaRuntimeGetVersion and cudaDriverGetVersion APIs do not +semantically map to the same values as HIP_VERSION macro, hipRuntimeGetVersion and +hipDriverGetVersion APIs. Please do not use them interchangeably when doing version checks. + +Eg: Instead of +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 +If it is desired to not take the code path for ROCm/HIP: +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && !defined(USE_ROCM) +If it is desired to take the code path for ROCm/HIP: +#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11000) || defined(USE_ROCM) +If it is desired to take the code path for ROCm/HIP only for specific HIP versions: +#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11000) || (defined(USE_ROCM) && ROCM_VERSION >= 40300) + + Refer to CUDA Semantics doc --------------------------- diff --git a/docs/source/notes/modules.rst b/docs/source/notes/modules.rst index 4eba02231b1ac..c1d978dc78115 100644 --- a/docs/source/notes/modules.rst +++ b/docs/source/notes/modules.rst @@ -117,7 +117,7 @@ multiple modules: Note that :class:`~torch.nn.Sequential` automatically feeds the output of the first ``MyLinear`` module as input into the :class:`~torch.nn.ReLU`, and the output of that as input into the second ``MyLinear`` module. As -shown, it is limited to in-order chaining of modules. +shown, it is limited to in-order chaining of modules with a single input and output. In general, it is recommended to define a custom module for anything beyond the simplest use cases, as this gives full flexibility on how submodules are used for a module's computation. @@ -258,16 +258,32 @@ It's also easy to move all parameters to a different device or change their prec dynamic_net(torch.randn(5, device='cuda', dtype=torch.float64)) : tensor([6.5166], device='cuda:0', dtype=torch.float64, grad_fn=) -These examples show how elaborate neural networks can be formed through module composition. To allow for -quick and easy construction of neural networks with minimal boilerplate, PyTorch provides a large library of -performant modules within the :mod:`torch.nn` namespace that perform computation commonly found within neural -networks, including pooling, convolutions, loss functions, etc. +More generally, an arbitrary function can be applied to a module and its submodules recursively by +using the :func:`~torch.nn.Module.apply` function. For example, to apply custom initialization to parameters +of a module and its submodules: + +.. code-block:: python + + # Define a function to initialize Linear weights. + # Note that no_grad() is used here to avoid tracking this computation in the autograd graph. + @torch.no_grad() + def init_weights(m): + if isinstance(m, nn.Linear): + nn.init.xavier_normal_(m.weight) + m.bias.fill_(0.0) + + # Apply the function recursively on the module and its submodules. + dynamic_net.apply(init_weights) + +These examples show how elaborate neural networks can be formed through module composition and conveniently +manipulated. To allow for quick and easy construction of neural networks with minimal boilerplate, PyTorch +provides a large library of performant modules within the :mod:`torch.nn` namespace that perform common neural +network operations like pooling, convolutions, loss functions, etc. In the next section, we give a full example of training a neural network. For more information, check out: -* Recursively :func:`~torch.nn.Module.apply` a function to a module and its submodules * Library of PyTorch-provided modules: `torch.nn `_ * Defining neural net modules: https://pytorch.org/tutorials/beginner/examples_nn/two_layer_net_module.html @@ -295,6 +311,12 @@ Optimizers from :mod:`torch.optim`: loss.backward() optimizer.step() + # After training, switch the module to eval mode to do inference, compute performance metrics, etc. + # (see discussion below for a description of training and evaluation modes) + ... + net.eval() + ... + In this simplified example, the network learns to simply output zero, as any non-zero output is "penalized" according to its absolute value by employing :func:`torch.abs` as a loss function. While this is not a very interesting task, the key parts of training are present: @@ -321,6 +343,38 @@ value of ``l1``\ 's ``weight`` parameter shows that its values are now much clos [ 0.0030], [-0.0008]], requires_grad=True) +Note that the above process is done entirely while the network module is in "training mode". Modules default to +training mode and can be switched between training and evaluation modes using :func:`~torch.nn.Module.train` and +:func:`~torch.nn.Module.eval`. They can behave differently depending on which mode they are in. For example, the +:class:`~torch.nn.BatchNorm` module maintains a running mean and variance during training that are not updated +when the module is in evaluation mode. In general, modules should be in training mode during training +and only switched to evaluation mode for inference or evaluation. Below is an example of a custom module +that behaves differently between the two modes: + +.. code-block:: python + + class ModalModule(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + if self.training: + # Add a constant only in training mode. + return x + 1. + else: + return x + + + m = ModalModule() + x = torch.randn(4) + + print('training mode output: {}'.format(m(x))) + : tensor([1.6614, 1.2669, 1.0617, 1.6213, 0.5481]) + + m.eval() + print('evaluation mode output: {}'.format(m(x))) + : tensor([ 0.6614, 0.2669, 0.0617, 0.6213, -0.4519]) + Training neural networks can often be tricky. For more information, check out: * Using Optimizers: https://pytorch.org/tutorials/beginner/examples_nn/two_layer_net_optim.html. @@ -409,12 +463,127 @@ Both persistent and non-persistent buffers are affected by model-wide device / d Buffers of a module can be iterated over using :func:`~torch.nn.Module.buffers` or :func:`~torch.nn.Module.named_buffers`. +.. code-block:: python + + for buffer in m.named_buffers(): + print(buffer) + +The following class demonstrates the various ways of registering parameters and buffers within a module: + +.. code-block:: python + + class StatefulModule(nn.Module): + def __init__(self): + super().__init__() + # Setting a nn.Parameter as an attribute of the module automatically registers the tensor + # as a parameter of the module. + self.param1 = nn.Parameter(torch.randn(2)) + + # Alternative string-based way to register a parameter. + self.register_parameter('param2', nn.Parameter(torch.randn(3))) + + # Reserves the "param3" attribute as a parameter, preventing it from being set to anything + # except a parameter. "None" entries like this will not be present in the module's state_dict. + self.register_parameter('param3', None) + + # Registers a list of parameters. + self.param_list = nn.ParameterList([nn.Parameter(torch.randn(2)) for i in range(3)]) + + # Registers a dictionary of parameters. + self.param_dict = nn.ParameterDict({ + 'foo': nn.Parameter(torch.randn(3)), + 'bar': nn.Parameter(torch.randn(4)) + }) + + # Registers a persistent buffer (one that appears in the module's state_dict). + self.register_buffer('buffer1', torch.randn(4), persistent=True) + + # Registers a non-persistent buffer (one that does not appear in the module's state_dict). + self.register_buffer('buffer2', torch.randn(5), persistent=False) + + # Reserves the "buffer3" attribute as a buffer, preventing it from being set to anything + # except a buffer. "None" entries like this will not be present in the module's state_dict. + self.register_buffer('buffer3', None) + + # Adding a submodule registers its parameters as parameters of the module. + self.linear = nn.Linear(2, 3) + + m = StatefulModule() + + # Save and load state_dict. + torch.save(m.state_dict(), 'state.pt') + m_loaded = StatefulModule() + m_loaded.load_state_dict(torch.load('state.pt')) + + # Note that non-persistent buffer "buffer2" and reserved attributes "param3" and "buffer3" do + # not appear in the state_dict. + print(m_loaded.state_dict()) + : OrderedDict([('param1', tensor([-0.0322, 0.9066])), + ('param2', tensor([-0.4472, 0.1409, 0.4852])), + ('buffer1', tensor([ 0.6949, -0.1944, 1.2911, -2.1044])), + ('param_list.0', tensor([ 0.4202, -0.1953])), + ('param_list.1', tensor([ 1.5299, -0.8747])), + ('param_list.2', tensor([-1.6289, 1.4898])), + ('param_dict.bar', tensor([-0.6434, 1.5187, 0.0346, -0.4077])), + ('param_dict.foo', tensor([-0.0845, -1.4324, 0.7022])), + ('linear.weight', tensor([[-0.3915, -0.6176], + [ 0.6062, -0.5992], + [ 0.4452, -0.2843]])), + ('linear.bias', tensor([-0.3710, -0.0795, -0.3947]))]) + For more information, check out: * Saving and loading: https://pytorch.org/tutorials/beginner/saving_loading_models.html * Serialization semantics: https://pytorch.org/docs/master/notes/serialization.html * What is a state dict? https://pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html +Module Initialization +--------------------- + +By default, parameters and floating-point buffers for modules provided by :mod:`torch.nn` are initialized during +module instantiation as 32-bit floating point values on the CPU using an initialization scheme determined to +perform well historically for the module type. For certain use cases, it may be desired to initialize with a different +dtype, device (e.g. GPU), or initialization technique. + +Examples: + +.. code-block:: python + + # Initialize module directly onto GPU. + m = nn.Linear(5, 3, device='cuda') + + # Initialize module with 16-bit floating point parameters. + m = nn.Linear(5, 3, dtype=torch.half) + + # Skip default parameter initialization and perform custom (e.g. orthogonal) initialization. + m = torch.nn.utils.skip_init(nn.Linear, 5, 3) + nn.init.orthogonal_(m.weight) + +Note that the device and dtype options demonstrated above also apply to any floating-point buffers registered +for the module: + +.. code-block:: python + + m = nn.BatchNorm2d(3, dtype=torch.half) + print(m.running_mean) + : tensor([0., 0., 0.], dtype=torch.float16) + +While module writers can use any device or dtype to initialize parameters in their custom modules, good practice is +to use ``dtype=torch.float`` and ``device='cpu'`` by default as well. Optionally, you can provide full flexibility +in these areas for your custom module by conforming to the convention demonstrated above that all +:mod:`torch.nn` modules follow: + +* Provide a ``device`` constructor kwarg that applies to any parameters / buffers registered by the module. +* Provide a ``dtype`` constructor kwarg that applies to any parameters / floating-point buffers registered by + the module. +* Only use initialization functions (i.e. functions from :mod:`torch.nn.init`) on parameters and buffers within the + module's constructor. Note that this is only required to use :func:`~torch.nn.utils.skip_init`; see + `this page `_ for an explanation. + +For more information, check out: + +* Skipping module parameter initialization: https://pytorch.org/tutorials/prototype/skip_param_init.html + Module Hooks ------------ @@ -443,16 +612,137 @@ All hooks allow the user to return an updated value that will be used throughout Thus, these hooks can be used to either execute arbitrary code along the regular module forward/backward or modify some inputs/outputs without having to change the module's ``forward()`` function. +Below is an example demonstrating usage of forward and backward hooks: + +.. code-block:: python + + torch.manual_seed(1) + + def forward_pre_hook(m, inputs): + # Allows for examination and modification of the input before the forward pass. + # Note that inputs are always wrapped in a tuple. + input = inputs[0] + return input + 1. + + def forward_hook(m, inputs, output): + # Allows for examination of inputs / outputs and modification of the outputs + # after the forward pass. Note that inputs are always wrapped in a tuple while outputs + # are passed as-is. + + # Residual computation a la ResNet. + return output + inputs[0] + + def backward_hook(m, grad_inputs, grad_outputs): + # Allows for examination of grad_inputs / grad_outputs and modification of + # grad_inputs used in the rest of the backwards pass. Note that grad_inputs and + # grad_outputs are always wrapped in tuples. + new_grad_inputs = [torch.ones_like(gi) * 42. for gi in grad_inputs] + return new_grad_inputs + + # Create sample module & input. + m = nn.Linear(3, 3) + x = torch.randn(2, 3, requires_grad=True) + + # ==== Demonstrate forward hooks. ==== + # Run input through module before and after adding hooks. + print('output with no forward hooks: {}'.format(m(x))) + : output with no forward hooks: tensor([[-0.5059, -0.8158, 0.2390], + [-0.0043, 0.4724, -0.1714]], grad_fn=) + + # Note that the modified input results in a different output. + forward_pre_hook_handle = m.register_forward_pre_hook(forward_pre_hook) + print('output with forward pre hook: {}'.format(m(x))) + : output with forward pre hook: tensor([[-0.5752, -0.7421, 0.4942], + [-0.0736, 0.5461, 0.0838]], grad_fn=) + + # Note the modified output. + forward_hook_handle = m.register_forward_hook(forward_hook) + print('output with both forward hooks: {}'.format(m(x))) + : output with both forward hooks: tensor([[-1.0980, 0.6396, 0.4666], + [ 0.3634, 0.6538, 1.0256]], grad_fn=) + + # Remove hooks; note that the output here matches the output before adding hooks. + forward_pre_hook_handle.remove() + forward_hook_handle.remove() + print('output after removing forward hooks: {}'.format(m(x))) + : output after removing forward hooks: tensor([[-0.5059, -0.8158, 0.2390], + [-0.0043, 0.4724, -0.1714]], grad_fn=) + + # ==== Demonstrate backward hooks. ==== + m(x).sum().backward() + print('x.grad with no backwards hook: {}'.format(x.grad)) + : x.grad with no backwards hook: tensor([[ 0.4497, -0.5046, 0.3146], + [ 0.4497, -0.5046, 0.3146]]) + + # Clear gradients before running backward pass again. + m.zero_grad() + x.grad.zero_() + + m.register_full_backward_hook(backward_hook) + m(x).sum().backward() + print('x.grad with backwards hook: {}'.format(x.grad)) + : x.grad with backwards hook: tensor([[42., 42., 42.], + [42., 42., 42.]]) + Advanced Features ----------------- PyTorch also provides several more advanced features that are designed to work with modules. All these functionalities -are "inherited" when writing a new module. In-depth discussion of these features can be found in the links below. +are available for custom-written modules, with the small caveat that certain features may require modules to conform +to particular constraints in order to be supported. In-depth discussion of these features and the corresponding +requirements can be found in the links below. -For more information, check out: +Distributed Training +******************** + +Various methods for distributed training exist within PyTorch, both for scaling up training using multiple GPUs +as well as training across multiple machines. Check out the +`distributed training overview page `_ for +detailed information on how to utilize these. + +Profiling Performance +********************* + +The `PyTorch Profiler `_ can be useful for identifying +performance bottlenecks within your models. It measures and outputs performance characteristics for +both memory usage and time spent. + +Improving Performance with Quantization +*************************************** + +Applying quantization techniques to modules can improve performance and memory usage by utilizing lower +bitwidths than floating-point precision. Check out the various PyTorch-provided mechanisms for quantization +`here `_. + +Improving Memory Usage with Pruning +*********************************** + +Large deep learning models are often over-parametrized, resulting in high memory usage. To combat this, PyTorch +provides mechanisms for model pruning, which can help reduce memory usage while maintaining task accuracy. The +`Pruning tutorial `_ describes how to utilize +the pruning techniques PyTorch provides or define custom pruning techniques as necessary. + +Deploying with TorchScript +************************** + +When deploying a model for use in production, the overhead of Python can be unacceptable due to its poor +performance characteristics. For cases like this, +`TorchScript `_ provides a way to load +and run an optimized model program from outside of Python, such as within a C++ program. + +Parametrizations +**************** + +For certain applications, it can be beneficial to constrain the parameter space during model training. For example, +enforcing orthogonality of the learned parameters can improve convergence for RNNs. PyTorch provides a mechanism for +applying `parametrizations `_ such as this, and +further allows for custom constraints to be defined. + +Transforming Modules with FX +**************************** -* Profiling: https://pytorch.org/tutorials/beginner/profiler.html -* Pruning: https://pytorch.org/tutorials/intermediate/pruning_tutorial.html -* Quantization: https://pytorch.org/tutorials/recipes/quantization.html -* Exporting modules to TorchScript (e.g. for usage from C++): - https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html +The `FX `_ component of PyTorch provides a flexible way to transform +modules by operating directly on module computation graphs. This can be used to programmatically generate or +manipulate modules for a broad array of use cases. To explore FX, check out these examples of using FX for +`convolution + batch norm fusion `_ and +`CPU performance analysis `_. diff --git a/docs/source/optim.rst b/docs/source/optim.rst index 2ded57ff87a1b..695f0a2a03f6d 100644 --- a/docs/source/optim.rst +++ b/docs/source/optim.rst @@ -210,7 +210,8 @@ algorithms. lr_scheduler.MultiplicativeLR lr_scheduler.StepLR lr_scheduler.MultiStepLR - lr_scheduler.WarmUpLR + lr_scheduler.ConstantLR + lr_scheduler.LinearLR lr_scheduler.ExponentialLR lr_scheduler.CosineAnnealingLR lr_scheduler.ReduceLROnPlateau diff --git a/docs/source/quantization.rst b/docs/source/quantization.rst index eb6c74c72facd..a86368ef8d660 100644 --- a/docs/source/quantization.rst +++ b/docs/source/quantization.rst @@ -35,6 +35,13 @@ that perform all or part of the computation in lower precision. Higher-level APIs are provided that incorporate typical workflows of converting FP32 model to lower precision with minimal accuracy loss. +Quantization requires users to be aware of three concepts: + +#. Quantization Config (Qconfig): Specifies how weights and activations are to be quantized. Qconfig is needed to create a quantized model. +#. Backend: Refers to kernels that support quantization, usually with different numerics. +#. Quantization engine (torch.backends.quantization.engine): When a quantized model is executed, the qengine specifies which backend is to be used for execution. It is important to ensure that the qengine is consistent with the Qconfig. + + Natively supported backends --------------------------- @@ -45,7 +52,8 @@ Today, PyTorch supports the following backends for running quantized operators e * ARM CPUs (typically found in mobile/embedded devices), via `qnnpack` (``_). -The corresponding implementation is chosen automatically based on the PyTorch build mode. +The corresponding implementation is chosen automatically based on the PyTorch build mode, though users +have the option to override this by setting `torch.backends.quantization.engine` to `fbgemm` or `qnnpack`. .. note:: @@ -58,7 +66,7 @@ The corresponding implementation is chosen automatically based on the PyTorch bu When preparing a quantized model, it is necessary to ensure that qconfig -and the qengine used for quantized computations match the backend on which +and the engine used for quantized computations match the backend on which the model will be executed. The qconfig controls the type of observers used during the quantization passes. The qengine controls whether `fbgemm` or `qnnpack` specific packing function is used when packing weights for linear @@ -139,16 +147,13 @@ The following table compares the differences between Eager Mode Quantization and +-----------------+-------------------+-------------------+ -Eager Mode Quantization -^^^^^^^^^^^^^^^^^^^^^^^ - -There are three types of quantization supported in Eager Mode Quantization: +There are three types of quantization supported: 1. dynamic quantization (weights quantized with activations read/stored in floating point and quantized for compute.) 2. static quantization (weights quantized, activations quantized, calibration required post training) -3. quantization aware training (weights quantized, activations quantized, +3. static quantization aware training (weights quantized, activations quantized, quantization numerics modeled during training) Please see our `Introduction to Quantization on Pytorch @@ -156,6 +161,40 @@ Please see our `Introduction to Quantization on Pytorch for a more comprehensive overview of the tradeoffs between these quantization types. +Operator coverage varies between dynamic and static quantization and is captured in the table below. +Note that for FX quantization, the corresponding functionals are also supported. + ++---------------------------+-------------------+--------------------+ +| |Static | Dynamic | +| |Quantization | Quantization | ++---------------------------+-------------------+--------------------+ +| | nn.Linear | | Y | | Y | +| | nn.Conv1d/2d/3d | | Y | | N | ++---------------------------+-------------------+--------------------+ +| | nn.LSTM | | N | | Y | +| | nn.GRU | | N | | Y | ++---------------------------+-------------------+--------------------+ +| | nn.RNNCell | | N | | Y | +| | nn.GRUCell | | N | | Y | +| | nn.LSTMCell | | N | | Y | ++---------------------------+-------------------+--------------------+ +|nn.EmbeddingBag | Y (activations | | +| | are in fp32) | Y | ++---------------------------+-------------------+--------------------+ +|nn.Embedding | Y | N | ++---------------------------+-------------------+--------------------+ +|nn.MultiheadAttention |Not Supported | Not supported | ++---------------------------+-------------------+--------------------+ +|Activations |Broadly supported | Un-changed, | +| | | computations | +| | | stay in fp32 | ++---------------------------+-------------------+--------------------+ + + +Eager Mode Quantization +^^^^^^^^^^^^^^^^^^^^^^^ + + Dynamic Quantization ~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/rpc/distributed_autograd.rst b/docs/source/rpc/distributed_autograd.rst index 61af22b9486f5..71cf1f2fd3178 100644 --- a/docs/source/rpc/distributed_autograd.rst +++ b/docs/source/rpc/distributed_autograd.rst @@ -65,7 +65,7 @@ an RPC. input tensors. The output gradients of this function are sent to the source node to the appropriate ``send`` function during the backward pass. - Each ``send-recv`` pair is assigned a globally unique ``autograd_message_id`` - to uniquely identify the pair. This is useful to lookup the corresponding + to uniquely identify the pair. This is useful to look up the corresponding function on a remote node during the backward pass. - For :ref:`rref`, whenever we call :meth:`torch.distributed.rpc.RRef.to_here` we attach an appropriate ``send-recv`` pair for the tensors involved. @@ -98,7 +98,7 @@ This context serves the following purpose: 2. During the forward pass we store the ``send`` and ``recv`` functions for each autograd pass in this context. This ensures we hold references to the appropriate nodes in the autograd graph to keep it alive. In addition to - this, it is easy to lookup the appropriate ``send`` and ``recv`` functions + this, it is easy to look up the appropriate ``send`` and ``recv`` functions during the backward pass. 3. In general we also use this context to store some metadata for each distributed autograd pass. diff --git a/docs/source/special.rst b/docs/source/special.rst index 06961dbeaaab6..b74d833c96324 100644 --- a/docs/source/special.rst +++ b/docs/source/special.rst @@ -6,10 +6,6 @@ torch.special The torch.special module, modeled after SciPy's `special `_ module. -This module is in BETA. New functions are still being added, and some -functions may change in future PyTorch releases. See the documentation of each -function for details. - .. automodule:: torch.special :noindex: diff --git a/docs/source/testing.rst b/docs/source/testing.rst index 981a636c53390..9f1e2c3c53f89 100644 --- a/docs/source/testing.rst +++ b/docs/source/testing.rst @@ -9,3 +9,4 @@ torch.testing .. automodule:: torch.testing .. autofunction:: assert_close +.. autofunction:: make_tensor diff --git a/docs/source/torch.rst b/docs/source/torch.rst index 88cbc6986bf31..5aa5dbc9387b4 100644 --- a/docs/source/torch.rst +++ b/docs/source/torch.rst @@ -88,6 +88,7 @@ Indexing, Slicing, Joining, Mutating Ops :nosignatures: cat + concat conj chunk dsplit diff --git a/docs/source/type_info.rst b/docs/source/type_info.rst index fe8eaa1769adf..0647cca544c0f 100644 --- a/docs/source/type_info.rst +++ b/docs/source/type_info.rst @@ -26,13 +26,18 @@ bits int The number of bits occupied by the type. eps float The smallest representable number such that ``1.0 + eps != 1.0``. max float The largest representable number. min float The smallest representable number (typically ``-max``). -tiny float The smallest positive representable number. +tiny float The smallest positive normal number. See notes. resolution float The approximate decimal resolution of this type, i.e., ``10**-precision``. ========== ===== ======================================== .. note:: The constructor of :class:`torch.finfo` can be called without argument, in which case the class is created for the pytorch default dtype (as returned by :func:`torch.get_default_dtype`). +.. note:: + `tiny` returns the smallest *normal* number, but there are smaller + subnormal numbers. See https://en.wikipedia.org/wiki/Denormal_number + for more information. + .. _iinfo-doc: diff --git a/scripts/onnx/test.sh b/scripts/onnx/test.sh index 4ee0cdad92ad1..f39d4f0fa5abf 100755 --- a/scripts/onnx/test.sh +++ b/scripts/onnx/test.sh @@ -79,7 +79,7 @@ if [[ "$BUILD_ENVIRONMENT" == *ort_test1* ]]; then fi if [[ "$BUILD_ENVIRONMENT" == *ort_test2* ]]; then # Update the loop for new opsets - for i in $(seq 10 13); do + for i in $(seq 10 14); do pytest "${args[@]}" \ "$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset$i" done diff --git a/setup.py b/setup.py index 8135e1e4c2f7c..6d9ed53dc66aa 100644 --- a/setup.py +++ b/setup.py @@ -854,6 +854,7 @@ def make_relative_rpath_args(path): 'console_scripts': [ 'convert-caffe2-to-onnx = caffe2.python.onnx.bin.conversion:caffe2_to_onnx', 'convert-onnx-to-caffe2 = caffe2.python.onnx.bin.conversion:onnx_to_caffe2', + 'torchrun = torch.distributed.run:main', ] } @@ -1028,8 +1029,6 @@ def print_box(msg): 'include/THC/*.cuh', 'include/THC/*.h*', 'include/THC/generic/*.h', - 'include/THCUNN/*.cuh', - 'include/THCUNN/generic/*.h', 'include/THH/*.cuh', 'include/THH/*.h*', 'include/THH/generic/*.h', diff --git a/test/ao/sparsity/test_pruner.py b/test/ao/sparsity/test_pruner.py index 8f5f6dd19abbe..55364536b6191 100644 --- a/test/ao/sparsity/test_pruner.py +++ b/test/ao/sparsity/test_pruner.py @@ -4,7 +4,7 @@ import torch from torch import nn -from torch.ao.sparsity import BasePruner, PruningParametrization +from torch.ao.sparsity import BasePruner, PruningParametrization, ZeroesParametrization from torch.nn.utils import parametrize from torch.testing._internal.common_utils import TestCase @@ -13,8 +13,13 @@ DEVICES = {"cpu", "cuda" if torch.cuda.is_available() else "cpu"} +NEEDS_ZEROS = { # these layers should have pruned indices zero-ed, not removed + nn.BatchNorm2d +} + class Linear(nn.Module): + r"""Model with Linear layers, in Sequential and outside, without biases""" def __init__(self): super().__init__() self.seq = nn.Sequential( @@ -29,6 +34,7 @@ def forward(self, x): class LinearB(nn.Module): + r"""Model with Linear layers, in Sequential and outside, with biases""" def __init__(self): super().__init__() self.seq = nn.Sequential( @@ -43,6 +49,8 @@ def forward(self, x): class MultipleLinear(nn.Module): + r"""Model with multiple Linear layers, in Sequential and outside, without biases + and with activation functions""" def __init__(self): super().__init__() self.seq = nn.Sequential( @@ -61,6 +69,8 @@ def forward(self, x): class MultipleLinearB(nn.Module): + r"""Model with multiple Linear layers, in Sequential and outside, with biases + and with activation functions""" def __init__(self): super().__init__() self.seq = nn.Sequential( @@ -79,6 +89,8 @@ def forward(self, x): class MultipleLinearMixed(nn.Module): + r"""Model with multiple Linear layers, in Sequential and outside, some with biases + and with activation functions""" def __init__(self): super().__init__() self.seq = nn.Sequential( @@ -97,6 +109,7 @@ def forward(self, x): class Conv2dA(nn.Module): + r"""Model with Conv2d layers, in Sequential and outside, without biases""" def __init__(self): super().__init__() self.seq = nn.Sequential( @@ -111,6 +124,7 @@ def forward(self, x): class Conv2dB(nn.Module): + r"""Model with Conv2d layers, in Sequential and outside, with biases""" def __init__(self): super().__init__() self.seq = nn.Sequential( @@ -125,6 +139,7 @@ def forward(self, x): class Conv2dC(nn.Module): + r"""Model with Conv2d layers, in Sequential and outside, with and without biases""" def __init__(self): super().__init__() self.seq = nn.Sequential( @@ -138,6 +153,24 @@ def forward(self, x): return x +class Conv2dBN(nn.Module): + r"""Model with Conv2d layers and BatchNorms""" + def __init__(self): + super().__init__() + self.seq = nn.Sequential( + nn.Conv2d(1, 32, 3, 1, bias=True), + nn.BatchNorm2d(32) + ) + self.conv2d = nn.Conv2d(32, 64, 3, 1, bias=True) + self.bn = nn.BatchNorm2d(64) + + def forward(self, x): + x = self.seq(x) + x = self.conv2d(x) + x = self.bn(x) + return x + + class SimplePruner(BasePruner): def update_mask(self, layer, **kwargs): layer.parametrizations.weight[0].pruned_outputs.add(1) @@ -150,50 +183,83 @@ def update_mask(self, layer, **kwargs): class TestBasePruner(TestCase): def _check_pruner_prepared(self, model, pruner, device): - for g in pruner.module_groups: - module = g['module'] - assert module.weight.device == device - # Check mask exists - assert hasattr(module, 'mask') - # Check parametrization exists and is correct - assert parametrize.is_parametrized(module) - assert hasattr(module, "parametrizations") - # Assume that this is the 1st/only parametrization - assert type(module.parametrizations.weight[0]) == PruningParametrization - - def _check_pruner_converted(self, model, pruner, device): - for g in pruner.module_groups: - module = g['module'] - assert module.weight.device == device - assert not hasattr(module, "parametrizations") - assert not hasattr(module, 'mask') + for config in pruner.module_groups: + modules = [] + if type(config['module']) is tuple: + for module in config['module']: + modules.append(module) + else: + module = config['module'] + modules.append(module) + for module in modules: + assert module.weight.device == device + # Check mask exists + assert hasattr(module, 'mask') + # Check parametrization exists and is correct + assert parametrize.is_parametrized(module) + assert hasattr(module, "parametrizations") + # Assume that this is the 1st/only parametrization + if isinstance(module, tuple(NEEDS_ZEROS)): + assert type(module.parametrizations.weight[0]) == ZeroesParametrization + else: + assert type(module.parametrizations.weight[0]) == PruningParametrization + + def _check_pruner_mask_squashed(self, model, pruner, device): + for config in pruner.module_groups: + modules = [] + if type(config['module']) is tuple: + for module in config['module']: + modules.append(module) + else: + module = config['module'] + modules.append(module) + for module in modules: + assert module.weight.device == device + assert not hasattr(module, "parametrizations") + assert not hasattr(module, 'mask') def _check_pruner_valid_before_step(self, model, pruner, device): - for g in pruner.module_groups: - module = g['module'] - assert module.weight.device == device - assert module.parametrizations.weight[0].pruned_outputs == set() + for config in pruner.module_groups: + modules = [] + if type(config['module']) is tuple: + for module in config['module']: + modules.append(module) + else: + module = config['module'] + modules.append(module) + for module in modules: + assert module.weight.device == device + assert module.parametrizations.weight[0].pruned_outputs == set() def _check_pruner_valid_after_step(self, model, pruner, pruned_set, device): - for g in pruner.module_groups: - module = g['module'] - assert module.weight.device == device - assert module.parametrizations.weight[0].pruned_outputs == pruned_set + for config in pruner.module_groups: + modules = [] + if type(config['module']) is tuple: + for module in config['module']: + modules.append(module) + else: + module = config['module'] + modules.append(module) + for module in modules: + assert module.weight.device == device + assert module.parametrizations.weight[0].pruned_outputs == pruned_set def _test_constructor_on_device(self, model, device): self.assertRaisesRegex(TypeError, 'with abstract methods update_mask', BasePruner) model = model.to(device) - pruner = SimplePruner(model, None, None) + pruner = SimplePruner(None) + pruner.prepare(model, None) for g in pruner.module_groups: module = g['module'] assert module.weight.device == device assert len(pruner.module_groups) == 2 pruner.step() # Can instantiate the model with configs - pruner = SimplePruner(model, [model.linear], {'test': 3}) + pruner = SimplePruner({'test': 3}) + pruner.prepare(model, [model.linear]) assert len(pruner.module_groups) == 1 - assert pruner.module_groups[0]['path'] == 'linear' + assert pruner.module_groups[0]['fqn'] == 'linear' assert 'test' in pruner.module_groups[0] assert pruner.module_groups[0]['test'] == 3 @@ -205,8 +271,8 @@ def test_constructor(self): def _test_prepare_linear_on_device(self, model, device): model = model.to(device) x = torch.ones(128, 16) - pruner = SimplePruner(model, None, None) - pruner.prepare() + pruner = SimplePruner(None) + pruner.prepare(model, None) self._check_pruner_prepared(model, pruner, device) assert model(x).shape == (128, 16) @@ -216,65 +282,71 @@ def test_prepare_linear(self): for model in models: self._test_prepare_linear_on_device(model, torch.device(device)) - def _test_prepare_conv2d_on_device(self, model, device): + def _test_prepare_conv2d_on_device(self, model, config, device): model = model.to(device) x = torch.ones((1, 1, 28, 28)) - pruner = SimplePruner(model, None, None) - pruner.prepare() + pruner = SimplePruner(None) + pruner.prepare(model, config) self._check_pruner_prepared(model, pruner, device) assert model(x).shape == (1, 64, 24, 24) def test_prepare_conv2d(self): - models = [Conv2dA(), Conv2dB(), Conv2dC()] + bn_model = Conv2dBN() + bn_config = [(bn_model.seq[0], bn_model.seq[1]), (bn_model.conv2d, bn_model.bn)] + + models = [Conv2dA(), Conv2dB(), Conv2dC(), bn_model] + configs = [None, None, None, bn_config] for device in DEVICES: - for model in models: - self._test_prepare_conv2d_on_device(model, torch.device(device)) + for model, config in zip(models, configs): + self._test_prepare_conv2d_on_device(model, config, torch.device(device)) - def _test_convert_linear_on_device(self, model, device): + def _test_squash_mask_linear_on_device(self, model, device): model = model.to(device) x = torch.ones(128, 16) - pruner = SimplePruner(model, None, None) - pruner.prepare() - pruner.convert() - self._check_pruner_converted(model, pruner, device) + pruner = SimplePruner(None) + pruner.prepare(model, None) + pruner.squash_mask() + self._check_pruner_mask_squashed(model, pruner, device) assert model(x).shape == (128, 16) - def test_convert_linear(self): + def test_squash_mask_linear(self): models = [Linear(), LinearB()] # without and with bias for device in DEVICES: for model in models: - self._test_convert_linear_on_device(model, torch.device(device)) + self._test_squash_mask_linear_on_device(model, torch.device(device)) - def _test_convert_conv2d_on_device(self, model, device): + def _test_squash_mask_conv2d_on_device(self, model, config, device): model = model.to(device) x = torch.ones((1, 1, 28, 28)) - pruner = SimplePruner(model, None, None) - pruner.prepare() - pruner.convert() - self._check_pruner_converted(model, pruner, device) + pruner = SimplePruner(None) + pruner.prepare(model, config) + pruner.squash_mask() + self._check_pruner_mask_squashed(model, pruner, device) assert model(x).shape == (1, 64, 24, 24) - def test_convert_conv2d(self): - models = [Conv2dA(), Conv2dB(), Conv2dC()] + def test_squash_mask_conv2d(self): + bn_model = Conv2dBN() + bn_config = [(bn_model.seq[0], bn_model.seq[1]), (bn_model.conv2d, bn_model.bn)] + + models = [Conv2dA(), Conv2dB(), Conv2dC(), bn_model] + configs = [None, None, None, bn_config] for device in DEVICES: - for model in models: - self._test_convert_conv2d_on_device(model, torch.device(device)) + for model, config in zip(models, configs): + self._test_squash_mask_conv2d_on_device(model, config, torch.device(device)) def _test_step_linear_on_device(self, model, is_basic, device): model = model.to(device) if is_basic: x = torch.ones(16, 16) - pruner = SimplePruner(model, None, None) - pruner.prepare() - pruner.enable_mask_update = True + pruner = SimplePruner(None) + pruner.prepare(model, None) self._check_pruner_valid_before_step(model, pruner, device) pruner.step() self._check_pruner_valid_after_step(model, pruner, {1}, device) else: x = torch.ones(7, 7) - pruner = MultiplePruner(model, None, None) - pruner.prepare() - pruner.enable_mask_update = True + pruner = MultiplePruner(None) + pruner.prepare(model, None) self._check_pruner_valid_before_step(model, pruner, device) pruner.step() self._check_pruner_valid_after_step(model, pruner, {1, 2}, device) @@ -288,19 +360,25 @@ def test_step_linear(self): for model in complex_models: self._test_step_linear_on_device(model, False, torch.device(device)) - def _test_step_conv2d_on_device(self, model, device): + def _test_step_conv2d_on_device(self, model, config, device): model = model.to(device) x = torch.ones((1, 1, 28, 28)) - pruner = SimplePruner(model, None, None) - pruner.prepare() - pruner.enable_mask_update = True + pruner = SimplePruner(None) + pruner.prepare(model, config) self._check_pruner_valid_before_step(model, pruner, device) pruner.step() + if type(model) is Conv2dBN: + assert pruner.get_module_pruned_outputs(model.seq[1]) == pruner.get_module_pruned_outputs(model.seq[0]) + assert pruner.get_module_pruned_outputs(model.bn) == pruner.get_module_pruned_outputs(model.conv2d) self._check_pruner_valid_after_step(model, pruner, {1}, device) assert model(x).shape == (1, 64, 24, 24) def test_step_conv2d(self): - models = [Conv2dA(), Conv2dB(), Conv2dC()] + bn_model = Conv2dBN() + bn_config = [(bn_model.seq[0], bn_model.seq[1]), (bn_model.conv2d, bn_model.bn)] + + models = [Conv2dA(), Conv2dB(), Conv2dC(), bn_model] + configs = [None, None, None, bn_config] for device in DEVICES: - for model in models: - self._test_step_conv2d_on_device(model, torch.device(device)) + for model, config in zip(models, configs): + self._test_step_conv2d_on_device(model, config, torch.device(device)) diff --git a/test/autograd/test_complex.py b/test/autograd/test_complex.py new file mode 100644 index 0000000000000..74fcfdafbce2a --- /dev/null +++ b/test/autograd/test_complex.py @@ -0,0 +1,103 @@ +import torch + +from torch.testing._internal.common_utils import TestCase, run_tests, gradcheck + + +class TestAutogradComplex(TestCase): + def test_view_func_for_complex_views(self): + # case 1: both parent and child have view_func + x = torch.randn(2, 2, 2, dtype=torch.double, requires_grad=True) + y = x.detach().requires_grad_(True) + + x0 = x.clone() + x1 = torch.view_as_complex(x0) + x2 = torch.view_as_real(x1) + x2.mul_(2) + x2.sum().backward() + + y0 = y.clone() + y0.mul_(2) + y0.sum().backward() + + self.assertEqual(x.grad, y.grad) + + # case 2: parent has view_func but child does not + x = torch.randn(2, 2, 2, dtype=torch.double, requires_grad=True) + y = x.detach().requires_grad_(True) + + def fn(a): + b = a.clone() + b1 = torch.view_as_complex(b) + b2 = b1.reshape(b1.numel()) + return b2 + + x0 = fn(x) + x0.mul_(2) + x0.sum().backward() + + y0 = fn(y) + y1 = y0.mul(2) + y1.sum().backward() + + self.assertEqual(x.grad, y.grad) + + # case 3: parent does not have a view_func but child does + x = torch.randn(10, dtype=torch.cdouble, requires_grad=True) + y = x.detach().requires_grad_(True) + + def fn(a, dim0_size=5): + b = a.clone() + b1 = b.reshape(dim0_size, 2) + b2 = torch.view_as_real(b1) + return b2 + + x0 = fn(x) + x0.mul_(2) + x0.sum().backward() + + y0 = fn(y) + y1 = y0.mul(2) + y1.sum().backward() + + self.assertEqual(x.grad, y.grad) + + def test_view_with_multi_output(self): + x = torch.randn(2, 2, 2, dtype=torch.double) + + x1 = torch.view_as_complex(x) + # Taking an invalid view should always be allowed as long as it is not + # modified inplace + res = x1.unbind(0) + + with self.assertRaisesRegex(RuntimeError, "output of a function that returns multiple views"): + res[0] += torch.rand(2, requires_grad=True) + + x.requires_grad_(True) + x1 = torch.view_as_complex(x) + # Taking an invalid view should always be allowed as long as it is not + # modified inplace + res = x1.unbind(0) + + with self.assertRaisesRegex(RuntimeError, "output of a function that returns multiple views"): + res[0] += torch.rand(2, requires_grad=True) + + def as_identity(self): + # view_as_real and view_as_complex behavior should be like an identity + def func(z): + z_ = torch.view_as_complex(z) + z_select = torch.select(z_, z_.dim() - 1, 0) + z_select_real = torch.view_as_real(z_select) + return z_select_real.sum() + + z = torch.randn(10, 2, 2, dtype=torch.double, requires_grad=True) + gradcheck(func, [z]) + func(z).backward() + + z1 = z.clone().detach().requires_grad_(True) + torch.select(z1, z1.dim() - 2, 0).sum().backward() + + self.assertEqual(z.grad, z1.grad) + + +if __name__ == '__main__': + run_tests() diff --git a/test/backward_compatibility/check_backward_compatibility.py b/test/backward_compatibility/check_backward_compatibility.py index e1dde921f102d..16b415a7368fa 100644 --- a/test/backward_compatibility/check_backward_compatibility.py +++ b/test/backward_compatibility/check_backward_compatibility.py @@ -36,86 +36,17 @@ # Internal, profiler-specific ops ("profiler::_call_end_callbacks_on_jit_fut*", datetime.date(9999, 1, 1)), ("profiler::_record_function_enter", datetime.date(9999, 1, 1)), - ("aten::_qr_helper", datetime.date(2021, 1, 31)), - ("aten::fft", datetime.date(2021, 1, 31)), - ("aten::ifft", datetime.date(2021, 1, 31)), - ("aten::irfft", datetime.date(2021, 1, 31)), - ("aten::rfft", datetime.date(2021, 1, 31)), - ("aten::linalg_svd", datetime.date(2021, 5, 15)), - ("aten::linalg_cholesky.out", datetime.date(2021, 8, 30)), - ("aten::linalg_cholesky_ex", datetime.date(2021, 8, 30)), - ("aten::linalg_cholesky_ex.L", datetime.date(2021, 8, 30)), ("aten::_cholesky_helper", datetime.date(9999, 1, 1)), ("aten::_lstsq_helper", datetime.date(9999, 1, 1)), - ("aten::linalg_lstsq", datetime.date(2021, 5, 1)), - ("aten::_svd_helper", datetime.date(2021, 1, 31)), ("aten::_syevd_helper", datetime.date(9999, 1, 1)), ("aten::_lu_solve_helper", datetime.date(9999, 1, 1)), ("aten::_lu_with_info", datetime.date(9999, 1, 1)), ("aten::_linalg_solve_out_helper_", datetime.date(9999, 1, 1)), - ("aten::_cudnn_rnn_flatten_weight", datetime.date(2020, 12, 31)), - ("aten::_cudnn_rnn", datetime.date(2020, 12, 31)), - ("aten::_cudnn_rnn_backward", datetime.date(2020, 12, 31)), - ("aten::quantile", datetime.date(2021, 1, 31)), - ("aten::nanquantile", datetime.date(2021, 1, 31)), - ("aten::make_dual", datetime.date(2021, 2, 20)), - ("aten::unpack_dual", datetime.date(2021, 2, 20)), - ("aten::_fft_with_size", datetime.date(2021, 1, 31)), - ("aten::thnn_conv_depthwise2d_backward", datetime.date(2021, 1, 31)), - ("aten::slow_conv3d_backward", datetime.date(2021, 1, 31)), - ("aten::thnn_conv2d_backward", datetime.date(2021, 1, 31)), - ("aten::slow_conv_transpose3d_backward", datetime.date(2021, 1, 31)), - ("aten::slow_conv_transpose2d_backward", datetime.date(2021, 1, 31)), - ("aten::set_", datetime.date(2021, 1, 31)), - ("aten::native_layer_norm", datetime.date(2021, 1, 31)), - ("aten::native_layer_norm_backward", datetime.date(2021, 1, 31)), - ("aten::elu_backward", datetime.date(2021, 1, 31)), - ("aten::_multinomial_alias_setup", datetime.date(2021, 1, 31)), - ("aten::_multinomial_alias_draw", datetime.date(2021, 1, 31)), - ("prim::profile_optional", datetime.date(2021, 1, 31)), - ("aten::fake_quantize_per_tensor_affine_backward", datetime.date(2021, 2, 20)), - ("aten::fake_quantize_per_channel_affine_backward", datetime.date(2021, 2, 20)), ("aten::rowwise_prune", datetime.date(9999, 1, 1)), - ("aten::_mode*", datetime.date(2021, 5, 2)), - ("aten::linalg_multi_dot", datetime.date(2021, 3, 25)), - ("aten::coalesce", datetime.date(2021, 4, 15)), - ("aten::empty_meta", datetime.date(2021, 4, 1)), - ("aten::div", datetime.date(2021, 4, 28)), - ("aten::divide", datetime.date(2021, 4, 28)), - ("aten::_var", datetime.date(2021, 5, 28)), - ("aten::_std", datetime.date(2021, 5, 28)), - ("aten::batch_norm_backward_elemt", datetime.date(2021, 5, 1)), - ("aten::assert_async", datetime.date(2021, 5, 1)), - ("aten::cumprod_backward", datetime.date(2021, 5, 1)), ("aten::_triangular_solve_helper", datetime.date(9999, 1, 1)), - ("aten::_addmv_impl_", datetime.date(2021, 5, 15)), ("aten::adaptive_avg_pool3d_backward", datetime.date(9999, 1, 1)), ("aten::_embedding_bag_dense_backward", datetime.date(9999, 1, 1)), - ("aten::_amp_update_scale", datetime.date(2021, 6, 1)), ("aten::randperm", datetime.date(9999, 1, 1)), - ("aten::linalg_vector_norm", datetime.date(2021, 5, 15)), - ("aten::repeat_interleave", datetime.date(2021, 6, 26)), - ("aten::one_hot", datetime.date(2021, 6, 15)), - ("aten::slice", datetime.date(2021, 6, 30)), - ("aten::conj", datetime.date(2021, 8, 1)), - ("aten::_conj", datetime.date(2021, 8, 1)), - ("aten::conj.out", datetime.date(2021, 8, 1)), - ("aten::segment_reduce_backward", datetime.date(2021, 6, 15)), - ("aten::segment_reduce", datetime.date(2021, 8, 26)), - ("aten::_segment_reduce_backward", datetime.date(2021, 8, 26)), - ("aten::thnn_conv_depthwise2d", datetime.date(2021, 8, 27)), - ("aten::thnn_conv_depthwise2d.out", datetime.date(2021, 8, 27)), - ("aten::thnn_conv_depthwise2d_forward", datetime.date(2021, 8, 27)), - ("aten::thnn_conv_depthwise2d_forward.out", datetime.date(2021, 8, 27)), - ("aten::thnn_conv_depthwise2d_backward", datetime.date(2021, 8, 27)), - ("aten::thnn_conv_depthwise2d_backward.out", datetime.date(2021, 8, 27)), - ("aten::_view_as_real_physical", datetime.date(2021, 8, 27)), - ("aten::_view_as_real_physical", datetime.date(2021, 8, 1)), - ("aten::_bmm", datetime.date(2021, 8, 14)), - ("aten::_bmm.out", datetime.date(2021, 8, 14)), - ("aten::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams", datetime.date(2021, 8, 15)), - ("aten::_cumsum", datetime.date(2021, 8, 31)), - ("aten::_cumprod", datetime.date(2021, 8, 31)), ] ALLOW_LIST_COMPILED = [ diff --git a/test/cpp/api/CMakeLists.txt b/test/cpp/api/CMakeLists.txt index 9bd9d6780fe7d..fc21afaef6a8a 100644 --- a/test/cpp/api/CMakeLists.txt +++ b/test/cpp/api/CMakeLists.txt @@ -41,6 +41,10 @@ set(TORCH_API_TEST_SOURCES ${TORCH_API_TEST_DIR}/grad_mode.cpp ) +if(USE_DEPLOY) + list(APPEND TORCH_API_TEST_SOURCES ${TORCH_API_TEST_DIR}/imethod.cpp) +endif() + if(USE_CUDA) list(APPEND TORCH_API_TEST_SOURCES ${TORCH_API_TEST_DIR}/parallel.cpp) endif() @@ -59,6 +63,10 @@ if(USE_CUDA) target_compile_definitions(test_api PRIVATE "USE_CUDA") endif() +if(USE_DEPLOY) + target_link_libraries(test_api PRIVATE torch_deploy) +endif() + # Workaround for https://github.com/pytorch/pytorch/issues/40941 if(USE_OPENMP AND CMAKE_COMPILER_IS_GNUCXX AND (CMAKE_CXX_COMPILER_VERSION VERSION_LESS 8.0.0)) # Compiling transformer.cpp or pow_test.cpp with -O2+ and both -fuse-openmp and -faligned-newout any optimization diff --git a/test/cpp/api/autograd.cpp b/test/cpp/api/autograd.cpp index 80d892d5195c9..edb73f90852a2 100644 --- a/test/cpp/api/autograd.cpp +++ b/test/cpp/api/autograd.cpp @@ -1,6 +1,8 @@ #include +#include #include +#include #include @@ -869,6 +871,261 @@ TEST(CustomAutogradTest, BackwardWithCreateGraphWarns) { } } +/** + * Tests for AutogradNotImplementedFallback + * - Check that we created the NotImplemented kernel when inputs require grad + * but when no inputs require grad, we should not create this node + * - check_inplace logic + * - view ops (TODO: not an official view yet, update this once InplaceOrView kernel is landed) + * - TODO: Tests for NDEBUG checks? + * - tensorlist input and output + * - multiple outputs / non-tensor output + * - rebase_history vs set_history + */ +namespace { + +torch::Tensor inplace_op(const torch::Tensor& self, const torch::Tensor& other) { + return self.add_(other); +} + +std::tuple two_arg_inplace_op(const torch::Tensor& self, const torch::Tensor& other) { + other.add_(self); + self.add_(other); + return std::tuple(self, other); +} + +std::tuple two_pairs_of_view_op(const torch::Tensor& self, const torch::Tensor& other) { + // This is not allowed. We test below that this calling into the boxed kernel will raise an error + auto self_view = self.view(-1); + auto other_view = other.view(-1); + return std::tuple(self_view, other_view); +} + +int64_t ret_single_non_tensor(const torch::Tensor& self, const torch::Tensor& other) { + return 12; +} + +torch::Tensor opt_op(const torch::Tensor& self, const c10::optional& other) { + if (other.has_value()) { + return self + other.value(); + } else { + return self.clone(); + } +} + +torch::Tensor my_custom_op(const torch::Tensor& self, const torch::Tensor& other) { + return self + other; +} + +std::tuple ret_tuple_non_tensor(const torch::Tensor& self, const torch::Tensor& other) { + auto a = self - other; + auto b = self + other; + return std::tuple(a, b, 12); +} + +torch::Tensor view_op(const torch::Tensor& self, const torch::Tensor& other) { + return self.view(-1); +} + +std::vector ret_tensor_vector(const torch::Tensor& self, const torch::Tensor& other) { + std::vector out; + out.push_back(self + other); + out.push_back(self - other); + return out; +} + +torch::Tensor tensorlist_op(const torch::Tensor& self, at::TensorList other) { + const auto& res = self.clone(); + for (const auto& t : other) { + res.add_(t); + } + return res; +} + +#define REGISTER_TEST_OP(name, schema, fn) \ + auto m = MAKE_TORCH_LIBRARY(_test); \ + m.def(schema); \ + auto m_autograd = MAKE_TORCH_LIBRARY_IMPL(_test, Autograd); \ + auto m_cpu = MAKE_TORCH_LIBRARY_IMPL(_test, CPU); \ + m_cpu.impl(name, c10::DispatchKey::CPU, TORCH_FN(fn)); \ + m_autograd.impl(name, c10::DispatchKey::Autograd, autogradNotImplementedFallback()); + +template +void assertBasicChecks(F op) { + auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true); + auto b = torch::tensor({1.}, {torch::kFloat32}); + auto c = torch::tensor({1.}, {torch::kFloat32}); + + // If any inputs require grad, + auto out1 = op(a, b); + ASSERT_THROWS_WITH(out1.backward(), "is not implemented"); + + // # Should not have grad_fn if none require grad + auto out2 = op(b, c); + ASSERT_THROWS_WITH(out2.backward(), "element 0 of tensors does not require grad and does not have a grad_fn"); + + // TODO: Forward AD Tests? +} + +} // namespace + +TEST(TestAutogradNotImplementedFallback, RetSingleNonTensor) { + REGISTER_TEST_OP("ret_single_non_tensor", "_test::ret_single_non_tensor(Tensor self, Tensor other) -> int", ret_single_non_tensor); + auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::ret_single_non_tensor", ""); + auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) { + return callOpUnboxed(opHandle, _1, _2); + }; + + auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true); + auto b = torch::tensor({1.}, {torch::kFloat32}); + + ASSERT_EQ(op(a, b), ret_single_non_tensor(a, b)); +} + +TEST(TestAutogradNotImplementedFallback, DoubleViewOP) { + REGISTER_TEST_OP("two_pairs_of_view_op", "_test::two_pairs_of_view_op(Tensor(a) self, Tensor(b) other) -> (Tensor(a), Tensor(b))", two_pairs_of_view_op); + auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::two_pairs_of_view_op", ""); + auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) { + return callOpUnboxed, const torch::Tensor&, const torch::Tensor&>(opHandle, _1, _2); + }; + auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true); + auto b = torch::tensor({1.}, {torch::kFloat32}); + ASSERT_THROWS_WITH(op(a, b), + "Expected only a single output in the operator schema to have a non-write alias annotation"); +} + +TEST(TestAutogradNotImplementedFallback, InplaceOp) { + REGISTER_TEST_OP("inplace_op", "_test::inplace_op(Tensor(a!) self, Tensor other) -> Tensor(a!)", inplace_op); + auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::inplace_op", ""); + auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) { + return callOpUnboxed(opHandle, _1, _2); + }; + + auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true); + auto b = torch::tensor({1.}, {torch::kFloat32}); + + // Check in-place + ASSERT_THROWS_WITH(op(a, b), + "a leaf Variable that requires grad is being used in an in-place operation"); + op(b, a); + a = a.clone(); + b = b.clone(); + auto c = op(a, b); + ASSERT_TRUE(torch::allclose(c, inplace_op(a, b))); + + // Test in-place on view + auto base = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone(); + auto view = base.view(-1); + auto t = torch::tensor({1.}, {torch::kFloat32}); + + torch::Tensor v_nograd; + { + c10::NoGradGuard guard; + v_nograd = base.view(-1); + op(v_nograd, t); + } + + ASSERT_THROWS_WITH(op(v_nograd, t), "A view was created in no_grad mode"); + ASSERT_EQ(op(view, t).unsafeGetTensorImpl(), view.unsafeGetTensorImpl()); + + // TODO: once we have InplaceOrView kernel, renable this since version counter would actually + // be incremented + // ASSERT_THAT(op(view, t).grad_fn()->name(), ::testing::HasSubstr("AsStridedBackward")); +} + +TEST(TestAutogradNotImplementedFallback, DoubleInplaceOp) { + REGISTER_TEST_OP("two_arg_inplace_op", "_test::two_arg_inplace_op(Tensor(a!) self, Tensor(b!) other) -> (Tensor(a!), Tensor(b!))", two_arg_inplace_op); + auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::two_arg_inplace_op", ""); + auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) { + return callOpUnboxed, const torch::Tensor&, const torch::Tensor&>(opHandle, _1, _2); + }; + auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true); + auto b = torch::tensor({1.}, {torch::kFloat32}); + + // Both are modified in-place! + ASSERT_THROWS_WITH(op(a, b), + "a leaf Variable that requires grad is being used in an in-place operation"); + ASSERT_THROWS_WITH(op(b, a), + "a leaf Variable that requires grad is being used in an in-place operation"); +} + +TEST(TestAutogradNotImplementedFallback, OptOp) { + REGISTER_TEST_OP("opt_op", "_test::opt_op(Tensor self, Tensor? other) -> Tensor", opt_op); + auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::opt_op", ""); + auto op = [&](const torch::Tensor& _1, const c10::optional& _2) { + return callOpUnboxed&>(opHandle, _1, _2); + }; + + auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true); + auto b = torch::tensor({1.}, {torch::kFloat32}); + + ASSERT_TRUE(torch::allclose(op(a, b), opt_op(a, b))); + ASSERT_TRUE(torch::allclose(op(a, {}), opt_op(a, {}))); +} + +TEST(TestAutogradNotImplementedFallback, OutOfPlaceAddition) { + REGISTER_TEST_OP("my_custom_op", "_test::my_custom_op(Tensor self, Tensor other) -> Tensor", my_custom_op); + auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::my_custom_op", ""); + auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) { + return callOpUnboxed(opHandle, _1, _2); + }; + + assertBasicChecks(op); +} + +TEST(TestAutogradNotImplementedFallback, RetTupleNonTensor) { + REGISTER_TEST_OP("ret_tuple_non_tensor", "_test::ret_tuple_non_tensor(Tensor self, Tensor other) -> (Tensor, Tensor, int)", ret_tuple_non_tensor); + auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::ret_tuple_non_tensor", ""); + auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) { + torch::Tensor out0; + torch::Tensor out1; + int64_t out2; + auto out = callOpUnboxed, const torch::Tensor&, const torch::Tensor&>(opHandle, _1, _2); + std::tie(out0, out1, out2) = std::move(out); + return out0; + }; + + assertBasicChecks(op); +} + +TEST(TestAutogradNotImplementedFallback, ViewOp) { + REGISTER_TEST_OP("view_op", "_test::view_op(Tensor(a) self, Tensor other) -> Tensor(a)", view_op); + auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::view_op", ""); + auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) { + return callOpUnboxed(opHandle, _1, _2); + }; + assertBasicChecks(op); +} + +TEST(TestAutogradNotImplementedFallback, RetTensorVector) { + REGISTER_TEST_OP("ret_tensor_vector", "_test::ret_tensor_vector(Tensor self, Tensor other) -> Tensor[]", ret_tensor_vector); + auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::ret_tensor_vector", ""); + auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) { + return callOpUnboxed, const torch::Tensor&, const torch::Tensor&>(opHandle, _1, _2)[0]; + }; + assertBasicChecks(op); +} + +TEST(TestAutogradNotImplementedFallback, TensorlistOp) { + REGISTER_TEST_OP("tensorlist_op", "_test::tensorlist_op(Tensor self, Tensor[] other) -> Tensor", tensorlist_op); + auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow("_test::tensorlist_op", ""); + auto op = [&](torch::Tensor _1, at::TensorList _2) { + return callOpUnboxed(opHandle, _1, _2); + }; + + auto a = torch::tensor({1.}, {torch::kFloat32}); + auto b = torch::tensor({1.}, {torch::kFloat32}); + auto c = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true); + std::vector vec = {b, c}; + auto out = op(a, vec); + + ASSERT_THROWS_WITH(torch::autograd::grad({out}, {vec[0]}), "One of the differentiated Tensors does not require grad"); + ASSERT_THROWS_WITH(torch::autograd::grad({out}, {vec[1]}), "is not implemented"); + + ASSERT_TRUE(at::allclose(op(a, vec), tensorlist_op(a, vec))); +} + + // TODO add these tests if needed // test_once_differentiable // test_sparse_backward diff --git a/test/cpp/api/functional.cpp b/test/cpp/api/functional.cpp index 582b1eebdb784..8b7889f1841ef 100644 --- a/test/cpp/api/functional.cpp +++ b/test/cpp/api/functional.cpp @@ -792,6 +792,20 @@ TEST_F(FunctionalTest, CrossEntropy) { ASSERT_TRUE(output.allclose(expected, 1e-04)); ASSERT_TRUE(F::cross_entropy(input, target).allclose(expected, 1e-04)); + + // label smoothing with class indices + input = torch::tensor({{3., 1.}, {1., 2.}}, torch::kFloat); + output = F::cross_entropy( + input, target, F::CrossEntropyFuncOptions().label_smoothing(0.15).reduction(torch::kMean)); + expected = torch::tensor(0.3326, torch::kFloat); + ASSERT_TRUE(output.allclose(expected, 1e-04)); + + // label smoothing with target probabilities + target = torch::tensor({{0.8, 0.2}, {0.1, 0.9}}, torch::kFloat); + output = F::cross_entropy( + input, target, F::CrossEntropyFuncOptions().label_smoothing(0.2).reduction(torch::kMean)); + expected = torch::tensor(0.5701, torch::kFloat); + ASSERT_TRUE(output.allclose(expected, 1e-04)); } TEST_F(FunctionalTest, MaxUnpool1d) { @@ -1034,17 +1048,19 @@ TEST_F(FunctionalTest, LeakyReLU) { const auto size = 3; for (const auto negative_slope : {0.0, 0.42, 1.0}) { for (const auto inplace : {false, true}) { - auto x = torch::linspace(-10.0, 10.0, size * size * size); - x.resize_({size, size, size}); - auto y_exp = (x < 0) * x * negative_slope + (x >= 0) * x; - auto y = F::leaky_relu(x, F::LeakyReLUFuncOptions() - .negative_slope(negative_slope).inplace(inplace)); + for (const auto type : {torch::kFloat, torch::kBFloat16}) { + auto x = torch::linspace(-10.0, 10.0, size * size * size).to(type); + x.resize_({size, size, size}); + auto y_exp = (x < 0) * x * negative_slope + (x >= 0) * x; + auto y = F::leaky_relu(x, F::LeakyReLUFuncOptions() + .negative_slope(negative_slope).inplace(inplace)); - ASSERT_EQ(y.ndimension(), 3); - ASSERT_EQ(y.sizes(), std::vector({size, size, size})); - ASSERT_TRUE(torch::allclose(y, y_exp)); - if (inplace) { - ASSERT_TRUE(torch::allclose(x, y_exp)); + ASSERT_EQ(y.ndimension(), 3); + ASSERT_EQ(y.sizes(), std::vector({size, size, size})); + ASSERT_TRUE(torch::allclose(y, y_exp)); + if (inplace) { + ASSERT_TRUE(torch::allclose(x, y_exp)); + } } } } @@ -1443,19 +1459,21 @@ TEST_F(FunctionalTest, RReLU) { for (const auto lower : {0.01, 0.1, 0.2}) { for (const auto upper : {0.3, 0.4, 0.5}) { for (const auto inplace : {false, true}) { - auto x = torch::linspace(-10.0, 10.0, size * size * size); - x.resize_({size, size, size}); - auto x_copy = x.clone(); - auto y = F::rrelu(x, F::RReLUFuncOptions().lower(lower) - .upper(upper).inplace(inplace)); - auto z = ((x_copy >= 0) * (x_copy == y) + - (x_copy < 0) * (y >= x_copy * upper) * (y <= lower * x_copy)) * 1.0; - - ASSERT_EQ(y.ndimension(), 3); - ASSERT_EQ(y.sizes(), std::vector({size, size, size})); - ASSERT_TRUE(torch::allclose(z, torch::ones_like(z))); - if (inplace) { - ASSERT_TRUE(torch::allclose(x, y)); + for (const auto type : {torch::kFloat, torch::kBFloat16}) { + auto x = torch::linspace(-10.0, 10.0, size * size * size).to(type); + x.resize_({size, size, size}); + auto x_copy = x.clone(); + auto y = F::rrelu(x, F::RReLUFuncOptions().lower(lower) + .upper(upper).inplace(inplace)); + auto z = ((x_copy >= 0) * (x_copy == y) + + (x_copy < 0) * (y >= x_copy * upper) * (y <= lower * x_copy)) * 1.0; + + ASSERT_EQ(y.ndimension(), 3); + ASSERT_EQ(y.sizes(), std::vector({size, size, size})); + ASSERT_TRUE(torch::allclose(z, torch::ones_like(z))); + if (inplace) { + ASSERT_TRUE(torch::allclose(x, y)); + } } } } @@ -1467,16 +1485,18 @@ TEST_F(FunctionalTest, RReLUDefaultOptions) { const auto size = 3; const auto lower = 1.0 / 8.0; const auto upper = 1.0 / 3.0; - auto x = torch::linspace(-10.0, 10.0, size * size * size); - x.resize_({size, size, size}); - auto x_copy = x.clone(); - auto y = F::rrelu(x); - auto z = ((x_copy >= 0) * (x_copy == y) + - (x_copy < 0) * (y >= x_copy * upper) * (y <= lower * x_copy)) * 1.0; + for (const auto type : {torch::kFloat, torch::kBFloat16}) { + auto x = torch::linspace(-10.0, 10.0, size * size * size).to(type); + x.resize_({size, size, size}); + auto x_copy = x.clone(); + auto y = F::rrelu(x); + auto z = ((x_copy >= 0) * (x_copy == y) + + (x_copy < 0) * (y >= x_copy * upper) * (y <= lower * x_copy)) * 1.0; - ASSERT_EQ(y.ndimension(), 3); - ASSERT_EQ(y.sizes(), std::vector({size, size, size})); - ASSERT_TRUE(torch::allclose(z, torch::ones_like(z))); + ASSERT_EQ(y.ndimension(), 3); + ASSERT_EQ(y.sizes(), std::vector({size, size, size})); + ASSERT_TRUE(torch::allclose(z, torch::ones_like(z))); + } } TEST_F(FunctionalTest, CELU) { diff --git a/test/cpp/api/imethod.cpp b/test/cpp/api/imethod.cpp index 3349d1b3a8a45..b8c12c649fd19 100644 --- a/test/cpp/api/imethod.cpp +++ b/test/cpp/api/imethod.cpp @@ -8,30 +8,43 @@ using namespace ::testing; using namespace caffe2; -// TODO(T96218435): Enable the following tests in OSS. +const char* simple = "torch/csrc/deploy/example/generated/simple"; +const char* simpleJit = "torch/csrc/deploy/example/generated/simple_jit"; + +// TODO(jwtan): Try unifying cmake and buck for getting the path. +const char* path(const char* envname, const char* path) { + const char* env = getenv(envname); + return env ? env : path; +} + +// Run `python torch/csrc/deploy/example/generate_examples.py` before running the following tests. +// TODO(jwtan): Figure out a way to automate the above step for development. (CI has it already.) TEST(IMethodTest, CallMethod) { - auto script_model = torch::jit::load(getenv("SIMPLE_JIT")); - auto script_method = script_model.get_method("forward"); + auto scriptModel = torch::jit::load(path("SIMPLE_JIT", simpleJit)); + auto scriptMethod = scriptModel.get_method("forward"); torch::deploy::InterpreterManager manager(3); - torch::deploy::Package p = manager.load_package(getenv("SIMPLE")); - auto py_model = p.load_pickle("model", "model.pkl"); - torch::deploy::PythonMethodWrapper py_method(py_model, "forward"); + torch::deploy::Package package = manager.load_package(path("SIMPLE", simple)); + auto pyModel = package.load_pickle("model", "model.pkl"); + torch::deploy::PythonMethodWrapper pyMethod(pyModel, "forward"); + + EXPECT_EQ(scriptMethod.name(), "forward"); + EXPECT_EQ(pyMethod.name(), "forward"); auto input = torch::ones({10, 20}); - auto output_py = py_method({input}); - auto output_script = script_method({input}); - EXPECT_TRUE(output_py.isTensor()); - EXPECT_TRUE(output_script.isTensor()); - auto output_py_tensor = output_py.toTensor(); - auto output_script_tensor = output_script.toTensor(); - - EXPECT_TRUE(output_py_tensor.equal(output_script_tensor)); - EXPECT_EQ(output_py_tensor.numel(), 200); + auto outputPy = pyMethod({input}); + auto outputScript = scriptMethod({input}); + EXPECT_TRUE(outputPy.isTensor()); + EXPECT_TRUE(outputScript.isTensor()); + auto outputPyTensor = outputPy.toTensor(); + auto outputScriptTensor = outputScript.toTensor(); + + EXPECT_TRUE(outputPyTensor.equal(outputScriptTensor)); + EXPECT_EQ(outputPyTensor.numel(), 200); } TEST(IMethodTest, GetArgumentNames) { - auto scriptModel = torch::jit::load(getenv("SIMPLE_JIT")); + auto scriptModel = torch::jit::load(path("SIMPLE_JIT", simpleJit)); auto scriptMethod = scriptModel.get_method("forward"); auto& scriptNames = scriptMethod.getArgumentNames(); @@ -39,7 +52,7 @@ TEST(IMethodTest, GetArgumentNames) { EXPECT_STREQ(scriptNames[0].c_str(), "input"); torch::deploy::InterpreterManager manager(3); - torch::deploy::Package package = manager.load_package(getenv("SIMPLE")); + torch::deploy::Package package = manager.load_package(path("SIMPLE", simple)); auto pyModel = package.load_pickle("model", "model.pkl"); torch::deploy::PythonMethodWrapper pyMethod(pyModel, "forward"); diff --git a/test/cpp/api/modules.cpp b/test/cpp/api/modules.cpp index 036ff5e4bf2ec..927d884709200 100644 --- a/test/cpp/api/modules.cpp +++ b/test/cpp/api/modules.cpp @@ -2315,6 +2315,31 @@ TEST_F(ModulesTest, CrossEntropyLoss) { ASSERT_TRUE( CrossEntropyLoss(CrossEntropyLossOptions().ignore_index(-100).reduction(torch::kMean)) ->forward(input, target).allclose(expected, 1e-04)); + + // label smoothing with class indices + loss = CrossEntropyLoss(CrossEntropyLossOptions().label_smoothing(0.15).reduction(torch::kMean)); + input = torch::tensor({{3., 1.}, {1., 2.}}, torch::dtype(torch::kFloat).requires_grad(true)); + target = torch::tensor({0, 1}, torch::kLong); + output = loss->forward(input, target); + expected = torch::tensor(0.3326, torch::kFloat); + s = output.sum(); + s.backward(); + + ASSERT_TRUE(output.allclose(expected, 1e-04)); + ASSERT_EQ(input.sizes(), input.grad().sizes()); + + // label smoothing with with target probabilities + loss = CrossEntropyLoss(CrossEntropyLossOptions().label_smoothing(0.2).reduction(torch::kMean)); + input = torch::tensor({{3., 1.}, {1., 2.}}, torch::dtype(torch::kFloat).requires_grad(true)); + target = torch::tensor({{0.8, 0.2}, {0.1, 0.9}}, torch::kFloat); + output = loss->forward(input, target); + expected = torch::tensor(0.5701, torch::kFloat); + s = output.sum(); + s.backward(); + + ASSERT_TRUE(output.allclose(expected, 1e-04)); + ASSERT_EQ(input.sizes(), input.grad().sizes()); + } TEST_F(ModulesTest, CosineSimilarity) { @@ -2521,25 +2546,27 @@ TEST_F(ModulesTest, LeakyReLU) { const auto size = 3; for (const auto inplace : {false, true}) { for (const auto negative_slope : {0.0, 0.42, 1.0}) { - LeakyReLU model {LeakyReLUOptions().negative_slope(negative_slope).inplace(inplace)}; - auto x = torch::linspace(-10.0, 10.0, size * size * size); - x.resize_({size, size, size}); - if (!inplace) { - x.requires_grad_(true); - } - auto x_orig = x.clone(); - auto y = model(x); - torch::Tensor s = y.sum(); + for (const auto type : {torch::kFloat, torch::kBFloat16}) { + LeakyReLU model {LeakyReLUOptions().negative_slope(negative_slope).inplace(inplace)}; + auto x = torch::linspace(-10.0, 10.0, size * size * size).to(type); + x.resize_({size, size, size}); + if (!inplace) { + x.requires_grad_(true); + } + auto x_orig = x.clone(); + auto y = model(x); + torch::Tensor s = y.sum(); - ASSERT_EQ(s.ndimension(), 0); - ASSERT_EQ(y.ndimension(), 3); - ASSERT_EQ(y.sizes(), std::vector({size, size, size})); - auto y_exp = (x_orig < 0) * x_orig * negative_slope + (x_orig >= 0) * x_orig; - ASSERT_TRUE(torch::allclose(y, y_exp)); - if (inplace) { - ASSERT_TRUE(torch::allclose(x, y_exp)); - } else { - s.backward(); + ASSERT_EQ(s.ndimension(), 0); + ASSERT_EQ(y.ndimension(), 3); + ASSERT_EQ(y.sizes(), std::vector({size, size, size})); + auto y_exp = (x_orig < 0) * x_orig * negative_slope + (x_orig >= 0) * x_orig; + ASSERT_TRUE(torch::allclose(y, y_exp)); + if (inplace) { + ASSERT_TRUE(torch::allclose(x, y_exp)); + } else { + s.backward(); + } } } } @@ -2740,26 +2767,28 @@ TEST_F(ModulesTest, RReLU) { for (const auto lower : {0.01, 0.1, 0.2}) { for (const auto upper : {0.3, 0.4, 0.5}) { for (const auto inplace : {false, true}) { - RReLU model {RReLUOptions().lower(lower).upper(upper).inplace(inplace)}; - auto x = torch::linspace(-10.0, 10.0, size * size * size); - x.resize_({size, size, size}); - if (!inplace) { - x.requires_grad_(true); - } - auto x_orig = x.clone(); - auto y = model(x); - torch::Tensor s = y.sum(); - - ASSERT_EQ(s.ndimension(), 0); - ASSERT_EQ(y.ndimension(), 3); - ASSERT_EQ(y.sizes(), std::vector({size, size, size})); - auto z = ((x_orig >= 0) * (x_orig == y) + - (x_orig < 0) * (y >= x_orig * upper) * (y <= lower * x_orig)) * 1.0; - ASSERT_TRUE(torch::allclose(z, torch::ones_like(z))); - if (inplace) { - ASSERT_TRUE(torch::allclose(x, y)); - } else { - s.backward(); + for (const auto type : {torch::kFloat, torch::kBFloat16}) { + RReLU model {RReLUOptions().lower(lower).upper(upper).inplace(inplace)}; + auto x = torch::linspace(-10.0, 10.0, size * size * size).to(type); + x.resize_({size, size, size}); + if (!inplace) { + x.requires_grad_(true); + } + auto x_orig = x.clone(); + auto y = model(x); + torch::Tensor s = y.sum(); + + ASSERT_EQ(s.ndimension(), 0); + ASSERT_EQ(y.ndimension(), 3); + ASSERT_EQ(y.sizes(), std::vector({size, size, size})); + auto z = ((x_orig >= 0) * (x_orig == y) + + (x_orig < 0) * (y >= x_orig * upper) * (y <= lower * x_orig)) * 1.0; + ASSERT_TRUE(torch::allclose(z, torch::ones_like(z))); + if (inplace) { + ASSERT_TRUE(torch::allclose(x, y)); + } else { + s.backward(); + } } } } diff --git a/test/cpp/jit/CMakeLists.txt b/test/cpp/jit/CMakeLists.txt index 08115433312f5..8bd37a1fb8a59 100644 --- a/test/cpp/jit/CMakeLists.txt +++ b/test/cpp/jit/CMakeLists.txt @@ -62,9 +62,11 @@ set(JIT_TEST_SRCS ${JIT_TEST_ROOT}/test_qualified_name.cpp ${JIT_TEST_ROOT}/test_save_load.cpp ${JIT_TEST_ROOT}/test_schema_matching.cpp + ${JIT_TEST_ROOT}/test_stack_opt.cpp ${JIT_TEST_ROOT}/test_subgraph_matcher.cpp ${JIT_TEST_ROOT}/test_subgraph_rewriter.cpp ${JIT_TEST_ROOT}/test_subgraph_utils.cpp + ${JIT_TEST_ROOT}/test_union.cpp ${JIT_TEST_ROOT}/test_utils.cpp ${JIT_TEST_ROOT}/test_script_profile.cpp ${JIT_TEST_ROOT}/test_jit_logging_levels.cpp diff --git a/test/cpp/jit/test_alias_analysis.cpp b/test/cpp/jit/test_alias_analysis.cpp index 1bd556a8980b7..c92cb4da46dde 100644 --- a/test/cpp/jit/test_alias_analysis.cpp +++ b/test/cpp/jit/test_alias_analysis.cpp @@ -1,11 +1,11 @@ #include #include +#include +#include #include -#include "torch/csrc/jit/frontend/ir_emitter.h" -#include "torch/csrc/jit/ir/alias_analysis.h" -#include "torch/csrc/jit/runtime/custom_operator.h" -#include "torch/csrc/utils/memory.h" +#include +#include namespace torch { namespace jit { @@ -484,7 +484,7 @@ TEST(AliasAnalysisTest, SafeToChangeAliasingRelationship) { TEST(WriteTrackingTest, Basic) { RegisterOperators reg({Operator( "prim::creates_alias(Tensor(a) x) -> Tensor(a)", - [](Stack* s) {}, + [](Stack&) {}, aliasAnalysisFromSchema())}); const auto creates_alias = Symbol::fromQualString("prim::creates_alias"); auto graph = std::make_shared(); @@ -660,6 +660,31 @@ TEST(ContainerAliasingTest, PrimitveValuesDontAliasContainers) { } } +TEST(ContainerAliasingTest, UnionAliasing) { + auto graph = std::make_shared(); + parseIR( + R"IR( + graph(%a : Dict(str, Tensor), + %b : Tensor[], + %c : Union(Dict(str, Tensor), Tensor[])): + return (%a, %b, %c) + )IR", + &*graph); + + AliasDb aliasDb(graph); + auto a = graph->outputs().at(0); + auto b = graph->outputs().at(1); + auto c = graph->outputs().at(2); + + EXPECT_TRUE(aliasDb.mayAlias(a, c)); + EXPECT_TRUE(aliasDb.mayAlias(b, c)); + EXPECT_TRUE(aliasDb.mayAlias(c, c)); + EXPECT_FALSE(aliasDb.mayAlias(a, b)); + EXPECT_TRUE(aliasDb.mayContainAlias(a, b)); + EXPECT_TRUE(aliasDb.mayContainAlias(a, c)); + EXPECT_TRUE(aliasDb.mayContainAlias(b, c)); +} + TEST(ContainerAliasingTest, InputsCanAliasOutputs) { // Test input aliasing auto graph = std::make_shared(); @@ -949,11 +974,11 @@ TEST(WildcardsTest, Basic) { RegisterOperators reg( {Operator( "prim::returns_wildcard(Tensor a) -> Tensor(*)", - [](Stack* stack) {}, + [](Stack&) {}, aliasAnalysisFromSchema()), Operator( "prim::writes(Tensor(z!) a) -> Tensor(a)", - [](Stack* stack) {}, + [](Stack&) {}, aliasAnalysisFromSchema())}); const auto returns_wildcard = Symbol::fromQualString("prim::returns_wildcard"); diff --git a/test/cpp/jit/test_concat_opt.cpp b/test/cpp/jit/test_concat_opt.cpp index 03c0ce6a58dae..5cb73d234927e 100644 --- a/test/cpp/jit/test_concat_opt.cpp +++ b/test/cpp/jit/test_concat_opt.cpp @@ -1,45 +1,15 @@ #include +#include #include #include +#include #include #include namespace torch { namespace jit { -namespace { - -void checkOutputs( - const std::vector& out1, - const std::vector& out2) { - ASSERT_EQ(out1.size(), out2.size()); - for (size_t i = 0; i < out1.size(); ++i) { - ASSERT_EQ(out1[i].sizes(), out2[i].sizes()); - float max_diff = (out1[i] - out2[i]).abs().max().item(); - ASSERT_EQ(max_diff, 0); - } -} - -std::vector runGraph( - std::shared_ptr graph, - const std::vector inputs) { - std::vector stack = fmap(inputs); - Code code(graph, "test"); - InterpreterState(code).run(stack); - TORCH_INTERNAL_ASSERT(!stack.empty()); - // Graph outputs that are handled below: - // * A list of Tensors. - // * 1 Tensor. - if (stack.front().isTensorList()) { - return stack.front().toTensorVector(); - } - TORCH_INTERNAL_ASSERT(stack.front().isTensor()); - return {stack.front().toTensor()}; -} - -} // namespace - TEST(ConcatOptTest, SimpleCommonInputsEliminationPrefix) { auto graph = std::make_shared(); @@ -64,7 +34,7 @@ TEST(ConcatOptTest, SimpleCommonInputsEliminationPrefix) { ASSERT_TRUE(EliminateConcatCommonInputs(graph)); graph->lint(); auto opt_outputs = runGraph(graph, inputs); - checkOutputs(orig_outputs, opt_outputs); + ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); // Graph after EliminateConcatCommonInputs: // graph(%0 : ..., @@ -109,7 +79,7 @@ TEST(ConcatOptTest, SimpleCommonInputsEliminationSuffix) { ASSERT_TRUE(EliminateConcatCommonInputs(graph)); graph->lint(); auto opt_outputs = runGraph(graph, inputs); - checkOutputs(orig_outputs, opt_outputs); + ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); // Graph after EliminateConcatCommonInputs: // graph(%0 : ..., @@ -161,7 +131,7 @@ TEST(ConcatOptTest, CommonInputsEliminationWithDifferentOrderInputs) { graph->lint(); auto opt_outputs = runGraph(graph, inputs); - checkOutputs(orig_outputs, opt_outputs); + ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); // No optimizations should have happened in this case since the inputs // to the `cat` are in different order. @@ -198,7 +168,7 @@ TEST(ConcatOptTest, MoreCommonInputsElimination) { ASSERT_TRUE(EliminateConcatCommonInputs(graph)); graph->lint(); auto opt_outputs = runGraph(graph, inputs); - checkOutputs(orig_outputs, opt_outputs); + ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); testing::FileCheck() .check_count("= prim::VarConcat(%0, %1, %5)", 1, /*exactly*/ true) @@ -233,7 +203,7 @@ TEST(ConcatOptTest, ExpandConcat) { graph->lint(); auto opt_outputs = runGraph(graph, inputs); - checkOutputs(orig_outputs, opt_outputs); + ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); // After full concat optimization we should have the following graph: // @@ -289,7 +259,7 @@ TEST(ConcatOptTest, ConcatWithoutResultShape) { graph->lint(); auto opt_outputs = runGraph(graph, inputs); - checkOutputs(orig_outputs, opt_outputs); + ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); // No optimizations should have happened in this case since the output // shape of `aten::cat` is not known. @@ -324,7 +294,7 @@ TEST(ConcatOptTest, ConcatWithoutInputShape) { graph->lint(); auto opt_outputs = runGraph(graph, inputs); - checkOutputs(orig_outputs, opt_outputs); + ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); // No optimizations should have happened in this case since the shape of %5, // which is an input to `aten::cat`, is not known. @@ -361,7 +331,7 @@ TEST(ConcatOptTest, UseVariadicCat) { graph->lint(); auto opt_outputs = runGraph(graph, inputs); - checkOutputs(orig_outputs, opt_outputs); + ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); // After replacing `aten::cat` with `prim::VarConcat` we should have the // following graph: @@ -406,7 +376,7 @@ TEST(OptimizeConcatTest, UseVariadicCatReplaceMultiple) { graph->lint(); auto opt_outputs = runGraph(graph, inputs); - checkOutputs(orig_outputs, opt_outputs); + ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); // After full concat optimization we should have the following graph: // @@ -446,7 +416,7 @@ TEST(ConcatOptTest, UseVariadicCatWithMultipleListUses) { graph->lint(); auto opt_outputs = runGraph(graph, inputs); - checkOutputs(orig_outputs, opt_outputs); + ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); // After replacing `aten::cat` with `prim::VarConcat` we should have the // following graph: @@ -488,7 +458,7 @@ TEST(ConcatOptTest, UseVariadicCatWithListMutationAfterCat) { ASSERT_TRUE(UseVariadicCat(graph)); graph->lint(); auto opt_outputs = runGraph(graph, inputs); - checkOutputs(orig_outputs, opt_outputs); + ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); // The input list to `aten::cat` is mutated only after `aten::cat` op. So, // it should have been replaced with `prim::VarConcat`. The transformed graph @@ -534,7 +504,7 @@ TEST(ConcatOptTest, UseVariadicCatWithListMutationBeforeCat) { ASSERT_FALSE(UseVariadicCat(graph)); graph->lint(); auto opt_outputs = runGraph(graph, inputs); - checkOutputs(orig_outputs, opt_outputs); + ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); // No transformation should have happened since the `prim::ListConstruct` is // mutated before `aten::cat`. @@ -549,7 +519,7 @@ TEST(ConcatOptTest, UseVariadicCatWithListMutationBeforeCat) { ASSERT_TRUE(RemoveListMutationAndUseVariadicCat(graph)); graph->lint(); auto opt_outputs = runGraph(graph, inputs); - checkOutputs(orig_outputs, opt_outputs); + ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); // The mutation of the list must be removed and the `aten::cat` op must // be replaced with the `prim::VarConcat` op in the graph. The transformed @@ -602,7 +572,7 @@ TEST(ConcatOptTest, UseVariadicCatWithMultipleListMutations) { ASSERT_TRUE(RemoveListMutationAndUseVariadicCat(graph)); graph->lint(); auto opt_outputs = runGraph(graph, inputs); - checkOutputs(orig_outputs, opt_outputs); + ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); // All the mutations of the list must be removed and the `aten::cat` ops must // be replaced with `prim::VarConcat` ops in the graph. The transformed graph @@ -659,7 +629,7 @@ TEST( ASSERT_TRUE(EliminateConcatCommonInputs(graph)); graph->lint(); auto opt_outputs = runGraph(graph, inputs); - checkOutputs(orig_outputs, opt_outputs); + ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); // After performing: // * Remove list mutation diff --git a/test/cpp/jit/test_custom_operators.cpp b/test/cpp/jit/test_custom_operators.cpp index a34ca33672c7b..39be82ea23430 100644 --- a/test/cpp/jit/test_custom_operators.cpp +++ b/test/cpp/jit/test_custom_operators.cpp @@ -31,7 +31,7 @@ TEST(CustomOperatorTest, InferredSchema) { Stack stack; push(stack, 2.0f, at::ones(5)); - op->getOperation()(&stack); + op->getOperation()(stack); at::Tensor output; pop(stack, output); @@ -61,7 +61,7 @@ TEST(CustomOperatorTest, ExplicitSchema) { Stack stack; push(stack, 2.0f, at::ones(5)); - op->getOperation()(&stack); + op->getOperation()(stack); at::Tensor output; pop(stack, output); @@ -109,7 +109,7 @@ TEST(CustomOperatorTest, ListParameters) { c10::List>( {c10::complex(2.4, -5.5), c10::complex(-1.3, 2)})); push(stack, c10::List({at::ones(5)})); - op->getOperation()(&stack); + op->getOperation()(stack); c10::List output; pop(stack, output); @@ -140,7 +140,7 @@ TEST(CustomOperatorTest, ListParameters2) { Stack stack; push(stack, c10::List({at::ones(5)})); - op->getOperation()(&stack); + op->getOperation()(stack); c10::List output; pop(stack, output); @@ -204,7 +204,7 @@ TEST(TestCustomOperator, OperatorGeneratorUndeclared) { torch::jit::RegisterOperators reg({OperatorGenerator( TORCH_SELECTIVE_NAME_IN_SCHEMA( op_list, "foofoo::not_exist(float a, Tensor b) -> Tensor"), - [](Stack* stack) { + [](Stack& stack) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) double a; at::Tensor b; @@ -223,7 +223,7 @@ TEST(TestCustomOperator, OperatorGeneratorBasic) { torch::jit::RegisterOperators reg({OperatorGenerator( TORCH_SELECTIVE_NAME_IN_SCHEMA( op_list, "foofoo::bar.template(float a, Tensor b) -> Tensor"), - [](Stack* stack) { + [](Stack& stack) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) double a; at::Tensor b; @@ -249,7 +249,7 @@ TEST(TestCustomOperator, OperatorGeneratorBasic) { Stack stack; push(stack, 2.0f, at::ones(5)); - op->getOperation()(&stack); + op->getOperation()(stack); at::Tensor output; pop(stack, output); diff --git a/test/cpp/jit/test_interpreter.cpp b/test/cpp/jit/test_interpreter.cpp index a2418918336c5..bfdc1f3a0cb7e 100644 --- a/test/cpp/jit/test_interpreter.cpp +++ b/test/cpp/jit/test_interpreter.cpp @@ -175,6 +175,15 @@ TEST(InterpreterTest, IgnorableArgsInSchema) { ASSERT_TRUE(op_to_specified_args_non_const["aten::conv2d"] == 6); } +TEST(InterpreterTest, IgnorableArgsInSchemaWithOut) { + auto graph = build_mobile_export_with_out(); + MobileCode function(graph, ""); + auto op_to_specified_args = function.op_to_num_specified_args(); + ASSERT_TRUE(op_to_specified_args.size() == 1); + // this should be 3 when the add_out flag is set to True + ASSERT_TRUE(op_to_specified_args["aten::add.out"] == 4); +} + TEST(InterpreterTest, runAsyncBasicTest) { /* TODO: there are some problem with C++ parsing script program involving diff --git a/test/cpp/jit/test_lite_interpreter.cpp b/test/cpp/jit/test_lite_interpreter.cpp index 3bd2becd8779d..b362c8a6ddb06 100644 --- a/test/cpp/jit/test_lite_interpreter.cpp +++ b/test/cpp/jit/test_lite_interpreter.cpp @@ -456,144 +456,6 @@ TEST(LiteInterpreterTest, BuiltinFunction) { } #if !defined FB_XPLAT_BUILD -TEST(LiteInterpreterTest, ModuleInfoBasic) { - Module m("M"); - m.define(R"JIT( - def forward(self, x): - return 2 * x - )JIT"); - - std::stringstream ss; - m._save_for_mobile(ss, {}, true); - mobile::Module bc = _load_for_mobile(ss); - - std::unordered_set module_debug_info_set; - size_t pc = 0; - while (true) { - try { - std::string module_info = bc.get_forward_method_debug_info(pc); - if (!module_info.empty() && - (module_info.find("debug_handle") == std::string::npos)) { - module_debug_info_set.insert(module_info); - } - ++pc; - } catch (const std::exception& e) { - break; - } - } - - AT_ASSERT(module_debug_info_set.count("top(M)::.aten::mul")); -} - -TEST(LiteInterpreterTest, NotSaveModuleInfo) { - Module m("M"); - m.define(R"JIT( - def forward(self, x): - return x + 5 - )JIT"); - - std::stringstream ss; - m._save_for_mobile(ss); - mobile::Module bc = _load_for_mobile(ss); - - size_t pc = 0; - while (true) { - try { - std::string module_info = bc.get_forward_method_debug_info(pc); - AT_ASSERT( - module_info.empty() || - (module_info.find("debug_handle") != std::string::npos)); - ++pc; - } catch (const std::exception& e) { - break; - } - } -} - -TEST(LiteInterpreterTest, OneSubmoduleModuleInfo) { - Module a("A"); - a.define(R"JIT( - def forward(self, x): - return 2 * x + 5 - )JIT"); - Module b("B"); - b.register_module("A0", a); - b.define(R"JIT( - def forward(self, x): - return self.A0.forward(x) + 1 - )JIT"); - - std::stringstream ss; - b._save_for_mobile(ss, {}, true); - mobile::Module bc = _load_for_mobile(ss); - - std::set module_debug_info_set; - size_t pc = 0; - while (true) { - try { - std::string module_info = bc.get_forward_method_debug_info(pc); - if (!module_info.empty() && - (module_info.find("debug_handle") == std::string::npos)) { - module_debug_info_set.insert(module_info); - } - ++pc; - } catch (const std::exception& e) { - break; - } - } - - AT_ASSERT(module_debug_info_set.count("top(B)::.aten::add")); - AT_ASSERT(module_debug_info_set.count( - "top(B)::.A0(A)::forward.aten::add")); - AT_ASSERT(module_debug_info_set.count( - "top(B)::.A0(A)::forward.aten::mul")); -} - -TEST(LiteInterpreterTest, TwoSubmodulesModuleInfo) { - Module a("A"); - a.define(R"JIT( - def forward(self, x): - return x + 1 - )JIT"); - Module b("B"); - b.define(R"JIT( - def forward(self, x): - return x + 2 - )JIT"); - Module c("C"); - c.register_module("A0", a); - c.register_module("B0", b); - c.define(R"JIT( - def forward(self, x): - return self.A0.forward(x) + self.B0.forward(x) - )JIT"); - - std::stringstream ss; - c._save_for_mobile(ss, {}, true); - mobile::Module bc = _load_for_mobile(ss); - - std::set module_debug_info_set; - size_t pc = 0; - while (true) { - try { - std::string module_info = bc.get_forward_method_debug_info(pc); - if (!module_info.empty() && - (module_info.find("debug_handle") == std::string::npos)) { - module_debug_info_set.insert(module_info); - } - ++pc; - } catch (const std::exception& e) { - break; - } - } - - AT_ASSERT(module_debug_info_set.count("top(C)::.aten::add")); - AT_ASSERT(module_debug_info_set.count( - "top(C)::.A0(A)::forward.aten::add")); - AT_ASSERT(module_debug_info_set.count( - "top(C)::.B0(B)::forward.aten::add")); -} - TEST(LiteInterpreterTest, GetRuntimeByteCodeVersion) { auto runtime_bytecode_version = _get_runtime_bytecode_version(); AT_ASSERT( @@ -795,187 +657,6 @@ TEST(LiteInterpreterTest, isCompatibleFail) { AT_ASSERT(result.status = ModelCompatibilityStatus::ERROR); } -#if !defined FB_XPLAT_BUILD -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -TEST(LiteInterpreterTest, SequentialModuleInfo) { - Module a("A"); - a.define(R"JIT( - def forward(self, x): - return x + 1 - )JIT"); - Module b("B"); - b.define(R"JIT( - def forward(self, x): - return x + 2 - )JIT"); - Module c("C"); - c.register_module("A0", a); - c.register_module("B0", b); - c.define(R"JIT( - def forward(self, x): - return self.A0.forward(self.B0.forward(x)) - )JIT"); - - std::stringstream ss; - c._save_for_mobile(ss, {}, true); - mobile::Module bc = _load_for_mobile(ss); - - std::set module_debug_info_set; - size_t pc = 0; - while (true) { - try { - std::string module_info = bc.get_forward_method_debug_info(pc); - if (!module_info.empty() && - (module_info.find("debug_handle") == std::string::npos)) { - module_debug_info_set.insert(module_info); - } - ++pc; - } catch (const std::exception& e) { - break; - } - } - - // class A(nn.Module): - // def __init__(self): - // super(A, self).__init__() - - // def forward(self, x): - // return x + 1 - - // class B(nn.Module): - // def __init__(self): - // super(B, self).__init__() - - // def forward(self, x): - // return x + 2 - - // class C(nn.Module): - // def __init__(self): - // super(C, self).__init__() - // self.A0 = A() - // self.B0 = B() - - // def forward(self, x): - // return self.A0.forward(self.B0.forward(x)) - - AT_ASSERT(module_debug_info_set.count("top(C)::.prim::Return")); - AT_ASSERT(module_debug_info_set.count( - "top(C)::.A0(A)::forward.aten::add")); - AT_ASSERT(module_debug_info_set.count( - "top(C)::.B0(B)::forward.aten::add")); -} - -TEST(LiteInterpreterTest, HierarchyModuleInfo) { - Module a("A"); - a.define(R"JIT( - def forward(self, x): - return x + 1 - )JIT"); - Module b("B"); - b.register_module("A0", a); - b.define(R"JIT( - def forward(self, x): - return self.A0.forward(x) + 1 - )JIT"); - Module c("C"); - c.register_module("B0", b); - c.define(R"JIT( - def forward(self, x): - return self.B0.forward(x) + 1 - )JIT"); - - std::stringstream ss; - c._save_for_mobile(ss, {}, true); - mobile::Module bc = _load_for_mobile(ss); - - std::set module_debug_info_set; - size_t pc = 0; - while (true) { - try { - std::string module_info = bc.get_forward_method_debug_info(pc); - if (!module_info.empty() && - (module_info.find("debug_handle") == std::string::npos)) { - module_debug_info_set.insert(module_info); - } - ++pc; - } catch (const std::exception& e) { - break; - } - } - - // There are 3 module information strings here. - // "top(C).forward": for the add operator in top. - // "top(C).B0(B).forward": for the add operator in B0. - // "top(C).B0(B).forward.A0(A).forward": for the add operator in A0. - AT_ASSERT(module_debug_info_set.count("top(C)::.aten::add")); - AT_ASSERT(module_debug_info_set.count( - "top(C)::.B0(B)::forward.aten::add")); - AT_ASSERT(module_debug_info_set.count( - "top(C)::.B0(B)::forward.A0(A)::forward.aten::add")); -} - -TEST(LiteInterpreterTest, DuplicatedClassTypeModuleInfo) { - Module a("A"); - a.define(R"JIT( - def forward(self, x): - return x + 5 - )JIT"); - Module b("B"); - b.register_module("A0", a); - b.register_module("A1", a); - b.define(R"JIT( - def forward(self, x): - return self.A0.forward(x) + self.A1.forward(x) - )JIT"); - - std::stringstream ss; - b._save_for_mobile(ss, {}, true); - mobile::Module bc = _load_for_mobile(ss); - - std::set module_debug_info_set; - size_t pc = 0; - while (true) { - try { - std::string module_info = bc.get_forward_method_debug_info(pc); - if (!module_info.empty() && - (module_info.find("debug_handle") == std::string::npos)) { - module_debug_info_set.insert(module_info); - } - ++pc; - } catch (const std::exception& e) { - break; - } - } - - // class A(nn.Module): - // def __init__(self): - // super(A, self).__init__() - - // def forward(self, x): - // return x + 5 - - // class B(nn.Module): - // def __init__(self): - // super(B, self).__init__() - // self.A0 = A() - // self.A1 = A() - - // def forward(self, x): - // return self.A0.forward(x) + self.A1.forward(x) - - // There are 3 module information strings here. - // "top(B).forward": for the add operator in top. - // "top(B).A0(A).forward": for the add operator in A0. - // "top(B).A1(A).forward": for the add operator in A1. - - AT_ASSERT(module_debug_info_set.count("top(B)::.aten::add")); - AT_ASSERT(module_debug_info_set.count( - "top(B)::.A0(A)::forward.aten::add")); - AT_ASSERT(module_debug_info_set.count( - "top(B)::.A1(A)::forward.aten::add")); -} -#endif // !defined(FB_XPLAT_BUILD) - TEST(LiteInterpreterTest, Eval) { std::vector inputs; @@ -1354,6 +1035,68 @@ TEST(LiteInterpreterTest, DefaultArgsPinvSpecifyDefault) { testLiteModuleCompareResultTensors(m, inputs); } +void testDefaultArgsPinvWithOutArg(int num_args) { + Module m("m"); + if (num_args == 1) { + m.define(R"( + def forward(self, input): + return torch.linalg_pinv(input, out=input) + )"); + } else if (num_args == 2) { + m.define(R"( + def forward(self, input): + return torch.linalg_pinv(input, 1e-5, out=input) + )"); + } else if (num_args == 3) { + m.define(R"( + def forward(self, input): + return torch.linalg_pinv(input, 1e-5, True, out=input) + )"); + } + + const int N = 28; + auto input = torch::range(1, N * N, 1); + input[0] = 10000; // a more stable matrix + input = input.view({N, N}); + auto ref = m.run_method("forward", input); + TORCH_CHECK(!input.equal(torch::range(1, N * N, 1))); + TORCH_CHECK(input.equal(ref.toTensor())); +} + +TEST(LiteInterpreterTest, DefaultArgsPinvWithOutArg) { + // Test with different number of specified arguments + out arg. + // Arguments not specified take default value. + for (int num_args = 1; num_args <= 3; ++num_args) { + testDefaultArgsPinvWithOutArg(num_args); + } +} + +TEST(LiteInterpreterTest, DefaultArgsWithOutArg) { + Module m("m"); + m.define(R"( + def forward(self, x, h): + torch.add(x, h, out=x) + )"); + + std::vector inputs; + auto input_x = 2 * torch::ones({}); + auto input_h = torch::ones({}); + auto ref = m.run_method("forward", input_x, input_h); + + std::stringstream ss; + + m._save_for_mobile(ss, {}, true); + mobile::Module bc = _load_for_mobile(ss); + bc.run_method("forward", input_x, input_h); + AT_ASSERT(input_x.equal(4 * torch::ones({}))); + + auto ops = _get_model_ops_and_info(ss); + auto op = ops.find("aten::add.out"); + TORCH_CHECK( + op != ops.end() && op->second.num_schema_args.has_value() && + op->second.num_schema_args.value() == 4); +} + TEST(LiteInterpreterTest, TestExceptionStackWithTwoLevelModuleHierarchy) { Module a("A"); a.define(R"( diff --git a/test/cpp/jit/test_misc.cpp b/test/cpp/jit/test_misc.cpp index 82f70fee1dd20..305d36a476213 100644 --- a/test/cpp/jit/test_misc.cpp +++ b/test/cpp/jit/test_misc.cpp @@ -520,6 +520,28 @@ TEST(SchemaParserTest, NestedArrays) { .getElementType())); } +TEST(SchemaParserTest, OutVariant) { + auto schema_with_out = parseSchema( + "at::foo(Tensor self, *, Tensor(a!) f, Tensor(b!) l) -> (Tensor(a!) f, Tensor(b!) l)"); + ASSERT_TRUE(schema_with_out.arguments().at(1).is_out()); + ASSERT_TRUE(schema_with_out.arguments().at(2).is_out()); + + auto schema_without_out = + parseSchema("at::foo(Tensor self, *, int scalar) -> (int)"); + + for (const auto& arg : schema_without_out.arguments()) { + ASSERT_TRUE(!arg.is_out()); + } + + auto schema_with_is_write = parseSchema( + "aten::ne_.Scalar(Tensor(a!) self, Scalar other) -> (Tensor(a!))"); + + for (const auto& arg : schema_with_is_write.arguments()) { + ASSERT_TRUE(!arg.is_out()); + } +} + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(SchemaParserTest, NamedReturns) { // named returns parseSchema("at::what(Tensor! i_will_be_written_to) -> ()"); @@ -1471,11 +1493,11 @@ TEST(NoneSchemaMatchTest, Basic) { RegisterOperators reg({ Operator( "prim::test_none() -> int?", - [](Stack* stack) { push(stack, IValue()); }, + [](Stack& stack) { push(stack, IValue()); }, aliasAnalysisFromSchema()), Operator( "prim::is_none(int? a) -> bool", - [](Stack* stack) { + [](Stack& stack) { IValue a = pop(stack); if (a.isNone()) { push(stack, true); diff --git a/test/cpp/jit/test_schema_matching.cpp b/test/cpp/jit/test_schema_matching.cpp index 31d332b718f53..c56d0bc28fe99 100644 --- a/test/cpp/jit/test_schema_matching.cpp +++ b/test/cpp/jit/test_schema_matching.cpp @@ -15,7 +15,7 @@ TEST(SchemaMatchingTest, VarType) { RegisterOperators reg({ Operator( "aten::test_vartype(t[] a, t b) -> (t)", - [](Stack* stack) { + [](Stack& stack) { c10::List list; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) double a; @@ -54,7 +54,7 @@ TEST(SchemaMatchingTest, VarType2) { RegisterOperators reg({ Operator( "aten::test_vartype2(t a, t[] b) -> (t[])", - [](Stack* stack) { + [](Stack& stack) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) double a; c10::List list; diff --git a/test/cpp/jit/test_stack_opt.cpp b/test/cpp/jit/test_stack_opt.cpp new file mode 100644 index 0000000000000..fea1bb5f81042 --- /dev/null +++ b/test/cpp/jit/test_stack_opt.cpp @@ -0,0 +1,308 @@ +#include + +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +TEST(StackOptTest, UseVariadicStack) { + auto graph = std::make_shared(); + + const std::string input = + R"IR( + graph(%0: Float(56, 56, 56), + %1: Float(56, 56, 56), + %2: Float(56, 56, 56), + %3: Float(56, 56, 56), + %4: Float(56, 56, 56), + %5: Float(56, 56, 56)): + %10 : int = prim::Constant[value=0]() + %input : Tensor[] = prim::ListConstruct(%0, %1, %2, %3, %4, %5) + %stack : Float(5, 56, 56, 56) = aten::stack(%input, %10) + return (%stack) + )IR"; + parseIR(input, graph.get()); + std::vector inputs = { + at::rand({56, 56, 56}, at::kCPU), + at::rand({56, 56, 56}, at::kCPU), + at::rand({56, 56, 56}, at::kCPU), + at::rand({56, 56, 56}, at::kCPU), + at::rand({56, 56, 56}, at::kCPU), + at::rand({56, 56, 56}, at::kCPU)}; + auto orig_outputs = runGraph(graph, inputs); + + ASSERT_TRUE(UseVariadicStack(graph)); + graph->lint(); + auto opt_outputs = runGraph(graph, inputs); + + ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); + + // After replacing `aten::stack` with `prim::VarStack` we should have the + // following graph: + // + // graph(%0 : ..., + // %1 : ...): + // %zero : int = prim:Constant[value=0]() + // %varstack : Tensor = prim::VarStack(%0, %1, %2, %3, %4, %5, %zero) + // return (%varstack) + testing::FileCheck() + .check_count("= prim::VarStack(", 1, /*exactly*/ true) + ->check_count("= aten::stack(", 0, /*exactly*/ true) + ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true) + ->run(*graph); +} + +TEST(StackOptTest, UseVariadicStackReplaceMultiple) { + auto graph = std::make_shared(); + + const std::string input = + R"IR( + graph(%0: Float(56, 56, 56), + %1: Float(56, 56, 56), + %2: Float(56, 56, 56), + %3: Float(56, 56, 56)): + %10 : int = prim::Constant[value=0]() + %input1 : Tensor[] = prim::ListConstruct(%0, %1) + %stack1 : Float(4, 56, 56, 56) = aten::stack(%input1, %10) + %input2 : Tensor[] = prim::ListConstruct(%2, %3) + %stack2 : Float(4, 56, 56, 56) = aten::stack(%input2, %10) + return (%stack1, %stack2) + )IR"; + parseIR(input, graph.get()); + std::vector inputs = { + at::rand({56, 56, 56}, at::kCPU), + at::rand({56, 56, 56}, at::kCPU), + at::rand({56, 56, 56}, at::kCPU), + at::rand({56, 56, 56}, at::kCPU)}; + auto orig_outputs = runGraph(graph, inputs); + + ASSERT_TRUE(UseVariadicStack(graph)); + graph->lint(); + auto opt_outputs = runGraph(graph, inputs); + + ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); + + // After full stack optimization we should have the following graph: + // + // graph(%0 : ..., + // %1 : ..., + // %2 : ..., + // %3 : ....): + // %zero : int = prim:Constant[value=0]() + // %varcat1 : Tensor = prim::VarStack(%0, %1, %zero) + // %varcat2 : Tensor = prim::VarStack(%2, %3, %zero) + // return (%varcat1, %varcat2) + testing::FileCheck() + .check_count("= prim::VarStack(", 2, /*exactly*/ true) + ->check_count("= aten::stack(", 0, /*exactly*/ true) + ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true) + ->run(*graph); +} + +TEST(StackOptTest, UseVariadicStackWithMultipleListUses) { + auto graph = std::make_shared(); + + const std::string input = + R"IR( + graph(%0: Float(56, 56, 56), + %1: Float(56, 56, 56)): + %2 : int = prim::Constant[value=0]() + %input : Tensor[] = prim::ListConstruct(%0, %1) + %stack : Float(2, 56, 56, 56) = aten::stack(%input, %2) + return (%stack, %input) + )IR"; + parseIR(input, graph.get()); + std::vector inputs = { + at::rand({56, 56, 56}, at::kCPU), at::rand({56, 56, 56}, at::kCPU)}; + auto orig_outputs = runGraph(graph, inputs); + + ASSERT_TRUE(UseVariadicStack(graph)); + graph->lint(); + auto opt_outputs = runGraph(graph, inputs); + + ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); + + // After replacing `aten::stack` with `prim::VarStack` we should have the + // following graph: + // + // graph(%0 : ..., + // %1 : ...): + // %zero : int = prim:Constant[value=0]() + // %input : Tensor[] = prim::ListConstruct(%0, %1) + // %varcat : Tensor = prim::VarStack(%0, %1, %zero) + // return (%varcat, %input) + testing::FileCheck() + .check_count("= prim::ListConstruct(", 1, /*exactly*/ true) + ->check_count("= prim::VarStack(", 1, /*exactly*/ true) + ->check_count("= aten::stack(", 0, /*exactly*/ true) + ->run(*graph); +} + +TEST(StackOptTest, UseVariadicStackWithListMutationAfterCat) { + auto graph = std::make_shared(); + + const std::string input = + R"IR( + graph(%0: Float(56, 56, 56), + %1: Float(56, 56, 56), + %2: Float(56, 56, 56)): + %10 : int = prim::Constant[value=0]() + %input : Tensor[] = prim::ListConstruct(%0, %1) + %stack : Float(3, 56, 56, 56) = aten::stack(%input, %10) + %11 : Tensor = aten::append(%input, %2) + return (%stack, %input) + )IR"; + parseIR(input, graph.get()); + std::vector inputs = { + at::rand({56, 56, 56}, at::kCPU), + at::rand({56, 56, 56}, at::kCPU), + at::rand({56, 56, 56}, at::kCPU)}; + auto orig_outputs = runGraph(graph, inputs); + + ASSERT_TRUE(UseVariadicStack(graph)); + graph->lint(); + auto opt_outputs = runGraph(graph, inputs); + ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); + + // The input list to `aten::stack` is mutated only after `aten::stack` op. So, + // it should have been replaced with `prim::VarStack`. The transformed graph + // should look like the following: + // + // graph(%0 : ..., + // %1 : ..., + // %2 : ...): + // %3 : int = prim:Constant[value=0]() + // %4 : Tensor[] = prim::ListConstruct(%0, %1) + // %7 : Tensor = prim::VarStack(%0, %1, %3) + // %6 : Tensor = aten::append(%4, %2) + // return (%7, %4) + testing::FileCheck() + .check_count("= prim::ListConstruct(", 1, /*exactly*/ true) + ->check_count("= prim::VarStack(", 1, /*exactly*/ true) + ->check_count("= aten::stack(", 0, /*exactly*/ true) + ->run(*graph); +} + +TEST(StackOptTest, UseVariadicStackWithListMutationBeforeCat) { + auto graph = std::make_shared(); + + const std::string input = + R"IR( + graph(%0: Float(56, 56, 56), + %1: Float(56, 56, 56), + %2: Float(56, 56, 56)): + %10 : int = prim::Constant[value=0]() + %input : Tensor[] = prim::ListConstruct(%0, %1) + %11 : Tensor = aten::append(%input, %2) + %stack : Float(3, 56, 56, 56) = aten::stack(%input, %10) + return (%stack) + )IR"; + parseIR(input, graph.get()); + std::vector inputs = { + at::rand({56, 56, 56}, at::kCPU), + at::rand({56, 56, 56}, at::kCPU), + at::rand({56, 56, 56}, at::kCPU)}; + auto orig_outputs = runGraph(graph, inputs); + + { + ASSERT_FALSE(UseVariadicStack(graph)); + graph->lint(); + auto opt_outputs = runGraph(graph, inputs); + ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); + + // No transformation should have happened since the `prim::ListConstruct` is + // mutated before `aten::stack`. + testing::FileCheck() + .check_count("= prim::ListConstruct(", 1, /*exactly*/ true) + ->check_count("= aten::stack(", 1, /*exactly*/ true) + ->check_count("= prim::VarStack(", 0, /*exactly*/ true) + ->run(*graph); + } + + { + ASSERT_TRUE(RemoveListMutationAndUseVariadicStack(graph)); + graph->lint(); + auto opt_outputs = runGraph(graph, inputs); + ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); + + // The mutation of the list must be removed and the `aten::stack` op must + // be replaced with the `prim::VarStack` op in the graph. The transformed + // graph should look like the following: + // + // graph(%0 : ..., + // %1 : ..., + // %2 : ...): + // %3 : int = prim:Constant[value=0]() + // %7 : Tensor = prim::VarStack(%0, %1, %2, %3) + // return (%7) + testing::FileCheck() + .check_count("= prim::VarStack(", 1, /*exactly*/ true) + ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true) + ->check_count("= aten::stack(", 0, /*exactly*/ true) + ->run(*graph); + } +} + +TEST(StackOptTest, UseVariadicStackWithMultipleListMutations) { + auto graph = std::make_shared(); + + const std::string input = + R"IR( + graph(%0: Float(56, 56, 56), + %1: Float(56, 56, 56), + %2: Float(56, 56, 56), + %3: Float(56, 56, 56), + %4: Float(56, 56, 56)): + %10 : int = prim::Constant[value=0]() + %input : Tensor[] = prim::ListConstruct(%0, %1) + %stack.1 : Float(5, 56, 56, 56) = aten::stack(%input, %10) + %11 : Tensor = aten::append(%input, %2) + %stack.2 : Float(5, 56, 56, 56) = aten::stack(%input, %10) + %12 : Tensor = aten::append(%input, %3) + %stack.3 : Float(5, 56, 56, 56) = aten::stack(%input, %10) + %13 : Tensor = aten::append(%input, %4) + %stack.4 : Float(5, 56, 56, 56) = aten::stack(%input, %10) + return (%stack.1, %stack.2, %stack.3, %stack.4) + )IR"; + parseIR(input, graph.get()); + std::vector inputs = { + at::rand({56, 56, 56}, at::kCPU), + at::rand({56, 56, 56}, at::kCPU), + at::rand({56, 56, 56}, at::kCPU), + at::rand({56, 56, 56}, at::kCPU), + at::rand({56, 56, 56}, at::kCPU)}; + auto orig_outputs = runGraph(graph, inputs); + + ASSERT_TRUE(RemoveListMutationAndUseVariadicStack(graph)); + graph->lint(); + auto opt_outputs = runGraph(graph, inputs); + ASSERT_TRUE(exactlyEqual(orig_outputs, opt_outputs)); + + // All the mutations of the list must be removed and the `aten::stack` ops + // must be replaced with `prim::VarStack` ops in the graph. The transformed + // graph should look like the following: + // + // graph(%0 : ..., + // %1 : ..., + // %2 : ..., + // %3 : ..., + // %4 : ...): + // %10 : int = prim:Constant[value=0]() + // %5 : Tensor = prim::VarStack(%0, %1, %10) + // %6 : Tensor = prim::VarStack(%0, %1, %2, %10) + // %7 : Tensor = prim::VarStack(%0, %1, %2, %3, %10) + // %8 : Tensor = prim::VarStack(%0, %1, %2, %3, %4, %10) + // return (%5, %6, %7, %8) + testing::FileCheck() + .check_count("= prim::VarStack(", 4, /*exactly*/ true) + ->check_count("= prim::ListConstruct(", 0, /*exactly*/ true) + ->check_count("= aten::stack(", 0, /*exactly*/ true) + ->run(*graph); +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/jit/test_union.cpp b/test/cpp/jit/test_union.cpp new file mode 100644 index 0000000000000..f35acd35d1ed6 --- /dev/null +++ b/test/cpp/jit/test_union.cpp @@ -0,0 +1,149 @@ +#include + +#include +#include +#include + +namespace torch { +namespace jit { + +class UnionTypeTest : public ::testing::Test { + public: + // None + const TypePtr none = NoneType::get(); + + // List[str] + const TypePtr l1 = ListType::ofStrings(); + + // Optional[int] + const TypePtr opt1 = OptionalType::create(IntType::get()); + + // Optional[float] + const TypePtr opt2 = OptionalType::create(FloatType::get()); + + // Optional[List[str]] + const TypePtr opt3 = OptionalType::create(ListType::ofStrings()); + + // Tuple[Optional[int], int] + const TypePtr tup1 = + TupleType::create({OptionalType::create(IntType::get()), IntType::get()}); + + // Tuple[int, int] + const TypePtr tup2 = TupleType::create({IntType::get(), IntType::get()}); + + bool hasType(UnionTypePtr u, TypePtr t) { + auto res = std::find(u->getTypes().begin(), u->getTypes().end(), t); + return res != u->getTypes().end(); + } +}; + +TEST_F(UnionTypeTest, UnionOperatorEquals) { + const UnionTypePtr u1 = UnionType::create({l1, tup2, StringType::get()}); + + // Same thing, but using different TypePtrs + const TypePtr l1_ = ListType::ofStrings(); + const TypePtr tup2_ = TupleType::create({IntType::get(), IntType::get()}); + const UnionTypePtr u2 = UnionType::create({l1_, tup2_, StringType::get()}); + + ASSERT_TRUE(*u1 == *u2); +} + +TEST_F(UnionTypeTest, UnionCreate_OptionalT1AndOptionalT2) { + // Goal: Union[int, float, None] + const UnionTypePtr u = UnionType::create({opt1, opt2}); + + ASSERT_EQ(u->getTypes().size(), 3); + ASSERT_TRUE(UnionTypeTest::hasType(u, IntType::get())); + ASSERT_TRUE(UnionTypeTest::hasType(u, FloatType::get())); + ASSERT_TRUE(UnionTypeTest::hasType(u, NoneType::get())); +} + +TEST_F(UnionTypeTest, UnionCreate_OptionalTAndT) { + // Goal: Union[int, None] + const UnionTypePtr u = UnionType::create({opt1, IntType::get()}); + + ASSERT_EQ(u->getTypes().size(), 2); + ASSERT_TRUE(UnionTypeTest::hasType(u, IntType::get())); + ASSERT_TRUE(UnionTypeTest::hasType(u, NoneType::get())); +} + +TEST_F(UnionTypeTest, UnionCreate_TupleWithSubtypingRelationship) { + // Goal: Union[Tuple[Optional[int], int], str] + const UnionTypePtr u = UnionType::create({StringType::get(), tup1, tup2}); + + ASSERT_EQ(u->getTypes().size(), 2); + ASSERT_TRUE(UnionTypeTest::hasType(u, StringType::get())); + ASSERT_TRUE(UnionTypeTest::hasType(u, tup1)); +} + +TEST_F(UnionTypeTest, UnionCreate_ContainerTAndT) { + // Goal: Union[List[str], str] + const UnionTypePtr u = UnionType::create({l1, StringType::get()}); + + ASSERT_EQ(u->getTypes().size(), 2); + ASSERT_TRUE(UnionTypeTest::hasType(u, StringType::get())); + ASSERT_TRUE(UnionTypeTest::hasType(u, ListType::ofStrings())); +} + +TEST_F(UnionTypeTest, UnionCreate_OptionalContainerTAndContainerTAndT) { + // Goal: Union[List[str], None, str] + const UnionTypePtr u = UnionType::create({l1, opt3, StringType::get()}); + + ASSERT_EQ(u->getTypes().size(), 3); + ASSERT_TRUE(UnionTypeTest::hasType(u, StringType::get())); + ASSERT_TRUE(UnionTypeTest::hasType(u, ListType::ofStrings())); +} + +TEST_F(UnionTypeTest, Subtyping_NumberType) { + // Union[int, float, Complex] + const UnionTypePtr union1 = + UnionType::create({IntType::get(), FloatType::get(), ComplexType::get()}); + + // Union[int, float, Complex, None] + const UnionTypePtr union2 = UnionType::create( + {IntType::get(), FloatType::get(), ComplexType::get(), NoneType::get()}); + + const NumberTypePtr num = NumberType::get(); + + ASSERT_TRUE(num->isSubtypeOf(union1)); + ASSERT_TRUE(union1->isSubtypeOf(num)); + ASSERT_TRUE(*num == *union1); + + ASSERT_TRUE(num->isSubtypeOf(union2)); + ASSERT_FALSE(union2->isSubtypeOf(num)); + ASSERT_FALSE(*num == *union2); +} + +TEST_F(UnionTypeTest, Subtyping_OptionalType) { + // Union[int, None] + const UnionTypePtr union1 = + UnionType::create({IntType::get(), NoneType::get()}); + + // Union[int, str, None] + const UnionTypePtr union2 = + UnionType::create({IntType::get(), StringType::get(), NoneType::get()}); + + // Union[int, str, List[str]] + const UnionTypePtr union3 = UnionType::create( + {IntType::get(), StringType::get(), ListType::ofStrings()}); + + ASSERT_TRUE(none->isSubtypeOf(opt1)); + ASSERT_TRUE(none->isSubtypeOf(union1)); + ASSERT_TRUE(none->isSubtypeOf(union2)); + ASSERT_FALSE(none->isSubtypeOf(union3)); + + ASSERT_FALSE(opt1->isSubtypeOf(none)); + ASSERT_TRUE(opt1->isSubtypeOf(union1)); + ASSERT_TRUE(opt1->isSubtypeOf(union2)); + ASSERT_FALSE(opt1->isSubtypeOf(union3)); + + ASSERT_FALSE(union1->isSubtypeOf(none)); + ASSERT_TRUE(union1->isSubtypeOf(opt1)); + ASSERT_TRUE(union1->isSubtypeOf(union2)); + ASSERT_FALSE(union1->isSubtypeOf(union3)); + + ASSERT_FALSE(union2->isSubtypeOf(union1)); +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/jit/test_utils.cpp b/test/cpp/jit/test_utils.cpp index 7750ba8f10fee..8da101e99bbdf 100644 --- a/test/cpp/jit/test_utils.cpp +++ b/test/cpp/jit/test_utils.cpp @@ -123,6 +123,21 @@ std::shared_ptr build_mobile_export_analysis_graph() { return g; } +std::shared_ptr build_mobile_export_with_out() { + const auto graph_string = R"IR( + graph(%x.1 : Tensor, + %y.1 : Tensor): + %8 : NoneType = prim::Constant() + %6 : int = prim::Constant[value=1]() + %7 : Tensor = aten::add(%x.1, %y.1, %6, %y.1) + return (%8))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + g->lint(); + return g; +} + std::shared_ptr build_mobile_export_analysis_graph_nested() { // this is pretty much same test as build_mobile_export_analysis_graph(), // but some aten::slice operators are hidden under block statement to check @@ -198,6 +213,7 @@ bool checkRtol(const at::Tensor& diff, const std::vector inputs) { } return diff.abs().max().item() < 2e-6 * maxValue; } + bool almostEqual(const at::Tensor& a, const at::Tensor& b) { return checkRtol(a - b, {a, b}); } @@ -206,6 +222,20 @@ bool exactlyEqual(const at::Tensor& a, const at::Tensor& b) { return (a - b).abs().max().item() == 0.f; } +bool exactlyEqual( + const std::vector& a, + const std::vector& b) { + if (a.size() != b.size()) { + return false; + } + for (size_t i = 0; i < a.size(); ++i) { + if (!exactlyEqual(a[i], b[i])) { + return false; + } + } + return true; +} + std::pair lstm( at::Tensor input, at::Tensor hx, @@ -243,10 +273,27 @@ RegisterOperators reg({ // because it always produces empty Tensors. Operator( "prim::MakeTestTensor() -> Tensor", - [](Stack* stack) { push(stack, at::Tensor()); }, + [](Stack& stack) { push(stack, at::Tensor()); }, aliasAnalysisFromSchema()), }); } // namespace +std::vector runGraph( + std::shared_ptr graph, + const std::vector& inputs) { + std::vector stack = fmap(inputs); + Code code(graph, "test"); + InterpreterState(code).run(stack); + TORCH_INTERNAL_ASSERT(!stack.empty()); + // Graph outputs that are handled below: + // * A list of Tensors. + // * 1 Tensor. + if (stack.front().isTensorList()) { + return stack.front().toTensorVector(); + } + TORCH_INTERNAL_ASSERT(stack.front().isTensor()); + return {stack.front().toTensor()}; +} + } // namespace jit } // namespace torch diff --git a/test/cpp/jit/test_utils.h b/test/cpp/jit/test_utils.h index 676759dca480f..1a1e1b82b10e8 100644 --- a/test/cpp/jit/test_utils.h +++ b/test/cpp/jit/test_utils.h @@ -74,6 +74,7 @@ std::pair runGradient( std::shared_ptr build_lstm(); std::shared_ptr build_mobile_export_analysis_graph(); +std::shared_ptr build_mobile_export_with_out(); std::shared_ptr build_mobile_export_analysis_graph_with_vararg(); std::shared_ptr build_mobile_export_analysis_graph_nested(); std::shared_ptr build_mobile_export_analysis_graph_non_const(); @@ -88,6 +89,13 @@ bool checkRtol(const at::Tensor& diff, const std::vector inputs); bool almostEqual(const at::Tensor& a, const at::Tensor& b); bool exactlyEqual(const at::Tensor& a, const at::Tensor& b); +bool exactlyEqual( + const std::vector& a, + const std::vector& b); + +std::vector runGraph( + std::shared_ptr graph, + const std::vector& inputs); std::pair lstm( at::Tensor input, diff --git a/test/cpp/tensorexpr/test_approx.cpp b/test/cpp/tensorexpr/test_approx.cpp index d761645b25b3f..8de395fe92796 100644 --- a/test/cpp/tensorexpr/test_approx.cpp +++ b/test/cpp/tensorexpr/test_approx.cpp @@ -11,7 +11,7 @@ using namespace torch::indexing; namespace te = torch::jit::tensorexpr; -static void vectorize(te::LoopNest* ln, te::Tensor* target, int width) { +static void vectorize(te::LoopNest* ln, te::Tensor target, int width) { auto loops = ln->getLoopStmtsFor(target); te::ForPtr inner, tail; ln->splitWithTail(loops[0], width, &inner, &tail); @@ -30,10 +30,9 @@ std::string diffs(const at::Tensor& a, const at::Tensor& b) { } TEST(Approx, log_vml) { - te::KernelScope ks; te::VarHandle N("N", te::kInt); te::Placeholder A("A", te::kFloat, {N}); - te::Tensor* B = te::Compute( + te::Tensor B = te::Compute( "B", {N}, [&](const te::VarHandle& i) { return log_vml(A.load(i)); }); te::LoopNest ln({B}); diff --git a/test/cpp/tensorexpr/test_aten.cpp b/test/cpp/tensorexpr/test_aten.cpp index 9eb141250cb35..040b7b0a920fb 100644 --- a/test/cpp/tensorexpr/test_aten.cpp +++ b/test/cpp/tensorexpr/test_aten.cpp @@ -15,7 +15,6 @@ namespace jit { using namespace torch::jit::tensorexpr; TEST(ATen, _cast_Float) { - KernelScope kernel_scope; const int kTotalSize = 128; Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kInt)); Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); @@ -43,7 +42,6 @@ TEST(ATen, _cast_Float) { } TEST(ATen, negInt) { - KernelScope kernel_scope; const int kTotalSize = 128; Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kInt)); Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kInt)); @@ -71,7 +69,6 @@ TEST(ATen, negInt) { } TEST(ATen, negFloat) { - KernelScope kernel_scope; const int kTotalSize = 128; Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); @@ -99,7 +96,6 @@ TEST(ATen, negFloat) { } TEST(ATen, addInt) { - KernelScope kernel_scope; const int kTotalSize = 128; Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kInt)); Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kInt)); @@ -136,7 +132,6 @@ TEST(ATen, addInt) { } TEST(ATen, addFloat) { - KernelScope kernel_scope; const int kTotalSize = 128; Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); @@ -173,7 +168,6 @@ TEST(ATen, addFloat) { } TEST(ATen, subInt) { - KernelScope kernel_scope; const int kTotalSize = 128; Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kInt)); Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kInt)); @@ -210,7 +204,6 @@ TEST(ATen, subInt) { } TEST(ATen, subFloat) { - KernelScope kernel_scope; const int kTotalSize = 128; Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); @@ -247,7 +240,6 @@ TEST(ATen, subFloat) { } TEST(ATen, lerp) { - KernelScope kernel_scope; const int kTotalSize = 128; Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); @@ -284,7 +276,6 @@ TEST(ATen, lerp) { } TEST(ATen, addcmulInt) { - KernelScope kernel_scope; const int kTotalSize = 128; Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kInt)); Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kInt)); @@ -326,7 +317,6 @@ TEST(ATen, addcmulInt) { } TEST(ATen, addcmulFloat) { - KernelScope kernel_scope; const int kTotalSize = 128; Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); @@ -368,7 +358,6 @@ TEST(ATen, addcmulFloat) { } TEST(ATen, mulInt) { - KernelScope kernel_scope; const int kTotalSize = 128; Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kInt)); Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kInt)); @@ -400,7 +389,6 @@ TEST(ATen, mulInt) { } TEST(ATen, mulFloat) { - KernelScope kernel_scope; const int kTotalSize = 128; Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); @@ -432,7 +420,6 @@ TEST(ATen, mulFloat) { } TEST(ATen, divInt) { - KernelScope kernel_scope; const int kTotalSize = 128; Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kInt)); Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kInt)); @@ -464,7 +451,6 @@ TEST(ATen, divInt) { } TEST(ATen, divFloat) { - KernelScope kernel_scope; const int kTotalSize = 128; Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); @@ -496,7 +482,6 @@ TEST(ATen, divFloat) { } TEST(ATen, maxInt) { - KernelScope kernel_scope; const int kTotalSize = 128; Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kInt)); Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kInt)); @@ -528,7 +513,6 @@ TEST(ATen, maxInt) { } TEST(ATen, maxFloat) { - KernelScope kernel_scope; const int kTotalSize = 128; Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); @@ -560,7 +544,6 @@ TEST(ATen, maxFloat) { } TEST(ATen, minInt) { - KernelScope kernel_scope; const int kTotalSize = 128; Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kInt)); Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kInt)); @@ -592,7 +575,6 @@ TEST(ATen, minInt) { } TEST(ATen, minFloat) { - KernelScope kernel_scope; const int kTotalSize = 128; Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); @@ -624,7 +606,6 @@ TEST(ATen, minFloat) { } void __ubsan_ignore_float_divide_by_zero__ testATenreciprocal() { - KernelScope kernel_scope; const int kTotalSize = 128; Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); @@ -651,7 +632,6 @@ void __ubsan_ignore_float_divide_by_zero__ testATenreciprocal() { } TEST(ATen, reluInt) { - KernelScope kernel_scope; const int kTotalSize = 128; Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kInt)); Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kInt)); @@ -678,7 +658,6 @@ TEST(ATen, reluInt) { } TEST(ATen, reluFloat) { - KernelScope kernel_scope; const int kTotalSize = 128; Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); @@ -707,7 +686,6 @@ TEST(ATen, reluFloat) { } TEST(ATen, logFloat) { - KernelScope kernel_scope; const int kTotalSize = 128; Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); @@ -734,7 +712,6 @@ TEST(ATen, logFloat) { } TEST(ATen, fastLogFloat) { - KernelScope kernel_scope; const int kTotalSize = 128; Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); @@ -766,7 +743,6 @@ TEST(ATen, fastLogFloat) { } TEST(ATen, fastTanhFloat) { - KernelScope kernel_scope; const int kTotalSize = 128; Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); @@ -798,7 +774,6 @@ TEST(ATen, fastTanhFloat) { } TEST(ATen, fastSigmoidFloat) { - KernelScope kernel_scope; const int kTotalSize = 128; Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); @@ -831,7 +806,6 @@ TEST(ATen, fastSigmoidFloat) { } TEST(ATen, log10Float) { - KernelScope kernel_scope; const int kTotalSize = 128; Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); @@ -858,7 +832,6 @@ TEST(ATen, log10Float) { } TEST(ATen, log2Float) { - KernelScope kernel_scope; const int kTotalSize = 128; Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); @@ -885,7 +858,6 @@ TEST(ATen, log2Float) { } TEST(ATen, expFloat) { - KernelScope kernel_scope; const int kTotalSize = 128; Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); @@ -913,7 +885,6 @@ TEST(ATen, expFloat) { } TEST(ATen, erfFloat) { - KernelScope kernel_scope; const int kTotalSize = 128; Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); @@ -941,7 +912,6 @@ TEST(ATen, erfFloat) { } TEST(ATen, cosFloat) { - KernelScope kernel_scope; const int kTotalSize = 128; Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); @@ -969,7 +939,6 @@ TEST(ATen, cosFloat) { } TEST(ATen, eqInt) { - KernelScope kernel_scope; constexpr int N = 128; Placeholder a(BufHandle("A", {N}, kInt)); Placeholder b(BufHandle("B", {N}, kInt)); @@ -995,7 +964,6 @@ TEST(ATen, eqInt) { } TEST(ATen, geInt) { - KernelScope kernel_scope; constexpr int N = 128; Placeholder a(BufHandle("A", {N}, kInt)); Placeholder b(BufHandle("B", {N}, kInt)); @@ -1021,7 +989,6 @@ TEST(ATen, geInt) { } TEST(ATen, gtInt) { - KernelScope kernel_scope; constexpr int N = 128; Placeholder a(BufHandle("A", {N}, kInt)); Placeholder b(BufHandle("B", {N}, kInt)); @@ -1047,7 +1014,6 @@ TEST(ATen, gtInt) { } TEST(ATen, leInt) { - KernelScope kernel_scope; constexpr int N = 128; Placeholder a(BufHandle("A", {N}, kInt)); Placeholder b(BufHandle("B", {N}, kInt)); @@ -1073,7 +1039,6 @@ TEST(ATen, leInt) { } TEST(ATen, ltInt) { - KernelScope kernel_scope; constexpr int N = 128; Placeholder a(BufHandle("A", {N}, kInt)); Placeholder b(BufHandle("B", {N}, kInt)); diff --git a/test/cpp/tensorexpr/test_boundsinference.cpp b/test/cpp/tensorexpr/test_boundsinference.cpp index fcfa8cec4bc49..2eb0dfb997da8 100644 --- a/test/cpp/tensorexpr/test_boundsinference.cpp +++ b/test/cpp/tensorexpr/test_boundsinference.cpp @@ -46,10 +46,9 @@ TEST(BoundsInference, _1) { // b[i] = a[i] // For this loop bounds inference should yield the following: // {{b, kStore, 0, 99}, {a, kLoad, 0, 99}} - KernelScope kernel_scope; ExprHandle n(100); Placeholder a(BufHandle("a", {n}, kFloat)); - Tensor* b = + Tensor b = Compute("b", {{n, "i"}}, [&](const VarHandle& i) { return a.load(i); }); LoopNest l({b}); auto bounds_info = inferBounds(l.root_stmt()); @@ -60,9 +59,9 @@ TEST(BoundsInference, _1) { ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad); verifyConstBounds(bounds_info.at(a.data())[0], {{0, 99}}); - ASSERT_EQ(bounds_info.at(b->buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(b->buf())[0], {{0, 99}}); + ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 99}}); } TEST(BoundsInference, _2) { @@ -71,10 +70,9 @@ TEST(BoundsInference, _2) { // b[i] = a[i] // For this loop bounds inference should yield the following: // {{b, kStore, 0, n-1}, {a, kLoad, 0, n-1}} - KernelScope kernel_scope; VarHandle n("n", kInt); Placeholder a(BufHandle("a", {n}, kFloat)); - Tensor* b = + Tensor b = Compute("b", {{n, "i"}}, [&](const VarHandle& i) { return a.load(i); }); LoopNest l({b}); auto bounds_info = inferBounds(l.root_stmt()); @@ -85,9 +83,9 @@ TEST(BoundsInference, _2) { ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad); verifyConstBounds(bounds_info.at(a.data())[0], {{0, -1}}); - ASSERT_EQ(bounds_info.at(b->buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(b->buf())[0], {{0, -1}}); + ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(b.buf())[0], {{0, -1}}); } TEST(BoundsInference, _3) { @@ -96,10 +94,9 @@ TEST(BoundsInference, _3) { // b[i] = a[i] * a[i+10] // For this loop bounds inference should yield the following: // {{b, kStore, 0, 99}, {a, kLoad, 0, 109}} - KernelScope kernel_scope; ExprHandle n(100); Placeholder a(BufHandle("a", {n + 10}, kFloat)); - Tensor* b = Compute("b", {{n, "i"}}, [&](const VarHandle& i) { + Tensor b = Compute("b", {{n, "i"}}, [&](const VarHandle& i) { return a.load(i) * a.load(i + 10); }); LoopNest l({b}); @@ -111,9 +108,9 @@ TEST(BoundsInference, _3) { ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad); verifyConstBounds(bounds_info.at(a.data())[0], {{0, 109}}); - ASSERT_EQ(bounds_info.at(b->buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(b->buf())[0], {{0, 99}}); + ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 99}}); } TEST(BoundsInference, _4) { @@ -125,17 +122,16 @@ TEST(BoundsInference, _4) { // for y in 0..200: // for x in 0..320: // c[y,x] = a[y,x] * b[y,x] - KernelScope kernel_scope; ExprHandle W(320); ExprHandle H(200); Placeholder a(BufHandle("a", {H, W}, kFloat)); - Tensor* b = Compute( + Tensor b = Compute( "b", {{H, "y"}, {W, "x"}}, [&](const VarHandle& y, const VarHandle& x) { return x * y; }); - Tensor* c = Compute( + Tensor c = Compute( "c", {{H, "y"}, {W, "x"}}, [&](const VarHandle& y, const VarHandle& x) { - return a.load(y, x) * b->load(y, x); + return a.load(y, x) * b.load(y, x); }); LoopNest l({c}); std::vector loops = l.getLoopStmtsFor(c); @@ -149,13 +145,13 @@ TEST(BoundsInference, _4) { ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad); verifyConstBounds(bounds_info.at(a.data())[0], {{0, 199}, {0, 319}}); - ASSERT_EQ(bounds_info.at(b->buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(b->buf())[0], {{0, 199}, {0, 319}}); + ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 199}, {0, 319}}); - ASSERT_EQ(bounds_info.at(c->buf()).size(), 1); - ASSERT_EQ(bounds_info.at(c->buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(c->buf())[0], {{0, 199}, {0, 319}}); + ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(c.buf())[0], {{0, 199}, {0, 319}}); } { // Infer bounds on the inner loop scope @@ -166,13 +162,13 @@ TEST(BoundsInference, _4) { ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad); verifyConstBounds(bounds_info.at(a.data())[0], {{-1, -1}, {0, 319}}); - ASSERT_EQ(bounds_info.at(b->buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(b->buf())[0], {{-1, -1}, {0, 319}}); + ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(b.buf())[0], {{-1, -1}, {0, 319}}); - ASSERT_EQ(bounds_info.at(c->buf()).size(), 1); - ASSERT_EQ(bounds_info.at(c->buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(c->buf())[0], {{-1, -1}, {0, 319}}); + ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(c.buf())[0], {{-1, -1}, {0, 319}}); } { // Infer bounds on the inner loop body's scope @@ -183,13 +179,13 @@ TEST(BoundsInference, _4) { ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad); verifyConstBounds(bounds_info.at(a.data())[0], {{-1, -1}, {-1, -1}}); - ASSERT_EQ(bounds_info.at(b->buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(b->buf())[0], {{-1, -1}, {-1, -1}}); + ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(b.buf())[0], {{-1, -1}, {-1, -1}}); - ASSERT_EQ(bounds_info.at(c->buf()).size(), 1); - ASSERT_EQ(bounds_info.at(c->buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(c->buf())[0], {{-1, -1}, {-1, -1}}); + ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(c.buf())[0], {{-1, -1}, {-1, -1}}); } } @@ -205,10 +201,9 @@ TEST(BoundsInference, _5) { // b[i_outer * 16 + i_inner] = a[i_outer * 16 + i_inner] // for i_tail in 0..100%16: // b[i_tail + (100/16)*16] = a[i_tail + (100/16)*16]; - KernelScope kernel_scope; ExprHandle n(100); Placeholder a(BufHandle("a", {n}, kFloat)); - Tensor* b = + Tensor b = Compute("b", {{n, "i"}}, [&](const VarHandle& i) { return a.load(i); }); LoopNest l({b}); @@ -229,9 +224,9 @@ TEST(BoundsInference, _5) { ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad); verifyConstBounds(bounds_info.at(a.data())[0], {{0, 95}}); - ASSERT_EQ(bounds_info.at(b->buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(b->buf())[0], {{0, 95}}); + ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 95}}); } { // Verify inferred bounds for the tail loop @@ -242,9 +237,9 @@ TEST(BoundsInference, _5) { ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad); verifyConstBounds(bounds_info.at(a.data())[0], {{96, 99}}); - ASSERT_EQ(bounds_info.at(b->buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(b->buf())[0], {{96, 99}}); + ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(b.buf())[0], {{96, 99}}); } } @@ -257,19 +252,18 @@ TEST(BoundsInference, _6) { // for y in 0..20: // for x in 0..32: // c[y,x] = a[y+100,x+100] * b[y*2,x*5] - KernelScope kernel_scope; ExprHandle W(320); ExprHandle H(200); ExprHandle CW(32); ExprHandle CH(20); Placeholder a(BufHandle("a", {H, W}, kFloat)); - Tensor* b = Compute( + Tensor b = Compute( "b", {{H, "y"}, {W, "x"}}, [&](const VarHandle& y, const VarHandle& x) { return x * y; }); - Tensor* c = Compute( + Tensor c = Compute( "c", {{CH, "y"}, {CW, "x"}}, [&](const VarHandle& y, const VarHandle& x) { - return a.load(y + 100, x + 100) * b->load(y * 2, x * 5); + return a.load(y + 100, x + 100) * b.load(y * 2, x * 5); }); LoopNest l({c}); std::vector loops = l.getLoopStmtsFor(c); @@ -283,13 +277,13 @@ TEST(BoundsInference, _6) { ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad); verifyConstBounds(bounds_info.at(a.data())[0], {{100, 119}, {100, 131}}); - ASSERT_EQ(bounds_info.at(b->buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(b->buf())[0], {{0, 38}, {0, 155}}); + ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 38}, {0, 155}}); - ASSERT_EQ(bounds_info.at(c->buf()).size(), 1); - ASSERT_EQ(bounds_info.at(c->buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(c->buf())[0], {{0, 19}, {0, 31}}); + ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(c.buf())[0], {{0, 19}, {0, 31}}); } { // Infer bounds on the inner loop scope @@ -300,13 +294,13 @@ TEST(BoundsInference, _6) { ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad); verifyConstBounds(bounds_info.at(a.data())[0], {{-1, -1}, {100, 131}}); - ASSERT_EQ(bounds_info.at(b->buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(b->buf())[0], {{-1, -1}, {0, 155}}); + ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(b.buf())[0], {{-1, -1}, {0, 155}}); - ASSERT_EQ(bounds_info.at(c->buf()).size(), 1); - ASSERT_EQ(bounds_info.at(c->buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(c->buf())[0], {{-1, -1}, {0, 31}}); + ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(c.buf())[0], {{-1, -1}, {0, 31}}); } { // Infer bounds on the inner loop body's scope @@ -317,23 +311,22 @@ TEST(BoundsInference, _6) { ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad); verifyConstBounds(bounds_info.at(a.data())[0], {{-1, -1}, {-1, -1}}); - ASSERT_EQ(bounds_info.at(b->buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kLoad); - verifyConstBounds(bounds_info.at(b->buf())[0], {{-1, -1}, {-1, -1}}); + ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(b.buf())[0], {{-1, -1}, {-1, -1}}); - ASSERT_EQ(bounds_info.at(c->buf()).size(), 1); - ASSERT_EQ(bounds_info.at(c->buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(c->buf())[0], {{-1, -1}, {-1, -1}}); + ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(c.buf())[0], {{-1, -1}, {-1, -1}}); } } TEST(BoundsInference, Adjacent) { - KernelScope kernel_scope; ExprHandle H(6); Placeholder a(BufHandle("a", {20}, kFloat)); - Tensor* b = + Tensor b = Compute("b", {{H, "x"}}, [&](const VarHandle& x) { return a.load(x); }); - Tensor* c = Compute( + Tensor c = Compute( "c", {{H, "x"}}, [&](const VarHandle& x) { return a.load(x + H); }); LoopNest l({b, c}); std::vector loops = NodeFinder::find(l.root_stmt()); @@ -348,9 +341,9 @@ TEST(BoundsInference, Adjacent) { ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad); verifyConstBounds(bounds_info.at(a.data())[0], {{0, 5}}); - ASSERT_EQ(bounds_info.at(b->buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(b->buf())[0], {{0, 5}}); + ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 5}}); } { // Infer bounds on the inner loop scope @@ -362,9 +355,9 @@ TEST(BoundsInference, Adjacent) { ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad); verifyConstBounds(bounds_info.at(a.data())[0], {{6, 11}}); - ASSERT_EQ(bounds_info.at(c->buf()).size(), 1); - ASSERT_EQ(bounds_info.at(c->buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(c->buf())[0], {{0, 5}}); + ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(c.buf())[0], {{0, 5}}); } { // Infer bounds on the high level program. @@ -377,24 +370,23 @@ TEST(BoundsInference, Adjacent) { ASSERT_EQ(bounds_info.at(a.data())[0].kind, kLoad); verifyConstBounds(bounds_info.at(a.data())[0], {{0, 11}}); - ASSERT_EQ(bounds_info.at(b->buf()).size(), 1); - ASSERT_EQ(bounds_info.at(b->buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(b->buf())[0], {{0, 5}}); + ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 5}}); - ASSERT_EQ(bounds_info.at(c->buf()).size(), 1); - ASSERT_EQ(bounds_info.at(c->buf())[0].kind, kStore); - verifyConstBounds(bounds_info.at(c->buf())[0], {{0, 5}}); + ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(c.buf())[0], {{0, 5}}); } } TEST(BoundsInference, MultipleTopLoopLoad) { - KernelScope kernel_scope; Placeholder a(BufHandle("a", {100}, kFloat)); - Tensor* b = + Tensor b = Compute("b", {{64, "x"}}, [&](const VarHandle& x) { return a.load(x); }); - Tensor* c = Compute( + Tensor c = Compute( "c", {{32, "x"}}, [&](const VarHandle& x) { return a.load(x + 10); }); - Tensor* d = Compute( + Tensor d = Compute( "d", {{96, "x"}}, [&](const VarHandle& x) { return a.load(x + 2); }); LoopNest l({b, c, d}); @@ -418,7 +410,7 @@ TEST(BoundsInference, MultipleTopLoopLoad) { // b, c, d only written. { - auto bounds = bounds_info[b->buf()]; + auto bounds = bounds_info[b.buf()]; ASSERT_EQ(bounds.size(), 1); auto bound = bounds[0]; ASSERT_EQ(bound.kind, TensorAccessKind::kStore); @@ -426,7 +418,7 @@ TEST(BoundsInference, MultipleTopLoopLoad) { verifyConstBounds(bound, {{0, 63}}); } { - auto bounds = bounds_info[c->buf()]; + auto bounds = bounds_info[c.buf()]; ASSERT_EQ(bounds.size(), 1); auto bound = bounds[0]; ASSERT_EQ(bound.kind, TensorAccessKind::kStore); @@ -434,7 +426,7 @@ TEST(BoundsInference, MultipleTopLoopLoad) { verifyConstBounds(bound, {{0, 31}}); } { - auto bounds = bounds_info[d->buf()]; + auto bounds = bounds_info[d.buf()]; ASSERT_EQ(bounds.size(), 1); auto bound = bounds[0]; ASSERT_EQ(bound.kind, TensorAccessKind::kStore); @@ -444,7 +436,6 @@ TEST(BoundsInference, MultipleTopLoopLoad) { } TEST(BoundsInference, MultipleTopLoopStore) { - KernelScope kernel_scope; BufHandle a("a", {100}, kFloat); BufHandle b("b", {100}, kFloat); BufHandle c("c", {100}, kFloat); @@ -504,26 +495,24 @@ TEST(BoundsInference, MultipleTopLoopStore) { } TEST(BoundsInference, CacheReads) { - KernelScope kernel_scope; - - Tensor* A = Compute( + Tensor A = Compute( "A", {{64, "i"}, {64, "j"}}, [](const VarHandle& i, const VarHandle& j) { return i * j; }); - Tensor* B = Compute( + Tensor B = Compute( "B", {{20, "i"}, {10, "j"}}, [&](const VarHandle& i, const VarHandle& j) { - return A->load(i + 30, j + 3); + return A.load(i + 30, j + 3); }); - Tensor* C = Compute( + Tensor C = Compute( "C", {{20, "i"}, {10, "j"}}, [&](const VarHandle& i, const VarHandle& j) { - return A->load(i + 10, j + 20) + A->load(i + 30, j + 40); + return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); }); LoopNest l({B, C}); auto bounds_info_before = inferBounds(l.root_stmt()); StmtPtr j_loop = l.getLoopStmtsFor(B)[1]; - LoopNest::cacheAccesses(A->buf(), "A_local", j_loop); + LoopNest::cacheAccesses(A.buf(), "A_local", j_loop); auto bounds_info_after = inferBounds(l.root_stmt()); @@ -570,8 +559,7 @@ TEST(BoundsInference, CacheReads) { } TEST(BoundsInference, Flattened) { - KernelScope kernel_scope; - Tensor* b = Compute( + Tensor b = Compute( "b", {{3, "z"}, {4, "y"}, {5, "x"}}, [&](const VarHandle& z, const VarHandle& y, const VarHandle& x) { @@ -585,7 +573,7 @@ TEST(BoundsInference, Flattened) { // There's only one buffer. ASSERT_EQ(bounds_info.size(), 1); - auto& TABI = bounds_info[b->buf()][0]; + auto& TABI = bounds_info[b.buf()][0]; ASSERT_EQ(TABI.kind, TensorAccessKind::kStore); // Flattened bounds should have a single dimension. ASSERT_EQ(TABI.start.size(), 1); @@ -597,7 +585,6 @@ TEST(BoundsInference, Flattened) { } TEST(BoundsInference, GetPotentialHazards) { - KernelScope kernel_scope; BufHandle a("A", {5}, kInt); BufHandle b("B", {5}, kInt); BufHandle c("C", {5}, kInt); @@ -649,13 +636,11 @@ TEST(BoundsInference, GetPotentialHazards) { } TEST(BoundsInference, GetPotentialHazardsLoopNoHazard) { - KernelScope kernel_scope; - - Tensor* A = Compute( + Tensor A = Compute( "A", {{64, "i"}, {64, "j"}}, [](const VarHandle& i, const VarHandle& j) { return i * j; }); - Tensor* B = Compute( + Tensor B = Compute( "B", {{64, "i"}, {64, "j"}}, [](const VarHandle& i, const VarHandle& j) { return (i + 1) * (j + 1); }); @@ -677,15 +662,13 @@ TEST(BoundsInference, GetPotentialHazardsLoopNoHazard) { } TEST(BoundsInference, GetPotentialHazardsLoopCall) { - KernelScope kernel_scope; - - Tensor* A = Compute( + Tensor A = Compute( "A", {{64, "i"}, {64, "j"}}, [](const VarHandle& i, const VarHandle& j) { return i * j; }); - Tensor* B = Compute( + Tensor B = Compute( "B", {{64, "i"}, {64, "j"}}, [&](const VarHandle& i, const VarHandle& j) { - return A->load(i, j) + 5; + return A.load(i, j) + 5; }); LoopNest l({A, B}); @@ -704,9 +687,7 @@ TEST(BoundsInference, GetPotentialHazardsLoopCall) { } TEST(BoundsInference, GetPotentialHazardsLoopSplit) { - KernelScope kernel_scope; - - Tensor* A = Compute( + Tensor A = Compute( "A", {{64, "i"}, {64, "j"}}, [](const VarHandle& i, const VarHandle& j) { return i * j; }); @@ -731,8 +712,6 @@ TEST(BoundsInference, GetPotentialHazardsLoopSplit) { } TEST(BoundsInference, HasConflictingOverlapSameBufferWithPartialOverlap) { - KernelScope kernel_scope; - // Input IR: // for (int j = 10; j < 100; j++) { // A[j] = 10 * j; @@ -755,8 +734,6 @@ TEST(BoundsInference, HasConflictingOverlapSameBufferWithPartialOverlap) { } TEST(BoundsInference, HasConflictingOverlapSameBufferWithFullOverlap) { - KernelScope kernel_scope; - // Input IR: // for (int j = 10; j < 100; j++) { // A[j] = 10 * j; @@ -778,8 +755,6 @@ TEST(BoundsInference, HasConflictingOverlapSameBufferWithFullOverlap) { } TEST(BoundsInference, HasConflictingOverlapSameBufferWithFullOverlapRAW) { - KernelScope kernel_scope; - // Input IR: // for (int j = 10; j < 100; j++) { // A[j] = 10 * j; @@ -803,8 +778,6 @@ TEST(BoundsInference, HasConflictingOverlapSameBufferWithFullOverlapRAW) { } TEST(BoundsInference, HasConflictingOverlapSameBufferNotOverlapping) { - KernelScope kernel_scope; - // Input IR: // for (int j = 10; j < 100; j++) { // A[j] = 10 * j; @@ -827,8 +800,6 @@ TEST(BoundsInference, HasConflictingOverlapSameBufferNotOverlapping) { } TEST(BoundsInference, HasConflictingOverlap2DBufferWithOverlap) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 20; i++) { // for (int j = 0; j < 100; j++) { @@ -868,8 +839,6 @@ TEST(BoundsInference, HasConflictingOverlap2DBufferWithOverlap) { } TEST(BoundsInference, HasConflictingOverlap2DBufferWithNoOverlap) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 20; i++) { // for (int j = 0; j < 100; j++) { @@ -909,8 +878,6 @@ TEST(BoundsInference, HasConflictingOverlap2DBufferWithNoOverlap) { } TEST(BoundsInference, HasConflictingOverlapDifferentBuffers) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 20; i++) { // for (int j = 0; j < 100; j++) { @@ -949,8 +916,6 @@ TEST(BoundsInference, HasConflictingOverlapDifferentBuffers) { } TEST(BoundsInference, HasConflictingOverlapDueToRAWDependence) { - KernelScope kernel_scope; - // Input IR: // for (int j = 0; j < 100; j++) { // A[j] = 10 * j; @@ -978,8 +943,6 @@ TEST(BoundsInference, HasConflictingOverlapDueToRAWDependence) { } TEST(BoundsInference, HasConflictingOverlapDueToWARDependence) { - KernelScope kernel_scope; - // Input IR: // for (int k = 0; k < 100; k++) { // B[k] = 20 * A[99-k]; @@ -1007,8 +970,6 @@ TEST(BoundsInference, HasConflictingOverlapDueToWARDependence) { } TEST(BoundsInference, HasConflictingOverlapWithLoads) { - KernelScope kernel_scope; - // Input IR: // for (int k = 10; k < 100; k++) { // B[k] = 20 * A[99-k]; @@ -1041,8 +1002,6 @@ TEST(BoundsInference, HasConflictingOverlapWithLoads) { } TEST(BoundsInference, IsOverlapping) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 100; i++) { // A[i] = i * 10; // storeA1 diff --git a/test/cpp/tensorexpr/test_conv.cpp b/test/cpp/tensorexpr/test_conv.cpp index 63881d0d33cae..19372779094a6 100644 --- a/test/cpp/tensorexpr/test_conv.cpp +++ b/test/cpp/tensorexpr/test_conv.cpp @@ -21,7 +21,6 @@ static at::Tensor genTestData(c10::IntArrayRef args) { #ifdef TORCH_ENABLE_LLVM TEST(Conv, DepthwiseConv2D) { - te::KernelScope kernel_scope; constexpr int N = 1, C = 72, H = 56, W = 56; constexpr int K = 72, R = 3, S = 3; constexpr int kPad = 1, kStride = 2, kGroups = C; @@ -30,7 +29,7 @@ TEST(Conv, DepthwiseConv2D) { te::Placeholder input("input", te::kFloat, {N, C, H, W}); te::Placeholder weight("weight", te::kFloat, {K, CperG, R, S}); te::Placeholder bias("bias", te::kFloat, {K}); - te::Tensor* output = te::conv2d_depthwise( + te::Tensor output = te::conv2d_depthwise( input.handle(), weight.handle(), bias.handle(), kStride, kPad, kGroups); te::LoopNest loop({output}); @@ -53,7 +52,6 @@ TEST(Conv, DepthwiseConv2D) { } TEST(Conv, DepthwiseConv2DNoBias) { - te::KernelScope kernel_scope; constexpr int N = 1, C = 72, H = 56, W = 56; constexpr int K = 72, R = 3, S = 3; constexpr int kPad = 1, kStride = 2, kGroups = C; @@ -61,7 +59,7 @@ TEST(Conv, DepthwiseConv2DNoBias) { te::Placeholder input("input", te::kFloat, {N, C, H, W}); te::Placeholder weight("weight", te::kFloat, {K, CperG, R, S}); - te::Tensor* output = te::conv2d_depthwise( + te::Tensor output = te::conv2d_depthwise( input.handle(), weight.handle(), kStride, kPad, kGroups); te::LoopNest loop({output}); @@ -80,7 +78,6 @@ TEST(Conv, DepthwiseConv2DNoBias) { } TEST(Conv, DepthwiseConv2DDynamicShapes) { - te::KernelScope kernel_scope; te::VarHandle N_var("N", te::kInt); te::VarHandle C_var("C", te::kInt); te::VarHandle H_var("H", te::kInt); @@ -96,7 +93,7 @@ TEST(Conv, DepthwiseConv2DDynamicShapes) { te::Placeholder input("input", te::kFloat, {N_var, C_var, H_var, W_var}); te::Placeholder weight( "weight", te::kFloat, {K_var, CperG_var, R_var, S_var}); - te::Tensor* output = te::conv2d_depthwise( + te::Tensor output = te::conv2d_depthwise( input.handle(), weight.handle(), N_var, @@ -164,8 +161,6 @@ TEST(Conv, DepthwiseConv2DDynamicShapes) { #endif TEST(Conv, Conv2D) { - te::KernelScope kernel_scope; - // Input dimensions. constexpr int N = 1; constexpr int C = 3; @@ -195,7 +190,7 @@ TEST(Conv, Conv2D) { te::Placeholder inputB(te::BufHandle("input", {N, C, H, W}, te::kFloat)); te::Placeholder filterB(te::BufHandle("filter", {K, C, R, S}, te::kFloat)); - te::Tensor* conv = te::Reduce( + te::Tensor conv = te::Reduce( "conv", {{N, "n"}, {K, "k"}, {OH, "oh"}, {OW, "ow"}}, te::Sum(), diff --git a/test/cpp/tensorexpr/test_cpp_codegen.cpp b/test/cpp/tensorexpr/test_cpp_codegen.cpp index 82ea40d995f29..d40caa126e572 100644 --- a/test/cpp/tensorexpr/test_cpp_codegen.cpp +++ b/test/cpp/tensorexpr/test_cpp_codegen.cpp @@ -1,10 +1,11 @@ #include -#include +#include "test/cpp/tensorexpr/test_base.h" #include -#include +#include #include +#include #include namespace torch { @@ -12,46 +13,245 @@ namespace jit { using namespace torch::jit::tensorexpr; -TEST(CppPrinter, AllocateOnStackThenFree) { - KernelScope kernel_scope; - std::vector dims = {alloc(2), alloc(3)}; - BufPtr buf = alloc("x", dims, kInt); - AllocatePtr alloc_ = alloc(buf); - FreePtr free_ = alloc(buf); - BlockPtr block = Block::make({alloc_, free_}); - - std::stringstream ss; - CppPrinter printer(&ss); - printer.visit(block); - const std::string expected = R"( - # CHECK: { - # CHECK: int x[6]; - # CHECK: } +#define STR_CHECK(node, expected) \ + std::stringstream ss; \ + CppPrinter printer(&ss); \ + printer.visit(node); \ + ASSERT_EQ(ss.str(), expected) + +#define FILE_CHECK(node, pattern) \ + std::stringstream ss; \ + CppPrinter printer(&ss); \ + printer.visit(node); \ + torch::jit::testing::FileCheck().run(pattern, ss.str()) + +TEST(CppPrinter, IntImm) { + auto i = alloc(10); + STR_CHECK(i, "10"); +} + +TEST(CppPrinter, FloatImm) { + auto f = alloc(10); + STR_CHECK(f, "10.f"); +} + +TEST(CppPrinter, FloatImm1) { + auto f = alloc(10); + STR_CHECK(f, "10.f"); +} + +TEST(CppPrinter, DoubleImm) { + auto d = alloc(10); + STR_CHECK(d, "10.0"); +} + +TEST(CppPrinter, DoubleImm1) { + auto d = alloc(10.1); + STR_CHECK(d, "10.1"); +} + +TEST(CppPrinter, HalfImm) { + auto h = alloc(10); + STR_CHECK(h, "10"); +} + +TEST(CppPrinter, Add) { + auto add = alloc(alloc(1), alloc(2)); + STR_CHECK(add, "1 + 2"); +} + +TEST(CppPrinter, AddExpr1) { + auto add = alloc( + alloc(alloc(0), alloc(1)), + alloc(alloc(2), alloc(3))); + STR_CHECK(add, "(0 + 1) + (2 - 3)"); +} + +TEST(CppPrinter, AddExpr2) { + auto add = alloc( + alloc(alloc(0), alloc(1)), + alloc(alloc(2), alloc(3))); + STR_CHECK(add, "0 * 1 + (2 - 3)"); +} + +TEST(CppPrinter, AddExpr3) { + auto add = alloc( + alloc(alloc(0), alloc(1)), + alloc
(alloc(2), alloc(3))); + STR_CHECK(add, "(0 + 1) + 2 / 3"); +} + +TEST(CppPrinter, Mod) { + auto mod = alloc(alloc(1), alloc(2)); + STR_CHECK(mod, "1 % 2"); +} + +TEST(CppPrinter, ModFloat) { + auto mod = alloc(alloc(1), alloc(2)); + STR_CHECK(mod, "std::fmod(1.f, 2.f)"); +} + +TEST(CppPrinter, Max) { + auto max = alloc(alloc(1), alloc(2), false); + STR_CHECK(max, "std::max(1, 2)"); +} + +TEST(CppPrinter, MaxFloat) { + auto max = alloc(alloc(1), alloc(2), false); + STR_CHECK(max, "std::max(1.f, 2.f)"); +} + +TEST(CppPrinter, MaxHalf) { + auto max = alloc(alloc(1), alloc(2), false); + STR_CHECK(max, "(1 < 2) ? 2 : 1"); +} + +TEST(CppPrinter, And) { + auto v = alloc(alloc(1), alloc(2)); + STR_CHECK(v, "1 & 2"); +} + +TEST(CppPrinter, CompareSelect) { + auto cs = alloc( + alloc(1), + alloc(2), + alloc(1), + alloc(2), + CompareSelectOperation::kLE); + STR_CHECK(cs, "((1 <= 2) ? 1.f : 2.f)"); +} + +TEST(CppPrinter, IfThenElse) { + auto cond = alloc(alloc(1), alloc(2)); + auto true_value = alloc(alloc(0), alloc(1)); + auto false_value = alloc(alloc(2), alloc(3)); + auto v = alloc(cond, true_value, false_value); + STR_CHECK(v, "((1 + 2) ? 0 - 1 : 2 * 3)"); +} + +TEST(CppPrinter, AllocateFree) { + BufHandle buf("x", {2, 3}, kInt); + AllocatePtr alloc = Allocate::make(buf); + FreePtr free = Free::make(buf); + BlockPtr block = Block::make({alloc, free}); + + const std::string pattern = R"( + # CHECK: { + # CHECK: int* x = static_cast(malloc(24)); + # CHECK: free(x); + # CHECK: } + )"; + FILE_CHECK(block, pattern); +} + +TEST(CppPrinter, LoadStore) { + Placeholder a(BufHandle("A", {2, 3}, kInt)); + Placeholder b(BufHandle("B", {3, 4}, kInt)); + auto store = b.store({2, 2}, a.load(1, 1)); + STR_CHECK( + store, "B[(0 + 2 * (1 * 4)) + 2 * 1] = A[(0 + 1 * (1 * 3)) + 1 * 1];\n"); +} + +TEST(CppPrinter, Var) { + auto var = alloc("x", kInt); + STR_CHECK(var, "x"); +} + +TEST(CppPrinter, Cast) { + auto cast = alloc(kFloat, alloc(1)); + STR_CHECK(cast, "static_cast(1)"); +} + +TEST(CppPrinter, BitCast) { + auto cast = alloc(kInt, alloc(20)); + STR_CHECK(cast, "std::bitcast(20.f)"); +} + +TEST(CppPrinter, Let) { + auto var = alloc("x", kFloat); + auto val = alloc(2); + auto let = alloc(var, val); + STR_CHECK(let, "float x = 2.f;\n"); +} + +TEST(CppPrinter, For) { + constexpr int N = 1024; + Placeholder a(BufHandle("A", {N}, kInt)); + Placeholder b(BufHandle("B", {N}, kInt)); + Placeholder c(BufHandle("C", {N}, kInt)); + VarHandle i("i", kInt); + auto f = For::make(i, 0, N, c.store({i}, Add::make(a.load(i), b.load(i)))); + const std::string pattern = R"( + # CHECK: for (int i = 0; i < 1024; i++) { + # CHECK: C[i] = (A[i]) + (B[i]); + # CHECK: } )"; - torch::jit::testing::FileCheck().run(expected, ss.str()); -} - -TEST(CppPrinter, AllocateOnHeapThenFree) { - KernelScope kernel_scope; - std::vector dims = { - alloc(20), alloc(50), alloc(3)}; - BufPtr buf = alloc("y", dims, kLong); - AllocatePtr alloc_ = alloc(buf); - FreePtr free_ = alloc(buf); - BlockPtr block = Block::make({alloc_, free_}); - - std::stringstream ss; - CppPrinter printer(&ss); - printer.visit(block); - // size(long) = 8; - // dim0 * dim1 * dim2 * size(long) = 24000. - const std::string expected = R"( - # CHECK: { - # CHECK: int64_t* y = static_cast(malloc(24000)); - # CHECK: free(y); + FILE_CHECK(f, pattern); +} + +TEST(CppPrinter, Cond) { + Placeholder x(BufHandle("X", {1}, kInt)); + auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT); + auto cond = + Cond::make(cmp, x.store({0}, x.load(0) + 1), x.store({0}, x.load(0) - 1)); + const std::string pattern = R"( + # CHECK: if (((X[0] < 10) ? 1 : 0)) { + # CHECK: X[0] = (X[0]) + 1; + # CHECK: } else { + # CHECK: X[0] = (X[0]) - 1; # CHECK: } )"; - torch::jit::testing::FileCheck().run(expected, ss.str()); + FILE_CHECK(cond, pattern); +} + +TEST(CppPrinter, Intrinsics) { + const std::unordered_set> unsupported_ops{ + kRand, kSigmoid}; + for (int i = 0; i < kMaxIntrinsicsOp; i++) { + IntrinsicsOp op = static_cast(i); + if (unsupported_ops.count(op)) { + continue; + } + + if (Intrinsics::OpArgCount(op) == 1) { + auto v = alloc(op, alloc(2.0f)); + STR_CHECK(v, "std::" + v->func_name() + "(2.f)"); + } else { + auto v = + alloc(op, alloc(1.0f), alloc(2.0f)); + STR_CHECK(v, "std::" + v->func_name() + "(1.f, 2.f)"); + } + } +} + +TEST(CppPrinter, ExternalCall) { + std::vector dims{alloc(2), alloc(2)}; + auto output = alloc("out", dims, kFloat); + auto buf_arg1 = alloc("a", dims, kFloat); + auto buf_arg2 = alloc("b", dims, kFloat); + auto scalar_arg = alloc(alloc(1), alloc(2)); + std::vector buf_args{buf_arg1, buf_arg2}; + std::vector scalar_args{scalar_arg}; + auto call = + alloc(output, "nnc_aten_matmul", buf_args, scalar_args); + const std::string pattern = R"( + # CHECK: { + # CHECK: void* buf_ptrs[]{out, a, b}; + # CHECK: int64_t buf_ranks[]{2, 2, 2}; + # CHECK: int64_t buf_dims[]{2, 2, 2, 2, 2, 2}; + # CHECK: int8_t buf_dtypes[]{6, 6, 6}; + # CHECK: int64_t extra_args[]{1 + 2}; + # CHECK: nnc_aten_matmul( + # CHECK: 3, + # CHECK: buf_ptrs, + # CHECK: buf_ranks, + # CHECK: buf_dims, + # CHECK: buf_dtypes, + # CHECK: 1, + # CHECK: extra_args); + # CHECK: } + )"; + FILE_CHECK(call, pattern); } } // namespace jit diff --git a/test/cpp/tensorexpr/test_cuda.cpp b/test/cpp/tensorexpr/test_cuda.cpp index 3ca6e0d9f5c3a..164ff772d5b46 100644 --- a/test/cpp/tensorexpr/test_cuda.cpp +++ b/test/cpp/tensorexpr/test_cuda.cpp @@ -27,14 +27,13 @@ using namespace torch::jit::tensorexpr; template static void testCudaTestVectorAdd01_impl() { - KernelScope kernel_scope; const int num_iter = 3; const int block_count = 16; const int block_size = 128; Dtype dtype = ToDtype(); Placeholder a_buf("a", dtype, {num_iter, block_count, block_size}); Placeholder b_buf("b", dtype, {num_iter, block_count, block_size}); - Tensor* c = Compute( + Tensor c = Compute( "c", { {num_iter, "n"}, @@ -93,13 +92,12 @@ float sigmoid(float x) { } TEST(Cuda, Sigmoid_CUDA) { - KernelScope kernel_scope; const int num_iter = 3; const int block_count = 16; const int block_size = 128; Dtype dtype = ToDtype(); Placeholder a_buf("a", dtype, {num_iter, block_count, block_size}); - Tensor* c = Compute( + Tensor c = Compute( "c", { {num_iter, "n"}, @@ -162,10 +160,9 @@ TEST(Cuda, TestVectorAdd01_CUDA) { } static void testCudaTestVectorAdd02_impl(int N, int block_size) { - KernelScope kernel_scope; Placeholder a_buf("a", kFloat, {N}); Placeholder b_buf("b", kFloat, {N}); - Tensor* c = Compute( + Tensor c = Compute( "c", { {N, "N"}, @@ -222,10 +219,9 @@ TEST(Cuda, TestVectorAdd02_CUDA) { } TEST(Cuda, HalfCast_CUDA) { - KernelScope ks; auto half = ToDtype(); Placeholder a("a", half, {4}); - Tensor* b = Compute("b", {{4, "n"}}, [&](const VarHandle& i) { + Tensor b = Compute("b", {{4, "n"}}, [&](const VarHandle& i) { return Cast::make(kFloat, a.load(i)); }); @@ -261,13 +257,12 @@ TEST(Cuda, HalfCast_CUDA) { } TEST(Cuda, DynamicShape2D_CUDA) { - KernelScope kernel_scope; auto testWithSize = [](int32_t M, int32_t N) { VarHandle m("m", kInt); VarHandle n("n", kInt); Placeholder a(BufHandle("a", {m, n}, kFloat)); Placeholder b(BufHandle("b", {m, n}, kFloat)); - Tensor* c = Compute( + Tensor c = Compute( "c", {{m, "m"}, {n, "n"}}, [&](const VarHandle& i, const VarHandle& j) { return a.load(i, j) + b.load(i, j); }); @@ -324,11 +319,10 @@ TEST(Cuda, DynamicShape2D_CUDA) { } TEST(Cuda, TestRand01_CUDA) { - KernelScope kernel_scope; const int num_iter = 3; const int block_count = 16; const int block_size = 128; - Tensor* c = Compute( + Tensor c = Compute( "c", { {num_iter, "n"}, @@ -383,11 +377,10 @@ TEST(Cuda, TestRand01_CUDA) { } TEST(Cuda, DynamicShapeSplit_CUDA) { - KernelScope ks; constexpr int N = 4096; VarHandle n("n", kInt); Placeholder a(BufHandle("a", {n}, kFloat)); - Tensor* b = Compute( + Tensor b = Compute( "b", {{n, "n"}}, [&](const VarHandle& i) { return a.load(i) * 2.0f; }); LoopNest l({b}); ForPtr inner; @@ -434,7 +427,6 @@ TEST(Cuda, DynamicShapeSplit_CUDA) { TEST(Cuda, OneBlockOneThreadGlobalReduce1_CUDA) { const static int N = 1024; - KernelScope kernel_scope; Placeholder data_buf("data", kFloat, {N}); Placeholder output_buf("output", kFloat, {1}); @@ -501,7 +493,6 @@ TEST(Cuda, OneBlockOneThreadGlobalReduce1_CUDA) { TEST(Cuda, OneBlockMultiThreadGlobalReduce1_CUDA) { const static int N = 1024; - KernelScope kernel_scope; // This test does the following reduction: // clang-format off @@ -578,8 +569,6 @@ TEST(Cuda, OneBlockMultiThreadGlobalReduce1_CUDA) { } TEST(Cuda, NoThreadIdxWrite_1_CUDA) { - KernelScope kernel_scope; - // This test does the following reduction: // // for k in 0..1: // block-idx @@ -676,7 +665,6 @@ TEST(Cuda, NoThreadIdxWrite_1_CUDA) { TEST(Cuda, SharedMemReduce_1_CUDA) { // FIXME: this test is flaky in CI. - KernelScope kernel_scope; // This test does the following: // for k in 0..1: // block-idx // alloc(c, 64) @@ -814,7 +802,6 @@ TEST(Cuda, SharedMemReduce_1_CUDA) { } TEST(Cuda, LocalMemReduce_1_CUDA) { - KernelScope kernel_scope; // This test does the following: // for k in 0..1: // block-idx // b(k) = 0 @@ -925,19 +912,18 @@ TEST(Cuda, LocalMemReduce_1_CUDA) { } TEST(Cuda, HalfSupport_CUDA) { - KernelScope ks; auto half = ToDtype(); Placeholder a("a", half, {4}); - Tensor* b = Compute("b", {{4, "n"}}, [&](const VarHandle& i) { + Tensor b = Compute("b", {{4, "n"}}, [&](const VarHandle& i) { return Cast::make(half, ExprHandle(2.0f) * a.load(i)); }); - Tensor* c = Compute("c", {{4, "n"}}, [&](const VarHandle& i) { - return Cast::make(kFloat, Cast::make(half, ExprHandle(42)) + b->load(i)); + Tensor c = Compute("c", {{4, "n"}}, [&](const VarHandle& i) { + return Cast::make(kFloat, Cast::make(half, ExprHandle(42)) + b.load(i)); }); - Tensor* d = Compute("d", {{4, "n"}}, [&](const VarHandle& i) { - return Cast::make(half, c->load(i)); + Tensor d = Compute("d", {{4, "n"}}, [&](const VarHandle& i) { + return Cast::make(half, c.load(i)); }); LoopNest l({b, c, d}); @@ -983,10 +969,9 @@ TEST(Cuda, HalfSupport_CUDA) { } TEST(Cuda, HalfPropagation_CUDA) { - KernelScope kernel_scope; auto half = ToDtype(); Placeholder a("a", half, {4}); - Tensor* relu = Compute("relu", {{4, "n"}}, [&](const VarHandle& i) { + Tensor relu = Compute("relu", {{4, "n"}}, [&](const VarHandle& i) { return Max::make(a.load(i), ExprHandle(alloc(0)), true); }); @@ -1032,11 +1017,10 @@ TEST(Cuda, HalfPropagation_CUDA) { } TEST(Cuda, UnusedHalfArgument_CUDA) { - KernelScope kernel_scope; Placeholder a("a", kFloat, {4}); auto half = ToDtype(); Placeholder b("b", half, {4}); - Tensor* relu = Compute("relu", {{4, "n"}}, [&](const VarHandle& i) { + Tensor relu = Compute("relu", {{4, "n"}}, [&](const VarHandle& i) { return Max::make(a.load(i), ExprHandle(alloc(0)), true); }); @@ -1089,7 +1073,6 @@ TEST(Cuda, UnusedHalfArgument_CUDA) { } TEST(Cuda, PrioritizeDependents_CUDA) { - KernelScope kernel_scope; Placeholder a("a", kFloat, {10}); Placeholder b("b", kFloat, {12}); Placeholder c("c", kFloat, {12}); @@ -1163,15 +1146,14 @@ TEST(Cuda, PrioritizeDependents_CUDA) { /// Tests the case where there are two loops which have different extents bound /// to the same block dimension. We must mask the smaller extent loop body. TEST(Cuda, MaskBlockDim_CUDA) { - KernelScope kernel_scope; int A_SIZE = 100; int B_SIZE = 50; Placeholder a_buf("a", kFloat, {A_SIZE}); Placeholder b_buf("b", kFloat, {B_SIZE}); - Tensor* c = Compute("c", {{A_SIZE, "i"}}, [&](const VarHandle& i) { + Tensor c = Compute("c", {{A_SIZE, "i"}}, [&](const VarHandle& i) { return a_buf.load(i) + 10; }); - Tensor* d = Compute("d", {{B_SIZE, "i"}}, [&](const VarHandle& i) { + Tensor d = Compute("d", {{B_SIZE, "i"}}, [&](const VarHandle& i) { return a_buf.load(i) + b_buf.load(i); }); @@ -1256,15 +1238,14 @@ TEST(Cuda, MaskBlockDim_CUDA) { /// to the same thread dimension. This is the same as the above - the smaller /// rank write should be masked. But this time we also need to syncthreads. TEST(Cuda, MaskThreadDim_CUDA) { - KernelScope kernel_scope; int A_SIZE = 50; int B_SIZE = 100; Placeholder a_buf("a", kFloat, {A_SIZE}); Placeholder b_buf("b", kFloat, {B_SIZE}); - Tensor* c = Compute("c", {{A_SIZE, "i"}}, [&](const VarHandle& i) { + Tensor c = Compute("c", {{A_SIZE, "i"}}, [&](const VarHandle& i) { return a_buf.load(i) + 10; }); - Tensor* d = Compute("d", {{B_SIZE, "i"}}, [&](const VarHandle& i) { + Tensor d = Compute("d", {{B_SIZE, "i"}}, [&](const VarHandle& i) { return a_buf.load(i / 2) + b_buf.load(i); }); @@ -1351,15 +1332,14 @@ TEST(Cuda, MaskThreadDim_CUDA) { // Note: this is an extremely dumb pattern which we should never see, but is a // useful edge case to make sure we've got things covered. TEST(Cuda, MaskMultiBlockDim_CUDA) { - KernelScope kernel_scope; int A_SIZE = 100; int B_SIZE = 50; Placeholder a_buf("a", kFloat, {A_SIZE}); Placeholder b_buf("b", kFloat, {B_SIZE}); - Tensor* c = Compute("c", {{A_SIZE, "i"}}, [&](const VarHandle& i) { + Tensor c = Compute("c", {{A_SIZE, "i"}}, [&](const VarHandle& i) { return a_buf.load(i) + 10; }); - Tensor* d = Compute("d", {{B_SIZE, "i"}}, [&](const VarHandle& i) { + Tensor d = Compute("d", {{B_SIZE, "i"}}, [&](const VarHandle& i) { return a_buf.load(i) + b_buf.load(i); }); @@ -1445,15 +1425,14 @@ TEST(Cuda, MaskMultiBlockDim_CUDA) { // Note: this is an extremely dumb pattern which we should never see, but is a // useful edge case to make sure we've got things covered. TEST(Cuda, MaskBlockAndThreadDim_CUDA) { - KernelScope kernel_scope; int A_SIZE = 100; int B_SIZE = 50; Placeholder a_buf("a", kFloat, {A_SIZE}); Placeholder b_buf("b", kFloat, {B_SIZE}); - Tensor* c = Compute("c", {{A_SIZE, "i"}}, [&](const VarHandle& i) { + Tensor c = Compute("c", {{A_SIZE, "i"}}, [&](const VarHandle& i) { return a_buf.load(i) + 10; }); - Tensor* d = Compute("d", {{B_SIZE, "i"}}, [&](const VarHandle& i) { + Tensor d = Compute("d", {{B_SIZE, "i"}}, [&](const VarHandle& i) { return a_buf.load(i) + b_buf.load(i); }); @@ -1537,23 +1516,22 @@ TEST(Cuda, MaskBlockAndThreadDim_CUDA) { /// outer loop bound to blockDim.x and the inner loop bound to threadDim.x. In /// this case all writes with a rank smaller than the max should be masked. TEST(Cuda, MaskMultiDim_CUDA) { - KernelScope kernel_scope; int OUTER_SIZE = 10; int A_SIZE = 100; int B_SIZE = 50; Placeholder a_buf("a", kFloat, {OUTER_SIZE, A_SIZE}); Placeholder b_buf("b", kFloat, {OUTER_SIZE, B_SIZE}); - Tensor* c = Compute( + Tensor c = Compute( "C", {{OUTER_SIZE, "i"}, {A_SIZE, "j"}}, [&](const VarHandle& i, const VarHandle& j) { return ExprHandle(2) * a_buf.load(i, j); }); - Tensor* d = Compute( + Tensor d = Compute( "D", {{OUTER_SIZE, "i"}, {B_SIZE, "j"}}, [&](const VarHandle& i, const VarHandle& j) { - return c->load(i, j * 2) + b_buf.load(i, j); + return c.load(i, j * 2) + b_buf.load(i, j); }); LoopNest l({c, d}); @@ -1575,10 +1553,10 @@ TEST(Cuda, MaskMultiDim_CUDA) { const std::string& verification_pattern = R"IR( # CHECK-NOT: if ( -# CHECK: C[100 * blockIdx.x + threadIdx.x] = +# CHECK: C[threadIdx.x + 100 * blockIdx.x] = # CHECK: __syncthreads(); # CHECK: if (threadIdx.x<50 -# CHECK: D[50 * blockIdx.x + threadIdx.x] =)IR"; +# CHECK: D[threadIdx.x + 50 * blockIdx.x] =)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); @@ -1667,23 +1645,22 @@ TEST(Cuda, MaskMultiDim_CUDA) { // In this case both stores must be masked against the extent of the other loop, // incase it is larger. TEST(Cuda, MaskMultiDimSymbolic_CUDA) { - KernelScope kernel_scope; VarHandle OUTER_SIZE("OUTER_SIZE", kInt); VarHandle A_SIZE("A_SIZE", kInt); VarHandle B_SIZE("B_SIZE", kInt); Placeholder a_buf("a", kFloat, {OUTER_SIZE, A_SIZE}); Placeholder b_buf("b", kFloat, {OUTER_SIZE, B_SIZE}); - Tensor* c = Compute( + Tensor c = Compute( "C", {{OUTER_SIZE, "i"}, {A_SIZE, "j"}}, [&](const VarHandle& i, const VarHandle& j) { return ExprHandle(2) * a_buf.load(i, j); }); - Tensor* d = Compute( + Tensor d = Compute( "D", {{OUTER_SIZE, "i"}, {B_SIZE, "j"}}, [&](const VarHandle& i, const VarHandle& j) { - return c->load(i, j * 2) + b_buf.load(i, j); + return c.load(i, j * 2) + b_buf.load(i, j); }); LoopNest l({c, d}); @@ -1705,10 +1682,10 @@ TEST(Cuda, MaskMultiDimSymbolic_CUDA) { const std::string& verification_pattern = R"IR( # CHECK: if (threadIdx.x 0. // Note: this is a bit degenerate no one would actually write this for perf. TEST(Cuda, MaskMultiDimMultiAxis_CUDA) { - KernelScope kernel_scope; int OUTER_SIZE = 10; int A_SIZE = 30; int B_SIZE = 15; Placeholder a_buf("a", kFloat, {OUTER_SIZE, A_SIZE}); Placeholder b_buf("b", kFloat, {OUTER_SIZE, B_SIZE}); - Tensor* c = Compute( + Tensor c = Compute( "C", {{OUTER_SIZE, "i"}, {A_SIZE, "j"}}, [&](const VarHandle& i, const VarHandle& j) { return ExprHandle(2) * a_buf.load(i, j); }); - Tensor* d = Compute( + Tensor d = Compute( "D", {{OUTER_SIZE, "i"}, {B_SIZE, "j"}}, [&](const VarHandle& i, const VarHandle& j) { - return c->load(i, j * 2) + b_buf.load(i, j); + return c.load(i, j * 2) + b_buf.load(i, j); }); LoopNest l({c, d}); @@ -2119,7 +2093,7 @@ TEST(Cuda, MaskMultiDimMultiAxis_CUDA) { const std::string& verification_pattern = R"IR( # CHECK: if (threadIdx.y<1 -# CHECK: C[30 * blockIdx.x + threadIdx.x] = +# CHECK: C[threadIdx.x + 30 * blockIdx.x] = # CHECK: __syncthreads(); # CHECK: if (threadIdx.x<1 # CHECK: D[threadIdx.y + 15 * blockIdx.x] =)IR"; @@ -2211,24 +2185,23 @@ TEST(Cuda, MaskMultiDimMultiAxis_CUDA) { // the second loop is smaller in both cases - the second store must be masked // for both the block and thread dimension. TEST(Cuda, MaskMultiDimMultiLevel_CUDA) { - KernelScope kernel_scope; int OUTER_A_SIZE = 10; int OUTER_B_SIZE = 5; int A_SIZE = 30; int B_SIZE = 15; Placeholder a_buf("a", kFloat, {OUTER_A_SIZE, A_SIZE}); Placeholder b_buf("b", kFloat, {OUTER_B_SIZE, B_SIZE}); - Tensor* c = Compute( + Tensor c = Compute( "C", {{OUTER_A_SIZE, "i"}, {A_SIZE, "j"}}, [&](const VarHandle& i, const VarHandle& j) { return ExprHandle(2) * a_buf.load(i, j); }); - Tensor* d = Compute( + Tensor d = Compute( "D", {{OUTER_B_SIZE, "i"}, {B_SIZE, "j"}}, [&](const VarHandle& i, const VarHandle& j) { - return c->load(i, j * 2) + b_buf.load(i, j); + return c.load(i, j * 2) + b_buf.load(i, j); }); LoopNest l({c, d}); @@ -2250,7 +2223,7 @@ TEST(Cuda, MaskMultiDimMultiLevel_CUDA) { const std::string& verification_pattern = R"IR( # CHECK-NOT: if ( -# CHECK: C[30 * blockIdx.x + threadIdx.x] = +# CHECK: C[threadIdx.x + 30 * blockIdx.x] = # CHECK: __syncthreads(); # CHECK: if (blockIdx.x<5 # CHECK: if (threadIdx.x<15 diff --git a/test/cpp/tensorexpr/test_expr.cpp b/test/cpp/tensorexpr/test_expr.cpp index 7c234fb95cdb1..d2405353e8301 100644 --- a/test/cpp/tensorexpr/test_expr.cpp +++ b/test/cpp/tensorexpr/test_expr.cpp @@ -24,7 +24,6 @@ using namespace torch::jit::tensorexpr; using SimpleIRExprEval = ExprEval; TEST(Expr, BasicValueTest) { - KernelScope kernel_scope; ExprHandle a = IntImm::make(2), b = IntImm::make(3); ExprHandle c = Add::make(a, b); SimpleIRExprEval eval(c); @@ -32,7 +31,6 @@ TEST(Expr, BasicValueTest) { } TEST(Expr, BasicValueTest02) { - KernelScope kernel_scope; ExprHandle a(2.0f); ExprHandle b(3.0f); ExprHandle c(4.0f); @@ -43,7 +41,6 @@ TEST(Expr, BasicValueTest02) { } TEST(Expr, LetTest01) { - KernelScope kernel_scope; VarHandle x("x", kFloat); ExprHandle body = ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f)); SimpleIRExprEval eval(body); @@ -52,7 +49,6 @@ TEST(Expr, LetTest01) { } TEST(Expr, LetTest02) { - KernelScope kernel_scope; VarHandle x("x", kFloat); VarHandle y("y", kFloat); ExprHandle body = @@ -64,7 +60,6 @@ TEST(Expr, LetTest02) { } TEST(Expr, LetStmtTest01) { - KernelScope kernel_scope; Placeholder a_buf("a", kFloat, {1}); Placeholder b_buf("b", kFloat, {1}); @@ -88,7 +83,6 @@ TEST(Expr, LetStmtTest01) { } TEST(Expr, IntTest) { - KernelScope kernel_scope; VarHandle x("x", kInt); ExprHandle body = ExprHandle(2) + (x * ExprHandle(3) + ExprHandle(4)); SimpleIRExprEval eval(body); @@ -97,7 +91,6 @@ TEST(Expr, IntTest) { } TEST(Expr, FloatTest) { - KernelScope kernel_scope; VarHandle x("x", kFloat); ExprHandle body = ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f)); SimpleIRExprEval eval(body); @@ -106,7 +99,6 @@ TEST(Expr, FloatTest) { } TEST(Expr, ByteTest) { - KernelScope kernel_scope; VarHandle x("x", kByte); ExprHandle body = ExprHandle((uint8_t)2) + (x * ExprHandle((uint8_t)3) + ExprHandle((uint8_t)4)); @@ -116,7 +108,6 @@ TEST(Expr, ByteTest) { } TEST(Expr, CharTest) { - KernelScope kernel_scope; VarHandle x("x", kChar); ExprHandle body = ExprHandle((int8_t)2) + (x * ExprHandle((int8_t)3) + ExprHandle((int8_t)4)); @@ -126,7 +117,6 @@ TEST(Expr, CharTest) { } TEST(Expr, ShortTest) { - KernelScope kernel_scope; VarHandle x("x", kShort); ExprHandle body = ExprHandle((int16_t)2) + (x * ExprHandle((int16_t)3) + ExprHandle((int16_t)4)); @@ -136,7 +126,6 @@ TEST(Expr, ShortTest) { } TEST(Expr, LongTest) { - KernelScope kernel_scope; VarHandle x("x", kLong); ExprHandle body = ExprHandle((int64_t)2) + (x * ExprHandle((int64_t)3) + ExprHandle((int64_t)4)); @@ -146,7 +135,6 @@ TEST(Expr, LongTest) { } TEST(Expr, HalfTest) { - KernelScope kernel_scope; VarHandle x("x", kHalf); ExprHandle body = ExprHandle((at::Half)2) + (x * ExprHandle((at::Half)3) + ExprHandle((at::Half)4)); @@ -156,7 +144,6 @@ TEST(Expr, HalfTest) { } TEST(Expr, DoubleTest) { - KernelScope kernel_scope; VarHandle x("x", kDouble); ExprHandle body = ExprHandle((double)2) + (x * ExprHandle((double)3) + ExprHandle((double)4)); @@ -166,7 +153,6 @@ TEST(Expr, DoubleTest) { } TEST(Expr, VectorAdd01) { - KernelScope kernel_scope; const int kVectorSize = 8; const int kVectorCount = 128; const int kTotalSize = kVectorSize * kVectorCount; @@ -212,7 +198,6 @@ TEST(Expr, VectorAdd01) { } TEST(Expr, CompareSelectEQ) { - KernelScope kernel_scope; constexpr int N = 1024; Placeholder a(BufHandle("A", {N}, kInt)); Placeholder b(BufHandle("B", {N}, kInt)); @@ -251,7 +236,6 @@ TEST(Expr, CompareSelectDtypes) { // This test constructs a CompareSelect expression where the input dtype is // different from the output dtype and verifies that it works correctly: // result = ((int)lhs == (int)rhs) ? (float)retval1 : (float)retval2 - KernelScope kernel_scope; constexpr int N = 1024; Placeholder a(BufHandle("A", {N}, kInt)); Placeholder b(BufHandle("B", {N}, kInt)); @@ -290,7 +274,6 @@ TEST(Expr, CompareSelectDtypes) { } TEST(Expr, IntrinsicsDtypes) { - KernelScope kernel_scope; constexpr int N = 256; Placeholder a(BufHandle("A", {N}, kDouble)); Placeholder b(BufHandle("B", {N}, kDouble)); @@ -312,7 +295,6 @@ TEST(Expr, IntrinsicsDtypes) { } TEST(Expr, Substitute01) { - KernelScope kernel_scope; VarPtr x = alloc("x", kFloat); VarPtr y = alloc("y", kFloat); ExprPtr e = @@ -334,7 +316,6 @@ TEST(Expr, Substitute01) { } TEST(Expr, Math01) { - KernelScope kernel_scope; ExprHandle v = sin(ExprHandle(1.0f)); std::ostringstream oss; @@ -348,7 +329,6 @@ TEST(Expr, Math01) { } TEST(Expr, UnaryMath01) { - KernelScope kernel_scope; struct TestConfig { std::function func; std::function ref_func; @@ -416,7 +396,6 @@ TEST(Expr, UnaryMath01) { } TEST(Expr, BinaryMath01) { - KernelScope kernel_scope; struct TestConfig { std::function func; std::function ref_func; @@ -440,7 +419,6 @@ TEST(Expr, BinaryMath01) { } TEST(Expr, LogicalOps01) { - KernelScope kernel_scope; ExprHandle a(23); ExprHandle b(11); ExprHandle c(0.72f); @@ -473,7 +451,6 @@ TEST(Expr, LogicalOps01) { } TEST(Expr, LogicalOps02) { - KernelScope kernel_scope; ExprHandle a(23); ExprHandle b(11); ExprHandle c(0.72f); @@ -492,7 +469,6 @@ TEST(Expr, LogicalOps02) { } TEST(Expr, LogicalOps03) { - KernelScope kernel_scope; ExprHandle a(23); ExprHandle b(11); ExprHandle c(0.72f); @@ -550,7 +526,6 @@ TEST(Expr, LogicalOps03) { } TEST(Expr, BitwiseOps) { - KernelScope kernel_scope; ExprHandle a(59); ExprHandle b(11); ExprHandle c(101); @@ -562,7 +537,6 @@ TEST(Expr, BitwiseOps) { } TEST(Expr, DynamicShapeAdd) { - KernelScope kernel_scope; auto testWithSize = [](int32_t size) { VarHandle n("n", kInt); Placeholder a(BufHandle("a", {n}, kFloat)); @@ -582,7 +556,6 @@ TEST(Expr, DynamicShapeAdd) { } void testCond01() { - KernelScope kernel_scope; const int N = 16; PaddedBuffer a_v(N); Placeholder a_buf("a", kFloat, {N}); @@ -606,7 +579,6 @@ void testCond01() { } void testIfThenElse01() { - KernelScope kernel_scope; ExprHandle v = ifThenElse(ExprHandle(1), ExprHandle(1.0f), ExprHandle(2.0f)); std::ostringstream oss; @@ -618,7 +590,6 @@ void testIfThenElse01() { } void testIfThenElse02() { - KernelScope kernel_scope; ExprHandle v = ifThenElse(ExprHandle(0), ExprHandle(1.0f), ExprHandle(2.0f)); std::ostringstream oss; @@ -630,7 +601,6 @@ void testIfThenElse02() { } void testIfThenElse03() { - KernelScope kernel_scope; ExprHandle v = ifThenElse(BoolImm::make(false), ExprHandle(1.0f), ExprHandle(2.0f)); @@ -643,7 +613,6 @@ void testIfThenElse03() { } void testStmtClone() { - KernelScope kernel_scope; const int N = 16; Placeholder a_buf("a", kInt, {N}); diff --git a/test/cpp/tensorexpr/test_external_calls.cpp b/test/cpp/tensorexpr/test_external_calls.cpp index 9ae99ca5d3b2f..176158e7fe13a 100644 --- a/test/cpp/tensorexpr/test_external_calls.cpp +++ b/test/cpp/tensorexpr/test_external_calls.cpp @@ -20,8 +20,6 @@ namespace jit { using namespace torch::jit::tensorexpr; TEST(ExternalCall, Conv2d_float) { - KernelScope kernel_scope; - Placeholder Input("Input", kFloat, {1, 3, 224, 224}); Placeholder Weight("Weight", kFloat, {16, 3, 3, 3}); Placeholder Bias("Bias", kFloat, {16}); @@ -31,7 +29,7 @@ TEST(ExternalCall, Conv2d_float) { int64_t dilation = 1; int64_t groups = 1; - Tensor* Result = new Tensor( + Tensor Result = Tensor( ResultBuf.node(), ExternalCall::make( ResultBuf, @@ -84,7 +82,6 @@ TEST(ExternalCall, Conv2d_float) { TEST(ExternalCall, Conv2d_int) { // A similar test, but now using kInt tensors - KernelScope kernel_scope; Placeholder Input("Input", kInt, {1, 3, 224, 224}); Placeholder Weight("Weight", kInt, {16, 3, 3, 3}); @@ -95,7 +92,7 @@ TEST(ExternalCall, Conv2d_int) { int64_t dilation = 1; int64_t groups = 1; - Tensor* Result = new Tensor( + Tensor Result = Tensor( ResultBuf.node(), ExternalCall::make( ResultBuf, @@ -147,13 +144,11 @@ TEST(ExternalCall, Conv2d_int) { } TEST(ExternalCall, Conv2d_nobias_noargs) { - KernelScope kernel_scope; - Placeholder Input("Input", kFloat, {1, 16, 112, 112}); Placeholder Weight("Weight", kFloat, {16, 16, 1, 1}); BufHandle ResultBuf("Result", {1, 16, 112, 112}, kFloat); - Tensor* Result = new Tensor( + Tensor Result = Tensor( ResultBuf.node(), ExternalCall::make( ResultBuf, @@ -194,8 +189,6 @@ TEST(ExternalCall, Conv2d_nobias_noargs) { } TEST(ExternalCall, Addmm_float) { - KernelScope kernel_scope; - Placeholder Input("Input", kFloat, {100, 300}); Placeholder Mat1("Mat1", kFloat, {100, 200}); Placeholder Mat2("Mat2", kFloat, {200, 300}); @@ -203,7 +196,7 @@ TEST(ExternalCall, Addmm_float) { int64_t beta = 2; int64_t alpha = 2; - Tensor* Result = new Tensor( + Tensor Result = Tensor( ResultBuf.node(), ExternalCall::make( ResultBuf, @@ -252,8 +245,6 @@ TEST(ExternalCall, Addmm_float) { TEST(ExternalCall, Prepacked_Linear_float) { using namespace at::native::xnnpack; - KernelScope kernel_scope; - Placeholder Input("Input", kFloat, {100, 200}); BufHandle ResultBuf("Result", {100, 300}, kFloat); @@ -283,7 +274,7 @@ TEST(ExternalCall, Prepacked_Linear_float) { weight, bias, c10::optional(), c10::optional()); Placeholder DummyPrepacked("DummyPrepacked", kFloat, {1}); - Tensor* Result = new Tensor( + Tensor Result = Tensor( ResultBuf.node(), ExternalCall::make( ResultBuf, @@ -317,8 +308,6 @@ TEST(ExternalCall, Prepacked_Linear_float) { TEST(ExternalCall, Prepacked_Conv2d_float) { using namespace at::native::xnnpack; - KernelScope kernel_scope; - Placeholder Input("Input", kFloat, {1, 3, 224, 224}); BufHandle ResultBuf("Result", {1, 16, 112, 112}, kFloat); int64_t stride = 2; @@ -370,7 +359,7 @@ TEST(ExternalCall, Prepacked_Conv2d_float) { c10::optional()); Placeholder DummyPrepacked("DummyPrepacked", kFloat, {1}); - Tensor* Result = new Tensor( + Tensor Result = Tensor( ResultBuf.node(), ExternalCall::make( ResultBuf, @@ -404,7 +393,6 @@ TEST(ExternalCall, Prepacked_Conv2d_float) { #endif // USE_XNNPACK TEST(ExternalCall, BinaryFloat) { - KernelScope kernel_scope; using TensorFunc = std::function; using Test = std::tuple< std::vector, @@ -431,7 +419,7 @@ TEST(ExternalCall, BinaryFloat) { Placeholder B("", kFloat, toExprHandleVec(bShape)); BufHandle ResultBuf("Result", toExprHandleVec(resShape), kFloat); - Tensor* Result = new Tensor( + Tensor Result = Tensor( ResultBuf.node(), ExternalCall::make( ResultBuf, @@ -479,7 +467,6 @@ TEST(ExternalCall, BinaryFloat) { } TEST(ExternalCall, UnaryFloat) { - KernelScope kernel_scope; using TensorFunc = std::function; auto toExprHandleVec = [](std::vector v) { auto intV = std::vector(v.begin(), v.end()); @@ -516,7 +503,7 @@ TEST(ExternalCall, UnaryFloat) { Placeholder A("A", kFloat, toExprHandleVec(aShape)); BufHandle ResultBuf("Result", toExprHandleVec(resShape), kFloat); - Tensor* Result = new Tensor( + Tensor Result = Tensor( ResultBuf.node(), ExternalCall::make( ResultBuf, externCallName, {BufHandle(A.data())}, externCallArgs)); @@ -561,19 +548,18 @@ TEST(ExternalCall, UnaryFloat) { TEST(ExternalCall, ComputeInterop) { // This test verifies that Tensors using external calls can be used by and can // use Tensors built with Compute API. - KernelScope kernel_scope; - BufHandle ConvResultBuf("ConvResult", {1, 16, 112, 112}, kFloat); - BufHandle MatmulResultBuf("MatmulResult", {1, 16, 112, 112}, kFloat); + BufHandle ConvResultBuf("ConvResult", {1, 16, 32, 32}, kFloat); + BufHandle MatmulResultBuf("MatmulResult", {1, 16, 32, 32}, kFloat); - Tensor* Input = Compute( + Tensor Input = Compute( "Input", - {{1, "n"}, {16, "c"}, {112, "h"}, {112, "w"}}, + {{1, "n"}, {16, "c"}, {32, "h"}, {32, "w"}}, [&](const VarHandle& n, const VarHandle& c, const VarHandle& h, const VarHandle& w) { return FloatImm::make(5.0f); }); - Tensor* Weight = Compute( + Tensor Weight = Compute( "Weight", {{16, "n"}, {16, "c"}, {1, "kh"}, {1, "kw"}}, [&](const VarHandle& n, @@ -581,28 +567,28 @@ TEST(ExternalCall, ComputeInterop) { const VarHandle& h, const VarHandle& w) { return FloatImm::make(6.0f); }); - Tensor* ConvResult = new Tensor( + Tensor ConvResult = Tensor( ConvResultBuf.node(), ExternalCall::make( ConvResultBuf, "nnc_aten_conv2d", - {BufHandle(Input->buf()), BufHandle(Weight->buf())}, + {BufHandle(Input.buf()), BufHandle(Weight.buf())}, {})); - Tensor* MatmulResult = new Tensor( + Tensor MatmulResult = Tensor( MatmulResultBuf.node(), ExternalCall::make( MatmulResultBuf, "nnc_aten_matmul", - {BufHandle(ConvResult->buf()), BufHandle(ConvResult->buf())}, + {BufHandle(ConvResult.buf()), BufHandle(ConvResult.buf())}, {})); - Tensor* Result = Compute( + Tensor Result = Compute( "Result", - {{1, "n"}, {16, "c"}, {112, "h"}, {112, "w"}}, + {{1, "n"}, {16, "c"}, {32, "h"}, {32, "w"}}, [&](const VarHandle& n, const VarHandle& c, const VarHandle& h, const VarHandle& w) { - return ConvResult->load(n, c, h, w) + MatmulResult->load(n, c, h, w); + return ConvResult.load(n, c, h, w) + MatmulResult.load(n, c, h, w); }); LoopNest l({Input, Weight, ConvResult, MatmulResult, Result}); @@ -619,18 +605,18 @@ TEST(ExternalCall, ComputeInterop) { .layout(at::kStrided) .device(at::kCPU) .requires_grad(false); - at::Tensor input = at::ones({1, 16, 112, 112}, options) * 5.f; + at::Tensor input = at::ones({1, 16, 32, 32}, options) * 5.f; at::Tensor weight = at::ones({16, 16, 1, 1}, options) * 6.f; at::Tensor t = at::conv2d(input, weight); at::Tensor t2 = at::matmul(t, t); at::Tensor ref = t + t2; at::Tensor nnc_result; - std::vector input_buf(1 * 16 * 112 * 112, 5.f); + std::vector input_buf(1 * 16 * 32 * 32, 5.f); std::vector weight_buf(16 * 16 * 1 * 1, 6.f); - std::vector conv_result_buf(1 * 16 * 112 * 112, -1.f); - std::vector matmul_result_buf(1 * 16 * 112 * 112, -1.f); - std::vector result_buf(1 * 16 * 112 * 112, -1.f); + std::vector conv_result_buf(1 * 16 * 32 * 32, -1.f); + std::vector matmul_result_buf(1 * 16 * 32 * 32, -1.f); + std::vector result_buf(1 * 16 * 32 * 32, -1.f); #ifdef TORCH_ENABLE_LLVM LLVMCodeGen llvm_codegen( @@ -638,7 +624,7 @@ TEST(ExternalCall, ComputeInterop) { llvm_codegen.call( {input_buf, weight_buf, conv_result_buf, matmul_result_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); + nnc_result = at::from_blob(result_buf.data(), {1, 16, 32, 32}, options); ASSERT_TRUE(at::allclose(nnc_result, ref)); #endif @@ -647,42 +633,41 @@ TEST(ExternalCall, ComputeInterop) { ir_eval.call( {input_buf, weight_buf, conv_result_buf, matmul_result_buf, result_buf}); - nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); + nnc_result = at::from_blob(result_buf.data(), {1, 16, 32, 32}, options); ASSERT_TRUE(at::allclose(nnc_result, ref)); } TEST(ExternalCall, Inlining) { // This test verifies that Tensors using external calls can be used by and // can use Tensors built with Compute API. - KernelScope kernel_scope; BufHandle MatmulResultBuf("MatmulResult", {8, 8}, kFloat); - Tensor* A = Compute( + Tensor A = Compute( "A", {{8, "i"}, {8, "j"}}, [&](const VarHandle& i, const VarHandle& j) { return FloatImm::make(5.0f); }); - Tensor* B = Compute( + Tensor B = Compute( "B", {{8, "i"}, {8, "j"}}, [&](const VarHandle& i, const VarHandle& j) { return FloatImm::make(4.0f); }); - Tensor* MatmulResult = new Tensor( + Tensor MatmulResult = Tensor( MatmulResultBuf.node(), ExternalCall::make( MatmulResultBuf, "nnc_aten_matmul", - {BufHandle(A->buf()), BufHandle(B->buf())}, + {BufHandle(A.buf()), BufHandle(B.buf())}, {})); - Tensor* Result = Compute( + Tensor Result = Compute( "Result", {{8, "i"}, {8, "j"}}, [&](const VarHandle& i, const VarHandle& j) { - return MatmulResult->load(i, j) + FloatImm::make(3.0f); + return MatmulResult.load(i, j) + FloatImm::make(3.0f); }); StmtPtr root_stmt = alloc(std::vector( - {A->stmt(), B->stmt(), MatmulResult->stmt(), Result->stmt()})); - LoopNest l(root_stmt, {Result->buf()}); + {A.stmt(), B.stmt(), MatmulResult.stmt(), Result.stmt()})); + LoopNest l(root_stmt, {Result.buf()}); // Inlining should not inline anything here since all Bufs are either // defined or used in ExternalCalls diff --git a/test/cpp/tensorexpr/test_graph_opt.cpp b/test/cpp/tensorexpr/test_graph_opt.cpp index 3175d7f142524..e5a237f5f7541 100644 --- a/test/cpp/tensorexpr/test_graph_opt.cpp +++ b/test/cpp/tensorexpr/test_graph_opt.cpp @@ -45,7 +45,6 @@ TEST_F(GraphOpt, OptimizeCat) { torch::jit::parseIR(graph_string, g.get()); g->lint(); - KernelScope kernel_scope; TensorExprKernel kernel(g); // The `aten::log` op must be moved to the inputs of `aten::cat`. @@ -88,7 +87,6 @@ TEST_F(GraphOpt, OptimizeCat2) { torch::jit::parseIR(graph_string, g.get()); g->lint(); - KernelScope kernel_scope; TensorExprKernel kernel(g); // The `aten::log` and `aten::tanh` ops must be moved to the inputs of @@ -137,7 +135,6 @@ TEST_F(GraphOpt, OptimizeCat3) { torch::jit::parseIR(graph_string, g.get()); g->lint(); - KernelScope kernel_scope; TensorExprKernel kernel(g); // The `aten::tanh` op must be moved to the inputs of `aten::cat`. @@ -183,7 +180,6 @@ TEST_F(GraphOpt, OptimizeCatWithTypePromotionInUser) { torch::jit::parseIR(graph_string, g.get()); g->lint(); - KernelScope kernel_scope; TensorExprKernel kernel(g); // The `aten::tanh` op must be moved to the inputs of `aten::cat`. @@ -227,7 +223,6 @@ TEST_F(GraphOpt, OptimizeCatWithTypePromotionInCat) { torch::jit::parseIR(graph_string, g.get()); g->lint(); - KernelScope kernel_scope; TensorExprKernel kernel(g); // No transformation should have happened because the `aten::cat` op performs @@ -257,7 +252,6 @@ TEST_F(GraphOpt, OptimizeCatNoSingleTensorElementwiseOp) { torch::jit::parseIR(graph_string, g.get()); g->lint(); - KernelScope kernel_scope; TensorExprKernel kernel(g); // No transformation is expected since the consumers of cat are not @@ -290,7 +284,6 @@ TEST_F(GraphOpt, OptimizeCatNoSingleTensorElementwiseOp2) { torch::jit::parseIR(graph_string, g.get()); g->lint(); - KernelScope kernel_scope; TensorExprKernel kernel(g); // No transformation is expected since the consumers of cat are not diff --git a/test/cpp/tensorexpr/test_ir_printer.cpp b/test/cpp/tensorexpr/test_ir_printer.cpp index 76d9247579d7c..820f12689acca 100644 --- a/test/cpp/tensorexpr/test_ir_printer.cpp +++ b/test/cpp/tensorexpr/test_ir_printer.cpp @@ -17,7 +17,6 @@ namespace jit { using namespace torch::jit::tensorexpr; TEST(IRPrinter, BasicValueTest) { - KernelScope kernel_scope; ExprHandle a = IntImm::make(2), b = IntImm::make(3); ExprHandle c = Add::make(a, b); @@ -27,7 +26,6 @@ TEST(IRPrinter, BasicValueTest) { } TEST(IRPrinter, BasicValueTest02) { - KernelScope kernel_scope; ExprHandle a(2.0f); ExprHandle b(3.0f); ExprHandle c(4.0f); @@ -40,7 +38,6 @@ TEST(IRPrinter, BasicValueTest02) { } TEST(IRPrinter, CastTest) { - KernelScope kernel_scope; VarHandle x("x", kHalf); VarHandle y("y", kFloat); ExprHandle body = ExprHandle(2.f) + @@ -52,34 +49,33 @@ TEST(IRPrinter, CastTest) { } TEST(IRPrinter, FunctionName) { - KernelScope kernel_scope; int M = 4; int N = 20; - Tensor* producer = Compute( + Tensor producer = Compute( "producer", {{M, "m"}, {N, "n"}}, [&](const ExprHandle& m, const ExprHandle& n) { return m * n; }); - Tensor* chunk_0 = Compute( + Tensor chunk_0 = Compute( "chunk", {{M, "m"}, {N / 2, "n"}}, [&](const ExprHandle& m, const ExprHandle& n) { - return producer->load(m, n); + return producer.load(m, n); }); - Tensor* chunk_1 = Compute( + Tensor chunk_1 = Compute( "chunk", {{M, "m"}, {N / 2, "n"}}, [&](const ExprHandle& m, const ExprHandle& n) { - return producer->load(m, n + ExprHandle(N / 2)); + return producer.load(m, n + ExprHandle(N / 2)); }); - Tensor* consumer = Compute( + Tensor consumer = Compute( "consumer", {{M, "i"}, {N / 2, "j"}}, [&](const ExprHandle& i, const ExprHandle& j) { - return i * chunk_1->load(i, j); + return i * chunk_1.load(i, j); }); LoopNest l({chunk_0, chunk_1, consumer}); diff --git a/test/cpp/tensorexpr/test_ir_verifier.cpp b/test/cpp/tensorexpr/test_ir_verifier.cpp index 2c91d8b24b253..cbe15502ad1f9 100644 --- a/test/cpp/tensorexpr/test_ir_verifier.cpp +++ b/test/cpp/tensorexpr/test_ir_verifier.cpp @@ -17,7 +17,6 @@ namespace jit { using namespace torch::jit::tensorexpr; TEST(IRVerifier, BitwiseOps) { - KernelScope kernel_scope; VarPtr X = alloc("x", kInt); VarPtr Y = alloc("y", kFloat); { @@ -48,7 +47,6 @@ TEST(IRVerifier, BitwiseOps) { } TEST(IRVerifier, CompareSelect) { - KernelScope kernel_scope; ExprPtr X = alloc(1); ExprPtr Y = alloc(3.14f); { @@ -64,7 +62,6 @@ TEST(IRVerifier, CompareSelect) { } TEST(IRVerifier, Ramp) { - KernelScope kernel_scope; VarPtr I = alloc("i", kInt); VarPtr J = alloc("j", kFloat); { @@ -75,7 +72,6 @@ TEST(IRVerifier, Ramp) { } TEST(IRVerifier, Load) { - KernelScope kernel_scope; VarPtr I = alloc("i", kInt); VarPtr J = alloc("j", kLong); VarPtr K = alloc("k", kFloat); @@ -105,7 +101,6 @@ TEST(IRVerifier, Load) { } TEST(IRVerifier, IfThenElse) { - KernelScope kernel_scope; VarPtr I = alloc("i", kInt); VarPtr J = alloc("j", kLong); VarPtr K = alloc("k", kFloat); @@ -130,7 +125,6 @@ TEST(IRVerifier, IfThenElse) { } TEST(IRVerifier, For) { - KernelScope kernel_scope; VarPtr I = alloc("i", kInt); VarPtr J = alloc("j", kInt); StmtPtr body = alloc(std::vector({})); @@ -143,7 +137,6 @@ TEST(IRVerifier, For) { } TEST(IRVerifier, Block) { - KernelScope kernel_scope; VarPtr I = alloc("i", kInt); BufPtr B = alloc("B", std::vector({alloc(10)}), kInt); { @@ -160,7 +153,6 @@ TEST(IRVerifier, Block) { } TEST(IRVerifier, Store) { - KernelScope kernel_scope; VarPtr I = alloc("i", kInt); VarPtr J = alloc("j", kLong); VarPtr K = alloc("k", kFloat); diff --git a/test/cpp/tensorexpr/test_kernel.cpp b/test/cpp/tensorexpr/test_kernel.cpp index 8f36f54395f49..f4d3b16b964f2 100644 --- a/test/cpp/tensorexpr/test_kernel.cpp +++ b/test/cpp/tensorexpr/test_kernel.cpp @@ -39,7 +39,6 @@ TEST_F(Kernel, InliningIntermediates) { %4 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) %5: Float(5, 3, strides=[3, 1]) = aten::add(%4, %1, %one) return (%5))IR"; - KernelScope kernel_scope; auto graph = std::make_shared(); parseIR(graph_string, &*graph); TensorExprKernel k(graph); @@ -63,7 +62,6 @@ TEST_F(Kernel, InliningIntermediates) { continue; } - KernelScope kernel_scope; TemplateEnv env; env.s("device", use_cuda ? "cuda:0" : "cpu"); const auto graph_string = format(graph_template, env); @@ -88,8 +86,6 @@ TEST_F(Kernel, InliningIntermediates) { } TEST_F(Kernel, _1) { - KernelScope kernel_scope; - const auto graph_string = R"IR( graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), %1 : Float(5, 3, strides=[3, 1], device=cpu)): @@ -127,8 +123,6 @@ TEST_F(Kernel, _1) { } TEST_F(Kernel, _2) { - KernelScope kernel_scope; - const auto graph_string = R"IR( graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), %1 : Float(5, 3, strides=[1, 5], device=cpu)): @@ -167,8 +161,6 @@ TEST_F(Kernel, _2) { } TEST_F(Kernel, _3) { - KernelScope kernel_scope; - const auto graph_string = R"IR( graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), %1 : Float(5, 3, strides=[12, 2], device=cpu)): @@ -206,14 +198,56 @@ TEST_F(Kernel, _3) { } } +TEST_F(Kernel, Huge) { + const auto graph_string = R"IR( + graph(%x.1 : Float(4000000000, strides=[1], requires_grad=0, device=cpu)): + %1 : int = prim::Constant[value=0]() + %2 : Float(1, 4000000000, strides=[4000000000, 1], requires_grad=0, device=cpu) = aten::unsqueeze(%x.1, %1) + %3 : Float(1, 4000000000, strides=[4000000000, 1], requires_grad=0, device=cpu) = aten::relu(%2) + return (%3))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + TensorExprKernel k(graph); + std::ostringstream oss; + oss << *k.getCodeGenStmt(); + const std::string& verification_pattern = "# CHECK: 4000000000"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST_F(Kernel, ParallelStrided) { + const auto graph_string = R"IR( + graph(%0 : Float(5, 3, 40005, strides=[120015, 40005, 1], device=cpu), + %1 : Float(5, 3, 40005, strides=[960120, 160020, 2], device=cpu)): + %2 : Float(5, 3, 40005, strides=[120015, 40005, 1]) = aten::mul(%0, %1) + %3 : Float(5, 3, 40005, strides=[120015, 40005, 1]) = aten::mul(%0, %2) + return (%3))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto a = at::rand({5, 3, 40005}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({10, 6, 80010}, TensorOptions(kCPU).dtype(at::kFloat)) + .index( + {Slice(None, None, 2), + Slice(None, None, 2), + Slice(None, None, 2)}); + auto ref = a * (a * b); + auto o = at::zeros_like(ref); + TensorExprKernel k(graph); + std::vector inputs = {a, b}; + std::vector stack = fmap(inputs); + k.run(stack); + o = stack[0].toTensor(); + for (size_t i = 0; i < 5 * 3; i++) { + CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + } +} + TEST_F(Kernel, DISABLED_Shape_Inference) { // disabled: doesn't do stride propagation, and isn't being used currently // Test TensorExpr shape inference capabilities: it should only require shapes // for the inputs { - KernelScope kernel_scope; - const auto graph_string = R"IR( graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), %1 : Float(5, 3, strides=[12, 2], device=cpu)): @@ -251,8 +285,6 @@ TEST_F(Kernel, DISABLED_Shape_Inference) { } } { - KernelScope kernel_scope; - const auto graph_string = R"IR( graph(%0 : Float(8, 8, strides=[8, 1], device=cpu), %1 : Float(8, 8, strides=[8, 1], device=cpu)): @@ -292,7 +324,6 @@ TEST_F(Kernel, DISABLED_Shape_Inference) { } { // Test that shape inference handles aten::unsqueeze - KernelScope kernel_scope; const auto graph_string = R"IR( graph(%a : Float(4, 2, strides=[2, 1], device=cpu), @@ -355,7 +386,6 @@ TEST_F(Kernel, DISABLED_Shape_Inference) { } { // Test that shape inference handles aten::cat - KernelScope kernel_scope; const auto graph_string = R"IR( graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu), @@ -409,7 +439,6 @@ TEST_F(Kernel, DISABLED_Shape_Inference) { } { // Test that we throw an error when input list for aten::cat is empty - KernelScope kernel_scope; const auto graph_string = R"IR( graph(): @@ -427,7 +456,6 @@ TEST_F(Kernel, DISABLED_Shape_Inference) { } { // Test that we throw an error when 'dim' passed to aten::cat is invalid - KernelScope kernel_scope; const auto ir_dim_99 = R"IR( graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu), @@ -458,7 +486,6 @@ TEST_F(Kernel, DISABLED_Shape_Inference) { TEST_F(Kernel, CatInputTypesPromotion) { { // Test that we properly promote input types for aten::cat - KernelScope kernel_scope; const auto graph_string = R"IR( graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu), @@ -676,7 +703,6 @@ TEST_F(Kernel, SumAllAxes) { auto a = iotaTensor({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); for (auto scalar_type : {ScalarType::Undefined, ScalarType::Double}) { - KernelScope kernel_scope; TemplateEnv env; env.s("dtype", dtypeConstant(scalar_type)); if (scalar_type == ScalarType::Undefined) { @@ -745,7 +771,6 @@ TEST_F(Kernel, SumOneAxis) { for (int dim = -a.dim(); dim < a.dim(); ++dim) { for (bool keepdim : {false, true}) { for (auto scalar_type : {ScalarType::Undefined, ScalarType::Double}) { - KernelScope kernel_scope; TemplateEnv env; env.d("dim", dim); env.d("keepdim", keepdim); @@ -777,9 +802,9 @@ TEST_F(Kernel, SumOneAxis) { // Check the IR we produced const std::string& verification_pattern = R"IR( -# CHECK: for (int v = 0; v < +# CHECK: for (int64_t v = 0ll; v < # CHECK-NEXT: sum -# CHECK-NEXT: for (int v_1 = 0; v_1 < +# CHECK-NEXT: for (int64_t v_1 = 0ll; v_1 < # CHECK-NEXT: sum)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); @@ -812,7 +837,6 @@ TEST_F(Kernel, SumMultipleAxes) { for (int dim1 = 0; dim1 < a.dim(); ++dim1) { for (int dim2 = dim1 + 1; dim2 < a.dim(); ++dim2) { for (bool keepdim : {false, true}) { - KernelScope kernel_scope; TemplateEnv env; env.d("dim1", dim1); env.d("dim2", dim2); @@ -839,10 +863,10 @@ TEST_F(Kernel, SumMultipleAxes) { // Check the IR we produced const std::string& verification_pattern = R"IR( -# CHECK: int v = 0 -# CHECK: int v_1 = 0 -# CHECK: int v_2 = 0 -# CHECK: int v_3 = 0 +# CHECK: int64_t v = 0 +# CHECK: int64_t v_1 = 0 +# CHECK: int64_t v_2 = 0 +# CHECK: int64_t v_3 = 0 # CHECK: sum)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); @@ -888,7 +912,6 @@ TEST_F(Kernel, Softmax2D) { auto other_dim = (softmax_dim + 1) % a.dim(); auto ref = log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim); - KernelScope kernel_scope; TemplateEnv env; env.d("dim", softmax_dim); env.s("op", log_softmax ? "log_softmax" : "softmax"); @@ -964,7 +987,6 @@ TEST_F(Kernel, Softmax3D) { auto ref = log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim); - KernelScope kernel_scope; TemplateEnv env; env.d("dim", softmax_dim); env.s("op", log_softmax ? "log_softmax" : "softmax"); @@ -1046,7 +1068,6 @@ TEST_F(Kernel, Softmax4D) { auto ref = log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim); - KernelScope kernel_scope; TemplateEnv env; env.d("dim", softmax_dim); env.s("op", log_softmax ? "log_softmax" : "softmax"); @@ -1090,8 +1111,6 @@ TEST_F(Kernel, Softmax4D) { } TEST_F(Kernel, InlineProducerIntoReduction) { - KernelScope kernel_scope; - // Inline producer (mul) into reduction (sum). const auto graph_string = R"IR( graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), @@ -1112,8 +1131,8 @@ TEST_F(Kernel, InlineProducerIntoReduction) { // We should have only one loop in the end. const std::string& verification_pattern = R"IR( - # CHECK: for (int v = 0; v < 5; - # CHECK-NEXT: for (int v_1 = 0; v_1 < 3; + # CHECK: for (int64_t v = 0ll; v < 5 + # CHECK-NEXT: for (int64_t v_1 = 0ll; v_1 < 3 # CHECK-NEXT: sum # CHECK-NOT: for)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); @@ -1129,8 +1148,6 @@ TEST_F(Kernel, InlineProducerIntoReduction) { } TEST_F(Kernel, InlineReductionIntoConsumer) { - KernelScope kernel_scope; - // Inline producer (mul %2) into reduction (sum %4) but DO NOT // inline the reduction into consumer (mul %4). const auto graph_string = R"IR( @@ -1153,11 +1170,11 @@ TEST_F(Kernel, InlineReductionIntoConsumer) { // We should have two loops in the end. const std::string& verification_pattern = R"IR( - # CHECK: for (int v = 0; v < 5; - # CHECK-NEXT: for (int v_1 = 0; v_1 < 3; + # CHECK: for (int64_t v = 0ll; v < 5 + # CHECK-NEXT: for (int64_t v_1 = 0ll; v_1 < 3 # CHECK-NEXT: sum - # CHECK: for (int v_2 = 0; v_2 < 5; - # CHECK-NEXT: for (int v_3 = 0; v_3 < 3; + # CHECK: for (int64_t v_2 = 0ll; v_2 < 5 + # CHECK-NEXT: for (int64_t v_3 = 0ll; v_3 < 3 # CHECK-NEXT: aten_mul # CHECK-NOT: for)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); @@ -1179,7 +1196,6 @@ TEST_F(Kernel, SanitizeNames_CUDA) { %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) %4 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) return (%4))IR"; - KernelScope kernel_scope; auto graph = std::make_shared(); parseIR(graph_string, &*graph); graph->inputs().at(0)->setDebugName("aten::add:"); @@ -1195,6 +1211,43 @@ TEST_F(Kernel, SanitizeNames_CUDA) { ASSERT_TRUE(at::allclose(o, ref)); } +TEST_F(Kernel, SanitizeConstants_CUDA) { + const auto graph_string = R"IR( + graph(%x : Float(16, 16, strides=[16, 1], device=cuda:0)): + %none : NoneType = prim::Constant() + %size : int = prim::Constant[value=16]() + %sizes : int[] = prim::ListConstruct(%size, %size) + %30 : Device = prim::Constant[value="cuda"]() + %y : Float(16, 16, strides=[16, 1], device=cuda:0) = aten::ones(%sizes, %none, %none, %30, %none) + %z : Float(16, 16, strides=[16, 1], device=cuda:0) = aten::mul(%x, %y) + return (%z))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + // IRParser doesn't support tensor constants, so we insert a call to + // aten::ones and then const-prop it + ConstantPropagation(graph); + + // We set the name of the constant to include special characters that are + // not allowed. This should be fixed by the sanitizer in TensorExprKernel. + graph->nodes().front()->output()->setDebugName("illegal.name"); + + // Check if we have a constant node with illegal name in the graph. + auto const_node = graph->nodes().front(); + ASSERT_EQ(const_node->kind(), prim::Constant); + ASSERT_NE(const_node->output()->debugName().find('.'), std::string::npos); + + TensorExprKernel k(graph); + + auto x = at::rand({16, 16}, TensorOptions(kCUDA).dtype(at::kFloat)); + std::vector inputs = {x}; + std::vector stack = fmap(inputs); + k.run(stack); + auto o = stack[0].toTensor(); + auto y = at::ones({16, 16}, TensorOptions(kCUDA).dtype(at::kFloat)); + auto ref = x * y; + ASSERT_TRUE(at::allclose(o, ref)); +} + TEST_F(Kernel, ConstantTensors) { const auto graph_string = R"IR( graph(%x : Float(16, 16, strides=[16, 1], device=cpu)): @@ -1204,7 +1257,6 @@ TEST_F(Kernel, ConstantTensors) { %y : Float(16, 16, strides=[16, 1], device=cpu) = aten::ones(%sizes, %none, %none, %none, %none) %z : Float(16, 16, strides=[16, 1], device=cpu) = aten::mul(%x, %y) return (%z))IR"; - KernelScope kernel_scope; auto graph = std::make_shared(); parseIR(graph_string, &*graph); // IRParser doesn't support tensor constants, so we insert a call to @@ -1237,7 +1289,6 @@ TEST_F(Kernel, ConstantTensorsNonContiguous) { %y : Tensor = aten::t(%y_t) %z : Float(16, 16, strides=[16, 1], device=cpu) = aten::mul(%x, %y) return (%z))IR"; - KernelScope kernel_scope; auto graph = std::make_shared(); parseIR(graph_string, &*graph); // IRParser doesn't support tensor constants, so we generate several aten @@ -1261,7 +1312,6 @@ TEST_F(Kernel, ConstantTensorsNonContiguous) { TEST_F(Kernel, RunFast) { #ifdef TORCH_ENABLE_LLVM // TODO: Implement call_raw in IREval and remove the ifdef - KernelScope kernel_scope; const auto graph_string = R"IR( graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), @@ -1301,7 +1351,6 @@ TEST_F(Kernel, CodegenInspection) { %y : Tensor = aten::t(%y_t) %z : Float(16, 16, strides=[16, 1], device=cpu) = aten::mul(%x, %y) return (%z))IR"; - KernelScope kernel_scope; auto graph = std::make_shared(); parseIR(graph_string, &*graph); // IRParser doesn't support tensor constants, so we generate several aten @@ -1329,7 +1378,7 @@ TEST_F(Kernel, CodegenInspection) { #endif } -Tensor* lowerNanToNum( +Tensor lowerNanToNum( const std::vector& inputs, const std::vector& outputShape, const c10::optional& outputType, @@ -1353,7 +1402,6 @@ TEST_F(Kernel, CustomLowering) { %y : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::nan_to_num(%x, %none, %none, %none) return (%y) )IR"; - KernelScope kernel_scope; auto graph = std::make_shared(); parseIR(graph_string, &*graph); diff --git a/test/cpp/tensorexpr/test_llvm.cpp b/test/cpp/tensorexpr/test_llvm.cpp index 3776329a86a51..0e5cf5eb03a3d 100644 --- a/test/cpp/tensorexpr/test_llvm.cpp +++ b/test/cpp/tensorexpr/test_llvm.cpp @@ -36,7 +36,6 @@ using LLVMExprEval = ExprEval; #define IMM_TEST(Type, Name, Val) \ TEST(LLVM, Name##ImmTest) { \ - KernelScope kernel_scope; \ auto a = Name##Imm::make(Val); \ LLVMExprEval cg(a); \ if (std::is_floating_point()) { \ @@ -50,7 +49,6 @@ TEST_LLVM_SCALAR_TYPES(IMM_TEST) #define ADD_TEST(Type, Name, Val) \ TEST(LLVM, Name##AddTest) { \ - KernelScope kernel_scope; \ auto a = Name##Imm::make(Val); \ auto b = Name##Imm::make(Val * 2); \ auto c = Add::make(a, b); \ @@ -66,7 +64,6 @@ TEST_LLVM_SCALAR_TYPES(ADD_TEST) #define SUB_TEST(Type, Name, Val) \ TEST(LLVM, Name##SubTest) { \ - KernelScope kernel_scope; \ auto a = Name##Imm::make(Val * 2); \ auto b = Name##Imm::make(Val); \ auto c = Sub::make(a, b); \ @@ -82,7 +79,6 @@ TEST_LLVM_SCALAR_TYPES(SUB_TEST) #define MUL_TEST(Type, Name, Val) \ TEST(LLVM, Name##MulTest) { \ - KernelScope kernel_scope; \ auto a = Name##Imm::make(Val); \ auto b = Name##Imm::make((Type)4); \ auto c = Mul::make(a, b); \ @@ -98,7 +94,6 @@ TEST_LLVM_SCALAR_TYPES(MUL_TEST) #define DIV_TEST(Type, Name, Val) \ TEST(LLVM, Name##DivTest) { \ - KernelScope kernel_scope; \ auto a = Name##Imm::make((Type)6); \ auto b = Name##Imm::make((Type)3); \ auto c = Div::make(a, b); \ @@ -113,7 +108,6 @@ TEST_LLVM_SCALAR_TYPES(DIV_TEST) #undef DIV_TEST TEST(LLVM, IntToFloatCastTest) { - KernelScope kernel_scope; auto a = IntImm::make(2); auto b = Cast::make(kFloat, a); LLVMExprEval cg(b, {}); @@ -121,7 +115,6 @@ TEST(LLVM, IntToFloatCastTest) { } TEST(LLVM, FloatToIntCastTest) { - KernelScope kernel_scope; auto a = FloatImm::make(2.0); auto b = Cast::make(kInt, a); LLVMExprEval cg(b); @@ -129,7 +122,6 @@ TEST(LLVM, FloatToIntCastTest) { } TEST(LLVM, IntToLongCastTest) { - KernelScope kernel_scope; auto a = IntImm::make(12345); auto b = Cast::make(kLong, a); LLVMExprEval cg(b); @@ -137,7 +129,6 @@ TEST(LLVM, IntToLongCastTest) { } TEST(LLVM, ByteToCharCastTest) { - KernelScope kernel_scope; auto a = ByteImm::make(250); auto b = Cast::make(kChar, a); LLVMExprEval cg(b); @@ -145,7 +136,6 @@ TEST(LLVM, ByteToCharCastTest) { } TEST(LLVM, HalfToLongCastTest) { - KernelScope kernel_scope; auto a = HalfImm::make(2.0); auto b = Cast::make(kLong, a); LLVMExprEval cg(b); @@ -153,7 +143,6 @@ TEST(LLVM, HalfToLongCastTest) { } TEST(LLVM, ByteToDoubleCastTest) { - KernelScope kernel_scope; auto a = ByteImm::make(2); auto b = Cast::make(kDouble, a); LLVMExprEval cg(b); @@ -170,7 +159,6 @@ TEST(LLVM, BitCast) { // this is broken /*{ - KernelScope kernel_scope; at::Half k_; at::Half* k = &k_; *reinterpret_cast(k) = ref16; @@ -181,7 +169,6 @@ TEST(LLVM, BitCast) { }*/ { - KernelScope kernel_scope; float k = raw_bitcast(ref32); auto a = FloatImm::make(k); auto b = BitCast::make(kInt, a); @@ -190,7 +177,6 @@ TEST(LLVM, BitCast) { } { - KernelScope kernel_scope; double k = raw_bitcast(ref64); auto a = DoubleImm::make(k); auto b = BitCast::make(kLong, a); @@ -199,7 +185,6 @@ TEST(LLVM, BitCast) { } { - KernelScope kernel_scope; int64_t k = raw_bitcast(reff64); auto a = LongImm::make(k); auto b = BitCast::make(kDouble, a); @@ -208,7 +193,6 @@ TEST(LLVM, BitCast) { } { - KernelScope kernel_scope; int32_t k = raw_bitcast(reff32); auto a = IntImm::make(k); auto b = BitCast::make(kFloat, a); @@ -218,7 +202,6 @@ TEST(LLVM, BitCast) { } TEST(LLVM, fastLogFloat) { - KernelScope kernel_scope; const int kTotalSize = 128 * 128; Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); Placeholder b_buf(BufHandle("B", {ExprHandle(kTotalSize)}, kFloat)); @@ -250,8 +233,6 @@ TEST(LLVM, fastLogFloat) { } TEST(LLVM, LetTest01) { - KernelScope kernel_scope; - Placeholder a(BufHandle("A", {1}, kFloat)); std::vector v = {1, 0}; std::vector args({v.data()}); @@ -267,8 +248,6 @@ TEST(LLVM, LetTest01) { } TEST(LLVM, LetTest02) { - KernelScope kernel_scope; - Placeholder a(BufHandle("A", {1}, kFloat)); std::vector v = {1, 0}; std::vector args({v.data()}); @@ -287,8 +266,6 @@ TEST(LLVM, LetTest02) { } TEST(LLVM, LetTestMultitype) { - KernelScope kernel_scope; - Placeholder a(BufHandle("A", {1}, kDouble)); std::vector v = {1, 0}; std::vector args({v.data()}); @@ -310,7 +287,6 @@ TEST(LLVM, LetTestMultitype) { } TEST(LLVM, BufferTest) { - KernelScope kernel_scope; Placeholder a(BufHandle("A", {32}, kFloat)); std::vector v(5); std::vector args({v.data()}); @@ -320,7 +296,6 @@ TEST(LLVM, BufferTest) { } TEST(LLVM, BlockTest) { - KernelScope kernel_scope; Placeholder a(BufHandle("A", {32}, kInt)); std::vector v = {1, 2}; std::vector args({v.data()}); @@ -338,7 +313,6 @@ TEST(LLVM, BlockTest) { } TEST(LLVM, LoadStoreTest) { - KernelScope kernel_scope; Placeholder a(BufHandle("A", {1}, kInt)); Placeholder b(BufHandle("B", {1}, kInt)); std::vector a_buffer = {42}; @@ -353,7 +327,6 @@ TEST(LLVM, LoadStoreTest) { } TEST(LLVM, IfThenElseTest) { - KernelScope kernel_scope; Placeholder a(BufHandle("A", {1}, kInt)); Placeholder b(BufHandle("B", {1}, kInt)); Placeholder c(BufHandle("C", {1}, kInt)); @@ -371,8 +344,6 @@ TEST(LLVM, IfThenElseTest) { // if (x < 10) x = x + 1 TEST(LLVM, CondNoFalseBlockTest) { - KernelScope kernel_scope; - Placeholder x(BufHandle("X", {1}, kInt)); auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT); auto cond = Cond::make(cmp, x.store({0}, x.load(0) + 1), nullptr); @@ -396,8 +367,6 @@ TEST(LLVM, CondNoFalseBlockTest) { // x = x - 1; // } TEST(LLVM, CondTest) { - KernelScope kernel_scope; - Placeholder x(BufHandle("X", {1}, kInt)); auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT); auto cond = @@ -434,8 +403,6 @@ TEST(LLVM, CondTest) { // } // } TEST(LLVM, CondNestedTest) { - KernelScope kernel_scope; - Placeholder x(BufHandle("X", {1}, kInt)); auto true_cmp = CompareSelect::make(x.load(0), 5, CompareSelectOperation::kGT); @@ -470,7 +437,6 @@ TEST(LLVM, CondNestedTest) { } TEST(LLVM, DirectVectorization) { - KernelScope ks; constexpr int M = 3; constexpr int N = 64; BufHandle a("a", {M, N}, kFloat); @@ -491,7 +457,6 @@ TEST(LLVM, DirectVectorization) { } TEST(LLVM, VecLoadStoreTest) { - KernelScope kernel_scope; Placeholder a(BufHandle("A", {1}, kInt)); Placeholder b(BufHandle("B", {1}, kInt)); std::vector a_buffer = {1, 1, 1, 1}; @@ -513,7 +478,6 @@ TEST(LLVM, VecLoadStoreTest) { #define FLOAT_INTRINSICS_TEST(Name, Lanes) \ TEST(LLVM, VecFloat_##Name##Lane##Lanes##Test) { \ - KernelScope kernel_scope; \ Placeholder a(BufHandle("A", {1}, kFloat)); \ Placeholder b(BufHandle("B", {1}, kFloat)); \ float val = 0.5f; \ @@ -552,7 +516,6 @@ FLOAT_INTRINSICS_TEST(lgamma, 8) #define DOUBLE_INTRINSICS_TEST(Name, Lanes) \ TEST(LLVM, VecDouble_##Name##Lane##Lanes##Test) { \ - KernelScope kernel_scope; \ Placeholder a(BufHandle("A", {1}, kDouble)); \ Placeholder b(BufHandle("B", {1}, kDouble)); \ float val = 0.5f; \ @@ -590,13 +553,12 @@ DOUBLE_INTRINSICS_TEST(lgamma, 4) #undef DOUBLE_INTRINSICS_TEST TEST(LLVM, VectorizerLoadStoreTest) { - KernelScope kernel_scope; Placeholder a(BufHandle("A", {1}, kInt)); - Tensor* c = + Tensor c = Compute("c", {{4, "i"}}, [&](const VarHandle& i) { return a.load(i); }); - Placeholder c_buf(BufHandle(c->buf())); + Placeholder c_buf(BufHandle(c.buf())); LoopNest l({c}); StmtPtr s = l.root_stmt(); ASSERT_TRUE(LoopNest::vectorize(to(to(s)->front()))); @@ -613,14 +575,13 @@ TEST(LLVM, VectorizerLoadStoreTest) { } TEST(LLVM, VectorizeBitCast) { - KernelScope kernel_scope; Placeholder a(BufHandle("A", {128}, kInt)); - Tensor* c = Compute("c", {{128, "i"}}, [&](const VarHandle& i) { + Tensor c = Compute("c", {{128, "i"}}, [&](const VarHandle& i) { return bitcast(a.load(i)); }); - Placeholder c_buf(BufHandle(c->buf())); + Placeholder c_buf(BufHandle(c.buf())); LoopNest l({c}); StmtPtr s = l.root_stmt(); ASSERT_TRUE(LoopNest::vectorize(to(to(s)->front()))); @@ -639,7 +600,6 @@ TEST(LLVM, VectorizeBitCast) { } TEST(LLVM, MemcpyTest) { - KernelScope kernel_scope; constexpr int N = 32; Placeholder a(BufHandle("A", {N}, kInt)); Placeholder b(BufHandle("B", {N}, kInt)); @@ -661,7 +621,6 @@ TEST(LLVM, MemcpyTest) { } TEST(LLVM, BzeroTest) { - KernelScope kernel_scope; constexpr int N = 32; Placeholder b(BufHandle("B", {N}, kInt)); std::vector b_buffer(N, 11); @@ -679,7 +638,6 @@ TEST(LLVM, BzeroTest) { } TEST(LLVM, ElemwiseAdd) { - KernelScope kernel_scope; constexpr int N = 1024; Placeholder a(BufHandle("A", {N}, kInt)); Placeholder b(BufHandle("B", {N}, kInt)); @@ -705,7 +663,6 @@ TEST(LLVM, ElemwiseAdd) { } TEST(LLVM, ElemwiseAddFloat) { - KernelScope kernel_scope; constexpr int N = 1024; Placeholder a(BufHandle("A", {N}, kFloat)); Placeholder b(BufHandle("B", {N}, kFloat)); @@ -731,7 +688,6 @@ TEST(LLVM, ElemwiseAddFloat) { } TEST(LLVM, ElemwiseLog10Float) { - KernelScope kernel_scope; constexpr int N = 1024; Placeholder a(BufHandle("A", {N}, kFloat)); Placeholder b(BufHandle("B", {N}, kFloat)); @@ -758,7 +714,6 @@ TEST(LLVM, ElemwiseLog10Float) { } TEST(LLVM, ElemwiseLog1pFloat) { - KernelScope kernel_scope; constexpr int N = 1024; Placeholder a(BufHandle("A", {N}, kFloat)); Placeholder b(BufHandle("B", {N}, kFloat)); @@ -785,7 +740,6 @@ TEST(LLVM, ElemwiseLog1pFloat) { } TEST(LLVM, ElemwiseMaxInt) { - KernelScope kernel_scope; constexpr int N = 1024; Placeholder a(BufHandle("A", {N}, kInt)); Placeholder b(BufHandle("B", {N}, kInt)); @@ -812,7 +766,6 @@ TEST(LLVM, ElemwiseMaxInt) { } TEST(LLVM, ElemwiseMinInt) { - KernelScope kernel_scope; constexpr int N = 1024; Placeholder a(BufHandle("A", {N}, kInt)); Placeholder b(BufHandle("B", {N}, kInt)); @@ -839,7 +792,6 @@ TEST(LLVM, ElemwiseMinInt) { } TEST(LLVM, ElemwiseMaxFloat) { - KernelScope kernel_scope; constexpr int N = 1024; Placeholder a(BufHandle("A", {N}, kFloat)); Placeholder b(BufHandle("B", {N}, kFloat)); @@ -866,7 +818,6 @@ TEST(LLVM, ElemwiseMaxFloat) { } TEST(LLVM, ElemwiseMaxNaNFloat) { - KernelScope kernel_scope; constexpr int N = 1024; Placeholder a(BufHandle("A", {N}, kFloat)); Placeholder b(BufHandle("B", {N}, kFloat)); @@ -894,7 +845,6 @@ TEST(LLVM, ElemwiseMaxNaNFloat) { } TEST(LLVM, ElemwiseMinFloat) { - KernelScope kernel_scope; constexpr int N = 1024; Placeholder a(BufHandle("A", {N}, kFloat)); Placeholder b(BufHandle("B", {N}, kFloat)); @@ -921,7 +871,6 @@ TEST(LLVM, ElemwiseMinFloat) { } TEST(LLVM, ElemwiseMinNaNFloat) { - KernelScope kernel_scope; constexpr int N = 1024; Placeholder a(BufHandle("A", {N}, kFloat)); Placeholder b(BufHandle("B", {N}, kFloat)); @@ -949,7 +898,6 @@ TEST(LLVM, ElemwiseMinNaNFloat) { } TEST(LLVM, ElemwiseMod) { - KernelScope kernel_scope; constexpr int N = 1024; Placeholder a(BufHandle("A", {N}, kInt)); Placeholder b(BufHandle("B", {N}, kInt)); @@ -975,7 +923,6 @@ TEST(LLVM, ElemwiseMod) { } TEST(LLVM, CompareSelectIntEQ) { - KernelScope kernel_scope; constexpr int N = 1024; Placeholder a(BufHandle("A", {N}, kInt)); Placeholder b(BufHandle("B", {N}, kInt)); @@ -1016,7 +963,6 @@ TEST(LLVM, CompareSelectIntEQ) { } TEST(LLVM, CompareSelectFloatEQ) { - KernelScope kernel_scope; constexpr int N = 1024; Placeholder a(BufHandle("A", {N}, kFloat)); Placeholder b(BufHandle("B", {N}, kFloat)); @@ -1050,7 +996,6 @@ TEST(LLVM, CompareSelectFloatEQ) { } TEST(LLVM, CompareSelectByteGT) { - KernelScope kernel_scope; constexpr int N = 1024; Placeholder a(BufHandle("A", {N}, kByte)); Placeholder b(BufHandle("B", {N}, kByte)); @@ -1091,7 +1036,6 @@ TEST(LLVM, CompareSelectByteGT) { } TEST(LLVM, CompareSelectByteGE) { - KernelScope kernel_scope; constexpr int N = 1024; Placeholder a(BufHandle("A", {N}, kByte)); Placeholder b(BufHandle("B", {N}, kByte)); @@ -1127,7 +1071,6 @@ TEST(LLVM, CompareSelectByteGE) { } TEST(LLVM, CompareSelectByteLT) { - KernelScope kernel_scope; constexpr int N = 1024; Placeholder a(BufHandle("A", {N}, kByte)); Placeholder b(BufHandle("B", {N}, kByte)); @@ -1168,7 +1111,6 @@ TEST(LLVM, CompareSelectByteLT) { } TEST(LLVM, CompareSelectByteLE) { - KernelScope kernel_scope; constexpr int N = 1024; Placeholder a(BufHandle("A", {N}, kByte)); Placeholder b(BufHandle("B", {N}, kByte)); @@ -1204,7 +1146,6 @@ TEST(LLVM, CompareSelectByteLE) { } TEST(LLVM, StoreFloat) { - KernelScope kernel_scope; Placeholder result(BufHandle("result", {1}, kFloat)); std::vector result_buffer = {0.0f}; auto expr = result.store({0}, FloatImm::make(3.14f)); @@ -1215,14 +1156,13 @@ TEST(LLVM, StoreFloat) { } TEST(LLVM, SimpleMath01) { - KernelScope kernel_scope; const int N = 1024; - Tensor* tensor = Compute("f", {{N, "i"}}, [](const VarHandle& i) { + Tensor tensor = Compute("f", {{N, "i"}}, [](const VarHandle& i) { return cast(i * i + 1); }); LoopNest l({tensor}); StmtPtr stmt = l.root_stmt(); - Placeholder f_buf(BufHandle(tensor->buf())); + Placeholder f_buf(BufHandle(tensor.buf())); LLVMCodeGen cg(stmt, {f_buf}); PaddedBuffer f_v(N, "f_v"); @@ -1237,15 +1177,14 @@ TEST(LLVM, SimpleMath01) { } TEST(LLVM, ComputeMul) { - KernelScope kernel_scope; const int N = 1024; Placeholder a(BufHandle("a", {N}, kFloat)); Placeholder b(BufHandle("b", {N}, kFloat)); - Tensor* c = Compute("c", {{N, "i"}}, [&](const VarHandle& i) { + Tensor c = Compute("c", {{N, "i"}}, [&](const VarHandle& i) { return a.load(i) * b.load(i); }); - Placeholder c_buf(BufHandle(c->buf())); + Placeholder c_buf(BufHandle(c.buf())); LoopNest l({c}); StmtPtr s = l.root_stmt(); @@ -1260,17 +1199,16 @@ TEST(LLVM, ComputeMul) { } TEST(LLVM, BroadcastAdd) { - KernelScope kernel_scope; const int M = 32; const int N = 1024; Placeholder a(BufHandle("a", {M, N}, kFloat)); Placeholder b(BufHandle("b", {N}, kFloat)); - Tensor* c = Compute( + Tensor c = Compute( "c", {{M, "i"}, {N, "j"}}, [&](const VarHandle& i, const VarHandle& j) { return a.load(i, j) + b.load(j); }); - Placeholder c_buf(BufHandle(c->buf())); + Placeholder c_buf(BufHandle(c.buf())); LoopNest l({c}); l.prepareForCodegen(); StmtPtr s = l.root_stmt(); @@ -1293,7 +1231,6 @@ TEST(LLVM, BroadcastAdd) { } TEST(LLVM, BitwiseOps) { - KernelScope kernel_scope; auto a = IntImm::make(59); auto b = IntImm::make(11); auto c = IntImm::make(101); @@ -1306,7 +1243,6 @@ TEST(LLVM, BitwiseOps) { } TEST(LLVM, ArithmeticRightShift) { - KernelScope ks; auto a = CharImm::make(-4); auto b = CharImm::make(1); ExprHandle f = a >> b; @@ -1315,7 +1251,6 @@ TEST(LLVM, ArithmeticRightShift) { } TEST(LLVM, LogicalRightShift) { - KernelScope ks; auto a = ByteImm::make(0xfc); auto b = ByteImm::make(1); ExprHandle f = a >> b; @@ -1324,7 +1259,6 @@ TEST(LLVM, LogicalRightShift) { } TEST(LLVM, DynamicShapeAdd) { - KernelScope kernel_scope; auto testWithSize = [](int32_t size) { VarHandle n("n", kInt); Placeholder a(BufHandle("a", {n}, kFloat)); @@ -1346,7 +1280,6 @@ TEST(LLVM, DynamicShapeAdd) { } TEST(LLVM, BindDynamicShapeAdd) { - KernelScope kernel_scope; auto testWithSize = [](int32_t size) { VarHandle n("n", kInt); Placeholder a(BufHandle("a", {n}, kFloat)); @@ -1367,12 +1300,11 @@ TEST(LLVM, BindDynamicShapeAdd) { } TEST(LLVM, TensorDynamicShapeAdd) { - KernelScope kernel_scope; auto testWithSize = [](int32_t size) { VarHandle n("n", kInt); Placeholder a(BufHandle("a", {n}, kFloat)); Placeholder b(BufHandle("b", {n}, kFloat)); - Tensor* c = Compute("c", {{n, "n"}}, [&](const VarHandle& i) { + Tensor c = Compute("c", {{n, "n"}}, [&](const VarHandle& i) { return a.load(i) + b.load(i); }); LoopNest l({c}); @@ -1390,13 +1322,12 @@ TEST(LLVM, TensorDynamicShapeAdd) { } TEST(LLVM, DynamicShape2D) { - KernelScope kernel_scope; auto testWithSize = [](int32_t M, int32_t N) { VarHandle m("m", kInt); VarHandle n("n", kInt); Placeholder a(BufHandle("a", {m, n}, kFloat)); Placeholder b(BufHandle("b", {m, n}, kFloat)); - Tensor* c = Compute( + Tensor c = Compute( "c", {{m, "m"}, {n, "n"}}, [&](const VarHandle& i, const VarHandle& j) { return a.load(i, j) + b.load(i, j); }); @@ -1416,7 +1347,6 @@ TEST(LLVM, DynamicShape2D) { } TEST(LLVM, EmptyStmt) { - KernelScope kernel_scope; StmtPtr s = alloc(std::vector({})); LLVMCodeGen cg(s, {}); @@ -1425,10 +1355,9 @@ TEST(LLVM, EmptyStmt) { } TEST(LLVM, EliminatedStmt) { - KernelScope kernel_scope; Placeholder a(BufHandle("a", {1}, kFloat)); - Tensor* c = Compute("c", {{0, "m"}}, [&](const VarHandle& m) { return m; }); + Tensor c = Compute("c", {{0, "m"}}, [&](const VarHandle& m) { return m; }); LoopNest l({c}); l.prepareForCodegen(); @@ -1441,8 +1370,6 @@ TEST(LLVM, EliminatedStmt) { } TEST(LLVM, SimpleReduction) { - KernelScope kernel_scope; - int M = 128; int N = 64; const int kTotalSize = M * N; @@ -1452,7 +1379,7 @@ TEST(LLVM, SimpleReduction) { // TODO: why doesn't implicit vector work? std::vector axis = {DimArg(1)}; std::vector reduce_axis = {DimArg(M), DimArg(N)}; - Tensor* b = Reduce("sum", axis, Sum(), a, reduce_axis); + Tensor b = Reduce("sum", axis, Sum(), a, reduce_axis); LoopNest loop({b}); loop.prepareForCodegen(); @@ -1480,8 +1407,6 @@ TEST(LLVM, SimpleReduction) { } TEST(LLVM, RFactorReduction) { - KernelScope kernel_scope; - int M = 128; int N = 64; const int kTotalSize = M * N; @@ -1491,7 +1416,7 @@ TEST(LLVM, RFactorReduction) { // TODO: why doesn't implicit vector work? std::vector axis = {DimArg(1)}; std::vector reduce_axis = {DimArg(M), DimArg(N)}; - Tensor* b = Reduce("sum", axis, Sum(), a, reduce_axis); + Tensor b = Reduce("sum", axis, Sum(), a, reduce_axis); LoopNest loop({b}); std::vector loops = loop.getLoopStmtsFor(b); @@ -1502,7 +1427,7 @@ TEST(LLVM, RFactorReduction) { loops = loop.getLoopStmtsFor(b); loop_m = loops.at(2); loop_n = loops.at(1); - auto b_body = loop.getAllWritesToBuf(b->buf())[1]; + auto b_body = loop.getAllWritesToBuf(b.buf())[1]; ASSERT_TRUE(loop.rfactor(b_body, loop_n)); loop.prepareForCodegen(); @@ -1530,21 +1455,19 @@ TEST(LLVM, RFactorReduction) { } TEST(LLVM, RFactorVectorizedReduction) { - KernelScope kernel_scope; - int M = 128; int N = 64; const int kTotalSize = M * N; Placeholder a("a", kFloat, {1, M, N}); - Tensor* b = Reduce("sum", {{1, "K"}}, Sum(), a, {{M, "M"}, {N, "N"}}); + Tensor b = Reduce("sum", {{1, "K"}}, Sum(), a, {{M, "M"}, {N, "N"}}); LoopNest loopnest({b}); std::vector loops = loopnest.getLoopStmtsFor(b); // Reorder n and m loops loopnest.reorderAxis(loops.at(1), loops.at(2)); - auto b_body = loopnest.getAllWritesToBuf(b->buf()).at(1); - auto all_loops = loopnest.getAllLoopNestsWritingToBuf(b->buf()); + auto b_body = loopnest.getAllWritesToBuf(b.buf()).at(1); + auto all_loops = loopnest.getAllLoopNestsWritingToBuf(b.buf()); ASSERT_TRUE(all_loops.size() == 2 && all_loops[1].size() == 3); ASSERT_TRUE(loopnest.rfactor(b_body, all_loops[1][1])); auto distributed_loops = loopnest.distributeLoop(all_loops[1][1]); @@ -1578,43 +1501,54 @@ TEST(LLVM, RFactorVectorizedReduction) { ExpectAllNear(b_v, b_ref, 1e-5); } -TEST(LLVM, SimpleParallel) { - for (int test_cfg = 0; test_cfg < 4; test_cfg++) { - // Compute a simple operation, and try all loop-axis combination to be - // parallel or sequential. - KernelScope kernel_scope; - const int M = 4; - const int N = 6; - Tensor* f = Compute( - "f", {{M, "m"}, {N, "n"}}, [](const VarHandle& m, const VarHandle& n) { - return cast(m + n); - }); - LoopNest loop_nest({f}); - auto const& loops = loop_nest.getLoopStmtsFor(f); - ForPtr m = loops[0]; - ForPtr n = loops[1]; - if (test_cfg & 0x1) { - m->set_parallel(); - } - if (test_cfg & 0x2) { - n->set_parallel(); - } - loop_nest.prepareForCodegen(); - StmtPtr stmt = loop_nest.root_stmt(); - LLVMCodeGen cg(stmt, {f}); +template +static void testSimpleParallel() { + // Compute a simple operation, and try all loop-axis combination to be + // parallel or sequential. + const int M = 4; + const int N = 6; + Tensor f = Compute( + "f", {{M, "m"}, {N, "n"}}, [](const VarHandle& m, const VarHandle& n) { + return cast(m + n); + }); + LoopNest loop_nest({f}); + auto const& loops = loop_nest.getLoopStmtsFor(f); + ForPtr m = loops[0]; + ForPtr n = loops[1]; + if (outer) { + m->set_parallel(); + } + if (inner) { + n->set_parallel(); + } + loop_nest.prepareForCodegen(); + StmtPtr stmt = loop_nest.root_stmt(); + LLVMCodeGen cg(stmt, {f}); - PaddedBuffer f_v(M, N, "f_v"); - std::vector args({f_v.data()}); - int value = cg.value(args); - ASSERT_EQ(value, 0); - PaddedBuffer f_ref(M, N, "f_ref"); - for (int m = 0; m < M; m++) { - for (int n = 0; n < N; n++) { - f_ref(m, n) = m + n; - } + PaddedBuffer f_v(M, N, "f_v"); + std::vector args({f_v.data()}); + int value = cg.value(args); + ASSERT_EQ(value, 0); + PaddedBuffer f_ref(M, N, "f_ref"); + for (int m = 0; m < M; m++) { + for (int n = 0; n < N; n++) { + f_ref(m, n) = m + n; } - ExpectAllNear(f_v, f_ref, 1e-5); } + ExpectAllNear(f_v, f_ref, 1e-5); +} + +TEST(LLVM, SimpleParallelSS) { + testSimpleParallel(); +} +TEST(LLVM, SimpleParallelSP) { + testSimpleParallel(); +} +TEST(LLVM, SimpleParallelPS) { + testSimpleParallel(); +} +TEST(LLVM, SimpleParallelPP) { + testSimpleParallel(); } TEST(LLVM, CompositeParallel) { @@ -1623,24 +1557,23 @@ TEST(LLVM, CompositeParallel) { // Compute a composite operation, and try all loop-axis combination to be // parallel or sequential. for (int test_cfg = 0; test_cfg < test_count; test_cfg++) { - KernelScope kernel_scope; int M = 5; int N = 7; - Tensor* t1 = + Tensor t1 = Compute("t1", {{M, "M"}}, [](const VarHandle& m) { return m + 1.f; }); - Tensor* t2 = + Tensor t2 = Compute("t2", {{N, "N"}}, [](const VarHandle& n) { return n + 2.f; }); - Tensor* t3 = Compute( + Tensor t3 = Compute( "t3", {{M, "M"}, {N, "N"}}, [=](const VarHandle& m, const VarHandle& n) { - return t1->load(m) * t2->load(n); + return t1.load(m) * t2.load(n); }); - Tensor* t4 = Compute( + Tensor t4 = Compute( "t4", {{M, "M"}, {N, "N"}}, [=](const VarHandle& m, const VarHandle& n) { - return t3->load(m, n) + m + n; + return t3.load(m, n) + m + n; }); LoopNest loop_nest({t4}, {t1, t2, t3, t4}); std::vector loop_list; @@ -1687,15 +1620,13 @@ TEST(LLVM, CompositeParallel) { } TEST(LLVM, VectorizedGEMM) { - KernelScope ks; - int M = 32; int N = 32; int K = 48; Placeholder AP(BufHandle("A", {M, K}, kFloat)); Placeholder BP(BufHandle("B", {K, N}, kFloat)); - Tensor* CT = Reduce( + Tensor CT = Reduce( "gemm", {{M, "M"}, {N, "N"}}, Sum(), @@ -1771,12 +1702,11 @@ TEST(LLVM, VectorizedGEMM) { } TEST(LLVM, CallRaw) { - KernelScope kernel_scope; const int M = 32; VarHandle N("N", kInt); Placeholder a(BufHandle("a", {M, N}, kFloat)); Placeholder b(BufHandle("b", {N}, kFloat)); - Tensor* c = Compute( + Tensor c = Compute( "c", {{M, "i"}, {N, "j"}}, [&](const VarHandle& i, const VarHandle& j) { return a.load(i, j) + b.load(j); }); @@ -1793,7 +1723,7 @@ TEST(LLVM, CallRaw) { std::vector cv(M * N_value, 0); std::vector args({av.data(), bv.data(), cv.data(), &N_value}); - LLVMCodeGen cg(s, {a, b, BufHandle(c->buf()), N}); + LLVMCodeGen cg(s, {a, b, BufHandle(c.buf()), N}); cg.call_raw(args); for (int i = 0; i < M; i++) { @@ -1802,7 +1732,7 @@ TEST(LLVM, CallRaw) { } } - SimpleIREvaluator eval(s, {a, b, BufHandle(c->buf()), N}); + SimpleIREvaluator eval(s, {a, b, BufHandle(c.buf()), N}); eval.call_raw(args); for (int i = 0; i < M; i++) { @@ -1813,12 +1743,11 @@ TEST(LLVM, CallRaw) { } TEST(LLVM, CustomTarget) { - KernelScope kernel_scope; constexpr int M = 16; Placeholder a("a", kFloat, {M}); Placeholder b("b", kFloat, {M}); Placeholder c("c", kFloat, {M}); - Tensor* d = Compute("d", {{M, "m"}}, [&](const VarHandle& m) { + Tensor d = Compute("d", {{M, "m"}}, [&](const VarHandle& m) { return a.load(m) * b.load(m) + c.load(m); }); LoopNest nest({d}); diff --git a/test/cpp/tensorexpr/test_loopnest.cpp b/test/cpp/tensorexpr/test_loopnest.cpp index 440b169d57259..b1d59a1dee066 100644 --- a/test/cpp/tensorexpr/test_loopnest.cpp +++ b/test/cpp/tensorexpr/test_loopnest.cpp @@ -29,23 +29,31 @@ void checkIR(StmtPtr s, const std::string& pattern) { torch::jit::testing::FileCheck().run(pattern, oss.str()); } +void checkExprIR(ExprPtr e, const std::string& pattern) { + std::string prefixed_pattern = "# CHECK: " + pattern + "\n"; + std::ostringstream oss; + oss << *e << "\n"; + torch::jit::testing::FileCheck().run(prefixed_pattern, oss.str()); +} + +void checkExprIR(const ExprHandle& e, const std::string& pattern) { + checkExprIR(e.node(), pattern); +} + TEST(LoopNest, ExprSimple01) { - KernelScope kernel_scope; - Tensor* tensor = Compute( + Tensor tensor = Compute( "f", {{16, "X"}, {5, "y"}}, [](const VarHandle& x, const VarHandle& y) { return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; }); LoopNest l({tensor}); - std::vector loops = - l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); LoopNest::splitWithTail(loops[0], 2); LoopNest::splitWithTail(loops[0], 2); } TEST(LoopNest, ExprLower01) { - KernelScope kernel_scope; - Tensor* tensor = Compute( + Tensor tensor = Compute( "f", {{16, "x"}, {5, "y"}}, [](const VarHandle& x, const VarHandle& y) { return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; }); @@ -58,14 +66,12 @@ TEST(LoopNest, ExprLower01) { } TEST(LoopNest, ExprSimple02) { - KernelScope kernel_scope; auto func = [](const ExprHandle& x, const ExprHandle& y) { return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; }; - Tensor* tensor = Compute("f", {{26, "x"}, {5, "y"}}, func); + Tensor tensor = Compute("f", {{26, "x"}, {5, "y"}}, func); LoopNest l({tensor}); - std::vector loops = - l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); LoopNest::splitWithTail(loops[0], 4); @@ -153,18 +159,16 @@ void assertForRanges( } TEST(LoopNest, ExprSliceHeadWithLoopOptions) { - KernelScope kernel_scope; auto func = [](const ExprHandle& x) { return ExprHandle(1.0f) + cast(x); }; - Tensor* tensor = Compute("f", {{10, "x"}}, func); + Tensor tensor = Compute("f", {{10, "x"}}, func); LoopNest l({tensor}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr head; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr tail; - std::vector loops = - l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); loops[0]->set_gpu_block_index(LoopOptions::IDX_Y); LoopNest::sliceHead(loops[0], 2, &head, &tail); @@ -178,18 +182,16 @@ TEST(LoopNest, ExprSliceHeadWithLoopOptions) { } TEST(LoopNest, ExprSliceTailWithLoopOptions) { - KernelScope kernel_scope; auto func = [](const ExprHandle& x) { return ExprHandle(1.0f) + cast(x); }; - Tensor* tensor = Compute("f", {{10, "x"}}, func); + Tensor tensor = Compute("f", {{10, "x"}}, func); LoopNest l({tensor}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr head; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr tail; - std::vector loops = - l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); LoopNest::sliceTail(loops[0], 4, &head, &tail); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) @@ -212,18 +214,16 @@ TEST(LoopNest, ExprSliceTailWithLoopOptions) { TEST(LoopNest, ExprSliceHeadWhenFactorEqualsSize) { // When factor equals the For loop's original size, keep using the original // For loop. - KernelScope kernel_scope; auto func = [](const ExprHandle& x) { return ExprHandle(1.0f) + cast(x); }; - Tensor* tensor = Compute("f", {{10, "x"}}, func); + Tensor tensor = Compute("f", {{10, "x"}}, func); LoopNest l({tensor}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr head; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr tail; - std::vector loops = - l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); LoopNest::sliceHead(loops[0], 10, &head, &tail); ASSERT_EQ(head, loops[0]); @@ -234,18 +234,16 @@ TEST(LoopNest, ExprSliceHeadWhenFactorEqualsSize) { } TEST(LoopNest, ExprSliceHeadWhenFactorLargerThanSize) { - KernelScope kernel_scope; auto func = [](const ExprHandle& x) { return ExprHandle(1.0f) + cast(x); }; - Tensor* tensor = Compute("f", {{10, "x"}}, func); + Tensor tensor = Compute("f", {{10, "x"}}, func); LoopNest l({tensor}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr head; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr tail; - std::vector loops = - l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); LoopNest::sliceHead(loops[0], 100, &head, &tail); ASSERT_EQ(head, loops[0]); @@ -256,38 +254,34 @@ TEST(LoopNest, ExprSliceHeadWhenFactorLargerThanSize) { } TEST(LoopNest, ExprSliceHead) { - KernelScope kernel_scope; auto func = [](const ExprHandle& x) { return ExprHandle(1.0f) + cast(x); }; - Tensor* tensor = Compute("f", {{10, "x"}}, func); + Tensor tensor = Compute("f", {{10, "x"}}, func); LoopNest l({tensor}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr head; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr tail; - std::vector loops = - l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); LoopNest::sliceHead(loops[0], 4, &head, &tail); ASSERT_NE(head, nullptr); ASSERT_NE(head, loops[0]); ASSERT_NE(tail, nullptr); - ASSERT_NE(tail, loops[0]); + ASSERT_EQ(tail, loops[0]); BlockPtr body = getSimplifiedBody(l); assertForRanges(body, {{0, 4}, {4, 10}}); } TEST(LoopNest, ExprSliceHeadWithNonZeroStart) { - KernelScope kernel_scope; auto func = [](const ExprHandle& x) { return ExprHandle(1.0f) + cast(x); }; - Tensor* tensor = Compute("f", {{10, "x"}}, func); + Tensor tensor = Compute("f", {{10, "x"}}, func); LoopNest l({tensor}); - std::vector loops = - l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr head; @@ -308,18 +302,16 @@ TEST(LoopNest, ExprSliceHeadWithNonZeroStart) { TEST(LoopNest, ExprSliceTailWhenFactorEqualsSize) { // When factor equals the For loop's original size, keep using the original // For loop. - KernelScope kernel_scope; auto func = [](const ExprHandle& x) { return ExprHandle(1.0f) + cast(x); }; - Tensor* tensor = Compute("f", {{10, "x"}}, func); + Tensor tensor = Compute("f", {{10, "x"}}, func); LoopNest l({tensor}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr head; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr tail; - std::vector loops = - l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); LoopNest::sliceTail(loops[0], 10, &head, &tail); ASSERT_EQ(head, nullptr); @@ -332,18 +324,16 @@ TEST(LoopNest, ExprSliceTailWhenFactorEqualsSize) { TEST(LoopNest, ExprSliceTailWhenFactorLargerThanSize) { // When factor equals the For loop's original size, keep using the original // For loop. - KernelScope kernel_scope; auto func = [](const ExprHandle& x) { return ExprHandle(1.0f) + cast(x); }; - Tensor* tensor = Compute("f", {{10, "x"}}, func); + Tensor tensor = Compute("f", {{10, "x"}}, func); LoopNest l({tensor}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr head; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr tail; - std::vector loops = - l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); LoopNest::sliceTail(loops[0], 100, &head, &tail); ASSERT_EQ(head, nullptr); @@ -354,22 +344,20 @@ TEST(LoopNest, ExprSliceTailWhenFactorLargerThanSize) { } TEST(LoopNest, ExprSliceTail) { - KernelScope kernel_scope; auto func = [](const ExprHandle& x) { return ExprHandle(1.0f) + cast(x); }; - Tensor* tensor = Compute("f", {{10, "x"}}, func); + Tensor tensor = Compute("f", {{10, "x"}}, func); LoopNest l({tensor}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr head; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr tail; - std::vector loops = - l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); LoopNest::sliceTail(loops[0], 4, &head, &tail); ASSERT_NE(head, nullptr); - ASSERT_NE(head, loops[0]); + ASSERT_EQ(head, loops[0]); ASSERT_NE(tail, nullptr); ASSERT_NE(tail, loops[0]); @@ -381,19 +369,17 @@ TEST(LoopNest, ExprSplitAndSlice) { // 0: splitWithTail // 1: sliceTail on inner loop // 2: sliceHead on outer loop - KernelScope kernel_scope; auto func = [](const ExprHandle& x) { return ExprHandle(1.0f) + cast(x); }; - Tensor* tensor = Compute("f", {{100, "x"}}, func); + Tensor tensor = Compute("f", {{100, "x"}}, func); LoopNest l({tensor}); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr inner; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr tail; - std::vector loops = - l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); // outer: [0, 4) // inner: [0, 21) // tail: [84, 100) @@ -435,14 +421,12 @@ TEST(LoopNest, ExprSplitAndSlice) { TEST(LoopNest, ExprSliceAndNormalize) { // 0: sliceHead // 1: normalize tail - KernelScope kernel_scope; auto func = [](const ExprHandle& x) { return ExprHandle(1.0f) + cast(x); }; - Tensor* tensor = Compute("f", {{10, "x"}}, func); + Tensor tensor = Compute("f", {{10, "x"}}, func); LoopNest l({tensor}); - std::vector loops = - l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr head; @@ -469,13 +453,12 @@ TEST(LoopNest, ExprSliceWithVariableDimension) { auto testWithDimension = [](int dimension, const std::vector>& expected_for_ranges) { - KernelScope kernel_scope; VarHandle dim("dim", kInt); - Tensor* tensor = + Tensor tensor = Compute("f", {{dim, "x"}}, [](const ExprHandle& x) { return x; }); LoopNest l({tensor}); std::vector loops = - l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr head; @@ -506,14 +489,12 @@ TEST(LoopNest, ExprSliceWithVariableDimension) { } TEST(LoopNest, ExprSplitWithTail) { - KernelScope kernel_scope; auto func = [](const ExprHandle& x) { return ExprHandle(1.0f) + cast(x); }; - Tensor* tensor = Compute("f", {{199, "x"}}, func); + Tensor tensor = Compute("f", {{199, "x"}}, func); LoopNest l({tensor}); - std::vector loops = - l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) LoopNest::splitWithTail(loops[0], 17); // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) @@ -537,14 +518,12 @@ TEST(LoopNest, ExprSplitWithTail) { } TEST(LoopNest, ExprSplitWithTailNone) { - KernelScope kernel_scope; auto func = [](const ExprHandle& x, const ExprHandle& y) { return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; }; - Tensor* tensor = Compute("f", {{24, "x"}, {5, "y"}}, func); + Tensor tensor = Compute("f", {{24, "x"}, {5, "y"}}, func); LoopNest l({tensor}); - std::vector loops = - l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); LoopNest::splitWithTail(loops[0], 4); StmtPtr stmt = l.root_stmt(); @@ -596,19 +575,17 @@ TEST(LoopNest, ExprSplitWithTailNone) { } TEST(LoopNest, ExprSplitWithMask01) { - KernelScope kernel_scope; const int M = 26; const int N = 5; Placeholder a_buf("a", kFloat, {M, N}); Placeholder b_buf("b", kFloat, {M, N}); - Tensor* tensor = Compute( + Tensor tensor = Compute( "f", {{M, "m"}, {N, "n"}}, [&](const ExprHandle& m, const ExprHandle& n) { return a_buf.load(m, n) + b_buf.load(m, n) + 1.0f; }); LoopNest l({tensor}); - std::vector loops = - l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); LoopNest::splitWithMask(loops[1], 4); StmtPtr stmt = l.root_stmt(); @@ -633,17 +610,15 @@ TEST(LoopNest, ExprSplitWithMask01) { // Tests the case where we split a loop cleanly multiple times, we should not // insert any masks. TEST(LoopNest, ExprSplitWithMaskRepeatedNoMask) { - KernelScope kernel_scope; const int M = 64; Placeholder a_buf("a", kFloat, {M}); Placeholder b_buf("b", kFloat, {M}); - Tensor* tensor = Compute("f", {{M, "m"}}, [&](const ExprHandle& m) { + Tensor tensor = Compute("f", {{M, "m"}}, [&](const ExprHandle& m) { return a_buf.load(m) + b_buf.load(m) + 1.0f; }); LoopNest l({tensor}); - std::vector loops = - l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); LoopNest::splitWithMask(loops[0], 4); LoopNest::splitWithMask(loops[0], 4); @@ -661,8 +636,6 @@ TEST(LoopNest, ExprSplitWithMaskRepeatedNoMask) { } TEST(LoopNest, getLoopAt) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 100; i++) { // for (int j = 0; j < 100; j++) { @@ -720,19 +693,17 @@ TEST(LoopNest, getLoopAt) { } TEST(LoopNest, TileSimple) { - KernelScope kernel_scope; // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) const int M = 64, N = 64; Placeholder a_buf("a", kFloat, {M, N}); Placeholder b_buf("b", kFloat, {M, N}); - Tensor* tensor = Compute( + Tensor tensor = Compute( "f", {{M, "m"}, {N, "n"}}, [&](const ExprHandle& m, const ExprHandle& n) { return a_buf.load({m, n}) + b_buf.load({m, n}) + 1.0f; }); LoopNest l({tensor}); - std::vector loops = - l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) l.tile(loops[0], loops[1], 4, 8); @@ -767,19 +738,17 @@ TEST(LoopNest, TileSimple) { } TEST(LoopNest, TileWithTails) { - KernelScope kernel_scope; // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) const int M = 64, N = 64; Placeholder a_buf("a", kFloat, {M, N}); Placeholder b_buf("b", kFloat, {M, N}); - Tensor* tensor = Compute( + Tensor tensor = Compute( "f", {{M, "m"}, {N, "n"}}, [&](const ExprHandle& m, const ExprHandle& n) { return a_buf.load({m, n}) + b_buf.load({m, n}) + 1.0f; }); LoopNest l({tensor}); - std::vector loops = - l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) l.tile(loops[0], loops[1], 5, 9); @@ -815,12 +784,11 @@ TEST(LoopNest, TileWithTails) { } TEST(LoopNest, TileInMiddle) { - KernelScope kernel_scope; // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) const int M = 8, N = 8, L = 8, K = 8; Placeholder a_buf("a", kFloat, {M, N, L, K}); Placeholder b_buf("b", kFloat, {M, N, L, K}); - Tensor* tensor = Compute( + Tensor tensor = Compute( "f", {{M, "m"}, {N, "n"}, {L, "l"}, {K, "k"}}, [&](const ExprHandle& m, @@ -832,7 +800,7 @@ TEST(LoopNest, TileInMiddle) { LoopNest nest({tensor}); std::vector loops = - nest.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + nest.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) nest.tile(loops[1], loops[2], 3, 3); @@ -876,11 +844,10 @@ TEST(LoopNest, TileInMiddle) { } TEST(LoopNest, SplitWithTailWithLoopOptions) { - KernelScope kernel_scope; const int M = 21; Placeholder a_buf("a", kFloat, {M}); Placeholder b_buf("b", kFloat, {M}); - Tensor* tensor = Compute("f", {{M, "m"}}, [&](const ExprHandle& m) { + Tensor tensor = Compute("f", {{M, "m"}}, [&](const ExprHandle& m) { return a_buf.load(m) + b_buf.load(m) + 1.0f; }); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) @@ -907,11 +874,10 @@ TEST(LoopNest, SplitWithTailWithLoopOptions) { } TEST(LoopNest, SplitWithMaskWithLoopOptions) { - KernelScope kernel_scope; const int M = 21; Placeholder a_buf("a", kFloat, {M}); Placeholder b_buf("b", kFloat, {M}); - Tensor* tensor = Compute("f", {{M, "m"}}, [&](const ExprHandle& m) { + Tensor tensor = Compute("f", {{M, "m"}}, [&](const ExprHandle& m) { return a_buf.load(m) + b_buf.load(m) + 1.0f; }); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) @@ -932,13 +898,12 @@ TEST(LoopNest, SplitWithMaskWithLoopOptions) { } TEST(LoopNest, ScheduleBroadcastAddBuffer) { - KernelScope kernel_scope; const int M = 4; const int N = 5; const int K = 6; Placeholder a_buf("a", kFloat, {M, N}); Placeholder b_buf("b", kFloat, {N, K}); - Tensor* c = Compute( + Tensor c = Compute( "broadcast_add", {{M, "m"}, {N, "n"}, {K, "k"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { @@ -981,23 +946,22 @@ TEST(LoopNest, ScheduleBroadcastAddBuffer) { } TEST(LoopNest, ScheduleFunctionCall01) { - KernelScope kernel_scope; const int M = 4; const int N = 5; const int K = 6; Placeholder a_buf("a", kFloat, {M, N}); Placeholder b_buf("b", kFloat, {N, K}); - Tensor* c = Compute( + Tensor c = Compute( "broadcast_add", {{M, "m"}, {N, "n"}, {K, "k"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { return a_buf.load(m, n) + b_buf.load(n, k); }); - Tensor* d = Compute( + Tensor d = Compute( "d", {{M, "m"}, {N, "n"}, {K, "k"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return c->load(m, n, k) + 1; + return c.load(m, n, k) + 1; }); LoopNest l({d}, {c, d}); @@ -1038,7 +1002,6 @@ TEST(LoopNest, ScheduleFunctionCall01) { } TEST(LoopNest, ScheduleInlineSimple) { - KernelScope kernel_scope; const int M = 4; const int N = 5; const int K = 6; @@ -1047,22 +1010,22 @@ TEST(LoopNest, ScheduleInlineSimple) { Placeholder c_buf("c", kFloat, {M, N}); Placeholder d_buf("d", kFloat, {M, K}); - Tensor* x = Compute( + Tensor x = Compute( "x", {{M, "m1"}, {N, "n1"}, {K, "k1"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { return a_buf.load(m, n) * b_buf.load(n, k); }); - Tensor* y = Compute( + Tensor y = Compute( "y", {{M, "m2"}, {N, "n2"}, {K, "k2"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return c_buf.load(m, n) * d_buf.load(m, k) + x->load(m, n, k); + return c_buf.load(m, n) * d_buf.load(m, k) + x.load(m, n, k); }); LoopNest l1({y}, {x, y}); LoopNest l2(l1); - l2.computeInline(x->buf()); + l2.computeInline(x.buf()); l1.prepareForCodegen(); l2.prepareForCodegen(); @@ -1119,7 +1082,6 @@ static std::string remove_space(const std::string& str) { } void InlineFunc01Helper(const std::vector& inline_order) { - KernelScope kernel_scope; const int M = 4; const int N = 5; const int K = 6; @@ -1128,31 +1090,31 @@ void InlineFunc01Helper(const std::vector& inline_order) { Placeholder c_buf("c", kFloat, {M, N}); Placeholder d_buf("d", kFloat, {M, K}); - Tensor* x = Compute( + Tensor x = Compute( "x", {{M, "m1"}, {N, "n1"}, {K, "k1"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { return a_buf.load(m, n) * b_buf.load(n, k); }); - Tensor* y = Compute( + Tensor y = Compute( "y", {{M, "m2"}, {N, "n2"}, {K, "k2"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return c_buf.load(m, n) * d_buf.load(m, k) + x->load(m, n, k); + return c_buf.load(m, n) * d_buf.load(m, k) + x.load(m, n, k); }); - Tensor* z = Compute( + Tensor z = Compute( "z", {{M, "m3"}, {N, "n3"}, {K, "k3"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return x->load(m, n, k) + y->load(m, n, k); + return x.load(m, n, k) + y.load(m, n, k); }); LoopNest l({z}, {x, y, z}); for (const std::string& order : inline_order) { if (order == "x") { - l.computeInline(x->buf()); + l.computeInline(x.buf()); } else if (order == "y") { - l.computeInline(y->buf()); + l.computeInline(y.buf()); } else { throw std::runtime_error("Invalid order: " + order); } @@ -1207,7 +1169,7 @@ void InlineFunc01Helper(const std::vector& inline_order) { } if (inline_order.size() == 2) { - Tensor* z2 = Compute( + Tensor z2 = Compute( "z", {{M, "m3"}, {N, "n3"}, {K, "k3"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { @@ -1238,26 +1200,25 @@ TEST(LoopNest, ScheduleInlineFunc01) { // Make sure we cache random vars if we should. TEST(LoopNest, ScheduleInlineRandom) { - KernelScope kernel_scope; const int M = 4; const int N = 5; const int K = 6; - Tensor* x = Compute( + Tensor x = Compute( "x", {{M, "m1"}, {N, "n1"}, {K, "k1"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { return Mod::make(Intrinsics::make(kRand, kInt), 5); }); - Tensor* y = Compute( + Tensor y = Compute( "y", {{M, "m2"}, {N, "n2"}, {K, "k2"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return x->load(m, n, k) + x->load(m, n, k); + return x.load(m, n, k) + x.load(m, n, k); }); LoopNest l1({y}, {x, y}); - l1.computeInline(x->buf()); + l1.computeInline(x.buf()); // would normally compare results but Rand isn't implemented in the // SimpleIREvaluator, even if we could seed it. @@ -1274,27 +1235,26 @@ TEST(LoopNest, ScheduleInlineRandom) { // Make sure we don't cache random vars that are not being inlined. TEST(LoopNest, ScheduleInlineRandomUnrelated) { - KernelScope kernel_scope; const int M = 4; const int N = 5; const int K = 6; - Tensor* x = Compute( + Tensor x = Compute( "x", {{M, "m1"}, {N, "n1"}, {K, "k1"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { return m * n * k; }); - Tensor* y = Compute( + Tensor y = Compute( "y", {{M, "m2"}, {N, "n2"}, {K, "k2"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return x->load(m, n, k) + Intrinsics::make(kRand, kInt) + + return x.load(m, n, k) + Intrinsics::make(kRand, kInt) + Intrinsics::make(kRand, kInt); }); LoopNest l1({y}, {x, y}); - l1.computeInline(x->buf()); + l1.computeInline(x.buf()); // would normally compare results but Rand isn't implemented in the // SimpleIREvaluator, even if we could seed it. @@ -1305,29 +1265,28 @@ TEST(LoopNest, ScheduleInlineRandomUnrelated) { # CHECK: for (int m2 = 0; m2 < 4; m2++) # CHECK: for (int n2 = 0; n2 < 5; n2++) # CHECK: for (int k2 = 0; k2 < 6; k2++) -# CHECK: y[m2, n2, k2] = ((n2 * m2) * k2 + (rand())) + (rand());)IR"); +# CHECK: y[m2, n2, k2] = ((k2 * m2) * n2 + (rand())) + (rand());)IR"); } // Make sure we generate the right number of random values == the dimensionality // of the production tensor. TEST(LoopNest, ScheduleInlineRandomLowerDimensions) { - KernelScope kernel_scope; const int M = 4; const int N = 5; const int K = 6; - Tensor* x = Compute("x", {{M, "m1"}}, [&](const VarHandle& m) { + Tensor x = Compute("x", {{M, "m1"}}, [&](const VarHandle& m) { return Mod::make(Intrinsics::make(kRand, kInt), 5); }); - Tensor* y = Compute( + Tensor y = Compute( "y", {{M, "m2"}, {N, "n2"}, {K, "k2"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return x->load(m) + x->load(m); + return x.load(m) + x.load(m); }); LoopNest l1({y}, {x, y}); - l1.computeInline(x->buf()); + l1.computeInline(x.buf()); // would normally compare results but Rand isn't implemented in the // SimpleIREvaluator, even if we could seed it. @@ -1344,24 +1303,23 @@ TEST(LoopNest, ScheduleInlineRandomLowerDimensions) { // Make sure we don't screw up intrinsics thinking they're rand. TEST(LoopNest, ScheduleInlineIntrinsics) { - KernelScope kernel_scope; const int M = 4; const int N = 5; const int K = 6; Placeholder a_buf("a", kFloat, {M, N}); Placeholder b_buf("b", kFloat, {N, K}); - Tensor* x = Compute( + Tensor x = Compute( "x", {{M, "m1"}, {N, "n1"}, {K, "k1"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { return a_buf.load(m, n) * b_buf.load(n, k); }); - Tensor* y = Compute( + Tensor y = Compute( "y", {{M, "m2"}, {N, "n2"}, {K, "k2"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return Intrinsics::make(kSqrt, x->load(m, n, k)); + return Intrinsics::make(kSqrt, x.load(m, n, k)); }); PaddedBuffer a_v(M, N); @@ -1380,7 +1338,7 @@ TEST(LoopNest, ScheduleInlineIntrinsics) { LoopNest l1({y}, {x, y}); LoopNest l2(l1); - l2.computeInline(x->buf()); + l2.computeInline(x.buf()); l1.prepareForCodegen(); l2.prepareForCodegen(); @@ -1405,26 +1363,25 @@ TEST(LoopNest, ScheduleInlineIntrinsics) { // Make sure we can handle rand and non-rand intrinsics. TEST(LoopNest, ScheduleInlineRandWithIntrinsics) { - KernelScope kernel_scope; const int M = 4; const int N = 5; const int K = 6; - Tensor* x = Compute( + Tensor x = Compute( "x", {{M, "m1"}, {N, "n1"}, {K, "k1"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { return Intrinsics::make(kRand, kFloat); }); - Tensor* y = Compute( + Tensor y = Compute( "y", {{M, "m2"}, {N, "n2"}, {K, "k2"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return Intrinsics::make(kSqrt, x->load(m, n, k)); + return Intrinsics::make(kSqrt, x.load(m, n, k)); }); LoopNest l1({y}, {x, y}); - l1.computeInline(x->buf()); + l1.computeInline(x.buf()); StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); @@ -1439,32 +1396,30 @@ TEST(LoopNest, ScheduleInlineRandWithIntrinsics) { // Split a Compute then inline it into another compute. TEST(LoopNest, ScheduleSplitAThenInline) { - KernelScope kernel_scope; - Tensor* a = + Tensor a = Compute("a", {{18, "i"}}, [&](const VarHandle& i) { return i * i; }); - Tensor* b = Compute("b", {{2, "j"}}, [&](const VarHandle& j) { - return a->load(j + ExprHandle(8)); + Tensor b = Compute("b", {{2, "j"}}, [&](const VarHandle& j) { + return a.load(j + ExprHandle(8)); }); LoopNest l({b}, {a, b}); - std::vector loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); LoopNest::splitWithMask(loops[0], 4); - ASSERT_THROWS_WITH(l.computeInline(a->buf()), "compound indices"); + ASSERT_THROWS_WITH(l.computeInline(a.buf()), "compound indices"); } // Split a Compute then inline another Compute into it. TEST(LoopNest, ScheduleSplitBThenInline) { - KernelScope kernel_scope; - Tensor* a = + Tensor a = Compute("a", {{18, "i"}}, [&](const VarHandle& i) { return i * i; }); - Tensor* b = Compute("b", {{6, "j"}}, [&](const VarHandle& j) { - return a->load(j + ExprHandle(8)); + Tensor b = Compute("b", {{6, "j"}}, [&](const VarHandle& j) { + return a.load(j + ExprHandle(8)); }); LoopNest l({b}, {a, b}); - std::vector loops = l.getAllLoopNestsWritingToBuf(b->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(b.buf()).at(0); LoopNest::splitWithMask(loops[0], 3); - l.computeInline(a->buf()); + l.computeInline(a.buf()); l.prepareForCodegen(); StmtPtr s = IRSimplifier::simplify(l.root_stmt()); @@ -1479,33 +1434,31 @@ TEST(LoopNest, ScheduleSplitBThenInline) { // Split a Compute twice then inline it. TEST(LoopNest, ScheduleSplitTwiceThenInline) { - KernelScope kernel_scope; - Tensor* a = + Tensor a = Compute("a", {{18, "i"}}, [&](const VarHandle& i) { return i * i; }); - Tensor* b = Compute("b", {{2, "j"}}, [&](const VarHandle& j) { - return a->load(j + ExprHandle(8)); + Tensor b = Compute("b", {{2, "j"}}, [&](const VarHandle& j) { + return a.load(j + ExprHandle(8)); }); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr i_inner; LoopNest l({b}, {a, b}); - std::vector loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); LoopNest::splitWithMask(loops[0], 4, &i_inner); LoopNest::splitWithMask(i_inner, 2); - ASSERT_THROWS_WITH(l.computeInline(a->buf()), "compound indices"); + ASSERT_THROWS_WITH(l.computeInline(a.buf()), "compound indices"); } // Inline a Compute, then split. TEST(LoopNest, ScheduleInlineThenSplit) { - KernelScope kernel_scope; - Tensor* a = + Tensor a = Compute("a", {{18, "i"}}, [&](const VarHandle& i) { return i * i; }); - Tensor* b = Compute("b", {{6, "j"}}, [&](const VarHandle& j) { - return a->load(j + ExprHandle(8)); + Tensor b = Compute("b", {{6, "j"}}, [&](const VarHandle& j) { + return a.load(j + ExprHandle(8)); }); LoopNest l({b}, {a, b}); - l.computeInline(a->buf()); + l.computeInline(a.buf()); std::vector loops = NodeFinder::find(l.root_stmt()); LoopNest::splitWithMask(loops.back(), 3); @@ -1522,17 +1475,16 @@ TEST(LoopNest, ScheduleInlineThenSplit) { // Split a Compute, inline it, then split the result. TEST(LoopNest, ScheduleSplitInlineThenSplit) { - KernelScope kernel_scope; - Tensor* a = + Tensor a = Compute("a", {{18, "i"}}, [&](const VarHandle& i) { return i * i; }); - Tensor* b = Compute("b", {{16, "j"}}, [&](const VarHandle& j) { - return a->load(j + ExprHandle(8)); + Tensor b = Compute("b", {{16, "j"}}, [&](const VarHandle& j) { + return a.load(j + ExprHandle(8)); }); LoopNest l({b}, {a, b}); auto loops = NodeFinder::find(l.root_stmt()); LoopNest::splitWithMask(loops.back(), 2); - l.computeInline(a->buf()); + l.computeInline(a.buf()); loops = NodeFinder::find(l.root_stmt()); LoopNest::splitWithMask(loops.front(), 2); @@ -1549,36 +1501,34 @@ TEST(LoopNest, ScheduleSplitInlineThenSplit) { // Oversplit a loop that is simplified out after inlining. TEST(LoopNest, ScheduleSplitInlineSimplify) { - KernelScope kernel_scope; - Tensor* a = Compute("a", {{18, "i"}}, [&](const VarHandle& i) { + Tensor a = Compute("a", {{18, "i"}}, [&](const VarHandle& i) { return ExprHandle(4) * i - ExprHandle(2) * i; }); - Tensor* b = Compute("b", {{2, "j"}}, [&](const VarHandle& j) { - return a->load(j) - ExprHandle(1); + Tensor b = Compute("b", {{2, "j"}}, [&](const VarHandle& j) { + return a.load(j) - ExprHandle(1); }); LoopNest l({b}, {a, b}); - std::vector loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); LoopNest::splitWithMask(loops[0], 4); - ASSERT_THROWS_WITH(l.computeInline(a->buf()), "compound indices"); + ASSERT_THROWS_WITH(l.computeInline(a.buf()), "compound indices"); } // Inline a Compute with two consumers. TEST(LoopNest, ScheduleInlineThreeMixedOnce) { - KernelScope kernel_scope; - Tensor* a = + Tensor a = Compute("a", {{18, "i"}}, [&](const VarHandle& i) { return i * i; }); - Tensor* b = Compute("b", {{6, "j"}}, [&](const VarHandle& j) { - return a->load(j + ExprHandle(8)); + Tensor b = Compute("b", {{6, "j"}}, [&](const VarHandle& j) { + return a.load(j + ExprHandle(8)); }); - Tensor* c = Compute( + Tensor c = Compute( "c", {{4, "k"}, {3, "l"}}, [&](const VarHandle& k, const VarHandle& l) { - return a->load(k) * b->load(l); + return a.load(k) * b.load(l); }); LoopNest l({c}, {a, b, c}); - std::vector loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0); - l.computeInline(a->buf()); + std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); + l.computeInline(a.buf()); l.prepareForCodegen(); StmtPtr s = IRSimplifier::simplify(l.root_stmt()); @@ -1595,21 +1545,20 @@ TEST(LoopNest, ScheduleInlineThreeMixedOnce) { // Inline Compute A into B, then inline B into C. TEST(LoopNest, ScheduleInlineThreeMixedTwice) { - KernelScope kernel_scope; - Tensor* a = + Tensor a = Compute("a", {{18, "i"}}, [&](const VarHandle& i) { return i * i; }); - Tensor* b = Compute("b", {{6, "j"}}, [&](const VarHandle& j) { - return a->load(j + ExprHandle(8)); + Tensor b = Compute("b", {{6, "j"}}, [&](const VarHandle& j) { + return a.load(j + ExprHandle(8)); }); - Tensor* c = Compute( + Tensor c = Compute( "c", {{4, "k"}, {3, "l"}}, [&](const VarHandle& k, const VarHandle& l) { - return a->load(k) * b->load(l); + return a.load(k) * b.load(l); }); LoopNest l({c}, {a, b, c}); - std::vector loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0); - l.computeInline(a->buf()); - l.computeInline(b->buf()); + std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); + l.computeInline(a.buf()); + l.computeInline(b.buf()); l.prepareForCodegen(); StmtPtr s = IRSimplifier::simplify(l.root_stmt()); @@ -1626,20 +1575,19 @@ TEST(LoopNest, ScheduleInlineThreeMixedTwice) { // Inline a Compute that is both a producer and consumer. TEST(LoopNest, ScheduleInlineThreeMixedInner) { - KernelScope kernel_scope; - Tensor* a = + Tensor a = Compute("a", {{18, "i"}}, [&](const VarHandle& i) { return i * i; }); - Tensor* b = Compute("b", {{6, "j"}}, [&](const VarHandle& j) { - return a->load(j + ExprHandle(8)); + Tensor b = Compute("b", {{6, "j"}}, [&](const VarHandle& j) { + return a.load(j + ExprHandle(8)); }); - Tensor* c = Compute( + Tensor c = Compute( "c", {{4, "k"}, {3, "l"}}, [&](const VarHandle& k, const VarHandle& l) { - return a->load(k) * b->load(l); + return a.load(k) * b.load(l); }); LoopNest l({c}, {a, b, c}); - std::vector loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0); - l.computeInline(b->buf()); + std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); + l.computeInline(b.buf()); l.prepareForCodegen(); StmtPtr s = IRSimplifier::simplify(l.root_stmt()); @@ -1656,50 +1604,48 @@ TEST(LoopNest, ScheduleInlineThreeMixedInner) { // Split 3 Computes, then inline the first two into the last. TEST(LoopNest, ScheduleInlineThreeMixedSplit) { - KernelScope kernel_scope; - Tensor* a = + Tensor a = Compute("a", {{18, "i"}}, [&](const VarHandle& i) { return i * i; }); - Tensor* b = Compute("b", {{6, "j"}}, [&](const VarHandle& j) { - return a->load(j + ExprHandle(8)); + Tensor b = Compute("b", {{6, "j"}}, [&](const VarHandle& j) { + return a.load(j + ExprHandle(8)); }); - Tensor* c = Compute( + Tensor c = Compute( "c", {{4, "k"}, {3, "l"}}, [&](const VarHandle& k, const VarHandle& l) { - return a->load(k) * b->load(l); + return a.load(k) * b.load(l); }); LoopNest l({c}, {a, b, c}); - std::vector loops = l.getAllLoopNestsWritingToBuf(a->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); LoopNest::splitWithMask(loops[0], 4); - loops = l.getAllLoopNestsWritingToBuf(b->buf()).at(0); + loops = l.getAllLoopNestsWritingToBuf(b.buf()).at(0); LoopNest::splitWithMask(loops[0], 3); - loops = l.getAllLoopNestsWritingToBuf(c->buf()).at(0); + loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); LoopNest::splitWithMask(loops[0], 2); - ASSERT_THROWS_WITH(l.computeInline(a->buf()), "compound indices"); + ASSERT_THROWS_WITH(l.computeInline(a.buf()), "compound indices"); } // Check that inlining works for output tensors too TEST(LoopNest, ScheduleInlineOutputTensors) { - KernelScope kernel_scope; const int M = 4; const int N = 5; const int K = 6; - Tensor* x = Compute( + Tensor x = Compute( "x", {{M, "m1"}, {N, "n1"}, {K, "k1"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { return m * n * k; }); - Tensor* y = Compute( + Tensor y = Compute( "y", {{M, "m2"}, {N, "n2"}, {K, "k2"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return x->load(m, n, k) + m; + return x.load(m, n, k) + m; }); LoopNest l1({x, y}); - l1.computeInline(x->buf()); + l1.computeInline(x.buf()); // would normally compare results but Rand isn't implemented in the // SimpleIREvaluator, even if we could seed it. @@ -1710,29 +1656,28 @@ TEST(LoopNest, ScheduleInlineOutputTensors) { # CHECK: for (int m1 = 0; m1 < 4; m1++) # CHECK: for (int n1 = 0; n1 < 5; n1++) # CHECK: for (int k1 = 0; k1 < 6; k1++) -# CHECK: x[m1, n1, k1] = (n1 * m1) * k1; +# CHECK: x[m1, n1, k1] = (k1 * m1) * n1; # CHECK: for (int m2 = 0; m2 < 4; m2++) # CHECK: for (int n2 = 0; n2 < 5; n2++) # CHECK: for (int k2 = 0; k2 < 6; k2++) -# CHECK: y[m2, n2, k2] = (n2 * m2) * k2 + m2;)IR"); +# CHECK: y[m2, n2, k2] = (k2 * m2) * n2 + m2;)IR"); } TEST(LoopNest, ScheduleFuserStyle) { - KernelScope kernel_scope; const int kVectorSize = 8; const int kVectorCount = 128; const int kTotalSize = kVectorSize * kVectorCount; Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); - Tensor* b = Compute( + Tensor b = Compute( "f", {{kTotalSize, "i"}}, [&](const std::vector& axes) { return a_buf.load(axes[0]) + 11.0f; }); - Tensor* c = Compute( + Tensor c = Compute( "g", {{kTotalSize, "i"}}, [&](const std::vector& axes) { - return b->load(axes[0]) + 1.0f; + return b.load(axes[0]) + 1.0f; }); LoopNest l({b, c}); @@ -1751,7 +1696,6 @@ TEST(LoopNest, ScheduleFuserStyle) { } TEST(LoopNest, ScheduleFuserThreeArg) { - KernelScope kernel_scope; const int kVectorSize = 8; const int kVectorCount = 128; const int kTotalSize = kVectorSize * kVectorCount; @@ -1761,14 +1705,14 @@ TEST(LoopNest, ScheduleFuserThreeArg) { Placeholder c(BufHandle("C", {ExprHandle(kTotalSize)}, kFloat)); Placeholder d(BufHandle("D", {ExprHandle(kTotalSize)}, kFloat)); - Tensor* e = Compute("e", {{kTotalSize, "i"}}, [&](const VarHandle& i) { + Tensor e = Compute("e", {{kTotalSize, "i"}}, [&](const VarHandle& i) { return a.load(i) + b.load(i); }); - Tensor* f = Compute("f", {{kTotalSize, "i"}}, [&](const VarHandle& i) { - return e->load(i) + c.load(i); + Tensor f = Compute("f", {{kTotalSize, "i"}}, [&](const VarHandle& i) { + return e.load(i) + c.load(i); }); - Tensor* g = Compute("g", {{kTotalSize, "i"}}, [&](const VarHandle& i) { - return f->load(i) + d.load(i); + Tensor g = Compute("g", {{kTotalSize, "i"}}, [&](const VarHandle& i) { + return f.load(i) + d.load(i); }); LoopNest l({g}, {e, f, g}); @@ -1790,13 +1734,12 @@ TEST(LoopNest, ScheduleFuserThreeArg) { } TEST(LoopNest, ScheduleDynamicShape2D) { - KernelScope kernel_scope; auto testWithSize = [](int32_t M, int32_t N) { VarHandle m("m", kInt); VarHandle n("n", kInt); Placeholder a(BufHandle("a", {m, n}, kFloat)); Placeholder b(BufHandle("b", {m, n}, kFloat)); - Tensor* c = Compute( + Tensor c = Compute( "c", {{m, "m"}, {n, "n"}}, [&](const VarHandle& i, const VarHandle& j) { return a.load(i, j) + b.load(i, j); }); @@ -1829,14 +1772,13 @@ TEST(LoopNest, LoopNestComputeAt_1) { // should be in that loop after the transformation. Also, computation of A // should not be inlined into B. Instead, it should be computed into the temp, // and the temp should be used in B. - KernelScope kernel_scope; VarHandle N("N", kInt); - Tensor* A = Compute( + Tensor A = Compute( "A", {{N, "i_a"}}, [&](const VarHandle& i_a) { return i_a * i_a; }); - Tensor* B = Compute( - "B", {{N, "i_b"}}, [&](const VarHandle& i_b) { return A->load(i_b); }); + Tensor B = Compute( + "B", {{N, "i_b"}}, [&](const VarHandle& i_b) { return A.load(i_b); }); LoopNest l({B}, {A, B}); - std::vector loops = l.getAllLoopNestsWritingToBuf(B->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(B.buf()).at(0); LoopNest::computeAt(l.getLoopBodyFor(A), loops[0]); l.prepareForCodegen(); StmtPtr s = l.root_stmt(); @@ -1875,21 +1817,20 @@ TEST(LoopNest, LoopNestComputeAt_2) { // p[cy,cx+1] + p[cy+1,cx+1] // } // } - KernelScope kernel_scope; const int kW = 16, kH = 16; VarHandle W("W", kInt); VarHandle H("H", kInt); - Tensor* p = Compute( + Tensor p = Compute( "prod", {{H + 1, "py"}, {W + 1, "px"}}, [&](const VarHandle& py, const VarHandle& px) { return px * py; }); - Tensor* c = Compute( + Tensor c = Compute( "cons", {{H, "cy"}, {W, "cx"}}, [&](const VarHandle& y, const VarHandle& x) { - return p->load(y, x) + p->load(y + 1, x) + p->load(y, x + 1) + - p->load(y + 1, x + 1); + return p.load(y, x) + p.load(y + 1, x) + p.load(y, x + 1) + + p.load(y + 1, x + 1); }); std::vector c_ref(kW * kH, 0); @@ -1903,7 +1844,7 @@ TEST(LoopNest, LoopNestComputeAt_2) { { // First let's try to compute P at axis cy (the outer loop) LoopNest l(orig_loopnest); - std::vector loops = l.getAllLoopNestsWritingToBuf(c->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); LoopNest::computeAt(l.getLoopBodyFor(p), loops[0]); l.prepareForCodegen(); StmtPtr s = l.root_stmt(); @@ -1929,7 +1870,7 @@ TEST(LoopNest, LoopNestComputeAt_2) { { // Now let's try to compute P at axis cx (the inner loop) LoopNest l(orig_loopnest); - std::vector loops = l.getAllLoopNestsWritingToBuf(c->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); LoopNest::computeAt(l.getLoopBodyFor(p), loops[1]); l.prepareForCodegen(); StmtPtr s = l.root_stmt(); @@ -1963,32 +1904,29 @@ TEST(LoopNest, LoopNestComputeAt_3) { // D(x,y) = A(x, y+1) + C(x, y) // // i.e. when 'A' comes to 'D' directly and indirectly through 'C'. - KernelScope kernel_scope; const int kW = 16, kH = 16; VarHandle W("W", kInt); VarHandle H("H", kInt); - Tensor* A = Compute( + Tensor A = Compute( "A", {{H + 1, "ay"}, {W + 1, "ax"}}, [&](const VarHandle& ay, const VarHandle& ax) { return ax * ay; }); - Tensor* B = Compute( + Tensor B = Compute( "B", {{H + 1, "by"}, {W + 1, "bx"}}, - [&](const VarHandle& by, const VarHandle& bx) { - return A->load(by, bx); - }); - Tensor* C = Compute( + [&](const VarHandle& by, const VarHandle& bx) { return A.load(by, bx); }); + Tensor C = Compute( "C", {{H, "cy"}, {W, "cx"}}, [&](const VarHandle& cy, const VarHandle& cx) { - return B->load(cy, cx + 1); + return B.load(cy, cx + 1); }); - Tensor* D = Compute( + Tensor D = Compute( "D", {{H, "dy"}, {W, "dx"}}, [&](const VarHandle& dy, const VarHandle& dx) { - return A->load(dy + 1, dx) + C->load(dy, dx); + return A.load(dy + 1, dx) + C.load(dy, dx); }); std::vector c_ref(kW * kH, 0); @@ -2002,7 +1940,7 @@ TEST(LoopNest, LoopNestComputeAt_3) { { // First let's try to compute A at axis dy (the outer loop) LoopNest l(orig_loopnest); - std::vector loops = l.getAllLoopNestsWritingToBuf(D->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(D.buf()).at(0); LoopNest::computeAt(l.getLoopBodyFor(A), loops[0]); l.prepareForCodegen(); StmtPtr s = l.root_stmt(); @@ -2033,7 +1971,7 @@ TEST(LoopNest, LoopNestComputeAt_3) { { // Now let's try to compute A at axis dx (the inner loop) LoopNest l(orig_loopnest); - std::vector loops = l.getAllLoopNestsWritingToBuf(D->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(D.buf()).at(0); LoopNest::computeAt(l.getLoopBodyFor(A), loops[1]); l.prepareForCodegen(); StmtPtr s = l.root_stmt(); @@ -2066,21 +2004,19 @@ TEST(LoopNest, LoopNestComputeAt_3) { using Axis = const VarHandle&; TEST(LoopNest, Reduce2dComputeAt) { - KernelScope kernel_scope; - const int kW = 16, kH = 16; VarHandle W("W", kInt); VarHandle H("H", kInt); - Tensor* p = + Tensor p = Compute("prod", {{H + 1, "py"}, {W + 1, "px"}}, [&](Axis py, Axis px) { return px * py; }); - Tensor* c = Reduce( + Tensor c = Reduce( "cons", {{H, "cy"}, {W, "cx"}}, Sum(), - [&](Axis y, Axis x, Axis r, Axis s) { return p->load(y + r, x + s); }, + [&](Axis y, Axis x, Axis r, Axis s) { return p.load(y + r, x + s); }, {{2, "r"}, {2, "s"}}); std::vector c_ref(kW * kH, 0); @@ -2111,7 +2047,7 @@ TEST(LoopNest, Reduce2dComputeAt) { { // First let's try to compute P at axis cy (the outer loop) LoopNest l(orig_loopnest); - auto loops = l.getAllLoopNestsWritingToBuf(c->buf()).at(0); + auto loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); LoopNest::computeAt(l.getLoopBodyFor(p), loops[0]); // FIXME: Calling simplify here breaks the IR: // MALFORMED INPUT: could not find base node in Load - temp[...] @@ -2130,7 +2066,7 @@ TEST(LoopNest, Reduce2dComputeAt) { # CHECK: cons[(0 + cy * (1 * W)) + cx * 1] = int(0); # CHECK: for (int r = 0; r < 2; r++) { # CHECK: for (int s = 0; s < 2; s++) { -# CHECK: cons[(0 + cy * (1 * W)) + cx * 1] = (cons[(0 + cy * (1 * W)) + cx * 1]) + (temp[(0 + r * (1 * (W + 1))) + (s + cx) * 1]); +# CHECK: cons[(0 + cy * (1 * W)) + cx * 1] = (cons[(0 + cy * (1 * W)) + cx * 1]) + (temp[(0 + r * (1 * (W + 1))) + (cx + s) * 1]); # CHECK: } # CHECK: } # CHECK: } @@ -2148,7 +2084,7 @@ TEST(LoopNest, Reduce2dComputeAt) { { // Now let's try to compute P at axis cx (the inner loop) LoopNest l(orig_loopnest); - std::vector loops = l.getAllLoopNestsWritingToBuf(c->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); LoopNest::computeAt(l.getLoopBodyFor(p), loops[1]); l.simplify(); l.eliminateDeadStores(); @@ -2186,7 +2122,6 @@ TEST(LoopNest, DISABLED_Conv1d_NH) { // Lots of stuff is broken here. The computeAt swaps the axes for some odd // reason. Even without that, the index flattener fails due to "dimensions // mismatch in flatten index". - KernelScope kernel_scope; int N = 4; int H = 256; @@ -2194,17 +2129,17 @@ TEST(LoopNest, DISABLED_Conv1d_NH) { int Pad = 1; Placeholder IP("input", kFloat, {H}); - Tensor* A = + Tensor A = Compute("A", {{N, "np"}, {H + 2 * Pad, "hp"}}, [&](Axis n, Axis h) { auto cond = CompareSelect::make(h, Pad, 1, 0, kLT); cond = CompareSelect::make(h, H + Pad, 1, cond, kGE); return ifThenElse(cond, 0.f, IP.load(n, h - Pad)); }); - Tensor* B = Reduce( + Tensor B = Reduce( "B", {{N, "n"}, {H, "h"}}, Sum(), - [&](Axis n, Axis h, Axis r) { return A->load(n, h + r); }, + [&](Axis n, Axis h, Axis r) { return A.load(n, h + r); }, {{R, "r"}}); LoopNest l({B}); checkIR(l.root_stmt(), R"IR( @@ -2222,7 +2157,7 @@ TEST(LoopNest, DISABLED_Conv1d_NH) { # CHECK: } # CHECK: } )IR"); - std::vector loops = l.getAllLoopNestsWritingToBuf(B->buf()).at(0); + std::vector loops = l.getAllLoopNestsWritingToBuf(B.buf()).at(0); LoopNest::computeAt(l.getLoopBodyFor(A), loops[0]); // FIXME: The current IR is totally broken. The body of the inlined loop is: @@ -2280,8 +2215,7 @@ class LoopOrderHelper : public IRVisitor { }; TEST(LoopNest, LoopNestReorderAxis1) { - KernelScope kernel_scope; - Tensor* tensor = Compute( + Tensor tensor = Compute( "f", {{2, "x"}, {3, "y"}}, [](const VarHandle& x, const VarHandle& y) { return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; }); @@ -2292,7 +2226,7 @@ TEST(LoopNest, LoopNestReorderAxis1) { SimpleIREvaluator cg(stmt1, {tensor}); cg.call({stmt1_output}); - auto loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); LoopNest::reorderAxis(loops[0], loops[1]); StmtPtr stmt2 = Stmt::clone(l.root_stmt()); @@ -2313,7 +2247,7 @@ TEST(LoopNest, LoopNestReorderAxis1) { } // Reorder them back. - loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); LoopNest::reorderAxis(loops[0], loops[1]); StmtPtr stmt3 = l.root_stmt(); @@ -2329,8 +2263,7 @@ TEST(LoopNest, LoopNestReorderAxis1) { } TEST(LoopNest, LoopNestReorderPartialAxes) { - KernelScope kernel_scope; - Tensor* tensor = Compute( + Tensor tensor = Compute( "f", {{2, "x"}, {3, "y"}, {4, "z"}}, [](const VarHandle& x, const VarHandle& y, const VarHandle& z) { @@ -2347,7 +2280,7 @@ TEST(LoopNest, LoopNestReorderPartialAxes) { SimpleIREvaluator cg(stmt1, {tensor}); cg.call({stmt1_output}); - auto loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); LoopNest::reorderAxis(loops[0], loops[1]); ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "y,x,z,"); @@ -2361,7 +2294,7 @@ TEST(LoopNest, LoopNestReorderPartialAxes) { ASSERT_EQ(stmt1_output[i], stmt2_output[i]); } - loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); LoopNest::reorderAxis(loops[1], loops[2]); ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "y,z,x,"); @@ -2377,8 +2310,7 @@ TEST(LoopNest, LoopNestReorderPartialAxes) { } TEST(LoopNest, LoopNestReorderInternalAxis) { - KernelScope kernel_scope; - Tensor* tensor = Compute( + Tensor tensor = Compute( "f", {{1, "w"}, {2, "x"}, {3, "y"}, {4, "z"}}, [](const VarHandle& w, @@ -2398,7 +2330,7 @@ TEST(LoopNest, LoopNestReorderInternalAxis) { SimpleIREvaluator cg(stmt1, {tensor}); cg.call({stmt1_output}); - auto loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); LoopNest::reorderAxis(loops[2], loops[1]); ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "w,y,x,z,"); @@ -2414,8 +2346,7 @@ TEST(LoopNest, LoopNestReorderInternalAxis) { } TEST(LoopNest, LoopNestReorderEnclosingAxis) { - KernelScope kernel_scope; - Tensor* tensor = Compute( + Tensor tensor = Compute( "f", {{1, "w"}, {2, "x"}, {3, "y"}, {4, "z"}}, [](const VarHandle& w, @@ -2434,7 +2365,7 @@ TEST(LoopNest, LoopNestReorderEnclosingAxis) { SimpleIREvaluator cg(stmt1, {tensor}); cg.call({stmt1_output}); - auto loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); LoopNest::reorderAxis(loops[0], loops[3]); ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "z,x,y,w,"); @@ -2450,15 +2381,14 @@ TEST(LoopNest, LoopNestReorderEnclosingAxis) { } TEST(LoopNest, LoopNestReorderSameAxis) { - KernelScope kernel_scope; - Tensor* tensor = Compute( + Tensor tensor = Compute( "f", {{2, "x"}, {3, "y"}}, [](const VarHandle& x, const VarHandle& y) { return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; }); LoopNest l({tensor}); StmtPtr stmt1 = Stmt::clone(l.root_stmt()); - auto loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); LoopNest::reorderAxis(loops[1], loops[1]); StmtPtr stmt2 = Stmt::clone(l.root_stmt()); @@ -2479,9 +2409,7 @@ TEST(LoopNest, LoopNestReorderExtraStatements) { * Stmt 4 */ - KernelScope kernel_scope; - - Tensor* tensor = Compute( + Tensor tensor = Compute( "f", {{2, "x"}, {3, "y"}, {4, "z"}}, [](const VarHandle& x, const VarHandle& y, const VarHandle& z) { @@ -2492,7 +2420,7 @@ TEST(LoopNest, LoopNestReorderExtraStatements) { Placeholder extra(BufHandle("res", {6, 3}, kFloat)); - auto loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); VarHandle i = VarHandle(loops[0]->var()); @@ -2578,7 +2506,7 @@ TEST(LoopNest, LoopNestReorderExtraStatements) { * * */ - loops = l.getAllLoopNestsWritingToBuf(tensor->buf()).at(0); + loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); LoopNest::reorderAxis(loops[0], loops[2]); StmtPtr stmt3 = Stmt::clone(l.root_stmt()); @@ -2615,9 +2543,7 @@ void LoopNestReorderTestHelper( bool append, int index1, int index2) { - KernelScope kernel_scope; - - Tensor* c = Compute( + Tensor c = Compute( "5d", {{2, "a"}, {3, "b"}, {2, "c"}, {3, "d"}, {2, "e"}}, [](const std::vector&) { return -1; }); @@ -2625,7 +2551,7 @@ void LoopNestReorderTestHelper( Placeholder extra(BufHandle("extra", {5}, kInt)); - auto loops = l.getAllLoopNestsWritingToBuf(c->buf()).at(0); + auto loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); int j = 0; for (auto l : loops) { // Add an increment at each layer of the loop which counts the number of @@ -2666,7 +2592,7 @@ void LoopNestReorderTestHelper( ASSERT_EQ(extra1[i], expected_loops); } - loops = l.getAllLoopNestsWritingToBuf(c->buf()).at(0); + loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); LoopNest::reorderAxis(loops[index1], loops[index2]); StmtPtr stmt2 = Stmt::clone(l.root_stmt()); @@ -2732,7 +2658,6 @@ TEST(LoopNest, LoopNestReorderLongStringFull) { } TEST(LoopNest, LoopNestReorderInternalLoopNest) { - KernelScope kernel_scope; const int M = 4; const int N = 5; const int K = 6; @@ -2741,23 +2666,23 @@ TEST(LoopNest, LoopNestReorderInternalLoopNest) { Placeholder c_buf("c", kFloat, {M, N}); Placeholder d_buf("d", kFloat, {M, K}); - Tensor* x = Compute( + Tensor x = Compute( "x", {{M, "m1"}, {N, "n1"}, {K, "k1"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { return a_buf.load(m, n) * b_buf.load(n, k); }); - Tensor* y = Compute( + Tensor y = Compute( "y", {{M, "m2"}, {N, "n2"}, {K, "k2"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return c_buf.load(m, n) * d_buf.load(m, k) + x->load(m, n, k); + return c_buf.load(m, n) * d_buf.load(m, k) + x.load(m, n, k); }); - Tensor* z = Compute( + Tensor z = Compute( "z", {{M, "m3"}, {N, "n3"}, {K, "k3"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return x->load(m, n, k) + y->load(m, n, k); + return x.load(m, n, k) + y.load(m, n, k); }); LoopNest l({z}, {x, y, z}); @@ -2833,15 +2758,14 @@ TEST(LoopNest, LoopNestReorderInternalLoopNest) { } TEST(LoopNest, OuterLoopVectorization) { - KernelScope kernel_scope; - Tensor* tensor = Compute( + Tensor tensor = Compute( "f", {{8, "X"}, {8, "y"}}, [](const VarHandle& x, const VarHandle& y) { return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; }); LoopNest l({tensor}); ASSERT_TRUE( - LoopNest::vectorize(l.getAllLoopNestsWritingToBuf(tensor->buf())[0][0])); + LoopNest::vectorize(l.getAllLoopNestsWritingToBuf(tensor.buf())[0][0])); StmtPtr root_stmt = l.root_stmt(); BlockPtr outer_block = to(root_stmt); @@ -2861,8 +2785,6 @@ TEST(LoopNest, OuterLoopVectorization) { } TEST(LoopNest, VectorizeLoopNotNormalized) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 10; i++) { // for (int j = 1; j < 5; j++) { @@ -2886,12 +2808,11 @@ TEST(LoopNest, VectorizeLoopNotNormalized) { namespace { std::string constantUpperBoundLoopIR(int upper_bound_val) { - KernelScope kernel_scope; ExprHandle upper_bound(upper_bound_val); - Tensor* A = Compute( + Tensor A = Compute( "A", {{upper_bound, "x"}}, [&](const VarHandle& x) { return x * 2; }); LoopNest l({A}); - std::vector loops = l.getAllLoopNestsWritingToBuf(A->buf())[0]; + std::vector loops = l.getAllLoopNestsWritingToBuf(A.buf())[0]; StmtPtr unrolled = nullptr; LoopNest::unroll(loops[0], &unrolled); std::ostringstream oss; @@ -2913,15 +2834,14 @@ TEST(LoopNest, Unroll) { } TEST(LoopNest, UnrollOuter) { - KernelScope kernel_scope; ExprHandle outer_bound(3); ExprHandle inner_bound(4); - Tensor* A = Compute( + Tensor A = Compute( "A", {{outer_bound, "x"}, {inner_bound, "y"}}, [&](const VarHandle& x, const VarHandle& y) { return x + y; }); LoopNest l({A}); - std::vector loops = l.getAllLoopNestsWritingToBuf(A->buf())[0]; + std::vector loops = l.getAllLoopNestsWritingToBuf(A.buf())[0]; StmtPtr unrolled = nullptr; LoopNest::unroll(loops[0], &unrolled); checkIR(unrolled, R"IR( @@ -2937,15 +2857,14 @@ TEST(LoopNest, UnrollOuter) { } TEST(LoopNest, UnrollInner) { - KernelScope kernel_scope; ExprHandle outer_bound(3); ExprHandle inner_bound(4); - Tensor* A = Compute( + Tensor A = Compute( "A", {{outer_bound, "x"}, {inner_bound, "y"}}, [&](const VarHandle& x, const VarHandle& y) { return x + y; }); LoopNest l({A}); - std::vector loops = l.getAllLoopNestsWritingToBuf(A->buf())[0]; + std::vector loops = l.getAllLoopNestsWritingToBuf(A.buf())[0]; StmtPtr unrolled = nullptr; LoopNest::unroll( static_to(loops[0]->body()->stmts().front()), &unrolled); @@ -2959,7 +2878,6 @@ TEST(LoopNest, UnrollInner) { } TEST(LoopNest, UnrollMultipleStatements) { - KernelScope kernel_scope; const int kTotalSize = 3; BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); @@ -2972,7 +2890,7 @@ TEST(LoopNest, UnrollMultipleStatements) { Block::make( {Store::make(a_buf, {x}, x * 2), Store::make(b_buf, {x}, Load::make(a_buf, {x}))})); - Block::make({f}); + auto parent_block = Block::make({f}); StmtPtr unrolled = nullptr; LoopNest::unroll(f, &unrolled); checkIR(unrolled, R"IR( @@ -2985,8 +2903,6 @@ TEST(LoopNest, UnrollMultipleStatements) { } TEST(LoopNest, UnrollNonLiteralConstantBounds) { - KernelScope kernel_scope; - // Input IR: // for (int i = 2 - 1; i < 12 / 3; i++) { // for (int j = 0; j < 4; j++) { @@ -3031,19 +2947,17 @@ TEST(LoopNest, UnrollEmpty) { } TEST(LoopNest, NoUnroll) { - KernelScope kernel_scope; VarHandle upper_bound("N", kInt); - Tensor* A = Compute( + Tensor A = Compute( "A", {{upper_bound, "x"}}, [&](const VarHandle& x) { return x * 2; }); LoopNest l({A}); - std::vector loops = l.getAllLoopNestsWritingToBuf(A->buf())[0]; + std::vector loops = l.getAllLoopNestsWritingToBuf(A.buf())[0]; StmtPtr unrolled = nullptr; ASSERT_THROWS_WITH( LoopNest::unroll(loops[0], &unrolled), "non-constant loop"); } TEST(LoopNest, UnrollWithLet) { - KernelScope kernel_scope; const int kTotalSize = 3; BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); @@ -3058,7 +2972,7 @@ TEST(LoopNest, UnrollWithLet) { {Let::make(e, 7), Store::make(a_buf, {x}, e), Store::make(b_buf, {x}, e + 1)})); - Block::make({f}); + auto parent_block = Block::make({f}); StmtPtr unrolled = nullptr; LoopNest::unroll(f, &unrolled); std::ostringstream oss; @@ -3086,8 +3000,6 @@ TEST(LoopNest, UnrollWithLet) { } TEST(LoopNest, IsNormalized) { - KernelScope kernel_scope; - // Input IR: // for (int i = 50; i < 100; i++) { // A[i] = B[i]; @@ -3110,8 +3022,6 @@ TEST(LoopNest, IsNormalized) { } TEST(LoopNest, NormalizeStartPositive) { - KernelScope kernel_scope; - // Input IR: // for (int x = 50; x < 100; x++) { // A[x] = B[x]; @@ -3142,8 +3052,6 @@ TEST(LoopNest, NormalizeStartPositive) { } TEST(LoopNest, NormalizeStartNegative) { - KernelScope kernel_scope; - // Input IR: // for (int x = -50; x < 100; x++) { // A[x + 50] = B[x + 50]; @@ -3174,8 +3082,6 @@ TEST(LoopNest, NormalizeStartNegative) { } TEST(LoopNest, NormalizeStartZero) { - KernelScope kernel_scope; - // Input IR: // for (int x = 0; x < 100; x++) { // A[x] = B[x]; @@ -3208,8 +3114,6 @@ TEST(LoopNest, NormalizeStartZero) { } TEST(LoopNest, NormalizeStartVariable) { - KernelScope kernel_scope; - // Input IR: // for (int x = y; x < 100; x++) { // A[x] = B[x]; @@ -3225,7 +3129,7 @@ TEST(LoopNest, NormalizeStartVariable) { {Store::make(a_buf, {x}, Load::make(kInt, b_buf, {x})), Store::make(b_buf, {x}, x * 2)}); auto for_stmt = For::make(x, y, 100, for_body); - Block::make({for_stmt}); + auto parent_block = Block::make({for_stmt}); LoopNest::normalize(for_stmt); @@ -3235,15 +3139,13 @@ TEST(LoopNest, NormalizeStartVariable) { const std::string& expected_ir = R"IR( # CHECK: for (int x = 0; x < 100 - y; x++) { - # CHECK: A[y + x] = B[y + x]; - # CHECK: B[y + x] = 2 * (y + x); + # CHECK: A[x + y] = B[x + y]; + # CHECK: B[x + y] = 2 * (x + y); )IR"; torch::jit::testing::FileCheck().run(expected_ir, oss.str()); } TEST(LoopNest, NormalizeOnNestedOuterLoop) { - KernelScope kernel_scope; - // Input IR: // for (int x = 50; x < 100; x++) { // for (int y = 10; y < 100; y++) { @@ -3276,8 +3178,6 @@ TEST(LoopNest, NormalizeOnNestedOuterLoop) { } TEST(LoopNest, NormalizeOnNestedInnerLoop) { - KernelScope kernel_scope; - // Input IR: // for (int x = 50; x < 100; x++) { // for (int y = 10; y < 100; y++) { @@ -3304,18 +3204,16 @@ TEST(LoopNest, NormalizeOnNestedInnerLoop) { R"IR( # CHECK: for (int x = 50; x < 100; x++) { # CHECK: for (int y = 0; y < 90; y++) { - # CHECK: A[x] = (((B[y + 10]) + 2 * y) + (A[x])) + 20; + # CHECK: A[x] = (((A[x]) + (B[y + 10])) + 2 * y) + 20; )IR"; torch::jit::testing::FileCheck().run(expected_ir, oss.str()); } TEST(LoopNest, NormalizeAndSplitWithTail) { - KernelScope kernel_scope; - // Create a dummy tensor to construct LoopNest. ExprHandle n(100); Placeholder a(BufHandle("a", {n}, kFloat)); - Tensor* b = + Tensor b = Compute("b", {{n, "i"}}, [&](const VarHandle& i) { return a.load(i); }); LoopNest l({b}); @@ -3327,7 +3225,7 @@ TEST(LoopNest, NormalizeAndSplitWithTail) { BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); VarHandle x("x", kInt); auto for_stmt = For::make(x, 5, 10, Store::make(a_buf, {x}, x * 2)); - Block::make({for_stmt}); + auto parent_block = Block::make({for_stmt}); LoopNest::normalize(for_stmt); @@ -3359,8 +3257,6 @@ TEST(LoopNest, NormalizeAndSplitWithTail) { } TEST(LoopNest, FlattenSimpleLoopNest2D) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 10; i++) { // for (int j = 0; j < 5; j++) { @@ -3373,7 +3269,7 @@ TEST(LoopNest, FlattenSimpleLoopNest2D) { auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); auto inner_for = For::make(j, 0, 5, for_body); auto outer_for = For::make(i, 0, 10, inner_for); - Block::make({outer_for}); + auto parent_block = Block::make({outer_for}); std::vector loops = {outer_for, inner_for}; ForPtr flattened = nullptr; @@ -3402,8 +3298,6 @@ TEST(LoopNest, FlattenSimpleLoopNest2D) { } TEST(LoopNest, FlattenSimpleLoopNest3D) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 10; i++) { // for (int j = 0; j < 5; j++) { @@ -3420,7 +3314,7 @@ TEST(LoopNest, FlattenSimpleLoopNest3D) { auto for1 = For::make(k, 0, 7, for_body); auto for2 = For::make(j, 0, 5, for1); auto for3 = For::make(i, 0, 10, for2); - Block::make({for3}); + auto parent_block = Block::make({for3}); std::vector loops = {for3, for2, for1}; ForPtr flattened = nullptr; @@ -3449,8 +3343,6 @@ TEST(LoopNest, FlattenSimpleLoopNest3D) { } TEST(LoopNest, FlattenLoopNestAfterNormalize) { - KernelScope kernel_scope; - // Input IR: // for (int i = 2; i < 10; i++) { // for (int j = 3; j < 15; j++) { @@ -3463,7 +3355,7 @@ TEST(LoopNest, FlattenLoopNestAfterNormalize) { auto for_body = Block::make({Store::make(a_buf, {i - 2, j - 3}, i * j)}); auto inner_for = For::make(j, 3, 15, for_body); auto outer_for = For::make(i, 2, 10, inner_for); - Block::make({outer_for}); + auto parent_block = Block::make({outer_for}); std::vector loops = {outer_for, inner_for}; ForPtr flattened = nullptr; @@ -3492,8 +3384,6 @@ TEST(LoopNest, FlattenLoopNestAfterNormalize) { } TEST(LoopNest, FlattenLoopNestWithNonLiteralConstantBounds) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 15-5; i++) { // for (int j = 0; j < 20/4; j++) { @@ -3534,8 +3424,6 @@ TEST(LoopNest, FlattenLoopNestWithNonLiteralConstantBounds) { } TEST(LoopNest, FlattenImperfectLoopNest) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 10; i++) { // A[i, i] = 0; @@ -3565,8 +3453,6 @@ TEST(LoopNest, FlattenImperfectLoopNest) { } TEST(LoopNest, FlattenReductionLoopNest) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 10; i++) { // S[i] = 0; @@ -3598,18 +3484,17 @@ TEST(LoopNest, FlattenReductionLoopNest) { } TEST(LoopNest, FlattenReductionLoopNestFromTensor) { - KernelScope kernel_scope; const int M = 3; const int N = 7; VarHandle m("m", kInt); VarHandle n("n", kInt); Placeholder b(BufHandle("b", {m, n}, kFloat)); - Tensor* c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}}); + Tensor c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}}); LoopNest loop({c}); HashProvider hasher; auto hash_before = hasher.hash(loop.root_stmt()); - auto loops = loop.getAllLoopNestsWritingToBuf(c->buf())[1]; + auto loops = loop.getAllLoopNestsWritingToBuf(c.buf())[1]; ForPtr flattened = nullptr; ASSERT_FALSE(LoopNest::flatten(loops, &flattened)); ASSERT_EQ(flattened, nullptr); @@ -3618,8 +3503,6 @@ TEST(LoopNest, FlattenReductionLoopNestFromTensor) { } TEST(LoopNest, FlattenIncorrectLoopsAsInput) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 10; i++) { // for (int j = 0; j < 5; j++) { @@ -3658,42 +3541,39 @@ TEST(LoopNest, FlattenIncorrectLoopsAsInput) { } TEST(LoopNest, DetectInlineRankMismatch) { - KernelScope kernel_scope; const int kTotalSize = 8; Placeholder a_buf(BufHandle("A", {ExprHandle(kTotalSize)}, kFloat)); - Tensor* a = Compute("a", {{kTotalSize, "i"}}, [&](const VarHandle& i) { + Tensor a = Compute("a", {{kTotalSize, "i"}}, [&](const VarHandle& i) { return a_buf.load(i); }); - Tensor* reshape = Compute( + Tensor reshape = Compute( "reshape", {{kTotalSize / 2, "i"}, {2, "j"}}, - [&](const VarHandle& i, const VarHandle& j) { return a->load(i, j); }); + [&](const VarHandle& i, const VarHandle& j) { return a.load(i, j); }); LoopNest l({reshape}, {a, reshape}); ASSERT_THROWS_WITH( l.computeInline(l.getLoopBodyFor(a)), - "Placeholder indexed access is inconsistent with its rank"); + "Number of indices doesn't match buf rank in the fuser."); } TEST(LoopNest, CacheReadsSimple) { - KernelScope kernel_scope; - - Tensor* A = Compute( + Tensor A = Compute( "A", {{64, "i"}, {64, "j"}}, [](const VarHandle& i, const VarHandle& j) { return i * j; }); - Tensor* B = Compute( + Tensor B = Compute( "B", {{20, "i"}, {10, "j"}}, [&](const VarHandle& i, const VarHandle& j) { - return A->load(i + 30, j + 3); + return A.load(i + 30, j + 3); }); - Tensor* C = Compute( + Tensor C = Compute( "C", {{20, "i"}, {10, "j"}}, [&](const VarHandle& i, const VarHandle& j) { - return A->load(i + 10, j + 20) + A->load(i + 30, j + 40); + return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); }); LoopNest l({B, C}, {A, B, C}); - StmtPtr j_loop = l.getAllLoopNestsWritingToBuf(B->buf())[0][1]; - LoopNest::cacheAccesses(A->buf(), "A_local", j_loop); + StmtPtr j_loop = l.getAllLoopNestsWritingToBuf(B.buf())[0][1]; + LoopNest::cacheAccesses(A.buf(), "A_local", j_loop); l.prepareForCodegen(); StmtPtr result = IRSimplifier::simplify(l.root_stmt()); @@ -3712,7 +3592,7 @@ TEST(LoopNest, CacheReadsSimple) { #CHECK: A_local[j_1] = A[ #CHECK: } #CHECK: for (int j_2 -#CHECK: B[10 * i_1 + j_2] = A_local[j_2]; +#CHECK: B[j_2 + 10 * i_1] = A_local[j_2]; #CHECK: } #CHECK: } #CHECK: for (int i_2 @@ -3744,24 +3624,22 @@ TEST(LoopNest, CacheReadsSimple) { } TEST(LoopNest, CacheReadsOuter) { - KernelScope kernel_scope; - - Tensor* A = Compute( + Tensor A = Compute( "A", {{64, "i"}, {64, "j"}}, [](const VarHandle& i, const VarHandle& j) { return i * j; }); - Tensor* B = Compute( + Tensor B = Compute( "B", {{20, "i"}, {10, "j"}}, [&](const VarHandle& i, const VarHandle& j) { - return A->load(i + 30, j + 40) + A->load(i + 31, j + 41); + return A.load(i + 30, j + 40) + A.load(i + 31, j + 41); }); - Tensor* C = Compute( + Tensor C = Compute( "C", {{20, "i"}, {10, "j"}}, [&](const VarHandle& i, const VarHandle& j) { - return A->load(i + 10, j + 20) + A->load(i + 30, j + 40); + return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); }); LoopNest l({B, C}, {A, B, C}); - StmtPtr i_loop = l.getAllLoopNestsWritingToBuf(B->buf())[0][0]; - LoopNest::cacheAccesses(A->buf(), "A_local", i_loop); + StmtPtr i_loop = l.getAllLoopNestsWritingToBuf(B.buf())[0][0]; + LoopNest::cacheAccesses(A.buf(), "A_local", i_loop); l.prepareForCodegen(); StmtPtr result = IRSimplifier::simplify(l.root_stmt()); @@ -3769,7 +3647,7 @@ TEST(LoopNest, CacheReadsOuter) { checkIR(result, R"IR( #CHECK: Allocate(A_local); // dtype=int, dims=[21, 11] #CHECK: A_local[j_1 + 11 * i_1] = -#CHECK: B[10 * i_2 + j_2] = (A_local[(j_2 + 11 * i_2) + 12]) + (A_local[j_2 + 11 * i_2]); +#CHECK: B[j_2 + 10 * i_2] = (A_local[j_2 + 11 * i_2]) + (A_local[(j_2 + 11 * i_2) + 12]); )IR"); std::vector b_data(200, 0); @@ -3792,31 +3670,29 @@ TEST(LoopNest, CacheReadsOuter) { } TEST(LoopNest, CacheReadsInternal) { - KernelScope kernel_scope; - - Tensor* A = Compute( + Tensor A = Compute( "A", {{64, "i"}, {64, "j"}}, [](const VarHandle& i, const VarHandle& j) { return i * j; }); - Tensor* B = Compute( + Tensor B = Compute( "B", {{20, "i"}, {10, "j"}}, [&](const VarHandle& i, const VarHandle& j) { - return A->load(i + 30, j + 40) + A->load(i + 31, j + 41); + return A.load(i + 30, j + 40) + A.load(i + 31, j + 41); }); - Tensor* C = Compute( + Tensor C = Compute( "C", {{20, "i"}, {10, "j"}}, [&](const VarHandle& i, const VarHandle& j) { - return A->load(i + 10, j + 20) + A->load(i + 30, j + 40); + return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); }); LoopNest l({B, C}, {A, B, C}); - StmtPtr j_loop = l.getAllLoopNestsWritingToBuf(B->buf())[0][1]; - LoopNest::cacheAccesses(A->buf(), "A_local", j_loop); + StmtPtr j_loop = l.getAllLoopNestsWritingToBuf(B.buf())[0][1]; + LoopNest::cacheAccesses(A.buf(), "A_local", j_loop); l.prepareForCodegen(); StmtPtr result = IRSimplifier::simplify(l.root_stmt()); checkIR(result, R"IR( #CHECK: Allocate(A_local); // dtype=int, dims=[2, 11] #CHECK: A_local[j_1 + 11 * i_2] = -#CHECK: B[10 * i_1 + j_2] = (A_local[j_2 + 12]) + (A_local[j_2]); +#CHECK: B[j_2 + 10 * i_1] = (A_local[j_2 + 12]) + (A_local[j_2]); )IR"); std::vector b_data(200, 0); @@ -3839,32 +3715,30 @@ TEST(LoopNest, CacheReadsInternal) { } TEST(LoopNest, CacheReadsInner) { - KernelScope kernel_scope; - - Tensor* A = Compute( + Tensor A = Compute( "A", {{64, "i"}, {64, "j"}}, [](const VarHandle& i, const VarHandle& j) { return i * j; }); // note im changing the offset of the first arg of the first call to A. - Tensor* B = Compute( + Tensor B = Compute( "B", {{20, "i"}, {10, "j"}}, [&](const VarHandle& i, const VarHandle& j) { - return A->load(i + 34, j + 40) + A->load(i + 30, j + 41); + return A.load(i + 34, j + 40) + A.load(i + 30, j + 41); }); - Tensor* C = Compute( + Tensor C = Compute( "C", {{20, "i"}, {10, "j"}}, [&](const VarHandle& i, const VarHandle& j) { - return A->load(i + 10, j + 20) + A->load(i + 30, j + 40); + return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); }); LoopNest l({B, C}, {A, B, C}); StmtPtr body = l.getLoopBodyFor(B); - LoopNest::cacheAccesses(A->buf(), "A_local", body); + LoopNest::cacheAccesses(A.buf(), "A_local", body); l.prepareForCodegen(); StmtPtr result = IRSimplifier::simplify(l.root_stmt()); checkIR(result, R"IR( #CHECK: Allocate(A_local); // dtype=int, dims=[5, 2] -#CHECK: A_local[2 * i_2 + j_2] = -#CHECK: B[10 * i_1 + j_1] = (A_local[1]) + (A_local[8]); +#CHECK: A_local[j_2 + 2 * i_2] = +#CHECK: B[j_1 + 10 * i_1] = (A_local[1]) + (A_local[8]); )IR"); std::vector b_data(200, 0); @@ -3887,24 +3761,22 @@ TEST(LoopNest, CacheReadsInner) { } TEST(LoopNest, CacheWritesSimple) { - KernelScope kernel_scope; - - Tensor* A = Compute( + Tensor A = Compute( "A", {{64, "i"}, {64, "j"}}, [](const VarHandle& i, const VarHandle& j) { return i * j; }); - Tensor* B = Compute( + Tensor B = Compute( "B", {{20, "i"}, {10, "j"}}, [&](const VarHandle& i, const VarHandle& j) { - return A->load(i + 30, j + 40) + A->load(i + 31, j + 41); + return A.load(i + 30, j + 40) + A.load(i + 31, j + 41); }); - Tensor* C = Compute( + Tensor C = Compute( "C", {{20, "i"}, {10, "j"}}, [&](const VarHandle& i, const VarHandle& j) { - return A->load(i + 10, j + 20) + A->load(i + 30, j + 40); + return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); }); LoopNest l({B, C}, {A, B, C}); - StmtPtr a_loop = l.getAllLoopNestsWritingToBuf(A->buf())[0][1]; - LoopNest::cacheAccesses(A->buf(), "A_local", a_loop); + StmtPtr a_loop = l.getAllLoopNestsWritingToBuf(A.buf())[0][1]; + LoopNest::cacheAccesses(A.buf(), "A_local", a_loop); l.prepareForCodegen(); StmtPtr result = IRSimplifier::simplify(l.root_stmt()); @@ -3914,7 +3786,7 @@ TEST(LoopNest, CacheWritesSimple) { #CHECK: for (int j = 0; j < 64 #CHECK: A_local[j] = i * j; #CHECK: for (int j_1 = 0; j_1 < 64 -#CHECK: A[64 * i + j_1] = A_local[ +#CHECK: A[j_1 + 64 * i] = A_local[ #CHECK: Free(A_local); #CHECK-NOT: A_local )IR"); @@ -3939,7 +3811,6 @@ TEST(LoopNest, CacheWritesSimple) { } TEST(LoopNest, DeadStoreElimination) { - KernelScope kernel_scope; VarHandle y("y", kInt); VarHandle x("x_tail", kInt); BufHandle f("f", {26, 5}, kInt); @@ -3980,7 +3851,6 @@ TEST(LoopNest, DeadStoreElimination) { } TEST(LoopNest, DeadStoreEliminationWithIntermediates) { - KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); VarHandle z("z", kInt); @@ -4006,7 +3876,7 @@ TEST(LoopNest, DeadStoreEliminationWithIntermediates) { // Will eliminate the write to g, but not f since it used by the producer of // h. - LoopNest loop(stmt, {h.node()}); + LoopNest loop(Stmt::clone(stmt), {h.node()}); loop.eliminateDeadStores(); checkIR(loop.root_stmt(), R"IR( @@ -4027,8 +3897,6 @@ TEST(LoopNest, DeadStoreEliminationWithIntermediates) { } TEST(LoopNest, CompoundTensorSimple) { - KernelScope kernel_scope; - BufHandle a_buf("A", {10, 5}, kInt); VarHandle i("i", kInt); VarHandle j("j", kInt); @@ -4043,7 +3911,7 @@ TEST(LoopNest, CompoundTensorSimple) { auto outer_for2 = For::make(x, 0, 10, inner_for2); BlockPtr body = Block::make({outer_for1, outer_for2}); - Tensor* A = new Tensor(a_buf.node(), body); + Tensor A = Tensor(a_buf.node(), body); LoopNest l({A}); l.prepareForCodegen(); @@ -4066,30 +3934,27 @@ TEST(LoopNest, CompoundTensorSimple) { } TEST(LoopNest, InlineConstantIndex) { - KernelScope kernel_scope; const int N = 10; Placeholder x_buf("a", kFloat, {1, N, 1}); - Tensor* y = Compute( + Tensor y = Compute( "f", {{1, "m"}, {N, "n"}, {1, "o"}}, [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& o) { return x_buf.load(m, n, o); }); - Tensor* z = Compute( + Tensor z = Compute( "f", {{1, "m"}, {N, "n"}, {1, "o"}}, [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& o) { - return y->load(m, n, o); + return y.load(m, n, o); }); LoopNest l({z}, {y, z}); l.simplify(); - ASSERT_TRUE(l.computeInline(y->buf())); + ASSERT_TRUE(l.computeInline(y.buf())); } TEST(LoopNest, CompoundTensorUsed) { - KernelScope kernel_scope; - BufHandle a_buf("A", {10, 5}, kInt); VarHandle i("i", kInt); VarHandle j("j", kInt); @@ -4104,14 +3969,14 @@ TEST(LoopNest, CompoundTensorUsed) { auto outer_for2 = For::make(x, 0, 10, inner_for2); BlockPtr body = Block::make({outer_for1, outer_for2}); - Tensor* A = new Tensor(a_buf.node(), body); - Tensor* B = Compute( + Tensor A = Tensor(a_buf.node(), body); + Tensor B = Compute( "B", {{10, "i"}, {3, "j"}}, [&](const VarHandle& i, const VarHandle& j) { - return A->load(i, j + 1) + A->load(i, j + 2); + return A.load(i, j + 1) + A.load(i, j + 2); }); LoopNest l({B}, {A, B}); - ASSERT_FALSE(l.computeInline(A->buf())); + ASSERT_FALSE(l.computeInline(A.buf())); l.prepareForCodegen(); std::vector a_data(50, 0); @@ -4134,8 +3999,6 @@ TEST(LoopNest, CompoundTensorUsed) { } TEST(LoopNest, InlineFromLoad) { - KernelScope kernel_scope; - constexpr int N = 1024; BufHandle a("A", {N}, kInt); BufHandle b("B", {N}, kInt); @@ -4160,8 +4023,6 @@ TEST(LoopNest, InlineFromLoad) { } TEST(LoopNest, OptimizeConditionalsSimple) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 20; i++) { // A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5]) @@ -4202,8 +4063,6 @@ TEST(LoopNest, OptimizeConditionalsSimple) { } TEST(LoopNest, OptimizeConditionalsNestedConditions) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 20; i++) { // A[i] = IfThenElse(i<10, IfThenElse(i<5, B[i], C[i-5]), D[i-10]) @@ -4251,8 +4110,6 @@ TEST(LoopNest, OptimizeConditionalsNestedConditions) { } TEST(LoopNest, OptimizeConditionalsMultipleStores) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 20; i++) { // A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5]) @@ -4311,8 +4168,6 @@ TEST(LoopNest, OptimizeConditionalsMultipleStores) { } TEST(LoopNest, OptimizeConditionalsMultipleStoresInOneLoop) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 50; i++) { // A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5]) @@ -4365,8 +4220,6 @@ TEST(LoopNest, OptimizeConditionalsMultipleStoresInOneLoop) { } TEST(LoopNest, OptimizeConditionalsOuterLoopVar) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 20; i++) { // for (int j = 0; j < 100; j++) { @@ -4410,8 +4263,6 @@ TEST(LoopNest, OptimizeConditionalsOuterLoopVar) { } TEST(LoopNest, OptimizeConditionalsCompValuesNotOrdered) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 20; i++) { // A[i] = IfThenElse(i<5, IfThenElse(i<10, B[i], C[i-5]), D[i-10]) @@ -4451,8 +4302,6 @@ TEST(LoopNest, OptimizeConditionalsCompValuesNotOrdered) { } TEST(LoopNest, OptimizeConditionalsCompValuesNotConstants) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 20; i++) { // A[i] = IfThenElse(i5, B[i], C[i-5]), D[i-10]) @@ -4534,8 +4381,6 @@ TEST(LoopNest, OptimizeConditionalsInvalidCondition) { } TEST(LoopNest, OptimizeConditionalsInvalidCondition2) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 20; i++) { // A[i] = IfThenElse(10, Tensor*> colReduce( - int M, - int N) { +static std::pair, Tensor> colReduce(int M, int N) { auto a = std::make_unique("a", kFloat, std::vector{M, N}); - Tensor* t = Reduce( + Tensor t = Reduce( "b", {{N, "n"}}, Sum(), @@ -4710,10 +4547,10 @@ static std::pair, Tensor*> colReduce( return {std::move(a), t}; } -static StmtPtr splitTailReorder(Tensor* b) { +static StmtPtr splitTailReorder(Tensor b) { constexpr int kVectorWidth = 8; LoopNest nest({b}); - auto loops = nest.getAllLoopNestsWritingToBuf(b->buf())[0]; + auto loops = nest.getAllLoopNestsWritingToBuf(b.buf())[0]; nest.splitWithTail(loops[0], kVectorWidth); // Now the loopnests will look like: // @@ -4734,24 +4571,24 @@ static StmtPtr splitTailReorder(Tensor* b) { // Write #2: "b[n_outer * 8 + n_inner] = ReduceOp(...)" // Loopnest #2: {n_outer, n_inner, m}; // We will have to reorder n_inner and m. - auto loopnests = nest.getAllLoopNestsWritingToBuf(b->buf()); + auto loopnests = nest.getAllLoopNestsWritingToBuf(b.buf()); LoopNest::reorderAxis(loopnests[1][1], loopnests[1][2]); nest.prepareForCodegen(); return nest.root_stmt(); } -static StmtPtr splitMaskReorder(Tensor* b) { +static StmtPtr splitMaskReorder(Tensor b) { constexpr int kVectorWidth = 8; LoopNest nest({b}); - auto loops = nest.getAllLoopNestsWritingToBuf(b->buf())[1]; + auto loops = nest.getAllLoopNestsWritingToBuf(b.buf())[1]; nest.splitWithMask(loops[0], kVectorWidth); - loops = nest.getAllLoopNestsWritingToBuf(b->buf())[1]; + loops = nest.getAllLoopNestsWritingToBuf(b.buf())[1]; LoopNest::reorderAxis(loops[1], loops[2]); nest.prepareForCodegen(); return nest.root_stmt(); } -static void checkColReduce(StmtPtr s, Placeholder& p, Tensor* t) { +static void checkColReduce(StmtPtr s, Placeholder& p, Tensor t) { int M = immediateAs(p.dim(0)); int N = immediateAs(p.dim(1)); PaddedBuffer a(M, N); @@ -4773,7 +4610,6 @@ static void checkColReduce(StmtPtr s, Placeholder& p, Tensor* t) { } TEST(LoopNest, ColReduceSplitTailEvenReorder) { - KernelScope kernel_scope; constexpr int M = 76, N = 128; auto p = colReduce(M, N); StmtPtr s = splitTailReorder(p.second); @@ -4796,7 +4632,6 @@ TEST(LoopNest, ColReduceSplitTailEvenReorder) { } TEST(LoopNest, ColReduceSplitTailUnevenReorder) { - KernelScope kernel_scope; constexpr int M = 76, N = 100; auto p = colReduce(M, N); StmtPtr s = splitTailReorder(p.second); @@ -4822,7 +4657,6 @@ TEST(LoopNest, ColReduceSplitTailUnevenReorder) { } TEST(LoopNest, ColReduceSplitMaskEvenReorder) { - KernelScope kernel_scope; constexpr int M = 76, N = 128; auto p = colReduce(M, N); StmtPtr s = splitMaskReorder(p.second); @@ -4830,7 +4664,6 @@ TEST(LoopNest, ColReduceSplitMaskEvenReorder) { } TEST(LoopNest, ColReduceSplitMaskUnevenReorder) { - KernelScope kernel_scope; constexpr int M = 76, N = 100; auto p = colReduce(M, N); StmtPtr s = splitMaskReorder(p.second); @@ -4838,8 +4671,6 @@ TEST(LoopNest, ColReduceSplitMaskUnevenReorder) { } TEST(LoopNest, ReorderAxisWithMultipleConds) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 20; i++) { // if i > 5 { @@ -4879,17 +4710,16 @@ TEST(LoopNest, ReorderAxisWithMultipleConds) { } TEST(LoopNest, VectorizeUse) { - KernelScope kernel_scope; constexpr int N = 8; Placeholder a("a", kFloat, {N}); - Tensor* b = Compute( + Tensor b = Compute( "b", {{N, "n"}}, [&](const VarHandle& n) { return a.load(n) + 1.0f; }); - Tensor* c = Compute( - "c", {{N, "n"}}, [&](const VarHandle& n) { return b->load(n) + 2.0f; }); + Tensor c = Compute( + "c", {{N, "n"}}, [&](const VarHandle& n) { return b.load(n) + 2.0f; }); LoopNest nest({c}, {b, c}); - auto loops = nest.getAllLoopNestsWritingToBuf(b->buf())[0]; + auto loops = nest.getAllLoopNestsWritingToBuf(b.buf())[0]; ASSERT_TRUE(LoopNest::vectorize(loops[0])); - loops = nest.getAllLoopNestsWritingToBuf(c->buf())[0]; + loops = nest.getAllLoopNestsWritingToBuf(c.buf())[0]; ASSERT_TRUE(LoopNest::vectorize(loops[0])); nest.prepareForCodegen(); // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) @@ -4904,19 +4734,18 @@ TEST(LoopNest, VectorizeUse) { } const char* int64Loop = R"IR( -# CHECK: for (int64_t n = 0; n < 12; n++) { -# CHECK: b[n] = (a[n]) + 1; +# CHECK: for (int64_t n = 0ll; n < 12ll; n++) { +# CHECK: b[n] = (a[n]) + 1ll; # CHECK: } )IR"; TEST(LoopNest, Int64Direct) { - KernelScope kernel_scope; - constexpr int64_t N = 12; Placeholder a("a", kLong, {N}); Placeholder b("b", kLong, {N}); VarHandle n("n", kLong); - StmtPtr s = For::make(n, 0, N, b.store({n}, a.load({n}) + LongImm::make(1l))); + StmtPtr s = For::make( + n, LongImm::make(0l), N, b.store({n}, a.load({n}) + LongImm::make(1l))); s = IRSimplifier::simplify(s); std::ostringstream oss; oss << *s; @@ -4924,11 +4753,9 @@ TEST(LoopNest, Int64Direct) { } TEST(LoopNest, Int64Compute) { - KernelScope kernel_scope; - constexpr int64_t N = 12; Placeholder a("a", kLong, {N}); - Tensor* b = Compute("b", {{N, "n"}}, [&](const VarHandle& n) { + Tensor b = Compute("b", {{N, "n"}}, [&](const VarHandle& n) { return a.load(n) + LongImm::make(1l); }); LoopNest nest({b}); @@ -4940,8 +4767,6 @@ TEST(LoopNest, Int64Compute) { } TEST(LoopNest, DistributeLoopWithAllStmtsAsPivots) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 20; i++) { // A[i] = 0; @@ -5002,8 +4827,6 @@ TEST(LoopNest, DistributeLoopWithAllStmtsAsPivots) { } TEST(LoopNest, DistributeLoopWithOneStmtAsPivot) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 20; i++) { // A[i] = 0; @@ -5061,8 +4884,6 @@ TEST(LoopNest, DistributeLoopWithOneStmtAsPivot) { } TEST(LoopNest, DistributeLoopWithoutAnyPivot) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 20; i++) { // A[i] = 0; @@ -5123,8 +4944,6 @@ TEST(LoopNest, DistributeLoopWithoutAnyPivot) { } TEST(LoopNest, DistributeLoopOverInnerLoops) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 20; i++) { // A[i] = 0; @@ -5182,8 +5001,6 @@ TEST(LoopNest, DistributeLoopOverInnerLoops) { } TEST(LoopNest, DistributeLoopAndParentsWithoutAnyPivot) { - KernelScope kernel_scope; - // Input IR: // for (int m = 0; m < 50; m++) { // for (int i = 0; i < 20; i++) { @@ -5293,8 +5110,6 @@ TEST(LoopNest, DistributeLoopAndParentsWithoutAnyPivot) { } TEST(LoopNest, fuseLoopsSimple) { - KernelScope kernel_scope; - // Input IR: // for (int j = 0; j < 100; j++) { // A[j] = 10 * j; @@ -5329,8 +5144,6 @@ TEST(LoopNest, fuseLoopsSimple) { } TEST(LoopNest, fuseLoopsMultiple) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 100; i++) { // A[i+100] = 20 + i; @@ -5372,8 +5185,6 @@ TEST(LoopNest, fuseLoopsMultiple) { } TEST(LoopNest, fuseLoopsNested) { - KernelScope kernel_scope; - // Input IR: // for (int m = 0; m < 20; m++) { // A[m] = 0; @@ -5434,8 +5245,6 @@ TEST(LoopNest, fuseLoopsNested) { } TEST(LoopNest, fuseLoopsNested2D) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 20; i++) { // for (int j = 0; j < 100; j++) { @@ -5494,8 +5303,6 @@ TEST(LoopNest, fuseLoopsNested2D) { } TEST(LoopNest, fuseLoopsNested2DInner) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 20; i++) { // for (int j = 0; j < 100; j++) { @@ -5536,8 +5343,6 @@ TEST(LoopNest, fuseLoopsNested2DInner) { } TEST(LoopNest, fuseLoopsDifferentStopBounds) { - KernelScope kernel_scope; - // Input IR: // for (int j = 0; j < 100; j++) { // A[j] = 10 * j; @@ -5559,8 +5364,6 @@ TEST(LoopNest, fuseLoopsDifferentStopBounds) { } TEST(LoopNest, fuseLoopsDifferentStartBounds) { - KernelScope kernel_scope; - // Input IR: // for (int j = 0; j < 100; j++) { // A[j] = 10 * j; @@ -5582,8 +5385,6 @@ TEST(LoopNest, fuseLoopsDifferentStartBounds) { } TEST(LoopNest, fuseLoopsNotContiguous) { - KernelScope kernel_scope; - // Input IR: // for (int j = 0; j < 100; j++) { // A[j] = 10 * j; @@ -5607,8 +5408,6 @@ TEST(LoopNest, fuseLoopsNotContiguous) { } TEST(LoopNest, fuseLoopsWithDifferentParents) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 50; i++) { // for (int j = 0; j < 100; j++) { @@ -5636,8 +5435,6 @@ TEST(LoopNest, fuseLoopsWithDifferentParents) { } TEST(LoopNest, fuseLoopsWithVariableBounds) { - KernelScope kernel_scope; - // Input IR: // for (int j = 0; j < N; j++) { // A[j] = 10 * j; @@ -5674,8 +5471,6 @@ TEST(LoopNest, fuseLoopsWithVariableBounds) { } TEST(LoopNest, fuseLoopsWithExprBounds) { - KernelScope kernel_scope; - // Input IR: // for (int j = 0; j < M + N; j++) { // A[j] = 10 * j; @@ -5712,8 +5507,6 @@ TEST(LoopNest, fuseLoopsWithExprBounds) { } TEST(LoopNest, fuseLoopsWithDifferentExprBounds) { - KernelScope kernel_scope; - // Input IR: // for (int j = M; j < N * 2; j++) { // A[j] = 10 * j; @@ -5751,8 +5544,6 @@ TEST(LoopNest, fuseLoopsWithDifferentExprBounds) { } TEST(LoopNest, fuseLoopsWithNonOverlappingBufferAccesses) { - KernelScope kernel_scope; - // Input IR: // for (int j = 10; j < 100; j++) { // A[j] = 10 * j; @@ -5788,8 +5579,6 @@ TEST(LoopNest, fuseLoopsWithNonOverlappingBufferAccesses) { } TEST(LoopNest, fuseLoopsWithNonOverlapping2DBufferAccesses) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 20; i++) { // for (int j = 0; j < 100; j++) { @@ -5838,8 +5627,6 @@ TEST(LoopNest, fuseLoopsWithNonOverlapping2DBufferAccesses) { } TEST(LoopNest, fuseLoopsWithReductions) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 20; i++) { // A[i] = 0 @@ -5886,8 +5673,6 @@ TEST(LoopNest, fuseLoopsWithReductions) { } TEST(LoopNest, fuseLoopsWith2DReductions) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 20; i++) { // for (int j = 0; j < 50; j++) { @@ -5946,8 +5731,6 @@ TEST(LoopNest, fuseLoopsWith2DReductions) { } TEST(LoopNest, fuseLoopsWithComplexIndices) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 20; i++) { // for (int j = 0; j < 20; j++) { @@ -5994,8 +5777,6 @@ TEST(LoopNest, fuseLoopsWithComplexIndices) { } TEST(LoopNest, fuseLoopsWithMixedLoopVarsAsIndices) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 20; i++) { // for (int j = 0; j < 20; j++) { @@ -6025,8 +5806,6 @@ TEST(LoopNest, fuseLoopsWithMixedLoopVarsAsIndices) { } TEST(LoopNest, fuseLoopsWithTranspose) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 20; i++) { // for (int j = 0; j < 20; j++) { @@ -6056,8 +5835,6 @@ TEST(LoopNest, fuseLoopsWithTranspose) { } TEST(LoopNest, fuseLoopsThatViolateDependencies1) { - KernelScope kernel_scope; - // Input IR: // for (int j = 10; j < 100; j++) { // A[j] = 10 * j; @@ -6079,8 +5856,6 @@ TEST(LoopNest, fuseLoopsThatViolateDependencies1) { } TEST(LoopNest, fuseLoopsThatViolateDependencies2) { - KernelScope kernel_scope; - // Input IR: // for (int j = 10; j < 100; j++) { // A[j] = 10 * j; @@ -6102,8 +5877,6 @@ TEST(LoopNest, fuseLoopsThatViolateDependencies2) { } TEST(LoopNest, fuseLoopsThatViolateDependencies3) { - KernelScope kernel_scope; - // Input IR: // for (int m = 0; m < 20; m++) { // A[m] = 0; @@ -6147,8 +5920,6 @@ TEST(LoopNest, fuseLoopsThatViolateDependencies3) { } TEST(LoopNest, fuseLoopsThatViolateDependencies4) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 20; i++) { // for (int j = 0; j < 100; j++) { @@ -6191,8 +5962,6 @@ TEST(LoopNest, fuseLoopsThatViolateDependencies4) { } TEST(LoopNest, fuseLoopsThatViolateDependencies5) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 20; i++) { // for (int j = 0; j < 100; j++) { @@ -6221,8 +5990,6 @@ TEST(LoopNest, fuseLoopsThatViolateDependencies5) { } TEST(LoopNest, fuseLoopsThatViolateDependencies6) { - KernelScope kernel_scope; - // Input IR: // for (int j = 0; j < 100; j++) { // A[j] = 10 * j; @@ -6249,8 +6016,6 @@ TEST(LoopNest, fuseLoopsThatViolateDependencies6) { } TEST(LoopNest, fuseLoopsThatViolateDependencies7) { - KernelScope kernel_scope; - // Input IR: // for (int k = 0; k < 100; k++) { // B[k] = 20 * A[99-k]; @@ -6277,8 +6042,6 @@ TEST(LoopNest, fuseLoopsThatViolateDependencies7) { } TEST(LoopNest, areLoopsPerfectlyNested) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 20; i++) { // for (int j = 0; j < 30; j++) { @@ -6321,8 +6084,6 @@ TEST(LoopNest, areLoopsPerfectlyNested) { } TEST(LoopNest, reorderNestedLoops2D) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 20; i++) { // for (int j = 0; j < 30; j++) { @@ -6347,8 +6108,6 @@ TEST(LoopNest, reorderNestedLoops2D) { } TEST(LoopNest, reorderNestedLoops3D) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 20; i++) { // for (int j = 0; j < 30; j++) { @@ -6378,8 +6137,6 @@ TEST(LoopNest, reorderNestedLoops3D) { } TEST(LoopNest, reorderNestedLoops4D) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 20; i++) { // for (int j = 0; j < 30; j++) { @@ -6417,8 +6174,6 @@ TEST(LoopNest, reorderNestedLoops4D) { } TEST(LoopNest, reorderTrivialPermutation) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 20; i++) { // for (int j = 0; j < 30; j++) { @@ -6448,8 +6203,6 @@ TEST(LoopNest, reorderTrivialPermutation) { } TEST(LoopNest, reorderInvalidPermutations) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 20; i++) { // for (int j = 0; j < 30; j++) { @@ -6487,8 +6240,6 @@ TEST(LoopNest, reorderInvalidPermutations) { } TEST(LoopNest, reorderInvalidLoopNest) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 20; i++) { // for (int j = 0; j < 30; j++) { @@ -6530,8 +6281,6 @@ TEST(LoopNest, reorderInvalidLoopNest) { } TEST(LoopNest, compressBufferSimple) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 100; ++i) { // for (int j = 0; j < 200; ++j) { @@ -6576,8 +6325,6 @@ TEST(LoopNest, compressBufferSimple) { } TEST(LoopNest, compressBufferMultipleDims) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 100; ++i) { // for (int j = 0; j < 200; ++j) { @@ -6616,8 +6363,6 @@ TEST(LoopNest, compressBufferMultipleDims) { } TEST(LoopNest, compressBufferMultipleDims2) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 100; ++i) { // for (int j = 0; j < 200; ++j) { @@ -6666,8 +6411,6 @@ TEST(LoopNest, compressBufferMultipleDims2) { } TEST(LoopNest, compressBufferDifferentOrderIndices) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 100; ++i) { // for (int j = 0; j < 200; ++j) { @@ -6712,8 +6455,6 @@ TEST(LoopNest, compressBufferDifferentOrderIndices) { } TEST(LoopNest, compressBufferVariableBounds) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < M; ++i) { // for (int j = 0; j < N; ++j) { @@ -6761,8 +6502,6 @@ TEST(LoopNest, compressBufferVariableBounds) { } TEST(LoopNest, compressBufferNoCommonParentLoops) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 100; ++i) { // for (int j = 0; j < 200; ++j) { @@ -6812,8 +6551,6 @@ TEST(LoopNest, compressBufferNoCommonParentLoops) { } TEST(LoopNest, compressBufferIndicesMixed) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 100; ++i) { // for (int j = 0; j < 200; ++j) { @@ -6860,8 +6597,6 @@ TEST(LoopNest, compressBufferIndicesMixed) { } TEST(LoopNest, compressMultipleBuffers) { - KernelScope kernel_scope; - // Input IR: // for (int i = 0; i < 100; ++i) { // for (int j = 0; j < 200; ++j) { diff --git a/test/cpp/tensorexpr/test_memdependency.cpp b/test/cpp/tensorexpr/test_memdependency.cpp index 7f844c5ba4cf4..c9990dcacfb41 100644 --- a/test/cpp/tensorexpr/test_memdependency.cpp +++ b/test/cpp/tensorexpr/test_memdependency.cpp @@ -19,8 +19,6 @@ using namespace torch::jit::tensorexpr; // larger and fully encloses B, while ContainedOrEqual is the reverse. Equal // ranges are ContainedOrEqual. TEST(MemDependency, BoundOverlap) { - KernelScope kernel_scope; - using namespace analysis; auto CB = [](int s, int e) { @@ -79,7 +77,6 @@ TEST(MemDependency, BoundOverlap) { } TEST(MemDependency, BoundOverlapSymbolic) { - KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); VarHandle z("z", kInt); @@ -116,8 +113,6 @@ TEST(MemDependency, BoundOverlapSymbolic) { // This uses boundOverlap on each dimension and return the "lowest" kind of // overlap. TEST(MemDependency, BoundOverlapMultiDim) { - KernelScope kernel_scope; - using namespace analysis; auto CB = [](int s, int e) { @@ -189,8 +184,6 @@ TEST(MemDependency, BoundOverlapMultiDim) { // Test the helper we use to subtract bounds: returns the regions(s) of A which // remain after removing the region of B. TEST(MemDependency, BoundSubtract) { - KernelScope kernel_scope; - using namespace analysis; auto CB = [](int s, int e) { @@ -224,7 +217,6 @@ TEST(MemDependency, BoundSubtract) { } TEST(MemDependency, BoundSubtractSymbolic) { - KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); VarHandle z("z", kInt); @@ -273,8 +265,6 @@ TEST(MemDependency, BoundSubtractSymbolic) { // Tests the helper function that does subtraction, but for multi dimensional // indices bounds. TEST(MemDependency, BoundSubtractMultiDim) { - KernelScope kernel_scope; - using namespace analysis; auto CB = [](int s, int e) { @@ -335,7 +325,6 @@ TEST(MemDependency, BoundSubtractMultiDim) { // Tests the multi dimensional subtraction code for bounds that cannot be fully // materialized. TEST(MemDependency, BoundSubtractMultiDimSymbolic) { - KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); @@ -403,7 +392,6 @@ TEST(MemDependency, BoundSubtractMultiDimSymbolic) { // Simple check that the analyzer does anything at all... TEST(MemDependency, MemDependencyCheckerSimple) { - KernelScope kernel_scope; BufHandle a("A", {1}, kInt); BufHandle b("B", {1}, kInt); @@ -429,7 +417,6 @@ TEST(MemDependency, MemDependencyCheckerSimple) { // Check that there is a difference between direct and indirect dependence. TEST(MemDependency, MemDependencyCheckerMultiStmt) { - KernelScope kernel_scope; BufHandle a("A", {1}, kInt); BufHandle b("B", {1}, kInt); BufHandle c("C", {1}, kInt); @@ -466,7 +453,6 @@ TEST(MemDependency, MemDependencyCheckerMultiStmt) { // Verify that we do filter writes that are totally overlapped by later writes. TEST(MemDependency, MemDependencyCheckerOverlap) { - KernelScope kernel_scope; BufHandle a("A", {1}, kInt); BufHandle b("B", {1}, kInt); @@ -499,7 +485,6 @@ TEST(MemDependency, MemDependencyCheckerOverlap) { // Verify that bounds match loop iterations, and that dependencies progress // across loop scopes. TEST(MemDependency, MemDependencyCheckerLoop) { - KernelScope kernel_scope; BufHandle a("A", {1}, kInt); BufHandle b("B", {1}, kInt); VarHandle x("x", kInt); @@ -541,7 +526,6 @@ TEST(MemDependency, MemDependencyCheckerLoop) { // Reductions should promote dependencies as well. TEST(MemDependency, MemDependencyCheckerLoopReduce) { - KernelScope kernel_scope; BufHandle a("A", {10}, kInt); BufHandle b("B", {10}, kInt); VarHandle x("x", kInt); @@ -600,7 +584,6 @@ TEST(MemDependency, MemDependencyCheckerLoopReduce) { // Lowering a reduction doesn't affect dependency analysis. TEST(MemDependency, MemDependencyCheckerLoopReduceExpanded) { - KernelScope kernel_scope; BufHandle a("A", {10}, kInt); BufHandle b("B", {10}, kInt); VarHandle x("x", kInt); @@ -654,7 +637,6 @@ TEST(MemDependency, MemDependencyCheckerLoopReduceExpanded) { // Can determine dependencies of outputs, through to inputs. TEST(MemDependency, MemDependencyCheckerInputsOutputs) { - KernelScope kernel_scope; BufHandle a("A", {10}, kInt); BufHandle b("B", {10}, kInt); VarHandle x("x", kInt); @@ -707,7 +689,6 @@ TEST(MemDependency, MemDependencyCheckerInputsOutputs) { // Can tell if an output does not depend on an input. TEST(MemDependency, MemDependencyCheckerOutputDoesntDepend) { - KernelScope kernel_scope; BufHandle a("A", {10}, kInt); BufHandle b("B", {10}, kInt); VarHandle x("x", kInt); @@ -746,7 +727,6 @@ TEST(MemDependency, MemDependencyCheckerOutputDoesntDepend) { // Verify different loop extents produce accesses with different bounds, and // that later accesses find dependencies that overlap their entire bound range. TEST(MemDependency, MemDependencyCheckerLoopBounds) { - KernelScope kernel_scope; BufHandle a("A", {10}, kInt); BufHandle b("B", {10}, kInt); BufHandle c("C", {10}, kInt); @@ -928,7 +908,6 @@ TEST(MemDependency, MemDependencyCheckerLoopBounds) { // Verify that we can still infer bounds when the loop var is offset. TEST(MemDependency, MemDependencyCheckerLoopBoundsIndexShift) { - KernelScope kernel_scope; BufHandle a("A", {10}, kInt); BufHandle b("B", {10}, kInt); VarHandle x("x", kInt); @@ -1111,7 +1090,6 @@ TEST(MemDependency, MemDependencyCheckerLoopBoundsIndexShift) { // iteration. This is affected by whether or not we can trust the execution // order of the loop. TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { - KernelScope kernel_scope; BufHandle a("A", {5}, kInt); BufHandle b("B", {5}, kInt); VarHandle x("x", kInt); @@ -1749,7 +1727,6 @@ TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { // TODO: actually this only works because of the size of the ranges, revist this // test after strided overlap is implemented. TEST(MemDependency, MemDependencyCheckerLoopDistinctStrides) { - KernelScope kernel_scope; BufHandle a("A", {20}, kInt); BufHandle b("B", {20}, kInt); VarHandle x("x", kInt); @@ -1775,7 +1752,6 @@ TEST(MemDependency, MemDependencyCheckerLoopDistinctStrides) { /* TODO(nickg) - this test will fail due to the lack of stride math in Bound TEST(MemDependency, MemDependencyCheckerLoopDistinctStrides) { - KernelScope kernel_scope; BufHandle a("A", {20}, kInt); BufHandle b("B", {20}, kInt); BufHandle c("C", {10}, kInt); @@ -1806,7 +1782,6 @@ TEST(MemDependency, MemDependencyCheckerLoopDistinctStrides) { // analysis on Stmts using Cond. TEST(MemDependency, MemDependencyCheckerLoopBoundsCond) { - KernelScope kernel_scope; BufHandle a("A", {10}, kInt); BufHandle b("B", {10}, kInt); BufHandle c("C", {10}, kInt); @@ -2002,7 +1977,6 @@ TEST(MemDependency, MemDependencyCheckerLoopBoundsCond) { // Stmts using IfThenElse. TEST(MemDependency, MemDependencyCheckerIfThenElse) { - KernelScope kernel_scope; BufHandle a("A", {10}, kInt); BufHandle b("B", {10}, kInt); BufHandle c("C", {10}, kInt); @@ -2112,7 +2086,6 @@ TEST(MemDependency, MemDependencyCheckerIfThenElse) { // Cutting a loop with single elem writes TEST(MemDependency, MemDependencyCheckerCutLoop) { - KernelScope kernel_scope; BufHandle a("A", {10}, kInt); BufHandle b("B", {10}, kInt); VarHandle x("x", kInt); @@ -2194,7 +2167,6 @@ TEST(MemDependency, MemDependencyCheckerCutLoop) { // Dynamic shapes (load in indices). TEST(MemDependency, MemDependencyCheckerDynamicShapes) { - KernelScope kernel_scope; BufHandle a("A", {100}, kInt); BufHandle b("B", {100}, kInt); BufHandle c("C", {100}, kInt); @@ -2436,7 +2408,6 @@ TEST(MemDependency, MemDependencyCheckerDynamicShapes) { // Verify multi dimensional bounds work. TEST(MemDependency, MemDependencyCheckerMultiDim) { - KernelScope kernel_scope; int M = 10, N = 9, K = 12; BufHandle a("A", {M, N, K}, kInt); BufHandle b("B", {M, N, K}, kInt); @@ -2703,8 +2674,6 @@ TEST(MemDependency, MemDependencyCheckerMultiDim) { // Various tests using the external Compute/Reduce API. TEST(MemDependency, MemDependencyCheckerComputeAPI) { - KernelScope kernel_scope; - using namespace analysis; /* for (int m = 0; m < 4; m++) { @@ -2726,28 +2695,28 @@ TEST(MemDependency, MemDependencyCheckerComputeAPI) { // Can determine if 2 loops created by Compute are dependent. Placeholder a_buf("a", kFloat, {4, 5}); Placeholder b_buf("b", kFloat, {5, 6}); - Tensor* c = Compute( + Tensor c = Compute( "broadcast_add", {{4, "m"}, {5, "n"}, {6, "k"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { return a_buf.load(m, n) + b_buf.load(n, k); }); - Tensor* d = Compute( + Tensor d = Compute( "d", {{4, "m"}, {5, "n"}, {6, "k"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return c->load(m, n, k) + 1; + return c.load(m, n, k) + 1; }); LoopNest l({d}, {c, d}); - MemDependencyChecker analyzer({a_buf.data(), b_buf.data()}, {d->buf()}); + MemDependencyChecker analyzer({a_buf.data(), b_buf.data()}, {d.buf()}); l.root_stmt()->accept(&analyzer); // Sanity test: Output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(d->buf(), a_buf.data())); - ASSERT_TRUE(analyzer.dependsIndirectly(d->buf(), b_buf.data())); + ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), a_buf.data())); + ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), b_buf.data())); // Second loop depends on first loop. auto c_loop = l.getLoopStmtsFor(c)[0]; @@ -2756,8 +2725,6 @@ TEST(MemDependency, MemDependencyCheckerComputeAPI) { } TEST(MemDependency, MemDependencyCheckerComputeInline) { - KernelScope kernel_scope; - using namespace analysis; /* for (int m = 0; m < 4; m++) { @@ -2773,44 +2740,42 @@ TEST(MemDependency, MemDependencyCheckerComputeInline) { Placeholder a_buf("a", kFloat, {4, 5}); Placeholder b_buf("b", kFloat, {5, 6}); - Tensor* c = Compute( + Tensor c = Compute( "broadcast_add", {{4, "m"}, {5, "n"}, {6, "k"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { return a_buf.load(m, n) + b_buf.load(n, k); }); - Tensor* d = Compute( + Tensor d = Compute( "d", {{4, "m"}, {5, "n"}, {6, "k"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { - return c->load(m, n, k) + 1; + return c.load(m, n, k) + 1; }); LoopNest l({d}, {c, d}); - l.computeInline(c->buf()); + l.computeInline(c.buf()); - MemDependencyChecker analyzer({a_buf.data(), b_buf.data()}, {d->buf()}); + MemDependencyChecker analyzer({a_buf.data(), b_buf.data()}, {d.buf()}); l.root_stmt()->accept(&analyzer); // Sanity test: Output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(d->buf(), a_buf.data())); - ASSERT_TRUE(analyzer.dependsIndirectly(d->buf(), b_buf.data())); + ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), a_buf.data())); + ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), b_buf.data())); // broadcast_add tensor should not appear in trace at all. for (auto& wi : analyzer.getHistory()) { - ASSERT_NE(wi->var(), c->buf()->base_handle()); + ASSERT_NE(wi->var(), c.buf()->base_handle()); } } TEST(MemDependency, MemDependencyCheckerComputeSplit) { - KernelScope kernel_scope; - using namespace analysis; // Split an axis, so the number of loops != the number of dimensions. Placeholder a_buf("a", kFloat, {4, 5}); Placeholder b_buf("b", kFloat, {5, 6}); - Tensor* c = Compute( + Tensor c = Compute( "broadcast_add", {{4, "m"}, {5, "n"}, {6, "k"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { @@ -2819,13 +2784,12 @@ TEST(MemDependency, MemDependencyCheckerComputeSplit) { LoopNest l({c}); - MemDependencyChecker analyzer_before( - {a_buf.data(), b_buf.data()}, {c->buf()}); + MemDependencyChecker analyzer_before({a_buf.data(), b_buf.data()}, {c.buf()}); l.root_stmt()->accept(&analyzer_before); l.splitWithTail(l.getLoopStmtsFor(c)[0], 2); - MemDependencyChecker analyzer_after({a_buf.data(), b_buf.data()}, {c->buf()}); + MemDependencyChecker analyzer_after({a_buf.data(), b_buf.data()}, {c.buf()}); StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); stmt->accept(&analyzer_after); @@ -2852,14 +2816,12 @@ TEST(MemDependency, MemDependencyCheckerComputeSplit) { } TEST(MemDependency, MemDependencyCheckerComputeReorder) { - KernelScope kernel_scope; - using namespace analysis; // Reorder an axis, so the loop order doesn't match the indexing order. Placeholder a_buf("a", kFloat, {4, 5}); Placeholder b_buf("b", kFloat, {5, 6}); - Tensor* c = Compute( + Tensor c = Compute( "broadcast_add", {{4, "m"}, {5, "n"}, {6, "k"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { @@ -2868,14 +2830,13 @@ TEST(MemDependency, MemDependencyCheckerComputeReorder) { LoopNest l({c}); - MemDependencyChecker analyzer_before( - {a_buf.data(), b_buf.data()}, {c->buf()}); + MemDependencyChecker analyzer_before({a_buf.data(), b_buf.data()}, {c.buf()}); l.root_stmt()->accept(&analyzer_before); auto loops = l.getLoopStmtsFor(c); l.reorderAxis(loops[0], loops[1]); - MemDependencyChecker analyzer_after({a_buf.data(), b_buf.data()}, {c->buf()}); + MemDependencyChecker analyzer_after({a_buf.data(), b_buf.data()}, {c.buf()}); StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); stmt->accept(&analyzer_after); @@ -2902,8 +2863,6 @@ TEST(MemDependency, MemDependencyCheckerComputeReorder) { } TEST(MemDependency, MemDependencyCheckerComputeReduce) { - KernelScope kernel_scope; - using namespace analysis; /* for (int l2 = 0; l2 < 2; l2++) { * for (int n1 = 0; n1 < 3; n1++) { @@ -2928,22 +2887,22 @@ TEST(MemDependency, MemDependencyCheckerComputeReduce) { Placeholder a(BufHandle("a", {2, 3, 6}, kFloat)); Placeholder b(BufHandle("b", {2, 3, 6}, kFloat)); - Tensor* c = Compute( + Tensor c = Compute( "scale", {{2, "l2"}, {3, "n1"}, {6, "m1"}}, [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { return b.load(l, n, m) * a.load(l, n, m); }); - Tensor* d = Reduce("sum", {{2, "l1"}}, Sum(), c, {{3, "n1"}, {6, "m1"}}); + Tensor d = Reduce("sum", {{2, "l1"}}, Sum(), c, {{3, "n1"}, {6, "m1"}}); LoopNest l({d}, {c, d}); - MemDependencyChecker analyzer({a.data(), b.data()}, {d->buf()}); + MemDependencyChecker analyzer({a.data(), b.data()}, {d.buf()}); l.root_stmt()->accept(&analyzer); // Sanity test: Output depends on input. - ASSERT_TRUE(analyzer.dependsIndirectly(d->buf(), a.data())); - ASSERT_TRUE(analyzer.dependsIndirectly(d->buf(), b.data())); + ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), a.data())); + ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), b.data())); // Second loop depends on first loop. auto c_loop = l.getLoopStmtsFor(c)[0]; @@ -2957,7 +2916,6 @@ TEST(MemDependency, MemDependencyCheckerComputeReduce) { } TEST(MemDependency, MemDependencyCheckerComputeGEMM) { - KernelScope kernel_scope; int M = 1024; int N = 1024; int K = 2048; @@ -2965,7 +2923,7 @@ TEST(MemDependency, MemDependencyCheckerComputeGEMM) { Placeholder AP(BufHandle("A", {M, K}, kFloat)); Placeholder BP(BufHandle("B", {K, N}, kFloat)); - Tensor* CT = Reduce( + Tensor CT = Reduce( "gemm", {{M, "M"}, {N, "N"}}, Sum(), @@ -3011,7 +2969,7 @@ TEST(MemDependency, MemDependencyCheckerComputeGEMM) { } { auto const& loops = loop.getLoopStmtsFor(CT); - loop.cacheAccesses(CT->buf(), "C_regs", loops[2]); + loop.cacheAccesses(CT.buf(), "C_regs", loops[2]); } MemDependencyChecker analyzer_unlowered( @@ -3026,12 +2984,12 @@ TEST(MemDependency, MemDependencyCheckerComputeGEMM) { stmt->accept(&analyzer_unlowered); // Outputs depend on inputs. - ASSERT_TRUE(analyzer_unlowered.dependsIndirectly(CT->buf(), AP.data())); - ASSERT_TRUE(analyzer_unlowered.dependsIndirectly(CT->buf(), BP.data())); + ASSERT_TRUE(analyzer_unlowered.dependsIndirectly(CT.buf(), AP.data())); + ASSERT_TRUE(analyzer_unlowered.dependsIndirectly(CT.buf(), BP.data())); // The last write to gemm should cover the total bound of the output. std::shared_ptr outputAccess = - analyzer_unlowered.output(CT->buf()); + analyzer_unlowered.output(CT.buf()); // A single dependency. ASSERT_EQ(outputAccess->dependencies().size(), 1); diff --git a/test/cpp/tensorexpr/test_ops.cpp b/test/cpp/tensorexpr/test_ops.cpp index 674dbd9cb0199..586c093e213d1 100644 --- a/test/cpp/tensorexpr/test_ops.cpp +++ b/test/cpp/tensorexpr/test_ops.cpp @@ -6,7 +6,7 @@ using namespace torch::jit::tensorexpr; -using Tensors = std::vector; +using Tensors = std::vector; using Args = std::vector; std::unique_ptr compile( const Args& inputs, @@ -20,15 +20,13 @@ std::unique_ptr compile( } TEST(Ops, Sum) { - KernelScope ks; - std::vector testDims = {{0}, {1}, {0, 1}}; for (auto const& dims : testDims) { constexpr int M = 8; constexpr int N = 16; Placeholder a("a", kFloat, {M, N}); - Tensor* b = computeSum({a.handle(), dims, false}, c10::kFloat); + Tensor b = computeSum({a.handle(), dims, false}, c10::kFloat); auto cg = compile({a}, {b}); auto at = at::arange(M * N, at::kFloat).view({M, N}); diff --git a/test/cpp/tensorexpr/test_reductions.cpp b/test/cpp/tensorexpr/test_reductions.cpp index bd71a4fd8da14..3d2c0ecc27bfe 100644 --- a/test/cpp/tensorexpr/test_reductions.cpp +++ b/test/cpp/tensorexpr/test_reductions.cpp @@ -24,7 +24,6 @@ namespace jit { using namespace torch::jit::tensorexpr; TEST(Reductions, ReduceSum0D_1) { - KernelScope kernel_scope; const int M = 10; Placeholder b(BufHandle("b", {M}, kFloat)); @@ -35,7 +34,7 @@ TEST(Reductions, ReduceSum0D_1) { std::vector out(M, -1.f); - Tensor* c = Reduce("sum", {{M, "m"}}, Sum(), b, {}); + Tensor c = Reduce("sum", {{M, "m"}}, Sum(), b, {}); LoopNest loop({c}); loop.prepareForCodegen(); StmtPtr s = loop.root_stmt(); @@ -50,7 +49,6 @@ TEST(Reductions, ReduceSum0D_1) { } TEST(Reductions, ReduceSum0D_2) { - KernelScope kernel_scope; const int M = 10; Placeholder b(BufHandle("b", {}, kFloat)); @@ -59,7 +57,7 @@ TEST(Reductions, ReduceSum0D_2) { std::vector out(1, -1.f); - Tensor* c = Reduce("sum", {}, Sum(), b, {}); + Tensor c = Reduce("sum", {}, Sum(), b, {}); LoopNest loop({c}); loop.prepareForCodegen(); StmtPtr s = loop.root_stmt(); @@ -73,8 +71,6 @@ TEST(Reductions, ReduceSum0D_2) { // Sum an array to a single value. TEST(Reductions, ReduceSum1D) { - KernelScope kernel_scope; - Placeholder b(BufHandle("b", {10}, kFloat)); std::vector in(10); for (int j = 0; j < 10; ++j) { @@ -83,7 +79,7 @@ TEST(Reductions, ReduceSum1D) { std::vector out(1, -1.f); - Tensor* c = Reduce("sum", {}, Sum(), b, {{10, "m"}}); + Tensor c = Reduce("sum", {}, Sum(), b, {{10, "m"}}); LoopNest loop({c}); loop.prepareForCodegen(); StmtPtr s = loop.root_stmt(); @@ -96,8 +92,6 @@ TEST(Reductions, ReduceSum1D) { } // Sum a 2D tensor to a 1D tensor with dynamic shapes. TEST(Reductions, ReduceSum2D) { - KernelScope kernel_scope; - const int M = 3; const int N = 7; @@ -114,7 +108,7 @@ TEST(Reductions, ReduceSum2D) { std::vector out(M, -1.f); - Tensor* c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}}); + Tensor c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}}); LoopNest loop({c}); loop.prepareForCodegen(); StmtPtr s = loop.root_stmt(); @@ -138,14 +132,12 @@ TEST(Reductions, ReduceSum2D) { // Sum a 3D tensor to both a 2D and 1D tensor, then reduce the 2D tensor flat to // check our work. TEST(Reductions, ReduceSum3D) { - KernelScope kernel_scope; - const int M = 10; VarHandle m("m", kInt); Placeholder b(BufHandle("b", {2, 3, m}, kFloat)); - Tensor* c = Reduce("sum", {{2, "l"}, {3, "n"}}, Sum(), b, {{m, "m"}}); + Tensor c = Reduce("sum", {{2, "l"}, {3, "n"}}, Sum(), b, {{m, "m"}}); LoopNest loop({c}); loop.prepareForCodegen(); StmtPtr s = loop.root_stmt(); @@ -175,7 +167,7 @@ TEST(Reductions, ReduceSum3D) { ASSERT_EQ(cData[i], expected); } - Tensor* d = Reduce("sum2", {{2, "l"}}, Sum(), b, {{3, "n"}, {m, "m"}}); + Tensor d = Reduce("sum2", {{2, "l"}}, Sum(), b, {{3, "n"}, {m, "m"}}); LoopNest loop2({d}); loop2.prepareForCodegen(); StmtPtr s2 = loop2.root_stmt(); @@ -192,8 +184,8 @@ TEST(Reductions, ReduceSum3D) { } // This is the same as just reducing the original result across that axis. - Placeholder c_buf(BufHandle(c->buf())); - Tensor* e = Reduce("sum3", {{2, "l"}}, Sum(), c_buf, {{3, "m"}}); + Placeholder c_buf(BufHandle(c.buf())); + Tensor e = Reduce("sum3", {{2, "l"}}, Sum(), c_buf, {{3, "m"}}); LoopNest loop3({e}); loop3.prepareForCodegen(); StmtPtr s3 = loop3.root_stmt(); @@ -209,8 +201,6 @@ TEST(Reductions, ReduceSum3D) { // Sum a large (10 D) Tensor 5 dimensions in. TEST(Reductions, ReduceSum10D) { - KernelScope kernel_scope; - Placeholder in_(BufHandle("in_", {2, 3, 2, 3, 2, 3, 2, 3, 2, 3}, kFloat)); const int InputSize = 2 * 3 * 2 * 3 * 2 * 3 * 2 * 3 * 2 * 3; Placeholder out_(BufHandle("out_", {2, 3, 2, 3, 2}, kFloat)); @@ -219,7 +209,7 @@ TEST(Reductions, ReduceSum10D) { std::vector in(InputSize, 1.f); std::vector out(OutputSize, -1.f); - Tensor* c = Reduce( + Tensor c = Reduce( "sum", {{2, "a"}, {3, "b"}, {2, "c"}, {3, "d"}, {2, "e"}}, Sum(), @@ -243,8 +233,6 @@ TEST(Reductions, ReduceSum10D) { // Reduce via Mul rather than Add using a custom Reducer. TEST(Reductions, ReduceProduct) { - KernelScope kernel_scope; - const int M = 4; const int N = 4; @@ -261,7 +249,7 @@ TEST(Reductions, ReduceProduct) { Reducer product( ExprHandle(1.f), [](ExprHandle a, ExprHandle b) { return a * b; }); - Tensor* c = Reduce("product", {{M, "m"}}, product, b, {{N, "n"}}); + Tensor c = Reduce("product", {{M, "m"}}, product, b, {{N, "n"}}); LoopNest loop({c}); loop.prepareForCodegen(); StmtPtr s = loop.root_stmt(); @@ -284,8 +272,6 @@ TEST(Reductions, ReduceProduct) { // Maximum reductions. TEST(Reductions, ReduceMax) { - KernelScope kernel_scope; - Placeholder in_(BufHandle("b", {10}, kFloat)); std::vector in(10); @@ -294,7 +280,7 @@ TEST(Reductions, ReduceMax) { in[j] = j; } - Tensor* dm1 = Reduce("max", {}, Maximum(kFloat), in_, {{10, "m"}}); + Tensor dm1 = Reduce("max", {}, Maximum(kFloat), in_, {{10, "m"}}); LoopNest loop({dm1}); loop.prepareForCodegen(); @@ -309,7 +295,7 @@ TEST(Reductions, ReduceMax) { Placeholder in2_(BufHandle("b", {2, 5}, kFloat)); std::vector out2(2, -1.f); - Tensor* m2d = Reduce("max", {{2, "n"}}, Maximum(kFloat), in2_, {{5, "m"}}); + Tensor m2d = Reduce("max", {{2, "n"}}, Maximum(kFloat), in2_, {{5, "m"}}); LoopNest loop2({m2d}); loop2.prepareForCodegen(); @@ -325,8 +311,6 @@ TEST(Reductions, ReduceMax) { // Minimum reduction, with custom initialization. TEST(Reductions, ReduceMinCustomInitializer) { - KernelScope kernel_scope; - VarHandle minInit("minInit", kFloat); Placeholder in_(BufHandle("b", {10}, kFloat)); @@ -336,7 +320,7 @@ TEST(Reductions, ReduceMinCustomInitializer) { in[j] = 10 + j; } - Tensor* min = Reduce( + Tensor min = Reduce( "min", {}, Minimum(ExprHandle(minInit)), @@ -363,8 +347,6 @@ TEST(Reductions, ReduceMinCustomInitializer) { // Example implementation of Any/All. // TODO: this is very awkward without logical And/Or operators. TEST(Reductions, ReduceAnyAll) { - KernelScope kernel_scope; - VarHandle searchValue("searchValue", kInt); Placeholder b(BufHandle("b", {4, 10}, kInt)); @@ -372,7 +354,7 @@ TEST(Reductions, ReduceAnyAll) { return CompareSelect::make(a, 1, 1, b, kEQ); }); - Tensor* any = Reduce( + Tensor any = Reduce( "anyEqual", {{4, "i"}}, anyEqSV, @@ -415,7 +397,7 @@ TEST(Reductions, ReduceAnyAll) { return CompareSelect::make(a, 0, 0, b, kEQ); }); - Tensor* allGreaterThan = Reduce( + Tensor allGreaterThan = Reduce( "allGreaterThan", {{4, "i"}}, allGTSV, @@ -449,8 +431,6 @@ TEST(Reductions, ReduceAnyAll) { } TEST(Reductions, ReduceMatmul2D) { - KernelScope kernel_scope; - Placeholder tA(BufHandle("tA", {3, 2}, kFloat)); Placeholder tB(BufHandle("tB", {2, 3}, kFloat)); @@ -465,7 +445,7 @@ TEST(Reductions, ReduceMatmul2D) { } } - Tensor* mm = Reduce( + Tensor mm = Reduce( "mm", {{3, "m"}, {3, "n"}}, Sum(), @@ -491,8 +471,6 @@ TEST(Reductions, ReduceMatmul2D) { } TEST(Reductions, ReduceRfactorLike) { - KernelScope kernel_scope; - Placeholder in(BufHandle("in", {10, 10}, kFloat)); std::vector in_(100); for (int i = 0; i < 100; ++i) { @@ -501,10 +479,10 @@ TEST(Reductions, ReduceRfactorLike) { std::vector in_rf_(10, -2.f); std::vector out(1, -1.f); - Tensor* l1 = Reduce("l1", {{10, "i"}}, Sum(), in, {{10, "j"}}); - Placeholder in_rf(BufHandle(l1->buf())); + Tensor l1 = Reduce("l1", {{10, "i"}}, Sum(), in, {{10, "j"}}); + Placeholder in_rf(BufHandle(l1.buf())); - Tensor* l2 = Reduce("l2", {}, Sum(), in_rf, {{10, "i"}}); + Tensor l2 = Reduce("l2", {}, Sum(), in_rf, {{10, "i"}}); LoopNest loop({l1, l2}); loop.prepareForCodegen(); @@ -518,20 +496,18 @@ TEST(Reductions, ReduceRfactorLike) { } TEST(Reductions, ReduceAsProducer) { - KernelScope kernel_scope; - const int M = 10; VarHandle m("m", kInt); Placeholder a(BufHandle("a", {2, 3}, kFloat)); Placeholder b(BufHandle("b", {2, 3, m}, kFloat)); - Tensor* c = Reduce("sum", {{2, "l1"}, {3, "n1"}}, Sum(), b, {{m, "m1"}}); - Tensor* d = Compute( + Tensor c = Reduce("sum", {{2, "l1"}, {3, "n1"}}, Sum(), b, {{m, "m1"}}); + Tensor d = Compute( "scale", {{2, "l2"}, {3, "n1"}}, [&](const VarHandle& l, const VarHandle& n) { - return c->load(l, n) * a.load(l, n); + return c.load(l, n) * a.load(l, n); }); LoopNest loop({d}, {c, d}); loop.prepareForCodegen(); @@ -563,21 +539,19 @@ TEST(Reductions, ReduceAsProducer) { } TEST(Reductions, ReduceAsConsumer) { - KernelScope kernel_scope; - const int M = 10; VarHandle m("m", kInt); Placeholder a(BufHandle("a", {2, 3, m}, kFloat)); Placeholder b(BufHandle("b", {2, 3, m}, kFloat)); - Tensor* c = Compute( + Tensor c = Compute( "scale", {{2, "l2"}, {3, "n1"}, {m, "m1"}}, [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { return b.load(l, n, m) * a.load(l, n, m); }); - Tensor* d = Reduce("sum", {{2, "l1"}}, Sum(), c, {{3, "n1"}, {m, "m1"}}); + Tensor d = Reduce("sum", {{2, "l1"}}, Sum(), c, {{3, "n1"}, {m, "m1"}}); LoopNest loop({d}, {c, d}); loop.prepareForCodegen(); StmtPtr s = loop.root_stmt(); @@ -614,8 +588,6 @@ TEST(Reductions, ReduceAsConsumer) { } TEST(Reductions, SplitReduceAxis) { - KernelScope kernel_scope; - Placeholder in(BufHandle("in", {16, 8}, kFloat)); std::vector in_(16 * 8); @@ -626,7 +598,7 @@ TEST(Reductions, SplitReduceAxis) { } std::vector out(16, -1.f); - Tensor* tensor = Reduce("sum", {{16, "m"}}, Sum(), in, {{8, "n"}}); + Tensor tensor = Reduce("sum", {{16, "m"}}, Sum(), in, {{8, "n"}}); LoopNest l({tensor}); std::vector loops = l.getLoopStmtsFor(tensor); LoopNest::splitWithTail(loops[1], 2); @@ -645,8 +617,6 @@ TEST(Reductions, SplitReduceAxis) { } TEST(Reductions, SplitNonReduceAxis) { - KernelScope kernel_scope; - Placeholder in(BufHandle("in", {16, 8}, kFloat)); std::vector in_(16 * 8); @@ -656,7 +626,7 @@ TEST(Reductions, SplitNonReduceAxis) { } } std::vector out(16, -1.f); - Tensor* tensor = Reduce("sum", {{16, "m"}}, Sum(), in, {{8, "n"}}); + Tensor tensor = Reduce("sum", {{16, "m"}}, Sum(), in, {{8, "n"}}); LoopNest l({tensor}); std::vector loops = l.getLoopStmtsFor(tensor); LoopNest::splitWithTail(loops[0], 2); @@ -676,7 +646,6 @@ TEST(Reductions, SplitNonReduceAxis) { } TEST(Reductions, ReorderedReductionInitializer) { - KernelScope kernel_scope; /* From the quip: for k in 0..1: // blockIdx for m in 0..128: @@ -687,14 +656,14 @@ TEST(Reductions, ReorderedReductionInitializer) { Placeholder in(BufHandle("in", {1, 12, 6}, kFloat)); std::vector in_(12 * 6, 1.f); - Tensor* tensor_ = Reduce("sum", {{1, "k"}, {12, "n"}}, Sum(), in, {{6, "m"}}); + Tensor tensor_ = Reduce("sum", {{1, "k"}, {12, "n"}}, Sum(), in, {{6, "m"}}); LoopNest l_({tensor_}); l_.prepareForCodegen(); StmtPtr s_ = Stmt::clone(l_.root_stmt()); s_ = IRSimplifier::simplify(s_); - Tensor* tensor = Reduce("sum", {{1, "k"}, {12, "n"}}, Sum(), in, {{6, "m"}}); + Tensor tensor = Reduce("sum", {{1, "k"}, {12, "n"}}, Sum(), in, {{6, "m"}}); LoopNest l({tensor}); auto loops = l.getLoopStmtsFor(tensor); @@ -726,8 +695,6 @@ TEST(Reductions, ReorderedReductionInitializer) { } TEST(Reductions, ReduceRfactor) { - KernelScope kernel_scope; - const int M = 10; const int N = 10; VarHandle m("m", kInt); @@ -741,10 +708,10 @@ TEST(Reductions, ReduceRfactor) { std::vector out(1, -1.f); - Tensor* c = Reduce("sum", {}, Sum(), b, {{m, "m"}, {n, "n"}}); + Tensor c = Reduce("sum", {}, Sum(), b, {{m, "m"}, {n, "n"}}); LoopNest loop({c}); std::vector loops = loop.getLoopStmtsFor(c); - auto c_body = loop.getAllWritesToBuf(c->buf())[1]; + auto c_body = loop.getAllWritesToBuf(c.buf())[1]; ASSERT_TRUE(loop.rfactor(c_body, loops.at(0))); auto rc = NodeFinder::find(loop.root_stmt()); ASSERT_EQ(rc.size(), 2); @@ -759,8 +726,6 @@ TEST(Reductions, ReduceRfactor) { } TEST(Reductions, Reduce3DRfactorInner) { - KernelScope kernel_scope; - const int M = 10; const int N = 10; const int K = 10; @@ -776,10 +741,10 @@ TEST(Reductions, Reduce3DRfactorInner) { std::vector out(1, -1.f); - Tensor* c = Reduce("sum", {}, Sum(), b, {{m, "m"}, {n, "n"}, {k, "k"}}); + Tensor c = Reduce("sum", {}, Sum(), b, {{m, "m"}, {n, "n"}, {k, "k"}}); LoopNest loop({c}); std::vector loops = loop.getLoopStmtsFor(c); - auto c_body = loop.getAllWritesToBuf(c->buf())[1]; + auto c_body = loop.getAllWritesToBuf(c.buf())[1]; ASSERT_FALSE(loop.rfactor(c_body, loops.at(2))); auto rc = NodeFinder::find(loop.root_stmt()); ASSERT_EQ(rc.size(), 1); @@ -794,8 +759,6 @@ TEST(Reductions, Reduce3DRfactorInner) { } TEST(Reductions, Reduce3DRfactorOuter) { - KernelScope kernel_scope; - const int M = 10; const int N = 10; const int K = 10; @@ -811,10 +774,10 @@ TEST(Reductions, Reduce3DRfactorOuter) { std::vector out(1, -1.f); - Tensor* c = Reduce("sum", {}, Sum(), b, {{m, "m"}, {n, "n"}, {k, "k"}}); + Tensor c = Reduce("sum", {}, Sum(), b, {{m, "m"}, {n, "n"}, {k, "k"}}); LoopNest loop({c}); std::vector loops = loop.getLoopStmtsFor(c); - auto c_body = loop.getAllWritesToBuf(c->buf())[1]; + auto c_body = loop.getAllWritesToBuf(c.buf())[1]; ASSERT_TRUE(loop.rfactor(c_body, loops.at(0))); auto rc = NodeFinder::find(loop.root_stmt()); ASSERT_EQ(rc.size(), 2); @@ -828,8 +791,6 @@ TEST(Reductions, Reduce3DRfactorOuter) { } TEST(Reductions, ReduceRepeatedInternalRfactor) { - KernelScope kernel_scope; - Placeholder in_(BufHandle("in_", {2, 3, 4, 5, 6}, kFloat)); const int InputSize = 2 * 3 * 4 * 5 * 6; @@ -837,7 +798,7 @@ TEST(Reductions, ReduceRepeatedInternalRfactor) { std::vector out(1, -1.f); std::vector ref(1, -1.f); - Tensor* c = Reduce( + Tensor c = Reduce( "sum", {}, Sum(), @@ -854,7 +815,7 @@ TEST(Reductions, ReduceRepeatedInternalRfactor) { IRSimplifier::simplify(refloop.root_stmt()), {in_, c}); ref_cg.call({in, ref}); - BufPtr tmp_buf = c->buf(); + BufPtr tmp_buf = c.buf(); for (int idx = 0; idx < rfac_number; idx++) { auto reduce = loop.getAllWritesToBuf(tmp_buf)[1]; @@ -875,8 +836,6 @@ TEST(Reductions, ReduceRepeatedInternalRfactor) { // Split a reduction axis with a tail loop. TEST(Reductions, ReduceSplitTail) { - KernelScope kernel_scope; - const int M = 10; const int N = 10; const int K = 10; @@ -890,7 +849,7 @@ TEST(Reductions, ReduceSplitTail) { for (int i = 0; i < 3; ++i) { std::vector out(M, -1.f); - Tensor* c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}}); + Tensor c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}}); LoopNest loop({c}); std::vector loops = loop.getLoopStmtsFor(c); LoopNest::splitWithTail(loops[i], 8); @@ -908,8 +867,6 @@ TEST(Reductions, ReduceSplitTail) { // Split a reduction axis cleanly so there is no tail loop. TEST(Reductions, ReduceSplitNoTail) { - KernelScope kernel_scope; - const int M = 10; const int N = 10; const int K = 10; @@ -922,7 +879,7 @@ TEST(Reductions, ReduceSplitNoTail) { for (int i = 0; i < 3; ++i) { std::vector out(M, -1.f); - Tensor* c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}}); + Tensor c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}}); LoopNest loop({c}); std::vector loops = loop.getLoopStmtsFor(c); LoopNest::splitWithTail(loops[i], 5); @@ -941,8 +898,6 @@ TEST(Reductions, ReduceSplitNoTail) { // Split a reduction axis with only a tail loop (the split loop will be size 0 // and eliminated out). TEST(Reductions, ReduceOverSplitTail) { - KernelScope kernel_scope; - const int M = 10; const int N = 10; const int K = 10; @@ -956,7 +911,7 @@ TEST(Reductions, ReduceOverSplitTail) { for (int i = 0; i < 3; ++i) { std::vector out(M, -1.f); - Tensor* c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}}); + Tensor c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}}); LoopNest loop({c}); std::vector loops = loop.getLoopStmtsFor(c); LoopNest::splitWithTail(loops[i], 16); @@ -974,8 +929,6 @@ TEST(Reductions, ReduceOverSplitTail) { // Split a reduction axis with a mask. TEST(Reductions, ReduceSplitMask) { - KernelScope kernel_scope; - const int M = 10; const int N = 10; const int K = 10; @@ -989,7 +942,7 @@ TEST(Reductions, ReduceSplitMask) { for (int i = 0; i < 3; ++i) { std::vector out(M, -1.f); - Tensor* c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}}); + Tensor c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}}); LoopNest loop({c}); std::vector loops = loop.getLoopStmtsFor(c); LoopNest::splitWithMask(loops[i], 8); @@ -1007,8 +960,6 @@ TEST(Reductions, ReduceSplitMask) { // Split a reduction axis cleanly not requiring a mask. TEST(Reductions, ReduceSplitNoMask) { - KernelScope kernel_scope; - const int M = 10; const int N = 10; const int K = 10; @@ -1021,7 +972,7 @@ TEST(Reductions, ReduceSplitNoMask) { for (int i = 0; i < 3; ++i) { std::vector out(M, -1.f); - Tensor* c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}}); + Tensor c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}}); LoopNest loop({c}); std::vector loops = loop.getLoopStmtsFor(c); LoopNest::splitWithMask(loops[i], 5); @@ -1039,8 +990,6 @@ TEST(Reductions, ReduceSplitNoMask) { // Split a reduction axis with all logic in the mask. TEST(Reductions, ReduceOverSplitMask) { - KernelScope kernel_scope; - const int M = 10; const int N = 10; const int K = 10; @@ -1054,7 +1003,7 @@ TEST(Reductions, ReduceOverSplitMask) { for (int i = 0; i < 3; ++i) { std::vector out(M, -1.f); - Tensor* c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}}); + Tensor c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}}); LoopNest loop({c}); std::vector loops = loop.getLoopStmtsFor(c); LoopNest::splitWithMask(loops[i], 16); @@ -1073,8 +1022,6 @@ TEST(Reductions, ReduceOverSplitMask) { // Test an rfactor when there are two ReduceOps in the graph due to a // splitWithTail. TEST(Reductions, ReduceSplitRfactor) { - KernelScope kernel_scope; - const int M = 2; const int N = 10; const int K = 10; @@ -1090,16 +1037,16 @@ TEST(Reductions, ReduceSplitRfactor) { std::vector out(M, -1.f); - Tensor* c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}}); + Tensor c = Reduce("sum", {{M, "m"}}, Sum(), b, {{N, "n"}, {K, "k"}}); LoopNest loop({c}); std::vector loops = loop.getLoopStmtsFor(c); LoopNest::splitWithTail(loops[2], SPLIT_FACTOR); - auto c_body = loop.getAllWritesToBuf(c->buf())[2]; - auto all_loops = loop.getAllLoopNestsWritingToBuf(c->buf()); + auto c_body = loop.getAllWritesToBuf(c.buf())[2]; + auto all_loops = loop.getAllLoopNestsWritingToBuf(c.buf()); ASSERT_TRUE(all_loops.size() == 3 && all_loops.at(2).size() == 3); LoopNest::reorderAxis(all_loops[2][1], all_loops[2][2]); - all_loops = loop.getAllLoopNestsWritingToBuf(c->buf()); + all_loops = loop.getAllLoopNestsWritingToBuf(c.buf()); ASSERT_TRUE(all_loops.size() == 3 && all_loops.at(2).size() == 3); ASSERT_TRUE(loop.rfactor(c_body, all_loops[2][1])); loop.prepareForCodegen(); @@ -1117,8 +1064,6 @@ TEST(Reductions, ReduceSplitRfactor) { // Test an rfactor which ends up being eliminated since the total loop size is // smaller than the split factor. TEST(Reductions, ReduceOverSplitRfactor) { - KernelScope kernel_scope; - const int N = 10; const int K = 10; const int SPLIT_FACTOR = 16; @@ -1131,7 +1076,7 @@ TEST(Reductions, ReduceOverSplitRfactor) { std::vector out(1, -1.f); - Tensor* c = Reduce("sum", {}, Sum(), b, {{N, "n"}, {K, "k"}}); + Tensor c = Reduce("sum", {}, Sum(), b, {{N, "n"}, {K, "k"}}); LoopNest loop({c}); std::vector loops = loop.getLoopStmtsFor(c); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) @@ -1139,9 +1084,9 @@ TEST(Reductions, ReduceOverSplitRfactor) { LoopNest::splitWithTail(loops[1], SPLIT_FACTOR, &i, &t); LoopNest::reorderAxis(loops[0], i); - auto all_loops = loop.getAllLoopNestsWritingToBuf(c->buf()); + auto all_loops = loop.getAllLoopNestsWritingToBuf(c.buf()); ASSERT_TRUE(all_loops.size() == 3 && all_loops.at(1).size() == 3); - auto c_body = loop.getAllWritesToBuf(c->buf())[1]; + auto c_body = loop.getAllWritesToBuf(c.buf())[1]; ASSERT_TRUE(loop.rfactor(c_body, all_loops[1][0])); LoopNest::reorderAxis(all_loops[1][0], all_loops[1][2]); @@ -1174,7 +1119,6 @@ TEST(Reductions, ReduceOverSplitRfactor) { } TEST(Reductions, ReduceInlineReduction) { - KernelScope kernel_scope; const int M = 4; const int N = 5; const int K = 6; @@ -1182,9 +1126,9 @@ TEST(Reductions, ReduceInlineReduction) { Placeholder a_buf("a", kFloat, {M}); Placeholder b_buf("b", kFloat, {M, N, K}); - Tensor* x = Reduce("x", {{M, "m1"}}, Sum(), b_buf, {{N, "n1"}, {K, "k1"}}); - Tensor* y = Compute("y", {{M, "m2"}}, [&](const VarHandle& m) { - return a_buf.load(m) + x->load(m); + Tensor x = Reduce("x", {{M, "m1"}}, Sum(), b_buf, {{N, "n1"}, {K, "k1"}}); + Tensor y = Compute("y", {{M, "m2"}}, [&](const VarHandle& m) { + return a_buf.load(m) + x.load(m); }); PaddedBuffer a_v(M); @@ -1203,11 +1147,10 @@ TEST(Reductions, ReduceInlineReduction) { LoopNest l1({y}, {x, y}); // Cannot inline a reduction computation - ASSERT_FALSE(l1.computeInline(x->buf())); + ASSERT_FALSE(l1.computeInline(x.buf())); } TEST(Reductions, ReduceInlineConsumer) { - KernelScope kernel_scope; const int M = 4; const int N = 5; const int K = 6; @@ -1215,13 +1158,13 @@ TEST(Reductions, ReduceInlineConsumer) { Placeholder a_buf("a", kFloat, {M, N, K}); Placeholder b_buf("b", kFloat, {M, N, K}); - Tensor* x = Compute( + Tensor x = Compute( "x", {{M, "m1"}, {N, "n1"}, {K, "k1"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { return a_buf.load(m, n, k) + b_buf.load(m, n, k); }); - Tensor* y = Reduce("y", {{M, "m2"}}, Sum(), x, {{N, "n2"}, {K, "k2"}}); + Tensor y = Reduce("y", {{M, "m2"}}, Sum(), x, {{N, "n2"}, {K, "k2"}}); PaddedBuffer a_v(M, N, K); PaddedBuffer b_v(M, N, K); @@ -1237,7 +1180,7 @@ TEST(Reductions, ReduceInlineConsumer) { LoopNest l1({y}, {x, y}); LoopNest l2(l1); - l2.computeInline(x->buf()); + l2.computeInline(x.buf()); l1.prepareForCodegen(); l2.prepareForCodegen(); @@ -1261,7 +1204,6 @@ TEST(Reductions, ReduceInlineConsumer) { } TEST(Reductions, ReduceInlineReducerInternal) { - KernelScope kernel_scope; const int M = 4; const int N = 5; const int K = 6; @@ -1269,7 +1211,7 @@ TEST(Reductions, ReduceInlineReducerInternal) { Placeholder a_buf("a", kFloat, {M, N, K}); Placeholder b_buf("b", kFloat, {M, N, K}); - Tensor* x = Compute( + Tensor x = Compute( "x", {{M, "m1"}, {N, "n1"}, {K, "k1"}}, [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { @@ -1279,7 +1221,7 @@ TEST(Reductions, ReduceInlineReducerInternal) { Reducer minimum(ExprHandle(0.f), [&](ExprHandle a, ExprHandle b) { return Add::make(ExprHandle(1.f), Min::make(a, b, false)); }); - Tensor* y = Reduce("y", {{M, "m2"}}, minimum, x, {{N, "n2"}, {K, "k2"}}); + Tensor y = Reduce("y", {{M, "m2"}}, minimum, x, {{N, "n2"}, {K, "k2"}}); PaddedBuffer a_v(M, N, K); PaddedBuffer b_v(M, N, K); @@ -1295,7 +1237,7 @@ TEST(Reductions, ReduceInlineReducerInternal) { LoopNest l1({y}, {x, y}); LoopNest l2(l1); - l2.computeInline(x->buf()); + l2.computeInline(x.buf()); l1.prepareForCodegen(); l2.prepareForCodegen(); @@ -1319,8 +1261,6 @@ TEST(Reductions, ReduceInlineReducerInternal) { } TEST(Reductions, ReductionCacheAccessesOperatorAxis) { - KernelScope kernel_scope; - int L = 4; int N = 3; int M = 2; @@ -1328,16 +1268,16 @@ TEST(Reductions, ReductionCacheAccessesOperatorAxis) { Placeholder a(BufHandle("a", {L, N, M}, kFloat)); Placeholder b(BufHandle("b", {L, N, M}, kFloat)); - Tensor* c = Compute( + Tensor c = Compute( "scale", {{L, "l2"}, {N, "n1"}, {M, "m1"}}, [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { return b.load(l, n, m) * a.load(l, n, m); }); - Tensor* d = Reduce("sum", {{L, "l1"}}, Sum(), c, {{N, "n1"}, {M, "m1"}}); + Tensor d = Reduce("sum", {{L, "l1"}}, Sum(), c, {{N, "n1"}, {M, "m1"}}); - Tensor* e = Compute("scale", {{L, "l"}}, [&](const VarHandle& l) { - return b.load(0, 0, l) * d->load(l); + Tensor e = Compute("scale", {{L, "l"}}, [&](const VarHandle& l) { + return b.load(0, 0, l) * d.load(l); }); LoopNest l({e}, {c, d, e}); @@ -1346,7 +1286,7 @@ TEST(Reductions, ReductionCacheAccessesOperatorAxis) { SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e}); StmtPtr d_loop = l.getLoopStmtsFor(d)[0]; - l.cacheAccesses(d->buf(), "d_local", d_loop); + l.cacheAccesses(d.buf(), "d_local", d_loop); l.prepareForCodegen(); StmtPtr result = IRSimplifier::simplify(l.root_stmt()); @@ -1396,8 +1336,6 @@ TEST(Reductions, ReductionCacheAccessesOperatorAxis) { } TEST(Reductions, ReductionCacheAccessesOuterReduceAxis) { - KernelScope kernel_scope; - int L = 4; int N = 3; int M = 2; @@ -1405,16 +1343,16 @@ TEST(Reductions, ReductionCacheAccessesOuterReduceAxis) { Placeholder a(BufHandle("a", {L, N, M}, kFloat)); Placeholder b(BufHandle("b", {L, N, M}, kFloat)); - Tensor* c = Compute( + Tensor c = Compute( "scale", {{L, "l2"}, {N, "n1"}, {M, "m1"}}, [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { return b.load(l, n, m) * a.load(l, n, m); }); - Tensor* d = Reduce("sum", {{L, "l1"}}, Sum(), c, {{N, "n1"}, {M, "m1"}}); + Tensor d = Reduce("sum", {{L, "l1"}}, Sum(), c, {{N, "n1"}, {M, "m1"}}); - Tensor* e = Compute("scale", {{L, "l"}}, [&](const VarHandle& l) { - return b.load(0, 0, l) * d->load(l); + Tensor e = Compute("scale", {{L, "l"}}, [&](const VarHandle& l) { + return b.load(0, 0, l) * d.load(l); }); LoopNest l({e}, {c, d, e}); @@ -1423,7 +1361,7 @@ TEST(Reductions, ReductionCacheAccessesOuterReduceAxis) { SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e}); StmtPtr d_loop = l.getLoopStmtsFor(d)[1]; - l.cacheAccesses(d->buf(), "d_local", d_loop); + l.cacheAccesses(d.buf(), "d_local", d_loop); l.prepareForCodegen(); StmtPtr result = IRSimplifier::simplify(l.root_stmt()); @@ -1471,8 +1409,6 @@ TEST(Reductions, ReductionCacheAccessesOuterReduceAxis) { } TEST(Reductions, ReductionCacheAccessesInnerReduceAxis) { - KernelScope kernel_scope; - int L = 4; int N = 3; int M = 2; @@ -1480,16 +1416,16 @@ TEST(Reductions, ReductionCacheAccessesInnerReduceAxis) { Placeholder a(BufHandle("a", {L, N, M}, kFloat)); Placeholder b(BufHandle("b", {L, N, M}, kFloat)); - Tensor* c = Compute( + Tensor c = Compute( "scale", {{L, "l2"}, {N, "n1"}, {M, "m1"}}, [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { return b.load(l, n, m) * a.load(l, n, m); }); - Tensor* d = Reduce("sum", {{L, "l1"}}, Sum(), c, {{N, "n1"}, {M, "m1"}}); + Tensor d = Reduce("sum", {{L, "l1"}}, Sum(), c, {{N, "n1"}, {M, "m1"}}); - Tensor* e = Compute("scale", {{L, "l"}}, [&](const VarHandle& l) { - return b.load(0, 0, l) * d->load(l); + Tensor e = Compute("scale", {{L, "l"}}, [&](const VarHandle& l) { + return b.load(0, 0, l) * d.load(l); }); LoopNest l({e}, {c, d, e}); @@ -1498,7 +1434,7 @@ TEST(Reductions, ReductionCacheAccessesInnerReduceAxis) { SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e}); StmtPtr d_loop = l.getLoopStmtsFor(d)[2]; - l.cacheAccesses(d->buf(), "d_local", d_loop); + l.cacheAccesses(d.buf(), "d_local", d_loop); l.prepareForCodegen(); StmtPtr result = IRSimplifier::simplify(l.root_stmt()); @@ -1546,27 +1482,25 @@ TEST(Reductions, ReductionCacheAccessesInnerReduceAxis) { } TEST(Reductions, ReductionCacheBodyAccess) { - KernelScope kernel_scope; - Placeholder a(BufHandle("a", {24, 32, 12}, kFloat)); Placeholder b(BufHandle("b", {24, 32, 12}, kFloat)); - Tensor* c = Compute( + Tensor c = Compute( "scale", {{24, "l2"}, {32, "n1"}, {12, "m1"}}, [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { return b.load(l, n, m) * a.load(l, n, m); }); - Tensor* d = Reduce("sum", {{24, "l1"}}, Sum(), c, {{32, "n1"}, {12, "m1"}}); + Tensor d = Reduce("sum", {{24, "l1"}}, Sum(), c, {{32, "n1"}, {12, "m1"}}); - Tensor* e = Compute("scale", {{24, "l"}}, [&](const VarHandle& l) { - return b.load(0, 0, l) * d->load(l); + Tensor e = Compute("scale", {{24, "l"}}, [&](const VarHandle& l) { + return b.load(0, 0, l) * d.load(l); }); LoopNest l({e}, {c, d, e}); StmtPtr d_loop = l.getLoopStmtsFor(d)[1]; - l.cacheAccesses(c->buf(), "scale_local", d_loop); + l.cacheAccesses(c.buf(), "scale_local", d_loop); l.prepareForCodegen(); StmtPtr result = IRSimplifier::simplify(l.root_stmt()); @@ -1578,8 +1512,8 @@ TEST(Reductions, ReductionCacheBodyAccess) { #CHECK: Allocate(scale_local); // dtype=float, dims=[1, 32, 12] #CHECK: for (int j = 0; j < 32; j++) { #CHECK: for (int k = 0; k < 12; k++) { -#CHECK: scale_local[k + 12 * j] = scale[(k + 384 * l1) + 12 * j]; -#CHECK: sum[l1] = (sum[l1]) + (scale_local[12 * n1_1 + m1_1]); +#CHECK: scale_local[k + 12 * j] = scale[(k + 12 * j) + 384 * l1]; +#CHECK: sum[l1] = (sum[l1]) + (scale_local[m1_1 + 12 * n1_1]); #CHECK: scale_1[l] = (b[l]) * (sum[l]); #CHECK: Free(scale_local); )IR"; @@ -1587,21 +1521,19 @@ TEST(Reductions, ReductionCacheBodyAccess) { } TEST(Reductions, ReductionCacheConsumerAccess) { - KernelScope kernel_scope; - Placeholder a(BufHandle("a", {24, 32, 12}, kFloat)); Placeholder b(BufHandle("b", {24, 32, 12}, kFloat)); - Tensor* c = Compute( + Tensor c = Compute( "scale", {{24, "l2"}, {32, "n1"}, {12, "m1"}}, [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { return b.load(l, n, m) * a.load(l, n, m); }); - Tensor* d = Reduce("sum", {{24, "l1"}}, Sum(), c, {{32, "n1"}, {12, "m1"}}); + Tensor d = Reduce("sum", {{24, "l1"}}, Sum(), c, {{32, "n1"}, {12, "m1"}}); - Tensor* e = Compute("scale", {{24, "l"}}, [&](const VarHandle& l) { - return b.load(0, 0, l) * d->load(l); + Tensor e = Compute("scale", {{24, "l"}}, [&](const VarHandle& l) { + return b.load(0, 0, l) * d.load(l); }); LoopNest l({e}, {c, d, e}); @@ -1609,7 +1541,7 @@ TEST(Reductions, ReductionCacheConsumerAccess) { LoopNest::splitWithMask(l.getLoopStmtsFor(e)[0], 4); StmtPtr e_loop = l.getLoopStmtsFor(e)[1]; - l.cacheAccesses(d->buf(), "sum_local", e_loop); + l.cacheAccesses(d.buf(), "sum_local", e_loop); l.prepareForCodegen(); StmtPtr result = IRSimplifier::simplify(l.root_stmt()); @@ -1628,21 +1560,19 @@ TEST(Reductions, ReductionCacheConsumerAccess) { } TEST(Reductions, ReductionSplitCacheConsumerAccess) { - KernelScope kernel_scope; - Placeholder a(BufHandle("a", {24, 32, 12}, kFloat)); Placeholder b(BufHandle("b", {24, 32, 12}, kFloat)); - Tensor* c = Compute( + Tensor c = Compute( "scale", {{24, "l2"}, {32, "n1"}, {12, "m1"}}, [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { return b.load(l, n, m) * a.load(l, n, m); }); - Tensor* d = Reduce("sum", {{24, "l1"}}, Sum(), c, {{32, "n1"}, {12, "m1"}}); + Tensor d = Reduce("sum", {{24, "l1"}}, Sum(), c, {{32, "n1"}, {12, "m1"}}); - Tensor* e = Compute("scale", {{24, "l"}}, [&](const VarHandle& l) { - return b.load(0, 0, l) * d->load(l); + Tensor e = Compute("scale", {{24, "l"}}, [&](const VarHandle& l) { + return b.load(0, 0, l) * d.load(l); }); LoopNest l({e}, {c, d, e}); @@ -1656,7 +1586,7 @@ TEST(Reductions, ReductionSplitCacheConsumerAccess) { // Split reduction consumer. LoopNest::splitWithMask(l.getLoopStmtsFor(e)[0], 4, &inner); - l.cacheAccesses(d->buf(), "sum_local", inner); + l.cacheAccesses(d.buf(), "sum_local", inner); l.prepareForCodegen(); StmtPtr result = IRSimplifier::simplify(l.root_stmt()); @@ -1667,7 +1597,7 @@ TEST(Reductions, ReductionSplitCacheConsumerAccess) { const std::string& expected_ir = R"IR( #CHECK: Allocate(sum_local); // dtype=float, dims=[4] -#CHECK: sum[l1_inner + 4 * l1_outer] = (sum[l1_inner + 4 * l1_outer]) + (scale[((12 * n1_1 + 384 * l1_inner) + m1_1) + 1536 * l1_outer]); +#CHECK: sum[l1_inner + 4 * l1_outer] = (sum[l1_inner + 4 * l1_outer]) + (scale[((m1_1 + 12 * n1_1) + 1536 * l1_outer) + 384 * l1_inner]); #CHECK: for (int i = 0; i < 4 #CHECK: sum_local[i] = sum[i + 4 * l_outer]; #CHECK: scale_1[l_inner + 4 * l_outer] = (b[l_inner + 4 * l_outer]) * (sum_local[l_inner]); @@ -1676,21 +1606,19 @@ TEST(Reductions, ReductionSplitCacheConsumerAccess) { } TEST(Reductions, ReductionReorderCacheConsumerAccess) { - KernelScope kernel_scope; - Placeholder a(BufHandle("a", {24, 32, 12}, kFloat)); Placeholder b(BufHandle("b", {24, 32, 12}, kFloat)); - Tensor* c = Compute( + Tensor c = Compute( "scale", {{24, "l2"}, {32, "n1"}, {12, "m1"}}, [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { return b.load(l, n, m) * a.load(l, n, m); }); - Tensor* d = Reduce("sum", {{24, "l1"}}, Sum(), c, {{32, "n1"}, {12, "m1"}}); + Tensor d = Reduce("sum", {{24, "l1"}}, Sum(), c, {{32, "n1"}, {12, "m1"}}); - Tensor* e = Compute("scale", {{24, "l"}}, [&](const VarHandle& l) { - return b.load(0, 0, l) * d->load(l); + Tensor e = Compute("scale", {{24, "l"}}, [&](const VarHandle& l) { + return b.load(0, 0, l) * d.load(l); }); LoopNest l({e}, {c, d, e}); @@ -1705,7 +1633,7 @@ TEST(Reductions, ReductionReorderCacheConsumerAccess) { // Split reduction consumer. LoopNest::splitWithMask(l.getLoopStmtsFor(e)[0], 4, &inner); - l.cacheAccesses(d->buf(), "sum_local", inner); + l.cacheAccesses(d.buf(), "sum_local", inner); l.prepareForCodegen(); StmtPtr result = IRSimplifier::simplify(l.root_stmt()); @@ -1716,7 +1644,7 @@ TEST(Reductions, ReductionReorderCacheConsumerAccess) { const std::string& expected_ir = R"IR( #CHECK: Allocate(sum_local); // dtype=float, dims=[4] -#CHECK: sum[l1] = (sum[l1]) + (scale[(12 * n1_1 + m1_1) + 384 * l1]); +#CHECK: sum[l1] = (sum[l1]) + (scale[(m1_1 + 12 * n1_1) + 384 * l1]); #CHECK: for (int i = 0; i < 4 #CHECK: sum_local[i] = sum[i + 4 * l_outer]; #CHECK: scale_1[l_inner + 4 * l_outer] = (b[l_inner + 4 * l_outer]) * (sum_local[l_inner]); @@ -1725,8 +1653,6 @@ TEST(Reductions, ReductionReorderCacheConsumerAccess) { } TEST(Reductions, ReductionRfactorCacheTempOuter) { - KernelScope kernel_scope; - const int M = 10; const int N = 10; const int K = 10; @@ -1742,13 +1668,13 @@ TEST(Reductions, ReductionRfactorCacheTempOuter) { std::vector out(1, -1.f); - Tensor* c = Reduce("sum", {}, Sum(), b, {{m, "a"}, {n, "b"}, {k, "c"}}); + Tensor c = Reduce("sum", {}, Sum(), b, {{m, "a"}, {n, "b"}, {k, "c"}}); LoopNest loop({c}); std::vector loops = loop.getLoopStmtsFor(c); LoopNest::reorderAxis(loops.at(0), loops.at(1)); loops = loop.getLoopStmtsFor(c); - auto c_body = loop.getAllWritesToBuf(c->buf())[1]; + auto c_body = loop.getAllWritesToBuf(c.buf())[1]; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) BufPtr rfac_buf; ASSERT_TRUE(loop.rfactor(c_body, loops.at(0), &rfac_buf)); @@ -1786,7 +1712,6 @@ TEST(Reductions, ReductionRfactorCacheTempOuter) { #CHECK-NOT: tmp )IR"; torch::jit::testing::FileCheck().run(expected_ir, oss.str()); - SimpleIREvaluator cg(s, {b, c, m, n, k}); cg.call({in, out, M, N, K}); @@ -1794,8 +1719,6 @@ TEST(Reductions, ReductionRfactorCacheTempOuter) { } TEST(Reductions, ReductionRfactorCacheTempInner) { - KernelScope kernel_scope; - const int M = 10; const int N = 10; const int K = 10; @@ -1811,10 +1734,10 @@ TEST(Reductions, ReductionRfactorCacheTempInner) { std::vector out(1, -1.f); - Tensor* c = Reduce("sum", {}, Sum(), b, {{m, "a"}, {n, "b"}, {k, "c"}}); + Tensor c = Reduce("sum", {}, Sum(), b, {{m, "a"}, {n, "b"}, {k, "c"}}); LoopNest loop({c}); std::vector loops = loop.getLoopStmtsFor(c); - auto c_body = loop.getAllWritesToBuf(c->buf())[1]; + auto c_body = loop.getAllWritesToBuf(c.buf())[1]; LoopNest::reorderAxis(loops.at(0), loops.at(1)); loops = loop.getLoopStmtsFor(c); @@ -1858,8 +1781,6 @@ TEST(Reductions, ReductionRfactorCacheTempInner) { } TEST(Reductions, ReductionVectorize) { - KernelScope kernel_scope; - std::vector in_(8 * 8); for (int i = 0; i < 8; ++i) { for (int j = 0; j < 8; ++j) { @@ -1871,7 +1792,7 @@ TEST(Reductions, ReductionVectorize) { Placeholder in(BufHandle("in", {8, 8}, kFloat)); - Tensor* tensor = Reduce("sum", {{8, "m"}}, Sum(), in, {{8, "n"}}); + Tensor tensor = Reduce("sum", {{8, "m"}}, Sum(), in, {{8, "n"}}); LoopNest l_before({tensor}); LoopNest l(l_before); l_before.prepareForCodegen(); @@ -1905,19 +1826,15 @@ TEST(Reductions, ReductionVectorize) { } TEST(Reductions, ReductionVectorizeInner) { - KernelScope kernel_scope; - Placeholder in(BufHandle("in", {8, 8}, kFloat)); - Tensor* tensor = Reduce("sum", {{8, "m"}}, Sum(), in, {{8, "n"}}); + Tensor tensor = Reduce("sum", {{8, "m"}}, Sum(), in, {{8, "n"}}); LoopNest l({tensor}); ASSERT_FALSE(LoopNest::vectorize(l.getLoopStmtsFor(tensor)[1])); } TEST(Reductions, ReductionVectorizeRfactor) { - KernelScope kernel_scope; - std::vector in_(8 * 8); for (int i = 0; i < 8; ++i) { for (int j = 0; j < 8; ++j) { @@ -1929,7 +1846,7 @@ TEST(Reductions, ReductionVectorizeRfactor) { Placeholder in(BufHandle("in", {8, 8}, kFloat)); - Tensor* tensor = Reduce("sum", {}, Sum(), in, {{8, "m"}, {8, "n"}}); + Tensor tensor = Reduce("sum", {}, Sum(), in, {{8, "m"}, {8, "n"}}); LoopNest l_before({tensor}); LoopNest l(l_before); @@ -1944,7 +1861,7 @@ TEST(Reductions, ReductionVectorizeRfactor) { std::vector loops = l.getLoopStmtsFor(tensor); LoopNest::reorderAxis(loops[0], loops[1]); loops = l.getLoopStmtsFor(tensor); - auto tensor_body = l.getAllWritesToBuf(tensor->buf())[1]; + auto tensor_body = l.getAllWritesToBuf(tensor.buf())[1]; BufPtr rfac_buf = nullptr; ASSERT_TRUE(LoopNest::rfactor(tensor_body, loops.at(0), &rfac_buf)); @@ -1983,12 +1900,11 @@ TEST(Reductions, ReductionVectorizeRfactor) { } TEST(Reductions, InitFunction) { - KernelScope ks; constexpr int M = 32; constexpr int N = 16; Placeholder A("A", kFloat, {M, N}); Placeholder B("B", kFloat, {N}); - Tensor* C = Reduce( + Tensor C = Reduce( "C", {{N, "n"}}, Sum(), diff --git a/test/cpp/tensorexpr/test_registerizer.cpp b/test/cpp/tensorexpr/test_registerizer.cpp index a0ac095db757f..1338b6d19c929 100644 --- a/test/cpp/tensorexpr/test_registerizer.cpp +++ b/test/cpp/tensorexpr/test_registerizer.cpp @@ -13,7 +13,6 @@ using namespace torch::jit::tensorexpr; // Can replace a simple scalar access with a local variable. TEST(Registerizer, RegisterizerSimple) { - KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); StmtPtr stmt = Block::make( @@ -58,7 +57,6 @@ TEST(Registerizer, RegisterizerSimple) { // Won't do replacement of a loop access. TEST(Registerizer, RegisterizerLoop) { - KernelScope kernel_scope; BufHandle a("A", {10}, kInt); VarHandle x("x", kInt); StmtPtr stmt = Block::make( @@ -105,7 +103,6 @@ TEST(Registerizer, RegisterizerLoop) { // Won't replace even if the load is a fixed scalar, since the store could // invalidate it. TEST(Registerizer, RegisterizerLoopFixedLoad) { - KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); StmtPtr stmt = Block::make( @@ -152,7 +149,6 @@ TEST(Registerizer, RegisterizerLoopFixedLoad) { // We can registerize accesses that occur entirely within inner scopes, even if // they depend on the loop var. TEST(Registerizer, RegisterizerLoopInternal) { - KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); StmtPtr stmt = Block::make({For::make( @@ -192,8 +188,8 @@ TEST(Registerizer, RegisterizerLoopInternal) { R"IR( # CHECK: for (int x = 0; x < 10; x++) # CHECK: int A_1 = A[x]; -# CHECK: A_1 = x + A_1; -# CHECK: A_1 = x + A_1; +# CHECK: A_1 = A_1 + x; +# CHECK: A_1 = A_1 + x; # CHECK: A[x] = A_1; # CHECK: })IR"; @@ -203,7 +199,6 @@ TEST(Registerizer, RegisterizerLoopInternal) { // An access can be overlapped by another read in the same Expr. In this case // B[z] and B[y] overlap and prevent registerization of both accesses. TEST(Registerizer, RegisterizerLoopInternalLoadOverlap) { - KernelScope kernel_scope; BufHandle a("A", {10}, kInt); BufHandle b("B", {10}, kInt); VarHandle x("x", kInt); @@ -235,7 +230,6 @@ TEST(Registerizer, RegisterizerLoopInternalLoadOverlap) { } TEST(Registerizer, RegisterizerLoopInternalRepeated) { - KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); StmtPtr stmt = Block::make( @@ -273,12 +267,12 @@ TEST(Registerizer, RegisterizerLoopInternalRepeated) { * int A_1 = A[1]; * int A_2 = A[0]; * for (int x = 0; x < 10; x++) { - * A_2 = x + A_1; - * A_2 = x + A_1; + * A_2 = A_1 + x; + * A_2 = A_1 + x; * } * for (int x = 0; x < 10; x++) { - * A_2 = x + A_1; - * A_2 = x + A_1; + * A_2 = A_1 + x; + * A_2 = A_1 + x; * } * A[0] = A_2; */ @@ -291,12 +285,12 @@ TEST(Registerizer, RegisterizerLoopInternalRepeated) { # CHECK: int A_1 = A[1]; # CHECK: int A_2 = A[0]; # CHECK: for (int x = 0; x < 10; x++) -# CHECK: A_2 = x + A_1; -# CHECK: A_2 = x + A_1; +# CHECK: A_2 = A_1 + x; +# CHECK: A_2 = A_1 + x; # CHECK: } # CHECK: for (int x = 0; x < 10; x++) -# CHECK: A_2 = x + A_1; -# CHECK: A_2 = x + A_1; +# CHECK: A_2 = A_1 + x; +# CHECK: A_2 = A_1 + x; # CHECK: } # CHECK-NOT: A[1] # CHECK: A[0] = A_2; @@ -307,7 +301,6 @@ TEST(Registerizer, RegisterizerLoopInternalRepeated) { } TEST(Registerizer, RegisterizerLoopInternalRepeatedOverlapLoopVar) { - KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); StmtPtr stmt = Block::make( @@ -353,11 +346,10 @@ TEST(Registerizer, RegisterizerLoopInternalRepeatedOverlapLoopVar) { } TEST(Registerizer, RegisterizerLoopInternalRepeatedOverlapOther) { - KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); VarHandle y("y", kInt); - StmtPtr stmt = Block::make( + StmtPtr stmt = IRSimplifier::simplify(Block::make( {For::make( x, 0, @@ -373,7 +365,7 @@ TEST(Registerizer, RegisterizerLoopInternalRepeatedOverlapOther) { {Store::make(a, {0}, Add::make(x, Load::make(a, {y}))), Store::make(a, {0}, Add::make(x, Load::make(a, {y})))})) - }); + })); /* * for (int x = 0; x < 10; x++) { @@ -400,7 +392,6 @@ TEST(Registerizer, RegisterizerLoopInternalRepeatedOverlapOther) { // Will registerize multiple accesses of different items of the same buffer. TEST(Registerizer, RegisterizerMultiVar) { - KernelScope kernel_scope; BufHandle a("A", {2}, kInt); VarHandle x("x", kInt); StmtPtr stmt = Block::make({ @@ -456,7 +447,6 @@ TEST(Registerizer, RegisterizerMultiVar) { // Will registerize the valid accesses while skipping invalid replacements. TEST(Registerizer, RegisterizerVariableLoad) { - KernelScope kernel_scope; BufHandle a("A", {1}, kInt); BufHandle b("B", {10}, kInt); VarHandle x("x", kInt); @@ -512,7 +502,6 @@ TEST(Registerizer, RegisterizerVariableLoad) { // Can registerize variable accesses so long as the variable does not change. TEST(Registerizer, RegisterizerSymbolicIndices) { - KernelScope kernel_scope; VarHandle i("i", kInt); VarHandle N("N", kInt); BufHandle a("A", {N}, kInt); @@ -559,7 +548,6 @@ TEST(Registerizer, RegisterizerSymbolicIndices) { // Can registerize accesses dependent on multiple loop vars. TEST(Registerizer, RegisterizerMultiLoop) { - KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); VarHandle y("y", kInt); @@ -616,7 +604,6 @@ TEST(Registerizer, RegisterizerMultiLoop) { // Can registerize correctly if scalars already exist in the program. TEST(Registerizer, RegisterizerRepeated) { - KernelScope kernel_scope; BufHandle a("A", {2}, kInt); VarHandle x("x", kInt); StmtPtr stmt = Block::make({ @@ -673,7 +660,6 @@ TEST(Registerizer, RegisterizerRepeated) { // Can registerize the load of A. TEST(Registerizer, RegisterizerNoLoads) { - KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); StmtPtr stmt = Block::make( @@ -714,7 +700,6 @@ TEST(Registerizer, RegisterizerNoLoads) { // Can registerize the load of A but not the store of B. TEST(Registerizer, RegisterizerNoRepeatedStores) { - KernelScope kernel_scope; BufHandle a("A", {1}, kInt); BufHandle b("B", {10}, kInt); VarHandle x("x", kInt); @@ -763,7 +748,6 @@ TEST(Registerizer, RegisterizerNoRepeatedStores) { // Won't registerize if there are multiple accesses which may overlap. TEST(Registerizer, RegisterizerMultiVarOverlap) { - KernelScope kernel_scope; BufHandle a("A", {2}, kInt); VarHandle x("x", kInt); StmtPtr stmt = Block::make({ @@ -792,8 +776,6 @@ TEST(Registerizer, RegisterizerMultiVarOverlap) { } TEST(Registerizer, RegisterizerAllocs) { - KernelScope kernel_scope; - BufHandle a("A", {2}, kInt); BufHandle c("C", {1}, kInt); VarHandle x("x", kInt); @@ -860,7 +842,6 @@ TEST(Registerizer, RegisterizerAllocs) { } TEST(Registerizer, RegisterizerNoInitializer) { - KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); StmtPtr stmt = Block::make({For::make( @@ -900,7 +881,6 @@ TEST(Registerizer, RegisterizerNoInitializer) { } TEST(Registerizer, RegisterizerNoInitializerLoopVar) { - KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); StmtPtr stmt = Block::make({For::make( @@ -929,7 +909,6 @@ TEST(Registerizer, RegisterizerNoInitializerLoopVar) { } TEST(Registerizer, RegisterizerLoadThenStore) { - KernelScope kernel_scope; BufHandle a("A", {1}, kInt); BufHandle b("B", {1}, kInt); VarHandle x("x", kInt); @@ -980,7 +959,6 @@ TEST(Registerizer, RegisterizerLoadThenStore) { } TEST(Registerizer, RegisterizerParallelized) { - KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); LoopOptions loopOpts; @@ -1009,7 +987,6 @@ TEST(Registerizer, RegisterizerParallelized) { // Should be able to registerize this since the scalar would exist before the // branch. TEST(Registerizer, RegisterizerConditionAfter) { - KernelScope kernel_scope; BufHandle a("A", {5}, kInt); BufHandle b("B", {5}, kInt); BufHandle c("C", {5}, kInt); @@ -1059,7 +1036,6 @@ TEST(Registerizer, RegisterizerConditionAfter) { // Should be able to registerize this since the scalar exists in the same form // after the branch and there is no overlap. TEST(Registerizer, RegisterizerConditionBefore) { - KernelScope kernel_scope; BufHandle a("A", {5}, kInt); BufHandle b("B", {5}, kInt); BufHandle c("C", {5}, kInt); @@ -1111,7 +1087,6 @@ TEST(Registerizer, RegisterizerConditionBefore) { // Should be able to registerize this as the combination of the two above rules. TEST(Registerizer, RegisterizerConditionInside) { - KernelScope kernel_scope; BufHandle a("A", {5}, kInt); BufHandle b("B", {5}, kInt); BufHandle c("C", {5}, kInt); @@ -1171,7 +1146,6 @@ TEST(Registerizer, RegisterizerConditionInside) { // condition, and both sides are large enough to be registerized but cannot be // because there is no safe place to put the initializer or finalizer. TEST(Registerizer, RegisterizerConditionInsideOverlap1) { - KernelScope kernel_scope; BufHandle a("A", {5}, kInt); BufHandle b("B", {5}, kInt); BufHandle c("C", {5}, kInt); @@ -1231,7 +1205,6 @@ TEST(Registerizer, RegisterizerConditionInsideOverlap1) { // the condition, and the first group must be finalized before the Cond, the // second initialized after it. TEST(Registerizer, RegisterizerConditionInsideOverlap2) { - KernelScope kernel_scope; BufHandle a("A", {5}, kInt); BufHandle b("B", {5}, kInt); BufHandle c("C", {5}, kInt); @@ -1317,7 +1290,6 @@ TEST(Registerizer, RegisterizerConditionInsideOverlap2) { // the accesses in it don't need to be valid (think size checks on the index). // In this case the accesses cannot be registerized. TEST(Registerizer, RegisterizerConditionHidden) { - KernelScope kernel_scope; BufHandle a("A", {5}, kInt); BufHandle b("B", {5}, kInt); BufHandle c("C", {5}, kInt); @@ -1359,7 +1331,6 @@ TEST(Registerizer, RegisterizerConditionHidden) { // the user's fault). It "unhides" the conditional accesses, allowing // registerization to occur. TEST(Registerizer, RegisterizerConditionUnhidden) { - KernelScope kernel_scope; BufHandle a("A", {5}, kInt); BufHandle b("B", {5}, kInt); BufHandle c("C", {5}, kInt); @@ -1420,7 +1391,6 @@ TEST(Registerizer, RegisterizerConditionUnhidden) { // Can registerize a load that occurs in the condition of a Cond. TEST(Registerizer, RegisterizerCondCondition) { - KernelScope kernel_scope; BufHandle a("A", {5}, kInt); BufHandle b("B", {5}, kInt); BufHandle c("C", {5}, kInt); @@ -1471,7 +1441,6 @@ TEST(Registerizer, RegisterizerCondCondition) { // Appearing in the condition of a Cond makes it visible to the enclosing scope, // and so we can registerize internal usages. TEST(Registerizer, RegisterizerCondConditionUnhidden) { - KernelScope kernel_scope; BufHandle a("A", {5}, kInt); BufHandle b("B", {5}, kInt); BufHandle c("C", {5}, kInt); @@ -1520,7 +1489,6 @@ TEST(Registerizer, RegisterizerCondConditionUnhidden) { // Conditional hiding also works for IfThenElse exprs. TEST(Registerizer, RegisterizerIfThenElseHidden) { - KernelScope kernel_scope; BufHandle a("A", {5}, kInt); BufHandle b("B", {5}, kInt); BufHandle c("C", {5}, kInt); @@ -1562,7 +1530,6 @@ TEST(Registerizer, RegisterizerIfThenElseHidden) { // Conditional unhiding also works for IfThenElse exprs. TEST(Registerizer, RegisterizerIfThenElseUnhidden) { - KernelScope kernel_scope; BufHandle a("A", {5}, kInt); BufHandle b("B", {5}, kInt); BufHandle c("C", {5}, kInt); @@ -1617,7 +1584,6 @@ TEST(Registerizer, RegisterizerIfThenElseUnhidden) { // Nested IfThenElse exprs can't promote to higher level scopes. TEST(Registerizer, RegisterizerIfThenElseNested) { - KernelScope kernel_scope; BufHandle a("A", {5}, kInt); BufHandle b("B", {5}, kInt); BufHandle c("C", {5}, kInt); @@ -1661,7 +1627,6 @@ TEST(Registerizer, RegisterizerIfThenElseNested) { // to check that we don't promote the initializer/finalizer to the enclosing // Block. TEST(Registerizer, RegisterizerIfThenElseInternal) { - KernelScope kernel_scope; // Making these floats so they don't get simplified to a single access. BufHandle a("A", {5}, kFloat); BufHandle b("B", {5}, kFloat); @@ -1740,7 +1705,6 @@ TEST(Registerizer, RegisterizerIfThenElseInternal) { // Can registerize a load that occurs in the condition of an IfThenElse; TEST(Registerizer, RegisterizerIfThenElseCondition) { - KernelScope kernel_scope; BufHandle a("A", {5}, kInt); BufHandle b("B", {5}, kInt); BufHandle c("C", {5}, kInt); @@ -1786,7 +1750,6 @@ TEST(Registerizer, RegisterizerIfThenElseCondition) { // Appearing in the condition of a Cond makes it visible to the enclosing scope, // and so we can registerize internal usages. TEST(Registerizer, RegisterizerIfThenElseConditionUnhidden) { - KernelScope kernel_scope; BufHandle a("A", {5}, kInt); BufHandle b("B", {5}, kInt); BufHandle c("C", {5}, kInt); @@ -1826,7 +1789,6 @@ TEST(Registerizer, RegisterizerIfThenElseConditionUnhidden) { // Cannot promote accesses internal to IfThenElse branches even if the enclosing // scope if conditional. TEST(Registerizer, RegisterizerConditionBranchOnly) { - KernelScope kernel_scope; BufHandle a("A", {5}, kInt); VarHandle x("x", kInt); StmtPtr stmt = Block::make({For::make( @@ -1877,7 +1839,6 @@ TEST(Registerizer, RegisterizerConditionBranchOnly) { // We can registerize an IfThenElse that appears in the condition branch of a // Cond. This is a weird but valid thing to do. TEST(Registerizer, RegisterizerCondIfThenElse) { - KernelScope kernel_scope; BufHandle a("A", {5}, kInt); BufHandle b("B", {5}, kInt); BufHandle c("C", {5}, kInt); @@ -1927,7 +1888,6 @@ TEST(Registerizer, RegisterizerCondIfThenElse) { // Can registerize a conditional access in the RHS of a store unhidden by it's // LHS, and hoist it out of a loop. TEST(Registerizer, RegisterizerIfThenElseLoop) { - KernelScope kernel_scope; BufHandle a("A", {5}, kInt); BufHandle b("B", {5}, kInt); VarHandle x("x", kInt); @@ -1977,7 +1937,6 @@ TEST(Registerizer, RegisterizerIfThenElseLoop) { // Cannot registerize if the RHS overlaps the access creating visibility. TEST(Registerizer, RegisterizerIfThenElseLoopCut) { - KernelScope kernel_scope; BufHandle a("A", {5}, kInt); BufHandle b("B", {5}, kInt); VarHandle x("x", kInt); @@ -2016,7 +1975,6 @@ TEST(Registerizer, RegisterizerIfThenElseLoopCut) { // Simple case where an access is cut by an overlapping access later in the // program, we can registerize up until the overlap. TEST(Registerizer, RegisterizerPartialAfter) { - KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); StmtPtr stmt = Block::make( @@ -2044,7 +2002,7 @@ TEST(Registerizer, RegisterizerPartialAfter) { /* * int A_1 = 0; * for (int x = 0; x < 10; x++) { - * A_1 = x + A_1; + * A_1 = A_1 + x; * } * A[0] = A_1; * for (int x = 1; x < 10; x++) { @@ -2059,7 +2017,7 @@ TEST(Registerizer, RegisterizerPartialAfter) { R"IR( # CHECK: int A_1 = 0; # CHECK: for ( -# CHECK: A_1 = x + A_1; +# CHECK: A_1 = A_1 + x; # CHECK: } # CHECK: A[0] = A_1; # CHECK: for ( @@ -2073,7 +2031,6 @@ TEST(Registerizer, RegisterizerPartialAfter) { // We can registerize an access which overlaps a previous access, the // initializer must be inserted after the previous access. TEST(Registerizer, RegisterizerPartialBefore) { - KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); StmtPtr stmt = Block::make( @@ -2104,7 +2061,7 @@ TEST(Registerizer, RegisterizerPartialBefore) { * } * int A_1 = 0; * for (int x = 0; x < 10; x++) { - * A_1 = x + A_1; + * A_1 = A_1 + x; * } * A[0] = A_1; */ @@ -2120,7 +2077,7 @@ TEST(Registerizer, RegisterizerPartialBefore) { # CHECK: } # CHECK: int A_1 = 0; # CHECK: for ( -# CHECK: A_1 = x + A_1; +# CHECK: A_1 = A_1 + x; # CHECK: } # CHECK: A[0] = A_1;)IR"; @@ -2130,7 +2087,6 @@ TEST(Registerizer, RegisterizerPartialBefore) { // The combination of the previous two tests, an access is cut by an overlapping // access in both directions. TEST(Registerizer, RegisterizerPartialInside) { - KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x1("x1", kInt); VarHandle x2("x2", kInt); @@ -2161,7 +2117,7 @@ TEST(Registerizer, RegisterizerPartialInside) { /* * int A_1 = 2; * for (int x1 = 0; x1 < 10; x1++) { - * A_1 = x1 + A_1; + * A_1 = A_1 + x1; * } * A[0] = A_1; * for (int x2 = 1; x2 < 10; x2++) { @@ -2169,7 +2125,7 @@ TEST(Registerizer, RegisterizerPartialInside) { * } * int A_2 = A[0]; * for (int x3 = 0; x3 < 10; x3++) { - * A_2 = x3 + A_2; + * A_2 = A_2 + x3; * } * A[0] = A_2; */ @@ -2181,7 +2137,7 @@ TEST(Registerizer, RegisterizerPartialInside) { R"IR( # CHECK: int A_1 = 2; # CHECK: for ( -# CHECK: A_1 = x1 + A_1; +# CHECK: A_1 = A_1 + x1; # CHECK: } # CHECK: A[0] = A_1; # CHECK: for ( @@ -2189,7 +2145,7 @@ TEST(Registerizer, RegisterizerPartialInside) { # CHECK: } # CHECK: int A_2 = A[0]; # CHECK: for ( -# CHECK: A_2 = x3 + A_2; +# CHECK: A_2 = A_2 + x3; # CHECK: } # CHECK: A[0] = A_2;)IR"; @@ -2200,7 +2156,6 @@ TEST(Registerizer, RegisterizerPartialInside) { // access, we should break this into two scalars and write back to the buffer // before the condition. TEST(Registerizer, RegisterizerPartialCondition) { - KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); StmtPtr stmt = Block::make( @@ -2232,7 +2187,7 @@ TEST(Registerizer, RegisterizerPartialCondition) { /* * int A_1 = 2; * for (int x = 0; x < 10; x++) { - * A_1 = x + A_1; + * A_1 = A_1 + x; * } * A[0] = A_1; * if (x<5 ? 1 : 0) { @@ -2240,7 +2195,7 @@ TEST(Registerizer, RegisterizerPartialCondition) { * } * int A_2 = A[0]; * for (int x = 0; x < 10; x++) { - * A_2 = x + A_2; + * A_2 = A_2 + x; * } * A[0] = A_2; */ @@ -2252,7 +2207,7 @@ TEST(Registerizer, RegisterizerPartialCondition) { R"IR( # CHECK: int A_1 = 2; # CHECK: for ( -# CHECK: A_1 = x + A_1; +# CHECK: A_1 = A_1 + x; # CHECK: } # CHECK: A[0] = A_1; # CHECK: if ( @@ -2260,7 +2215,7 @@ TEST(Registerizer, RegisterizerPartialCondition) { # CHECK: } # CHECK: int A_2 = A[0]; # CHECK: for ( -# CHECK: A_2 = x + A_2; +# CHECK: A_2 = A_2 + x; # CHECK: } # CHECK: A[0] = A_2;)IR"; @@ -2270,7 +2225,6 @@ TEST(Registerizer, RegisterizerPartialCondition) { // Tests case where an access is cut by an internal conditional access which // itself is registerized. TEST(Registerizer, RegisterizerPartialConditionInternalCut) { - KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); StmtPtr stmt = Block::make( @@ -2333,7 +2287,6 @@ TEST(Registerizer, RegisterizerPartialConditionInternalCut) { // First statment in condition closes outer access, but can be registerized with // later statements. TEST(Registerizer, RegisterizerPartialConditionInternalStart) { - KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); StmtPtr stmt = Block::make( @@ -2397,7 +2350,6 @@ TEST(Registerizer, RegisterizerPartialConditionInternalStart) { // An access cuts two open overlaps and creates four scalar variables. TEST(Registerizer, RegisterizerPartialOverlapsTwo) { - KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); StmtPtr stmt = Block::make( @@ -2468,7 +2420,6 @@ TEST(Registerizer, RegisterizerPartialOverlapsTwo) { // Nested blocks will automatically be flattened and do not provent // registerization of enclosed accesses. TEST(Registerizer, RegisterizerNestedBlocks) { - KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); StmtPtr stmt = Block::make( @@ -2522,7 +2473,6 @@ TEST(Registerizer, RegisterizerNestedBlocks) { // The access can be registerized internally to a condition, but must ensure // that both initializer and finalizer are within the same condition. TEST(Registerizer, RegisterizerNestedConditions) { - KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); StmtPtr stmt = Block::make({Cond::make( @@ -2578,7 +2528,6 @@ TEST(Registerizer, RegisterizerNestedConditions) { // If an access exists outside the scope of the condition then we can lift // nested conditional usages into the same scalar. TEST(Registerizer, RegisterizerNestedConditionsUnhidden) { - KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); StmtPtr stmt = Block::make( @@ -2634,7 +2583,6 @@ TEST(Registerizer, RegisterizerNestedConditionsUnhidden) { } TEST(Registerizer, RegisterizerNestedConditionsHiddenFirst) { - KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); StmtPtr stmt = Block::make( @@ -2677,7 +2625,6 @@ TEST(Registerizer, RegisterizerNestedConditionsHiddenFirst) { } TEST(Registerizer, RegisterizerNestedConditionsHiddenSecond) { - KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); StmtPtr stmt = Block::make( @@ -2722,7 +2669,6 @@ TEST(Registerizer, RegisterizerNestedConditionsHiddenSecond) { // If an access is cut by another access internal to a condition block, it still // cuts the access. TEST(Registerizer, RegisterizerNestedConditionsCut) { - KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); StmtPtr stmt = Block::make( @@ -2761,7 +2707,6 @@ TEST(Registerizer, RegisterizerNestedConditionsCut) { } TEST(Registerizer, RegisterizerNestedConditionLoopHidden) { - KernelScope kernel_scope; BufHandle a("A", {10}, kInt); BufHandle b("B", {10}, kInt); VarHandle x("x", kInt); @@ -2808,7 +2753,6 @@ TEST(Registerizer, RegisterizerNestedConditionLoopHidden) { // Three loops and four element regions, three of which should be registerized // at different levels of the IR. TEST(Registerizer, RegisterizerNestedConditionThreeDeep) { - KernelScope kernel_scope; BufHandle a("A", {10}, kInt); BufHandle b("B", {10}, kInt); VarHandle x("x", kInt); @@ -2908,7 +2852,6 @@ TEST(Registerizer, RegisterizerNestedConditionThreeDeep) { // Can replace a simple scalar access with a local variable even when that // variable is an outer loop var. TEST(Registerizer, RegisterizerNestedLoopSimple) { - KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); VarHandle y("y", kInt); @@ -2937,7 +2880,7 @@ TEST(Registerizer, RegisterizerNestedLoopSimple) { * for (int y = 0; y < 10; y++) { * int A_1 = A[y]; * for (int x = 0; x < 10; x++) { - * A_1 = x + A_1; + * A_1 = A_1 + x; * } * A[y] = A_1; * } @@ -2951,7 +2894,7 @@ TEST(Registerizer, RegisterizerNestedLoopSimple) { # CHECK: for (int y # CHECK: int A_1 = A[y]; # CHECK: for (int x -# CHECK: A_1 = x + A_1; +# CHECK: A_1 = A_1 + x; # CHECK: } # CHECK: A[y] = A_1; # CHECK: })IR"; @@ -2963,7 +2906,6 @@ TEST(Registerizer, RegisterizerNestedLoopSimple) { // conditional access can be hoisted up through a loop to match an existing // access in a higher scope and the two can be registerized. TEST(Registerizer, RegisterizerHiddenAccessYes) { - KernelScope kernel_scope; BufHandle a("A", {10}, kInt); BufHandle b("B", {10}, kInt); VarHandle x("x", kInt); @@ -3046,7 +2988,6 @@ TEST(Registerizer, RegisterizerHiddenAccessYes) { // never unhidden at a higher scope and registerization occurs at the lower // scope. TEST(Registerizer, RegisterizerHiddenAccessNo) { - KernelScope kernel_scope; BufHandle a("A", {10}, kInt); BufHandle b("B", {10}, kInt); VarHandle x("x", kInt); @@ -3126,7 +3067,6 @@ TEST(Registerizer, RegisterizerHiddenAccessNo) { // two accesses here one is unhidden and the other isnt. A[0] can be // registerized but B[0] cannot. TEST(Registerizer, RegisterizerHiddenAccessMultiLoop) { - KernelScope kernel_scope; BufHandle a("A", {10}, kInt); BufHandle b("B", {10}, kInt); VarHandle x("x", kInt); @@ -3208,7 +3148,6 @@ TEST(Registerizer, RegisterizerHiddenAccessMultiLoop) { // Accesses are registerized inside two conditions, but the immeidate parent is // not a condition. TEST(Registerizer, RegisterizerTwoConditionalLoops) { - KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); StmtPtr stmt = Block::make( @@ -3280,7 +3219,6 @@ TEST(Registerizer, RegisterizerTwoConditionalLoops) { // Accesses are registerized inside two conditions, cut in the middle. TEST(Registerizer, RegisterizerTwoConditionalLoopsCut) { - KernelScope kernel_scope; BufHandle a("A", {1}, kInt); VarHandle x("x", kInt); StmtPtr stmt = Block::make( @@ -3362,17 +3300,16 @@ TEST(Registerizer, RegisterizerTwoConditionalLoopsCut) { // references a Let var in a local scope which cannot be hoisted out of the // loop. TEST(Registerizer, RegisterizerLoopLetVar) { - KernelScope kernel_scope; BufHandle a("A", {10}, kInt); VarHandle x("x", kInt); VarHandle y("y", kInt); - StmtPtr stmt = Block::make({For::make( + StmtPtr stmt = IRSimplifier::simplify(Block::make({For::make( x, 0, 10, Block::make( {Let::make(y, 30), - Store::make(a, {y}, Add::make(x, Load::make(a, {y})))}))}); + Store::make(a, {y}, Add::make(x, Load::make(a, {y})))}))})); /* * for (int x = 0; x < 10; x++) { @@ -3396,7 +3333,6 @@ TEST(Registerizer, RegisterizerLoopLetVar) { // references a Let var in an outer scope that does not prevent hoisting the // initializer. TEST(Registerizer, RegisterizerLoopLetVarOuter) { - KernelScope kernel_scope; BufHandle a("A", {10}, kInt); VarHandle x("x", kInt); VarHandle y("y", kInt); @@ -3422,7 +3358,7 @@ TEST(Registerizer, RegisterizerLoopLetVarOuter) { * int y = 30; * int A_1 = A[y]; * for (int x = 0; x < 10; x++) { - * A_1 = x + A_1; + * A_1 = A_1 + x; * } * A[y] = A_1; */ @@ -3435,7 +3371,7 @@ TEST(Registerizer, RegisterizerLoopLetVarOuter) { # CHECK: int y = 30; # CHECK: int A_1 = A[y]; # CHECK: for (int x -# CHECK: A_1 = x + A_1; +# CHECK: A_1 = A_1 + x; # CHECK: A[y] = A_1;)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); @@ -3444,7 +3380,6 @@ TEST(Registerizer, RegisterizerLoopLetVarOuter) { // Okay so the registerizer generally goes after index flattening, but just in // case. Test multi index registerization. TEST(Registerizer, RegisterizerMultiDim) { - KernelScope kernel_scope; BufHandle a("A", {3, 4, 5}, kInt); VarHandle x("x", kInt); StmtPtr stmt = Block::make( @@ -3490,7 +3425,6 @@ TEST(Registerizer, RegisterizerMultiDim) { // Wont registerize if only some dims match, but will still registerize distinct // elements. TEST(Registerizer, RegisterizerMultiDimPartial) { - KernelScope kernel_scope; BufHandle a("A", {3, 4, 5}, kInt); VarHandle x("x", kInt); StmtPtr stmt = Block::make( @@ -3516,7 +3450,7 @@ TEST(Registerizer, RegisterizerMultiDimPartial) { * int A_1 = A[0, 1, 4]; * int A_2 = A[0, 2, 2]; * for (int x = 0; x < 10; x++) { - * A_2 = x + A_1; + * A_2 = A_1 + x; * } * A[0, 2, 2] = A_2; */ @@ -3530,7 +3464,7 @@ TEST(Registerizer, RegisterizerMultiDimPartial) { # CHECK: int A_1 = A[0, 1, 4]; # CHECK: int A_2 = A[0, 2, 2]; # CHECK: for ( -# CHECK: A_2 = x + A_1; +# CHECK: A_2 = A_1 + x; # CHECK: A[0, 2, 2] = A_2;)IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); @@ -3538,7 +3472,6 @@ TEST(Registerizer, RegisterizerMultiDimPartial) { // If they could overlap across all dimensions we cannot registerize. TEST(Registerizer, RegisterizerMultiDimOverlap) { - KernelScope kernel_scope; BufHandle a("A", {3, 4, 5}, kInt); VarHandle x("x", kInt); VarHandle y("y", kInt); @@ -3573,7 +3506,6 @@ TEST(Registerizer, RegisterizerMultiDimOverlap) { // But, if one dimension is known to be distinct they do not overlap. TEST(Registerizer, RegisterizerMultiDimPartialOverlap) { - KernelScope kernel_scope; BufHandle a("A", {3, 4, 5}, kInt); VarHandle x("x", kInt); VarHandle y("y", kInt); @@ -3599,7 +3531,7 @@ TEST(Registerizer, RegisterizerMultiDimPartialOverlap) { * A[0, 1, 2] = 0; * int A_1 = A[y, 2, 4]; * for (int x = 0; x < 10; x++) { - * A[0, x, 2] = x + A_1; + * A[0, x, 2] = A_1 + x; * } */ @@ -3611,7 +3543,7 @@ TEST(Registerizer, RegisterizerMultiDimPartialOverlap) { # CHECK: A[0, 1, 2] = 0; # CHECK: int A_1 = A[y, 2, 4]; # CHECK: for ( -# CHECK: A[0, x, 2] = x + A_1; +# CHECK: A[0, x, 2] = A_1 + x; # CHECK: })IR"; torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); @@ -3619,7 +3551,6 @@ TEST(Registerizer, RegisterizerMultiDimPartialOverlap) { // A 3D reduction with different input dimensionality. TEST(Registerizer, RegisterizerMultiDim3DReduction1) { - KernelScope kernel_scope; BufHandle a("A", {10}, kInt); BufHandle b("B", {10, 10}, kInt); BufHandle c("C", {10, 10, 10}, kInt); @@ -3691,7 +3622,6 @@ TEST(Registerizer, RegisterizerMultiDim3DReduction1) { // A 3D reduction with the same smaller dimensionality using different loop // vars. TEST(Registerizer, RegisterizerMultiDim3DReduction2) { - KernelScope kernel_scope; BufHandle a("A", {10}, kInt); BufHandle b("B", {10}, kInt); BufHandle c("C", {10}, kInt); @@ -3736,12 +3666,12 @@ TEST(Registerizer, RegisterizerMultiDim3DReduction2) { /* * for (int x = 0; x < 10; x++) { - * int C_1 = C[x]; * int A_1 = A[x]; + * int C_1 = C[x]; * for (int y = 0; y < 10; y++) { * int B_1 = B[y]; * for (int z = 0; z < 10; z++) { - * C_1 = C_1 + A_1 * B_1; + * C_1 = A_1 * B_1 + C_1; * } * } * C[x] = C_1; @@ -3754,12 +3684,12 @@ TEST(Registerizer, RegisterizerMultiDim3DReduction2) { const std::string& verification_pattern = R"IR( # CHECK: for (int x -# CHECK: int C_1 = C[x]; # CHECK: int A_1 = A[x]; +# CHECK: int C_1 = C[x]; # CHECK: for (int y # CHECK: int B_1 = B[y]; # CHECK: for (int z -# CHECK: C_1 = C_1 + A_1 * B_1; +# CHECK: C_1 = A_1 * B_1 + C_1; # CHECK: } # CHECK: } # CHECK: C[x] = C_1; diff --git a/test/cpp/tensorexpr/test_simplify.cpp b/test/cpp/tensorexpr/test_simplify.cpp index a08d4ca974fd1..48983c8f4ba33 100644 --- a/test/cpp/tensorexpr/test_simplify.cpp +++ b/test/cpp/tensorexpr/test_simplify.cpp @@ -14,7 +14,6 @@ using namespace torch::jit::tensorexpr; using SimpleIRExprEval = ExprEval; TEST(Simplify, ConstantFoldSimple) { - KernelScope kernel_scope; ExprHandle a(2.0f); ExprHandle b(3.0f); ExprHandle f = (a + b); @@ -28,7 +27,6 @@ TEST(Simplify, ConstantFoldSimple) { } TEST(Simplify, ConstantFoldTwoLayer) { - KernelScope kernel_scope; ExprHandle a(2.0f); ExprHandle b(3.0f); ExprHandle c(4.0f); @@ -44,7 +42,6 @@ TEST(Simplify, ConstantFoldTwoLayer) { } TEST(Simplify, ConstantFoldShifts) { - KernelScope kernel_scope; ExprHandle a(7); ExprHandle b(2); ExprHandle c(3); @@ -59,7 +56,6 @@ TEST(Simplify, ConstantFoldShifts) { } TEST(Simplify, ConstantFoldBitwise) { - KernelScope kernel_scope; ExprHandle a(59); ExprHandle b(22); ExprHandle c(101); @@ -74,7 +70,6 @@ TEST(Simplify, ConstantFoldBitwise) { } TEST(Simplify, ConstantFoldMultiOp) { - KernelScope kernel_scope; ExprHandle a(2.0f); ExprHandle b(3.0f); ExprHandle c(4.0f); @@ -93,7 +88,6 @@ TEST(Simplify, ConstantFoldMultiOp) { } TEST(Simplify, ConstantFoldMinMax) { - KernelScope kernel_scope; ExprHandle a(12.0f); ExprHandle b(15.0f); ExprHandle c(17.0f); @@ -113,7 +107,6 @@ TEST(Simplify, ConstantFoldMinMax) { } TEST(Simplify, ConstantFoldIntrinsics) { - KernelScope kernel_scope; ExprHandle a(2.0f); ExprHandle b(3.0f); ExprHandle c(4.0f); @@ -135,7 +128,6 @@ TEST(Simplify, ConstantFoldIntrinsics) { } TEST(Simplify, ConstantFoldCastToBool) { - KernelScope kernel_scope; ExprHandle f = Cast::make(kBool, IntImm::make(0)); ExprHandle newF = IRSimplifier::simplify(f); SimpleIRExprEval eval(newF); @@ -143,7 +135,6 @@ TEST(Simplify, ConstantFoldCastToBool) { } TEST(Simplify, ConstantFoldWithVar) { - KernelScope kernel_scope; { VarHandle x("x", kInt); ExprHandle body = x * (ExprHandle(2) + ExprHandle(4)); @@ -174,7 +165,6 @@ TEST(Simplify, ConstantFoldWithVar) { } TEST(Simplify, ConditionalSelectFoldSimple) { - KernelScope kernel_scope; ExprHandle a(3.0f); ExprHandle b(4.0f); ExprHandle c(3.0f); @@ -221,7 +211,6 @@ TEST(Simplify, ConditionalSelectFoldSimple) { } TEST(Simplify, ConditionalSelectFoldTwoLayer) { - KernelScope kernel_scope; ExprHandle a(3.0f); ExprHandle b(2.0f); ExprHandle c(2.0f); @@ -269,7 +258,6 @@ TEST(Simplify, ConditionalSelectFoldTwoLayer) { } TEST(Simplify, ConditionalSelectFoldWithVar) { - KernelScope kernel_scope; VarHandle x("x", kFloat); ExprHandle f = x < 4.f; @@ -290,7 +278,6 @@ TEST(Simplify, ConditionalSelectFoldWithVar) { } TEST(Simplify, UnFoldableExpr) { - KernelScope kernel_scope; VarHandle x("x", kFloat); VarHandle y("y", kFloat); ExprHandle body = (ExprHandle(3) * x) + (ExprHandle(5) * y); @@ -308,7 +295,6 @@ TEST(Simplify, UnFoldableExpr) { } TEST(Simplify, HashSimple) { - KernelScope kernel_scope; VarHandle x("x", kFloat); ExprHandle a(2.0f); ExprHandle b(3.0f); @@ -329,7 +315,6 @@ TEST(Simplify, HashSimple) { } TEST(Simplify, HashEquivalence) { - KernelScope kernel_scope; VarHandle x("x", kFloat); VarHandle y("y", kFloat); ExprHandle f = (x * y) + (x * y); @@ -366,7 +351,6 @@ TEST(Simplify, HashEquivalence) { } TEST(Simplify, HashEquivalenceRand) { - KernelScope kernel_scope; ExprHandle f = Intrinsics::make(kRand, kFloat) + Intrinsics::make(kRand, kInt); @@ -386,7 +370,6 @@ TEST(Simplify, HashEquivalenceRand) { } TEST(Simplify, HashEquivalenceAfterFolding) { - KernelScope kernel_scope; VarHandle x("x", kFloat); ExprHandle a(2.0f); ExprHandle b(3.0f); @@ -412,8 +395,6 @@ TEST(Simplify, HashEquivalenceAfterFolding) { } TEST(Simplify, HashDifferenceTypes) { - KernelScope kernel_scope; - HashProvider hasher; std::vector immediates; @@ -446,7 +427,6 @@ TEST(Simplify, HashDifferenceTypes) { } TEST(Simplify, HashLargeExpression) { - KernelScope kernel_scope; constexpr int N = 1024; BufHandle a("A", {N}, kInt); BufHandle b("B", {N}, kInt); @@ -490,7 +470,6 @@ TEST(Simplify, HashLargeExpression) { } TEST(Simplify, HashForLoopOptions) { - KernelScope kernel_scope; constexpr int N = 1024; BufHandle a("A", {N}, kInt); BufHandle b("B", {N}, kInt); @@ -532,7 +511,6 @@ TEST(Simplify, HashForLoopOptions) { /// (2 + x) + 4 => x + 6 TEST(Simplify, SimplifyAdd) { - KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); @@ -558,7 +536,6 @@ TEST(Simplify, SimplifyAdd) { /// (2 - x) - 4 => -2 - x TEST(Simplify, SimplifySub) { - KernelScope kernel_scope; VarHandle x("x", kInt); ExprHandle body = (ExprHandle(2) - x) - ExprHandle(4); @@ -575,7 +552,6 @@ TEST(Simplify, SimplifySub) { /// 2 * (1 - x) - 4 => 2 * (-3 - x) TEST(Simplify, SimplifyMultiLayer) { - KernelScope kernel_scope; VarHandle x("x", kInt); ExprHandle body = ExprHandle(2) * ((ExprHandle(1) - x) - ExprHandle(4)); ExprHandle simplified = IRSimplifier::simplify(body); @@ -588,7 +564,6 @@ TEST(Simplify, SimplifyMultiLayer) { /// 2 * (3 * x) - (x * 4) => 2 * x TEST(Simplify, SimplifyMultiTerm) { - KernelScope kernel_scope; VarHandle x("x", kInt); ExprHandle body = (ExprHandle(2) * ((ExprHandle(3) * x)) - (x * ExprHandle(4))); @@ -606,7 +581,6 @@ TEST(Simplify, SimplifyMultiTerm) { /// 2 * (3 * (long)x) - (x * 4) => 2 * x TEST(Simplify, SimplifyCasts) { - KernelScope kernel_scope; VarHandle x("x", kLong); ExprHandle body = (ExprHandle(2) * ((ExprHandle(3) * x)) - (x * ExprHandle(4))); @@ -624,7 +598,6 @@ TEST(Simplify, SimplifyCasts) { /// (x + 0) * 1 => x TEST(Simplify, SimplifyEliminatesNoOps) { - KernelScope kernel_scope; VarHandle x("x", kInt); ExprHandle body = (x + ExprHandle(0)) * 1; @@ -636,7 +609,6 @@ TEST(Simplify, SimplifyEliminatesNoOps) { /// Cannot simplify this. TEST(Simplify, SimplifyMultiVar) { - KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); ExprHandle body = x * 24 + y * 34; @@ -649,17 +621,16 @@ TEST(Simplify, SimplifyMultiVar) { ASSERT_NE(lhs, nullptr); VarPtr varX = to(lhs->rhs()); ASSERT_NE(varX, nullptr); - ASSERT_EQ(varX->name_hint(), "y"); + ASSERT_EQ(varX->name_hint(), "x"); MulPtr rhs = to(root->rhs()); ASSERT_NE(rhs, nullptr); VarPtr varY = to(rhs->rhs()); ASSERT_NE(varY, nullptr); - ASSERT_EQ(varY->name_hint(), "x"); + ASSERT_EQ(varY->name_hint(), "y"); } // x + 2 + y => x + y + 2 TEST(Simplify, DISABLED_SimplifyReorderings) { - KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); ExprHandle body = x + 2 + y; @@ -676,7 +647,6 @@ TEST(Simplify, DISABLED_SimplifyReorderings) { /// y + x * 0 => y TEST(Simplify, SimplifyEliminatesVar) { - KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); ExprHandle body = y + x * ExprHandle(0); @@ -686,7 +656,6 @@ TEST(Simplify, SimplifyEliminatesVar) { } TEST(Simplify, SimplifyAdds) { - KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); @@ -698,8 +667,8 @@ TEST(Simplify, SimplifyAdds) { IS_NODE_WITH_NAME(Mul, simplified.node(), root); IS_IMM_WITH_VAL(Int, root->lhs(), 2); IS_NODE_WITH_NAME(Add, root->rhs(), add); - IS_VAR_WITH_NAME(add->lhs(), "y"); - IS_VAR_WITH_NAME(add->rhs(), "x"); + IS_VAR_WITH_NAME(add->lhs(), "x"); + IS_VAR_WITH_NAME(add->rhs(), "y"); } { @@ -757,7 +726,6 @@ TEST(Simplify, SimplifyAdds) { } TEST(Simplify, SimplifyMuls) { - KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); @@ -770,11 +738,11 @@ TEST(Simplify, SimplifyMuls) { IS_NODE_WITH_NAME(Mul, simplified.node(), mul); IS_NODE_WITH_NAME(Add, mul->lhs(), lhs); - IS_VAR_WITH_NAME(lhs->lhs(), "y"); - IS_VAR_WITH_NAME(lhs->rhs(), "x"); + IS_VAR_WITH_NAME(lhs->lhs(), "x"); + IS_VAR_WITH_NAME(lhs->rhs(), "y"); IS_NODE_WITH_NAME(Add, mul->rhs(), rhs); - IS_VAR_WITH_NAME(rhs->lhs(), "y"); - IS_VAR_WITH_NAME(rhs->rhs(), "x"); + IS_VAR_WITH_NAME(rhs->lhs(), "x"); + IS_VAR_WITH_NAME(rhs->rhs(), "y"); } { @@ -867,8 +835,8 @@ TEST(Simplify, SimplifyMuls) { ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul); IS_NODE_WITH_NAME(Add, mul->lhs(), lhs); - IS_VAR_WITH_NAME(lhs->lhs(), "y"); - IS_VAR_WITH_NAME(lhs->rhs(), "x"); + IS_VAR_WITH_NAME(lhs->lhs(), "x"); + IS_VAR_WITH_NAME(lhs->rhs(), "y"); IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs); IS_VAR_WITH_NAME(rhs->lhs(), "x"); IS_VAR_WITH_NAME(rhs->rhs(), "y"); @@ -959,7 +927,6 @@ TEST(Simplify, SimplifyMuls) { // Sub an expr from itself will result in zero. TEST(Simplify, SimplifySubs) { - KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); @@ -1125,7 +1092,6 @@ TEST(Simplify, SimplifySubs) { } TEST(Simplify, SimplifyDiv) { - KernelScope kernel_scope; VarHandle x("x", kInt); { @@ -1144,7 +1110,6 @@ TEST(Simplify, SimplifyDiv) { } TEST(Simplify, SimplifyDivWithLoopContext1) { - KernelScope kernel_scope; // Stmt to simplify: // for (int i = 0; i < 6; i++) { // A[i] = (i + 24) / 6; @@ -1166,7 +1131,6 @@ TEST(Simplify, SimplifyDivWithLoopContext1) { } TEST(Simplify, SimplifyDivWithLoopContext2) { - KernelScope kernel_scope; // Stmt to simplify: // for (int i = 0; i < 5; i++) { // A[i] = (i + 25) / 6; @@ -1188,7 +1152,6 @@ TEST(Simplify, SimplifyDivWithLoopContext2) { } TEST(Simplify, SimplifyDivWithLoopContext3) { - KernelScope kernel_scope; // Stmt to simplify: // for (int i = 0; i < 6; i++) { // A[i] = (i + 24) / (-6); @@ -1210,7 +1173,6 @@ TEST(Simplify, SimplifyDivWithLoopContext3) { } TEST(Simplify, SimplifyDivWithLoopContext4) { - KernelScope kernel_scope; // Stmt to simplify: // for (int i = 0; i < 5; i++) { // A[i] = (i - 5) / 6; @@ -1232,7 +1194,6 @@ TEST(Simplify, SimplifyDivWithLoopContext4) { } TEST(Simplify, SimplifyDivWithLoopContext5) { - KernelScope kernel_scope; // Stmt to simplify: // for (int i = 0; i < 6; i++) { // for (int j = 0; j < 10; j++) { @@ -1259,7 +1220,6 @@ TEST(Simplify, SimplifyDivWithLoopContext5) { } TEST(Simplify, SimplifyDivWithLoopContext6) { - KernelScope kernel_scope; // Stmt to simplify: // for (int i = 0; i < 6; i++) { // for (int j = -1; j < 9; j++) { @@ -1287,7 +1247,6 @@ TEST(Simplify, SimplifyDivWithLoopContext6) { } TEST(Simplify, SimplifyDivWithLoopContext7) { - KernelScope kernel_scope; // Stmt to simplify: // for (int i = 0; i < 6; i++) { // for (int j = 0; j < 10; j++) { @@ -1315,7 +1274,6 @@ TEST(Simplify, SimplifyDivWithLoopContext7) { } TEST(Simplify, SimplifyModWithLoopContext0) { - KernelScope kernel_scope; // Stmt to simplify: // for (int i = 0; i < 100; i++) { // A[i] = i % 100; @@ -1337,7 +1295,6 @@ TEST(Simplify, SimplifyModWithLoopContext0) { } TEST(Simplify, SimplifyModWithLoopContext1) { - KernelScope kernel_scope; // Stmt to simplify: // for (int i = 0; i < 6; i++) { // A[i] = (i + 24) % 6; @@ -1359,7 +1316,6 @@ TEST(Simplify, SimplifyModWithLoopContext1) { } TEST(Simplify, SimplifyModWithLoopContext2) { - KernelScope kernel_scope; // Stmt to simplify: // for (int i = 0; i < 5; i++) { // A[i] = (i + 25) % 6; @@ -1381,7 +1337,6 @@ TEST(Simplify, SimplifyModWithLoopContext2) { } TEST(Simplify, SimplifyModWithLoopContext3) { - KernelScope kernel_scope; // Stmt to simplify: // for (int i = 0; i < 6; i++) { // A[i] = (i + 24) % (-6); @@ -1403,7 +1358,6 @@ TEST(Simplify, SimplifyModWithLoopContext3) { } TEST(Simplify, SimplifyModWithLoopContext4) { - KernelScope kernel_scope; // Stmt to simplify: // for (int i = 0; i < 5; i++) { // A[i] = (i - 5) % 6; @@ -1425,7 +1379,6 @@ TEST(Simplify, SimplifyModWithLoopContext4) { } TEST(Simplify, SimplifyModWithLoopContext5) { - KernelScope kernel_scope; // Stmt to simplify: // for (int i = 0; i < 6; i++) { // for (int j = 0; j < 10; j++) { @@ -1452,7 +1405,6 @@ TEST(Simplify, SimplifyModWithLoopContext5) { } TEST(Simplify, SimplifyModWithLoopContext6) { - KernelScope kernel_scope; // Stmt to simplify: // for (int i = 0; i < 6; i++) { // for (int j = -1; j < 9; j++) { @@ -1480,7 +1432,6 @@ TEST(Simplify, SimplifyModWithLoopContext6) { } TEST(Simplify, SimplifyModWithLoopContext7) { - KernelScope kernel_scope; // Stmt to simplify: // for (int i = 0; i < 6; i++) { // for (int j = 0; j < 10; j++) { @@ -1508,7 +1459,6 @@ TEST(Simplify, SimplifyModWithLoopContext7) { } TEST(Simplify, SimplifyMod) { - KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); VarHandle z("z", kInt); @@ -1635,7 +1585,6 @@ TEST(Simplify, SimplifyMod) { // Test that mixing ops together simplifies as expected. TEST(Simplify, SimplifyMultiOp) { - KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); @@ -1654,14 +1603,14 @@ TEST(Simplify, SimplifyMultiOp) { } { - // (x + y) - (x * y) => x + y - (x * y) - ExprHandle body = (x + y) - (x * y); + // (x + y) - x * y => (x + y) - x * y + ExprHandle body = (x + y) - x * y; ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Sub, simplified.node(), sub); IS_NODE_WITH_NAME(Add, sub->lhs(), add); IS_NODE_WITH_NAME(Mul, sub->rhs(), mul); - IS_VAR_WITH_NAME(add->lhs(), "y"); - IS_VAR_WITH_NAME(add->rhs(), "x"); + IS_VAR_WITH_NAME(add->lhs(), "x"); + IS_VAR_WITH_NAME(add->rhs(), "y"); IS_VAR_WITH_NAME(mul->lhs(), "x"); IS_VAR_WITH_NAME(mul->rhs(), "y"); } @@ -1704,24 +1653,23 @@ TEST(Simplify, SimplifyMultiOp) { // Test that chaining many ops together works as expected. TEST(Simplify, SimplifyManyOps) { - KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); { - // x + y + x + x + y + y + x + y + x = 5 * x + 4 * y + // x + y + x + x + y + y + x + y + x = 4 * y + 5 * x ExprHandle body = x + y + x + x + y + y + x + y + x; ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Add, simplified.node(), add); IS_NODE_WITH_NAME(Mul, add->lhs(), lhs); - IS_IMM_WITH_VAL(Int, lhs->lhs(), 5); - IS_VAR_WITH_NAME(lhs->rhs(), "x"); + IS_IMM_WITH_VAL(Int, lhs->lhs(), 4); + IS_VAR_WITH_NAME(lhs->rhs(), "y"); IS_NODE_WITH_NAME(Mul, add->rhs(), rhs); - IS_IMM_WITH_VAL(Int, rhs->lhs(), 4); - IS_VAR_WITH_NAME(rhs->rhs(), "y"); + IS_IMM_WITH_VAL(Int, rhs->lhs(), 5); + IS_VAR_WITH_NAME(rhs->rhs(), "x"); } { @@ -1752,7 +1700,6 @@ TEST(Simplify, SimplifyManyOps) { } TEST(Simplify, SimplifyFactorization) { - KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); @@ -1765,8 +1712,8 @@ TEST(Simplify, SimplifyFactorization) { IS_IMM_WITH_VAL(Int, mul->lhs(), 2); IS_NODE_WITH_NAME(Add, mul->rhs(), add); - IS_VAR_WITH_NAME(add->lhs(), "y"); - IS_VAR_WITH_NAME(add->rhs(), "x"); + IS_VAR_WITH_NAME(add->lhs(), "x"); + IS_VAR_WITH_NAME(add->rhs(), "y"); } { @@ -1794,12 +1741,12 @@ TEST(Simplify, SimplifyFactorization) { IS_NODE_WITH_NAME(Add, simplified.node(), add); IS_NODE_WITH_NAME(Mul, add->lhs(), lhs); - IS_IMM_WITH_VAL(Int, lhs->lhs(), 5); - IS_VAR_WITH_NAME(lhs->rhs(), "y"); + IS_IMM_WITH_VAL(Int, lhs->lhs(), 2); + IS_VAR_WITH_NAME(lhs->rhs(), "x"); IS_NODE_WITH_NAME(Mul, add->rhs(), rhs); - IS_IMM_WITH_VAL(Int, rhs->lhs(), 2); - IS_VAR_WITH_NAME(rhs->rhs(), "x"); + IS_IMM_WITH_VAL(Int, rhs->lhs(), 5); + IS_VAR_WITH_NAME(rhs->rhs(), "y"); } { @@ -1813,8 +1760,8 @@ TEST(Simplify, SimplifyFactorization) { IS_IMM_WITH_VAL(Int, mul->lhs(), 10); IS_NODE_WITH_NAME(Add, mul->rhs(), add); - IS_VAR_WITH_NAME(add->lhs(), "y"); - IS_VAR_WITH_NAME(add->rhs(), "x"); + IS_VAR_WITH_NAME(add->lhs(), "x"); + IS_VAR_WITH_NAME(add->rhs(), "y"); } { @@ -1863,24 +1810,17 @@ TEST(Simplify, SimplifyFactorization) { VarHandle g("g", kInt); VarHandle h("h", kInt); - ExprHandle body = ExprHandle(0) + (ExprHandle(1024) * a) + - (ExprHandle(-1) * b) + (ExprHandle(-1) * c) + (ExprHandle(1) * d) + - (ExprHandle(1) * e) + (ExprHandle(32) * f) + (ExprHandle(-1024) * g) + - (ExprHandle(-32) * h); + ExprHandle body = a * 1024 + 0 + b * (-1) + c * (-1) + d * 1 + e * 1 + + f * 32 + g * (-1024) + h * (-32); ExprHandle simplified = IRSimplifier::simplify(body); - - // We only check for the top level nodes here, since the main purpose - // here is ensure that this simplification completes. - IS_NODE_WITH_NAME(Sub, simplified.node(), sub); - IS_NODE_WITH_NAME(Mul, sub->rhs(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 1024); - IS_VAR_WITH_NAME(mul->rhs(), "g"); + checkExprIR( + simplified, + "((((((d + e) + 1024 * a) + 32 * f) - b) - c) - 1024 * g) - 32 * h"); } } // (4 * x + y + z * 2) + (4 * x + y + z * 4) => 2 * (y + 3 * z + 4 * x) TEST(Simplify, SimplifyFactorizeUneven) { - KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); VarHandle z("z", kInt); @@ -1904,10 +1844,9 @@ TEST(Simplify, SimplifyFactorizeUneven) { IS_VAR_WITH_NAME(zmul->rhs(), "z"); } -// (x * y) + (2 * x) * (x + y) => 3 * (x * y) + 2 * (x * x) +// (x * y) + (2 * x) * (x + y) => 2 * (x * x) + 3 * (x * y) // This is kind of a placeholder test for variable factorization. TEST(Simplify, SimplifyDeeperTerms) { - KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); ExprHandle body = (x * y) + (ExprHandle(2) * x) * (x + y); @@ -1916,22 +1855,21 @@ TEST(Simplify, SimplifyDeeperTerms) { IS_NODE_WITH_NAME(Add, simplified.node(), add); IS_NODE_WITH_NAME(Mul, add->lhs(), lhs); - IS_IMM_WITH_VAL(Int, lhs->lhs(), 3); - IS_NODE_WITH_NAME(Mul, lhs->rhs(), xyTerm); - IS_VAR_WITH_NAME(xyTerm->lhs(), "x"); - IS_VAR_WITH_NAME(xyTerm->rhs(), "y"); + IS_IMM_WITH_VAL(Int, lhs->lhs(), 2); + IS_NODE_WITH_NAME(Mul, lhs->rhs(), xxTerm); + IS_VAR_WITH_NAME(xxTerm->lhs(), "x"); + IS_VAR_WITH_NAME(xxTerm->rhs(), "x"); IS_NODE_WITH_NAME(Mul, add->rhs(), rhs); - IS_IMM_WITH_VAL(Int, rhs->lhs(), 2); - IS_NODE_WITH_NAME(Mul, rhs->rhs(), xxTerm); - IS_VAR_WITH_NAME(xxTerm->rhs(), "x"); - IS_VAR_WITH_NAME(xxTerm->rhs(), "x"); + IS_IMM_WITH_VAL(Int, rhs->lhs(), 3); + IS_NODE_WITH_NAME(Mul, rhs->rhs(), xyTerm); + IS_VAR_WITH_NAME(xyTerm->lhs(), "x"); + IS_VAR_WITH_NAME(xyTerm->rhs(), "y"); } // Tests the difference between two less trivial expressions. // (m * (1 * n_1) + (n + 1)) - (m * (1 * n_1) + n) => 1 TEST(Simplify, SimplifyDeeperDifference) { - KernelScope kernel_scope; VarHandle n("n", kInt); VarHandle n_1("n_1", kInt); VarHandle m("m", kInt); @@ -1945,7 +1883,6 @@ TEST(Simplify, SimplifyDeeperDifference) { // Test constant folding into the difference between expressions. // 2 + char((m * (1 * n_1) + (n + 1)) - (m * (1 * n_1) + n)) => 3 TEST(Simplify, SimplifyFoldComplexDifference) { - KernelScope kernel_scope; VarHandle n("n", kInt); VarHandle n_1("n_1", kInt); VarHandle m("m", kInt); @@ -1960,7 +1897,6 @@ TEST(Simplify, SimplifyFoldComplexDifference) { } TEST(Simplify, SimplifyIfComponents) { - KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); ExprHandle body = IfThenElse::make( @@ -1982,20 +1918,19 @@ TEST(Simplify, SimplifyIfComponents) { } TEST(Simplify, SimplifyOpaqueTerms) { - KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); { - // 2 * x/y * x - x/y * y => y * x/y + // 2 * x/y * y - x/y * y => x/y * y ExprHandle body = ((ExprHandle(2)) * (x / y) * y) - ((x / y) * y); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Mul, simplified.node(), mul); - IS_VAR_WITH_NAME(mul->lhs(), "y"); - IS_NODE_WITH_NAME(Div, mul->rhs(), div); + IS_NODE_WITH_NAME(Div, mul->lhs(), div); IS_VAR_WITH_NAME(div->lhs(), "x"); IS_VAR_WITH_NAME(div->rhs(), "y"); + IS_VAR_WITH_NAME(mul->rhs(), "y"); } { @@ -2008,8 +1943,6 @@ TEST(Simplify, SimplifyOpaqueTerms) { } TEST(Simplify, SimplifySymbolicMinMax) { - KernelScope kernel_scope; - { // Minimum with constant difference between terms. VarHandle x("x", kInt); @@ -2044,7 +1977,6 @@ TEST(Simplify, SimplifySymbolicMinMax) { } TEST(Simplify, SimplifyNestedMax) { - KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); VarHandle z("z", kInt); @@ -2055,46 +1987,46 @@ TEST(Simplify, SimplifyNestedMax) { ExprHandle simplified = IRSimplifier::simplify(body); // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - IS_BINOP_W_VARS(Add, simplified.node(), add, "y", "x"); + IS_BINOP_W_VARS(Add, simplified.node(), add, "x", "y"); } { - // Max(x + y, Max(x + y, z)) => Max(y + x, z) + // Max(x + y, Max(x + y, z)) => Max(x + y, z) ExprHandle body = Max::make(x + y, Max::make(x + y, z, true), true); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Max, simplified.node(), max); - IS_BINOP_W_VARS(Add, max->lhs(), add, "y", "x"); + IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y"); IS_VAR_WITH_NAME(max->rhs(), "z"); } { - // Max(x + y, Max(z, x + y)) => Max(y + x, z) + // Max(x + y, Max(z, x + y)) => Max(x + y, z) ExprHandle body = Max::make(x + y, Max::make(z, x + y, true), true); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Max, simplified.node(), max); - IS_BINOP_W_VARS(Add, max->lhs(), add, "y", "x"); + IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y"); IS_VAR_WITH_NAME(max->rhs(), "z"); } { - // Max(Max(x + y, z), x + y) => Max(y + x, z) + // Max(Max(x + y, z), x + y) => Max(x + y, z) ExprHandle body = Max::make(Max::make(x + y, z, true), x + y, true); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Max, simplified.node(), max); - IS_BINOP_W_VARS(Add, max->lhs(), add, "y", "x"); + IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y"); IS_VAR_WITH_NAME(max->rhs(), "z"); } { - // Max(Max(z, x + y), x + y) => Max(y + x, z) + // Max(Max(z, x + y), x + y) => Max(x + y, z) ExprHandle body = Max::make(Max::make(z, x + y, true), x + y, true); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Max, simplified.node(), max); - IS_BINOP_W_VARS(Add, max->lhs(), add, "y", "x"); + IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y"); IS_VAR_WITH_NAME(max->rhs(), "z"); } @@ -2112,55 +2044,39 @@ TEST(Simplify, SimplifyNestedMax) { } { - // Max(Min(x, y), Min(x, z)) => Min(x, Max(y, z)) + // Max(Min(x, y), Min(x, z)) => Min(Max(y, z), x) ExprHandle body = Max::make(Min::make(x, y, true), Min::make(x, z, true), true); ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min); - IS_VAR_WITH_NAME(min->lhs(), "x"); - IS_BINOP_W_VARS(Max, min->rhs(), max, "y", "z"); - ASSERT_TRUE(max->propagate_nans()); + checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)"); } { - // Max(Min(x, y), Min(z, x)) => Min(x, Max(y, z)) + // Max(Min(x, y), Min(z, x)) => Min(Max(y, z), x) ExprHandle body = Max::make(Min::make(x, y, true), Min::make(z, x, true), true); ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min); - IS_VAR_WITH_NAME(min->lhs(), "x"); - IS_BINOP_W_VARS(Max, min->rhs(), max, "y", "z"); - ASSERT_TRUE(max->propagate_nans()); + checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)"); } { - // Max(Min(y, x), Min(x, z)) => Min(x, Max(y, z)) + // Max(Min(y, x), Min(x, z)) => Min(Max(y, z), x) ExprHandle body = Max::make(Min::make(y, x, true), Min::make(x, z, true), true); ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min); - IS_VAR_WITH_NAME(min->lhs(), "x"); - IS_BINOP_W_VARS(Max, min->rhs(), max, "y", "z"); - ASSERT_TRUE(max->propagate_nans()); + checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)"); } { - // Max(Min(y, x), Min(z, x)) => Min(x, Max(y, z)) + // Max(Min(y, x), Min(z, x)) => Min(Max(y, z), x) ExprHandle body = Max::make(Min::make(y, x, true), Min::make(z, x, true), true); ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min); - IS_VAR_WITH_NAME(min->lhs(), "x"); - IS_BINOP_W_VARS(Max, min->rhs(), max, "y", "z"); - ASSERT_TRUE(max->propagate_nans()); + checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)"); } { - // Max(Min(y, x), Min(z, x)) => Max(Min(x, z), Min(x, y)) + // Max(Min(y, x), Min(z, x)) => Max(Min(x, y), Min(x, z)) // When all the ops in the pattern do not have the same propagate_nans, // it should not be simplified. ExprHandle body = @@ -2168,10 +2084,10 @@ TEST(Simplify, SimplifyNestedMax) { ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Max, simplified.node(), max); - IS_BINOP_W_VARS(Min, max->lhs(), min1, "x", "z"); - ASSERT_FALSE(min1->propagate_nans()); - IS_BINOP_W_VARS(Min, max->rhs(), min2, "x", "y"); - ASSERT_TRUE(min2->propagate_nans()); + IS_BINOP_W_VARS(Min, max->lhs(), min1, "x", "y"); + ASSERT_TRUE(min1->propagate_nans()); + IS_BINOP_W_VARS(Min, max->rhs(), min2, "x", "z"); + ASSERT_FALSE(min2->propagate_nans()); ASSERT_TRUE(max->propagate_nans()); } @@ -2304,18 +2220,7 @@ TEST(Simplify, SimplifyNestedMax) { 8, false); ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max1); - IS_NODE_WITH_NAME(Max, max1->lhs(), max2); - IS_VAR_WITH_NAME(max2->lhs(), "x"); - IS_NODE_WITH_NAME(Max, max2->rhs(), max3); - IS_BINOP_W_CONST(Max, max3->lhs(), max4, "z", 5); - ASSERT_TRUE(max4->propagate_nans()); - IS_VAR_WITH_NAME(max3->rhs(), "y"); - ASSERT_FALSE(max3->propagate_nans()); - ASSERT_TRUE(max2->propagate_nans()); - IS_IMM_WITH_VAL(Int, max1->rhs(), 8); - ASSERT_FALSE(max1->propagate_nans()); + checkExprIR(simplified, "Max(Max(Max(Max(z, 5, 1), y, 0), x, 1), 8, 0)"); } { @@ -2348,7 +2253,6 @@ TEST(Simplify, SimplifyNestedMax) { } TEST(Simplify, SimplifyNestedMin) { - KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); VarHandle z("z", kInt); @@ -2359,46 +2263,46 @@ TEST(Simplify, SimplifyNestedMin) { ExprHandle simplified = IRSimplifier::simplify(body); // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - IS_BINOP_W_VARS(Add, simplified.node(), add, "y", "x"); + IS_BINOP_W_VARS(Add, simplified.node(), add, "x", "y"); } { - // Min(x + y, Min(x + y, z)) => Min(y + x, z) + // Min(x + y, Min(x + y, z)) => Min(x + y, z) ExprHandle body = Min::make(x + y, Min::make(x + y, z, true), true); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Min, simplified.node(), min); - IS_BINOP_W_VARS(Add, min->lhs(), add, "y", "x"); + IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y"); IS_VAR_WITH_NAME(min->rhs(), "z"); } { - // Min(x + y, Min(z, x + y)) => Min(y + x, z) + // Min(x + y, Min(z, x + y)) => Min(x + y, z) ExprHandle body = Min::make(x + y, Min::make(z, x + y, true), true); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Min, simplified.node(), min); - IS_BINOP_W_VARS(Add, min->lhs(), add, "y", "x"); + IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y"); IS_VAR_WITH_NAME(min->rhs(), "z"); } { - // Min(Min(x + y, z), x + y) => Min(y + x, z) + // Min(Min(x + y, z), x + y) => Min(x + y, z) ExprHandle body = Min::make(Min::make(x + y, z, true), x + y, true); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Min, simplified.node(), min); - IS_BINOP_W_VARS(Add, min->lhs(), add, "y", "x"); + IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y"); IS_VAR_WITH_NAME(min->rhs(), "z"); } { - // Min(Min(z, x + y), x + y) => Min(y + x, z) + // Min(Min(z, x + y), x + y) => Min(x + y, z) ExprHandle body = Min::make(Min::make(z, x + y, true), x + y, true); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Min, simplified.node(), min); - IS_BINOP_W_VARS(Add, min->lhs(), add, "y", "x"); + IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y"); IS_VAR_WITH_NAME(min->rhs(), "z"); } @@ -2416,55 +2320,39 @@ TEST(Simplify, SimplifyNestedMin) { } { - // Min(Max(x, y), Max(x, z)) => Max(x, Min(y, z)) + // Min(Max(x, y), Max(x, z)) => Max(Min(y, z), x) ExprHandle body = Min::make(Max::make(x, y, true), Max::make(x, z, true), true); ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max); - IS_VAR_WITH_NAME(max->lhs(), "x"); - IS_BINOP_W_VARS(Min, max->rhs(), min, "y", "z"); - ASSERT_TRUE(min->propagate_nans()); + checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)"); } { - // Min(Max(x, y), Max(z, x)) => Max(x, Min(y, z)) + // Min(Max(x, y), Max(z, x)) => Max(Min(y, z), x) ExprHandle body = Min::make(Max::make(x, y, true), Max::make(z, x, true), true); ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max); - IS_VAR_WITH_NAME(max->lhs(), "x"); - IS_BINOP_W_VARS(Min, max->rhs(), min, "y", "z"); - ASSERT_TRUE(min->propagate_nans()); + checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)"); } { - // Min(Max(y, x), Max(x, z)) => Max(x, Min(y, z)) + // Min(Max(y, x), Max(x, z)) => Max(Min(y, z), x) ExprHandle body = Min::make(Max::make(y, x, true), Max::make(x, z, true), true); ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max); - IS_VAR_WITH_NAME(max->lhs(), "x"); - IS_BINOP_W_VARS(Min, max->rhs(), min, "y", "z"); - ASSERT_TRUE(min->propagate_nans()); + checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)"); } { - // Min(Max(y, x), Max(z, x)) => Max(x, Min(y, z)) + // Min(Max(y, x), Max(z, x)) => Max(Min(y, z), x) ExprHandle body = Min::make(Max::make(y, x, true), Max::make(z, x, true), true); ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Max, simplified.node(), max); - IS_VAR_WITH_NAME(max->lhs(), "x"); - IS_BINOP_W_VARS(Min, max->rhs(), min, "y", "z"); - ASSERT_TRUE(min->propagate_nans()); + checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)"); } { - // Min(Max(y, x), Max(z, x)) => Min(Max(x, z), Max(x, y)) + // Min(Max(y, x), Max(z, x)) => Min(Max(x, y), Max(x, z)) // When all the ops in the pattern do not have the same propagate_nans, // it should not be simplified. ExprHandle body = @@ -2472,10 +2360,10 @@ TEST(Simplify, SimplifyNestedMin) { ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Min, simplified.node(), min); - IS_BINOP_W_VARS(Max, min->lhs(), max1, "x", "z"); - ASSERT_FALSE(max1->propagate_nans()); - IS_BINOP_W_VARS(Max, min->rhs(), max2, "x", "y"); - ASSERT_TRUE(max2->propagate_nans()); + IS_BINOP_W_VARS(Max, min->lhs(), max1, "x", "y"); + ASSERT_TRUE(max1->propagate_nans()); + IS_BINOP_W_VARS(Max, min->rhs(), max2, "x", "z"); + ASSERT_FALSE(max2->propagate_nans()); ASSERT_TRUE(min->propagate_nans()); } @@ -2600,7 +2488,7 @@ TEST(Simplify, SimplifyNestedMin) { } { - // Min(Min(Min(Min(z, 5), y), x), 8) => Min(Min(x, Min(Min(z, 5), y)), 8) + // Min(Min(Min(Min(z, 5), y), x), 8) => Min(Min(Min(Min(z, 5), y), x), 8) // Do not simplify when all the Min ops do not have the same // propagate_nans. ExprHandle body = Min::make( @@ -2608,18 +2496,7 @@ TEST(Simplify, SimplifyNestedMin) { 8, false); ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Min, simplified.node(), min1); - IS_NODE_WITH_NAME(Min, min1->lhs(), min2); - IS_VAR_WITH_NAME(min2->lhs(), "x"); - IS_NODE_WITH_NAME(Min, min2->rhs(), min3); - IS_BINOP_W_CONST(Min, min3->lhs(), min4, "z", 5); - ASSERT_TRUE(min4->propagate_nans()); - IS_VAR_WITH_NAME(min3->rhs(), "y"); - ASSERT_FALSE(min3->propagate_nans()); - ASSERT_TRUE(min2->propagate_nans()); - IS_IMM_WITH_VAL(Int, min1->rhs(), 8); - ASSERT_FALSE(min1->propagate_nans()); + checkExprIR(simplified, "Min(Min(Min(Min(z, 5, 1), y, 0), x, 1), 8, 0)"); } { @@ -2652,8 +2529,6 @@ TEST(Simplify, SimplifyNestedMin) { } TEST(Simplify, SimplifyWontReorderFloat) { - KernelScope kernel_scope; - { // 3 * (3 * x) - 3 * (3 * y) => 9 * (x - y) // This is an expression we can simplify. @@ -2764,8 +2639,6 @@ TEST(Simplify, SimplifyWontReorderFloat) { } TEST(Simplify, SimplifyRoundModPattern) { - KernelScope kernel_scope; - { // (x/y)*y + x%y => x. VarHandle x("x", kInt); @@ -2922,16 +2795,7 @@ TEST(Simplify, SimplifyRoundModPattern) { VarHandle z("z", kInt); ExprHandle body = ((x / y) * y) + (x % z); ExprHandle simplified = IRSimplifier::simplify(body); - - IS_NODE_WITH_NAME(Add, simplified.node(), add); - IS_NODE_WITH_NAME(Mul, add->lhs(), roundMul); - IS_VAR_WITH_NAME(roundMul->lhs(), "y"); - IS_NODE_WITH_NAME(Div, roundMul->rhs(), roundDiv); - IS_VAR_WITH_NAME(roundDiv->lhs(), "x"); - IS_VAR_WITH_NAME(roundDiv->rhs(), "y"); - IS_NODE_WITH_NAME(Mod, add->rhs(), mod); - IS_VAR_WITH_NAME(mod->lhs(), "x"); - IS_VAR_WITH_NAME(mod->rhs(), "z"); + checkExprIR(simplified, "(x / y) * y + x % z"); } { @@ -2941,15 +2805,7 @@ TEST(Simplify, SimplifyRoundModPattern) { VarHandle z("z", kInt); ExprHandle body = (y * (x / z)) + (x % y); ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Add, simplified.node(), add); - IS_NODE_WITH_NAME(Mul, add->lhs(), roundMul); - IS_VAR_WITH_NAME(roundMul->lhs(), "y"); - IS_NODE_WITH_NAME(Div, roundMul->rhs(), roundDiv); - IS_VAR_WITH_NAME(roundDiv->lhs(), "x"); - IS_VAR_WITH_NAME(roundDiv->rhs(), "z"); - IS_NODE_WITH_NAME(Mod, add->rhs(), mod); - IS_VAR_WITH_NAME(mod->lhs(), "x"); - IS_VAR_WITH_NAME(mod->rhs(), "y"); + checkExprIR(simplified, "x % y + (x / z) * y"); } { @@ -2959,21 +2815,11 @@ TEST(Simplify, SimplifyRoundModPattern) { VarHandle z("z", kInt); ExprHandle body = ((x / y) * z) + (x % y); ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Add, simplified.node(), add); - IS_NODE_WITH_NAME(Mul, add->lhs(), roundMul); - IS_VAR_WITH_NAME(roundMul->lhs(), "z"); - IS_NODE_WITH_NAME(Div, roundMul->rhs(), roundDiv); - IS_VAR_WITH_NAME(roundDiv->lhs(), "x"); - IS_VAR_WITH_NAME(roundDiv->rhs(), "y"); - IS_NODE_WITH_NAME(Mod, add->rhs(), mod); - IS_VAR_WITH_NAME(mod->lhs(), "x"); - IS_VAR_WITH_NAME(mod->rhs(), "y"); + checkExprIR(simplified, "x % y + (x / y) * z"); } } TEST(Simplify, SimplifyRoundModPatternFactorization) { - KernelScope kernel_scope; - { // Full factorization. // 2 * (x/y * y) + 2 * (x%y) => 2 * x. @@ -3032,24 +2878,22 @@ TEST(Simplify, SimplifyRoundModPatternFactorization) { } TEST(Simplify, SimplifyRoundModPatternMultivar) { - KernelScope kernel_scope; - { // Multivar. - // (x/8) * 8 + (y/5)*5 + x%8 + y%5 => y + x. + // (x/8) * 8 + (y/5)*5 + x%8 + y%5 => x + y. VarHandle x("x", kInt); VarHandle y("y", kInt); ExprHandle body = (x / ExprHandle(8) * ExprHandle(8)) + (y / ExprHandle(5) * ExprHandle(5)) + (x % 8) + (y % 5); ExprHandle simplified = IRSimplifier::simplify(body); IS_NODE_WITH_NAME(Add, simplified.node(), add); - IS_VAR_WITH_NAME(add->lhs(), "y"); - IS_VAR_WITH_NAME(add->rhs(), "x"); + IS_VAR_WITH_NAME(add->lhs(), "x"); + IS_VAR_WITH_NAME(add->rhs(), "y"); } { // Find the right var. - // (y/8) * 8 x%8 + y%8 + z%8 => z%8 + x%8 + y + // (y/8) * 8 x%8 + y%8 + z%8 => x%8 + y + z%8 VarHandle x("x", kInt); VarHandle y("y", kInt); VarHandle z("z", kInt); @@ -3075,22 +2919,13 @@ TEST(Simplify, SimplifyRoundModPatternMultivar) { VarHandle y("y", kInt); VarHandle z("z", kInt); - ExprHandle body = x + (z + ExprHandle(512) * y) % ExprHandle(16) + - ExprHandle(16) * ((z + ExprHandle(512) * y) / ExprHandle(16)); + ExprHandle body = x + (z + y * 512) % 16 + ((z + y * 512) / 16 * 16); ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Add, simplified.node(), add); - IS_VAR_WITH_NAME(add->rhs(), "x"); - IS_NODE_WITH_NAME(Add, add->lhs(), add2); - IS_VAR_WITH_NAME(add2->lhs(), "z"); - IS_NODE_WITH_NAME(Mul, add2->rhs(), mul); - IS_IMM_WITH_VAL(Int, mul->lhs(), 512); - IS_VAR_WITH_NAME(mul->rhs(), "y"); + checkExprIR(simplified, "x + (z + 512 * y)"); } } TEST(Simplify, SimplifyModRoundModPattern) { - KernelScope kernel_scope; - { // t/7 % 9 * 7 + t % 7 => t%63 VarHandle t("t", kInt); @@ -3135,13 +2970,7 @@ TEST(Simplify, SimplifyModRoundModPattern) { VarHandle k("k", kInt); ExprHandle body = (k * t / x % y) * x + k * t % x; ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Mod, simplified.node(), mod); - IS_NODE_WITH_NAME(Mul, mod->lhs(), mul1); - IS_VAR_WITH_NAME(mul1->lhs(), "t"); - IS_VAR_WITH_NAME(mul1->rhs(), "k"); - IS_NODE_WITH_NAME(Mul, mod->rhs(), mul2); - IS_VAR_WITH_NAME(mul2->lhs(), "x"); - IS_VAR_WITH_NAME(mul2->rhs(), "y"); + checkExprIR(simplified, "(k * t) % (x * y)"); } { @@ -3183,8 +3012,6 @@ TEST(Simplify, SimplifyModRoundModPattern) { } TEST(Simplify, SimplifyModRoundModPatternFactorization) { - KernelScope kernel_scope; - { // 2 * (t /7 % 9 * 7) + 2 * (t % 7) => 2 * (t % 63) VarHandle t("t", kInt); @@ -3252,18 +3079,12 @@ TEST(Simplify, SimplifyModRoundModPatternFactorization) { } TEST(Simplify, SimplifyModRoundModPatternMultivar) { - KernelScope kernel_scope; - { // t/7 % 9 * 7 + t % 7 + t => t % 63 + t VarHandle t("t", kInt); ExprHandle body = (t / 7 % 9) * 7 + t % 7 + t; ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Add, simplified.node(), add); - IS_NODE_WITH_NAME(Mod, add->rhs(), mod); - IS_VAR_WITH_NAME(mod->lhs(), "t"); - IS_IMM_WITH_VAL(Int, mod->rhs(), 63); - IS_VAR_WITH_NAME(add->lhs(), "t"); + checkExprIR(simplified, "t % 63 + t"); } { @@ -3306,19 +3127,7 @@ TEST(Simplify, SimplifyModRoundModPatternMultivar) { VarHandle k("k", kInt); ExprHandle body = (t / x % y) * x + t % x + (t / k / x % y) * x + t / k % x; ExprHandle simplified = IRSimplifier::simplify(body); - IS_NODE_WITH_NAME(Add, simplified.node(), add); - IS_NODE_WITH_NAME(Mod, add->lhs(), mod); - IS_VAR_WITH_NAME(mod->lhs(), "t"); - IS_NODE_WITH_NAME(Mul, mod->rhs(), mul); - IS_VAR_WITH_NAME(mul->lhs(), "x"); - IS_VAR_WITH_NAME(mul->rhs(), "y"); - IS_NODE_WITH_NAME(Mod, add->rhs(), mod2); - IS_NODE_WITH_NAME(Div, mod2->lhs(), div); - IS_VAR_WITH_NAME(div->lhs(), "t"); - IS_VAR_WITH_NAME(div->rhs(), "k"); - IS_NODE_WITH_NAME(Mul, mod2->rhs(), mul2); - IS_VAR_WITH_NAME(mul2->lhs(), "x"); - IS_VAR_WITH_NAME(mul2->rhs(), "y"); + checkExprIR(simplified, "(t / k) % (x * y) + t % (x * y)"); } { @@ -3374,8 +3183,6 @@ TEST(Simplify, SimplifyModRoundModPatternMultivar) { } TEST(Simplify, SimplifyDivisionScalarFactorization) { - KernelScope kernel_scope; - { // Simple factorization of numerator and denominator. // 8x / 4y => 2x / y. @@ -3446,8 +3253,6 @@ TEST(Simplify, SimplifyDivisionScalarFactorization) { } TEST(Simplify, SimplifyConstantBranches) { - KernelScope kernel_scope; - { // If the condition is constant true then take the true_value. // 1 ? x : y => x @@ -3504,8 +3309,6 @@ TEST(Simplify, SimplifyConstantBranches) { } TEST(Simplify, SimplifyConstantCond) { - KernelScope kernel_scope; - { // If the condition is constant true then take the true_value. // 1 ? A[0] = 1 : B[0] = 1 => A[0] = 1 @@ -3622,7 +3425,6 @@ TEST(Simplify, SimplifyConstantCond) { } TEST(Simplify, SimplifyEliminateEmptyCond) { - KernelScope kernel_scope; // If the branches are empty in different ways, eliminate. { VarHandle x("x", kInt); @@ -3650,8 +3452,6 @@ TEST(Simplify, SimplifyEliminateEmptyCond) { } TEST(Simplify, SimplifyConstantComparisons) { - KernelScope kernel_scope; - auto ComparisonTest = [](ExprHandle a, ExprHandle b, CompareSelectOperation op, int result) { ExprHandle body = CompareSelect::make(a, b, op); @@ -3696,7 +3496,6 @@ TEST(Simplify, SimplifyConstantComparisons) { } TEST(Simplify, SimplifySymbolicComparisons) { - KernelScope kernel_scope; VarHandle x("x", kInt); VarHandle y("y", kInt); @@ -3834,8 +3633,6 @@ TEST(Simplify, SimplifySymbolicComparisons) { } TEST(Simplify, SimplifyEliminateZeroLengthFor) { - KernelScope kernel_scope; - { // Will eliminate zero loop For. BufHandle a("A", {4}, kInt); @@ -3894,8 +3691,6 @@ TEST(Simplify, SimplifyEliminateZeroLengthFor) { } TEST(Simplify, SimplifyOneLoopFor) { - KernelScope kernel_scope; - { // Will remove the loop if the body is run once. BufHandle a("A", {4}, kInt); @@ -3963,15 +3758,13 @@ TEST(Simplify, SimplifyOneLoopFor) { } TEST(Simplify, SimplifyForWontLoseLoopOptions) { - KernelScope kernel_scope; - { // Sanity check does nothing if the condition is not met. BufHandle a("A", {4}, kInt); BufHandle c("C", {4}, kInt); VarHandle i("i", kInt); LoopOptions options; - options.set_gpu_block_index(12); + options.set_gpu_block_index(LoopOptions::IDX_W); auto body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i})), options); StmtPtr simplified = IRSimplifier::simplify(body); @@ -3982,8 +3775,6 @@ TEST(Simplify, SimplifyForWontLoseLoopOptions) { } TEST(Simplify, SimplifyMultilevelFor) { - KernelScope kernel_scope; - { // Multiple layers of For will be simplified out. BufHandle a("A", {4}, kInt); @@ -4041,12 +3832,10 @@ TEST(Simplify, SimplifyMultilevelFor) { } TEST(Simplify, SimplifyForCleansUp) { - KernelScope kernel_scope; - { Placeholder a("a", kFloat, {1, 12, 1}); VarHandle x("x", kInt); - Tensor* b = Compute( + Tensor b = Compute( // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) "x", {{1, "i"}, {12, "m"}, {1, "n"}}, @@ -4071,8 +3860,6 @@ TEST(Simplify, SimplifyForCleansUp) { } TEST(Simplify, SimplifyEliminateEmptyFor) { - KernelScope kernel_scope; - { // Flatten many layers around an empty block to an empty block. StmtPtr last = alloc(std::vector({})); @@ -4088,8 +3875,6 @@ TEST(Simplify, SimplifyEliminateEmptyFor) { } TEST(Simplify, SimplifyFlattenBlock) { - KernelScope kernel_scope; - { // Flatten multiple blocks down to one. // { { { stmt1, stmt2 } } } => { stmt1, stmt2 } @@ -4173,8 +3958,6 @@ TEST(Simplify, SimplifyFlattenBlock) { } TEST(Simplify, SimplifyEliminateZeroLengthAlloc) { - KernelScope kernel_scope; - { // Simple positive case. BufHandle b("x", {0}, kInt); @@ -4249,8 +4032,6 @@ TEST(Simplify, SimplifyEliminateZeroLengthAlloc) { } TEST(Simplify, DontSimplifyRand) { - KernelScope kernel_scope; - { // rand() + rand() = rand() + rand() NOT 2 * rand(). ExprHandle body = @@ -4283,7 +4064,6 @@ TEST(Simplify, DontSimplifyRand) { } TEST(Simplify, SimplifyReorderForCond) { - KernelScope kernel_scope; BufHandle a("A", {4}, kInt); BufHandle b("B", {1}, kInt); BufHandle c("C", {4}, kInt); @@ -4482,7 +4262,6 @@ TEST(Simplify, SimplifyReorderForCond) { } TEST(Simplify, SimplifyFuseConditions) { - KernelScope kernel_scope; BufHandle a("A", {2}, kInt); BufHandle b("B", {2}, kInt); VarHandle i("i", kInt); @@ -4892,7 +4671,6 @@ TEST(Simplify, SimplifyFuseConditions) { } TEST(Simplify, SimplifySyncThreads) { - KernelScope kernel_scope; BufHandle a("A", {4}, kInt); VarHandle i("i", kInt); @@ -4990,7 +4768,6 @@ TEST(Simplify, SimplifySyncThreads) { } TEST(Simplify, SimplifyRampSubBroadcast) { - KernelScope kernel_scope; int num_lanes = 4; ExprHandle ramp = Ramp::make(ExprHandle(0), ExprHandle(6), num_lanes); ExprHandle broadcast = Broadcast::make(ExprHandle(-5), num_lanes); @@ -5004,7 +4781,6 @@ TEST(Simplify, SimplifyRampSubBroadcast) { } TEST(Simplify, SimplifyBroadcastTermExpander) { - KernelScope kernel_scope; int num_lanes = 8; ExprHandle bc0 = Broadcast::make(ExprHandle(0), num_lanes); ExprHandle bc1 = Broadcast::make(ExprHandle(1), num_lanes); @@ -5034,7 +4810,6 @@ TEST(Simplify, DISABLED_CompareSelectCondAlwaysInLoopBounds) { // for (int n = 1; n < N; n++) { // b[n] = 1.f; // } - KernelScope kernel_scope; constexpr int N = 8; Placeholder b("b", kFloat, {N}); VarHandle n("n", kInt); @@ -5059,7 +4834,6 @@ TEST(Simplify, DISABLED_IfThenCondAlwaysInLoopBounds) { // for (int n = 1; n < N; n++) { // b[n] = 1.f; // } - KernelScope kernel_scope; constexpr int N = 8; Placeholder b("b", kFloat, {N}); VarHandle n("n", kInt); @@ -5088,7 +4862,6 @@ TEST(Simplify, DISABLED_MultiClauseCondAlwaysInLoopBounds) { // for (int i = 1; i < 7; i++) { // for (int j = 1; j < 7; j++) { // b[i, j] = 1.f; - KernelScope kernel_scope; constexpr int N = 8; Placeholder b("b", kFloat, {N, N}); VarHandle i("i", kInt); @@ -5124,7 +4897,6 @@ TEST(Simplify, DISABLED_SimplifyLoopBounds) { // for (int i = 1; i < 3; i++) { // for (int j = 1; j < 3; j++) { // b[i, j] = (b[i, j]) + 1.f; - KernelScope kernel_scope; constexpr int N = 8; constexpr int K = 3; Placeholder a("a", kFloat, {N, N}); diff --git a/test/cpp/tensorexpr/test_te_fuser_pass.cpp b/test/cpp/tensorexpr/test_te_fuser_pass.cpp index 8dd616453362b..b82d383bc99b0 100644 --- a/test/cpp/tensorexpr/test_te_fuser_pass.cpp +++ b/test/cpp/tensorexpr/test_te_fuser_pass.cpp @@ -5,7 +5,6 @@ #include #include #include -#include #include #include @@ -15,19 +14,15 @@ namespace jit { using namespace torch::jit::tensorexpr; struct WithCPUFuser { - WithCPUFuser(bool val = true) - : cpuFuserEnabled(canFuseOnCPU()), parallel(texprParallelCPUEnabled()) { + WithCPUFuser(bool val = true) : cpuFuserEnabled(canFuseOnCPU()) { overrideCanFuseOnCPU(val); - setTexprParallelCPUEnabled(true); } ~WithCPUFuser() { overrideCanFuseOnCPU(cpuFuserEnabled); - setTexprParallelCPUEnabled(parallel); } bool cpuFuserEnabled; - bool parallel; }; TEST(TEFuserPass, FuserPass_1) { diff --git a/test/cpp/tensorexpr/test_type.cpp b/test/cpp/tensorexpr/test_type.cpp index cc8a6967b7255..67c1a0a528b7c 100644 --- a/test/cpp/tensorexpr/test_type.cpp +++ b/test/cpp/tensorexpr/test_type.cpp @@ -9,7 +9,6 @@ namespace jit { using namespace torch::jit::tensorexpr; TEST(Type, Test01) { - KernelScope kernel_scope; { Dtype dt1 = kInt; ASSERT_EQ(dt1, kInt); @@ -45,28 +44,24 @@ TEST(Type, Test01) { TEST(Type, BitCasting) { { - KernelScope kernel_scope; VarHandle x("x", kFloat); ExprHandle y = bitcast(x); // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) ASSERT_EQ(y.dtype(), kInt); } { - KernelScope kernel_scope; VarHandle x("x", kInt); ExprHandle y = bitcast(x); // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) ASSERT_EQ(y.dtype(), kFloat); } { - KernelScope kernel_scope; VarHandle x("x", kShort); ExprHandle y = bitcast(x); // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) ASSERT_EQ(y.dtype(), kHalf); } { - KernelScope kernel_scope; VarHandle x("x", kHalf); ExprHandle y = bitcast(x); // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) @@ -82,7 +77,6 @@ TEST(Type, BitCasting) { using SimpleIRExprEval = ExprEval; // this is broken /*{ - KernelScope kernel_scope; at::Half k_; at::Half* k = &k_; *reinterpret_cast(k) = ref16; @@ -93,7 +87,6 @@ TEST(Type, BitCasting) { }*/ { - KernelScope kernel_scope; float k = raw_bitcast(ref32); auto a = FloatImm::make(k); auto b = BitCast::make(kInt, a); @@ -102,7 +95,6 @@ TEST(Type, BitCasting) { } { - KernelScope kernel_scope; double k = raw_bitcast(ref64); auto a = DoubleImm::make(k); auto b = BitCast::make(kLong, a); @@ -111,7 +103,6 @@ TEST(Type, BitCasting) { } { - KernelScope kernel_scope; int64_t k = raw_bitcast(reff64); auto a = LongImm::make(k); auto b = BitCast::make(kDouble, a); @@ -120,7 +111,6 @@ TEST(Type, BitCasting) { } { - KernelScope kernel_scope; int32_t k = raw_bitcast(reff32); auto a = IntImm::make(k); auto b = BitCast::make(kFloat, a); @@ -130,27 +120,22 @@ TEST(Type, BitCasting) { // This segfaults :( /*{ - KernelScope kernel_scope; VarHandle x("x", kDouble); ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); } { - KernelScope kernel_scope; VarHandle x("x", kFloat); ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); } { - KernelScope kernel_scope; VarHandle x("x", kLong); ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); } { - KernelScope kernel_scope; VarHandle x("x", kShort); ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); } { - KernelScope kernel_scope; VarHandle x("x", kInt); ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); }*/ @@ -159,7 +144,6 @@ TEST(Type, BitCasting) { TEST(Type, Propagation) { // Same types: { - KernelScope kernel_scope; VarHandle x("x", kFloat); VarHandle y("y", kFloat); ExprHandle body = FloatImm::make(2.f) + @@ -168,7 +152,6 @@ TEST(Type, Propagation) { } // Int to bigger int: { - KernelScope kernel_scope; VarHandle x("x", kShort); VarHandle y("y", kLong); ExprHandle body = @@ -177,7 +160,6 @@ TEST(Type, Propagation) { } // Float to bigger float: { - KernelScope kernel_scope; VarHandle x("x", kHalf); VarHandle y("y", kDouble); ExprHandle body = @@ -186,7 +168,6 @@ TEST(Type, Propagation) { } // Int to Float: { - KernelScope kernel_scope; VarHandle x("x", kFloat); VarHandle y("y", kInt); ExprHandle body = @@ -195,7 +176,6 @@ TEST(Type, Propagation) { } // Smaller float, bigger Int: { - KernelScope kernel_scope; VarHandle x("x", kHalf); VarHandle y("y", kLong); ExprHandle body = @@ -204,7 +184,6 @@ TEST(Type, Propagation) { } // Bigger float, smaller Int: { - KernelScope kernel_scope; VarHandle x("x", kChar); VarHandle y("y", kDouble); ExprHandle body = @@ -213,7 +192,6 @@ TEST(Type, Propagation) { } // Sign change char/byte upgrades to short: { - KernelScope kernel_scope; VarHandle x("x", kChar); VarHandle y("y", kByte); ExprHandle body = diff --git a/test/cpp/tensorexpr/test_utils.h b/test/cpp/tensorexpr/test_utils.h index 01b92a7832a40..065e513c1a645 100644 --- a/test/cpp/tensorexpr/test_utils.h +++ b/test/cpp/tensorexpr/test_utils.h @@ -4,6 +4,7 @@ #include #include +#include #include namespace torch { @@ -69,5 +70,9 @@ using namespace torch::jit::tensorexpr; ASSERT_EQ(node_->op_type(), kRand); \ } +void checkIR(StmtPtr s, const std::string& pattern); +void checkExprIR(ExprPtr e, const std::string& pattern); +void checkExprIR(const ExprHandle& e, const std::string& pattern); + } // namespace jit } // namespace torch diff --git a/test/cpp/tensorexpr/tutorial.cpp b/test/cpp/tensorexpr/tutorial.cpp index 9320f47bfb3d8..0ec0968bebf8f 100644 --- a/test/cpp/tensorexpr/tutorial.cpp +++ b/test/cpp/tensorexpr/tutorial.cpp @@ -38,34 +38,30 @@ #include #include +#include +#include #include #include #include #include +#include +#include #include #include #include +#include using namespace torch::jit::tensorexpr; +// Helper function to print a snippet from a big multi-line string +static void printLinesToFrom(const std::string& input_str, int from, int to); + int main(int argc, char* argv[]) { - // Memory management for tensor expressions is currently done with memory - // arenas. That is, whenever an object is created it registers itself in an - // arena and the object is kept alive as long as the arena is alive. When the - // arena gets destructed, it deletes all objects registered in it. - // - // The easiest way to set up a memory arena is to use `KernelScope` class - it - // is a resource guard that creates a new arena on construction and restores - // the previously set arena on destruction. - // - // We will create a kernel scope here, and thus we'll set up a mem arena for - // the entire tutorial. - KernelScope kernel_scope; - - std::cout << "*** Structure of tensor expressions ***" << std::endl; + std::cout << "*** Structure of tensor expressions and statements ***" + << std::endl; { // A tensor expression is a tree of expressions. Each expression has a type, - // and that type defines what sub-expressions it the current expression has. + // and that type defines what sub-expressions the current expression has. // For instance, an expression of type 'Mul' would have a type 'kMul' and // two subexpressions: LHS and RHS. Each of these two sub-expressions could // also be a 'Mul' or some other expression. @@ -85,15 +81,21 @@ int main(int argc, char* argv[]) { // like we did in the previous example). Expression handles overload common // operations and allow us to express the same semantics in a more natural // way: - ExprHandle l = 1; + ExprHandle l = 5; ExprHandle r = Var::make("x", kInt); ExprHandle m = l * r; std::cout << "Tensor expression: " << *m.node() << std::endl; - // Prints: Tensor expression: 1 * x + // Prints: Tensor expression: 5 * x + + // Converting from handles to raw expressions and back is easy: + ExprHandle handle = Var::make("x", kInt); + ExprPtr raw_expr_from_handle = handle.node(); + ExprPtr raw_expr = alloc("x", kInt); + ExprHandle handle_from_raw_expr = ExprHandle(raw_expr); - // In a similar fashion we could construct arbitrarily complex expressions - // using mathematical and logical operations, casts between various data - // types, and a bunch of intrinsics. + // We could construct arbitrarily complex expressions using mathematical + // and logical operations, casts between various data types, and a bunch of + // intrinsics. ExprHandle a = Var::make("a", kInt); ExprHandle b = Var::make("b", kFloat); ExprHandle c = Var::make("c", kFloat); @@ -109,238 +111,232 @@ int main(int argc, char* argv[]) { // placeholder similar to Var, but with dimensions info. // // Let's construct a simple load: - BufHandle A("A", {ExprHandle(64), ExprHandle(32)}, kInt); - ExprHandle i = Var::make("i", kInt), j = Var::make("j", kInt); + BufHandle A("A", {64, 32}, kInt); + VarPtr i_var = alloc("i", kInt), j_var = alloc("j", kInt); + ExprHandle i(i_var), j(j_var); ExprHandle load = Load::make(A.dtype(), A, {i, j}); std::cout << "Tensor expression: " << *load.node() << std::endl; // Prints: Tensor expression: A[i, j] - } - std::cout << "*** Tensors, Functions, and Placeholders ***" << std::endl; - { - // A tensor computation is represented by Tensor class objects and - // consists of the following pieces: - // - domain, which is specified by a Buf expression - // - a tensor statement, specified by a Stmt object, that computation to - // be performed in this domain - - // Let's start with defining a domain. We do this by creating a Buf object. - - // First, let's specify the sizes: - std::vector dims = { - alloc(64), - alloc(32)}; // IntImm stands for Integer Immediate - // and represents an integer constant - - // Now we can create a Buf object by providing a name, dimensions, and a - // data type of the elements: - BufPtr buf = alloc("X", dims, kInt); - - // Next we need to spefify the computation. We can do that by either - // constructing a complete tensor statement for it (statements are - // examined in details in subsequent section), or by using a convenience - // method where we could specify axis and an element expression for the - // computation. In the latter case a corresponding statement would be - // constructed automatically. - - // Let's define two variables, i and j - they will be axis in our - // computation. - VarPtr i = alloc("i", kInt); - VarPtr j = alloc("j", kInt); - std::vector args = {i, j}; - - // Now we can define the body of the tensor computation using these - // variables. What this means is that values in our tensor are: - // X[i, j] = i * j - ExprPtr body = alloc(i, j); - - // Finally, we pass all these pieces together to Tensor constructor: - Tensor* X = new Tensor(buf, args, body); - std::cout << "Tensor computation: " << *X << std::endl; + // Tensor Expressions constitute Tensor Statements, which are used to + // represent computation of a given operator or a group of operators from a + // fusion group. + // + // There are three main kinds of tensor statements: + // - block + // - store + // - loop + // + // A Store represents a store to a single element of a tensor (or to a + // group of elements if it's a vectorized store). Store statements, + // similarly to Load expressions, have a base and indices, but on top of + // that they also include a value - an expression representing what needs + // to be stored at the given memory location. Let's create a Store stmt: + StmtPtr store_a = Store::make(A, {i, j}, i + j); + std::cout << "Store statement: " << *store_a << std::endl; + // Prints: Store statement: A[i, j] = i + j; + + // An operator fills the entire tensor, not just a single element, and to + // represent this we need to use For stmt: let's wrap our store stmt with + // two nested loops to represent that variables i and j need to iterate + // over some ranges. + ForPtr loop_j_a = For::make(VarHandle(j_var), 0, 32, store_a); + ForPtr loop_i_a = For::make(VarHandle(i_var), 0, 64, loop_j_a); + + std::cout << "Nested for loops: " << std::endl << *loop_i_a << std::endl; // Prints: - // Tensor computation: Tensor X[64, 32]: + // Nested for loops: // for (int i = 0; i < 64; i++) { // for (int j = 0; j < 32; j++) { - // X[i, j] = i * j; + // A[i, j] = i + j; // } // } - // TODO: Add an example of constructing a Tensor with a complete Stmt. - - // Similarly to how we provide a more convenient way of using handles for - // constructing Exprs, Tensors also have a more convenient API for - // construction. It is based on Compute API, which takes a name, - // dimensions, and a lambda specifying the computation body: - Tensor* Z = Compute( - "Z", - {{64, "i"}, {32, "j"}}, - [](const VarHandle& i, const VarHandle& j) { return i / j; }); - std::cout << "Tensor computation: " << *Z << std::endl; + // A Block statement is used when we need a sequence of other statements. + // E.g. if a fusion group contains several operators, we initially define + // separate loopnest for each of them and put them all into a common block: + BufHandle B("B", {64, 32}, kInt); + StmtPtr store_b = Store::make(B, {i, j}, A.load(i, j)); + ForPtr loop_j_b = For::make(VarHandle(j_var), 0, 32, store_b); + ForPtr loop_i_b = For::make(VarHandle(i_var), 0, 64, loop_j_b); + + BlockPtr block = Block::make({loop_i_a, loop_i_b}); + std::cout << "Compound Block statement: " << std::endl + << *block << std::endl; // Prints: - // Tensor computation: Tensor Z[64, 32]: - // for (int i = 0; i < 64; i++) { - // for (int j = 0; j < 32; j++) { - // Z[i, j] = i / j; + // Compound Block statement: + // { + // for (int i = 0; i < 64; i++) { + // for (int j = 0; j < 32; j++) { + // A[i, j] = i + j; + // } + // } + // for (int i = 0; i < 64; i++) { + // for (int j = 0; j < 32; j++) { + // B[i, j] = A[i, j]; + // } // } // } - // Tensors might access other tensors and external placeholders in their - // expressions. It can be done like so: - Placeholder P("P", kInt, {64, 32}); - Tensor* R = Compute( - "R", + // Manually constructing nested loops and blocks to represent a computation + // might be laborious, and instead we can use a 'Compute' API. This API + // requires us to specify dimensions and a lambda to compute a single + // element of the resulting tensor and returns a `Tensor` structure. This + // structure is simply a pair of a buffer that was created to represent the + // result of the computation (BufPtr) and a statement representing the + // computation itself (StmtPtr). + Tensor C = Compute( + "C", {{64, "i"}, {32, "j"}}, - [&](const VarHandle& i, const VarHandle& j) { - return Z->load(i, j) * P.load(i, j); - }); - std::cout << "Tensor computation: " << *R << std::endl; + [&](const VarHandle& i, const VarHandle& j) { return i * j; }); + std::cout << "Stmt produced by 'Compute' API: " << std::endl + << *C.stmt() << std::endl; // Prints: - // Tensor computation: Tensor R[64, 32]: + // Stmt produced by 'Compute' API: // for (int i = 0; i < 64; i++) { // for (int j = 0; j < 32; j++) { - // R[i, j] = (Z(i, j)) * (P[i, j]); + // C[i, j] = i * j; // } // } - // Placeholders could be thought of as external tensors, i.e. tensors for - // which we don't have the element expression. In other words, for `Tensor` - // we know an expression specifying how its elements can be computed (a - // mathematical formula). For external tensors, or placeholders, we don't - // have such an expression. They need to be considered as coming to us as - // inputs from outside - we can only load data from them. - // - // TODO: Show how reductions are represented and constructed + // To construct statements to represent computations with reductions, we + // can use a 'Reduce' API - it is similar to 'Compute' but takes a couple + // of extra arguments defining how to perform the reduction. Let's define a + // simple 2D sum of C using that: + Tensor D = Reduce( + "D", + {}, + Sum(), + [&](const VarHandle& i, const VarHandle& j) { return C.load(i, j); }, + {{64, "i"}, {32, "j"}}); + std::cout << "Stmt produced by 'Reduce' API: " << std::endl + << *D.stmt() << std::endl; } - std::cout << "*** Loopnests and Statements ***" << std::endl; + std::cout << "*** Loopnests transformations ***" << std::endl; { - // Creating a tensor expression is the first step to generate an executable - // code for it. A next step is to represent it as a loop nest and apply - // various loop transformations in order to get an optimal implementation. - // In Halide's or TVM's terms the first step was to define the algorithm of - // computation (what to compute?) and now we are getting to the schedule of - // the computation (how to compute?). + // When a statement for the computation is generated, we might want to + // apply some optimizations to it. These transformations allow us to end up + // with a statement producing the same results, but more efficiently. // - // Let's create a simple tensor expression and construct a loop nest for it. - Placeholder A("A", kFloat, {64, 32}); - Placeholder B("B", kFloat, {64, 32}); - Tensor* X = Compute( - "X", + // Let's look at a couple of transformations that are used in NNC. We will + // begin with constructing a Block statement like we did before. + + Tensor C = Compute( + "C", {{64, "i"}, {32, "j"}}, - [&](const VarHandle& i, const VarHandle& j) { - return A.load(i, j) + B.load(i, j); - }); - Tensor* Y = Compute( - "Y", + [&](const VarHandle& i, const VarHandle& j) { return i * (j + 1); }); + BufHandle c_buf(C.buf()); + Tensor D = Compute( + "D", {{64, "i"}, {32, "j"}}, [&](const VarHandle& i, const VarHandle& j) { - return sigmoid(X->load(i, j)); + return c_buf.load(i, j) - i; }); - std::cout << "Tensor computation X: " << *X - << "Tensor computation Y: " << *Y << std::endl; + StmtPtr block = Block::make({C.stmt(), D.stmt()}); + std::cout << "Stmt produced by 'Compute' API: " << std::endl + << *block << std::endl; // Prints: - // Tensor computation X: Tensor X[64, 32]: - // for (int i = 0; i < 64; i++) { - // for (int j = 0; j < 32; j++) { - // X[i, j] = (A[i, j]) + (B[i, j]); + // Stmt produced by 'Compute' API: + // { + // for (int i = 0; i < 64; i++) { + // for (int j = 0; j < 32; j++) { + // C[i, j] = i * (j + 1); + // } // } - // } - - // Tensor computation Y: Tensor Y[64, 32]: - // for (int i = 0; i < 64; i++) { - // for (int j = 0; j < 32; j++) { - // Y[i, j] = sigmoid(X(i, j)); + // for (int i_1 = 0; i_1 < 64; i_1++) { + // for (int j_1 = 0; j_1 < 32; j_1++) { + // D[i_1, j_1] = (C[i_1, j_1]) - i_1; + // } // } // } - // Creating a loop nest is as quite simple, we just need to specify a list - // of all and a list of output tensors: - // NOLINTNEXTLINE(bugprone-argument-comment) - LoopNest loopnest(/*outputs=*/{Y}, /*all=*/{X, Y}); - - // An IR used in LoopNest is based on tensor statements, represented by - // `Stmt` class. Statements are used to specify the loop nest structure, and - // to take a sneak peek at them, let's print out what we got right after - // creating our LoopNest object: - std::cout << *loopnest.root_stmt() << std::endl; + // One transformation we can apply to this computation is inlining: i.e. + // taking the expression that defines values of C and substituting a load + // from C with it. + // To do that, we first need to create a special object called LoopNest - + // all transformations are methods of this class. To create a loopnest we + // need to provide a list of output buffers and the root statement: + LoopNest nest(block, {D.buf()}); + + // We can always retrieve the Stmt back from LoopNest: + std::cout << "LoopNest root stmt: " << std::endl + << *nest.root_stmt() << std::endl; // Prints: + // LoopNest root stmt: // { // for (int i = 0; i < 64; i++) { // for (int j = 0; j < 32; j++) { - // X[i, j] = (A[i, j]) + (B[i, j]); + // C[i, j] = i * (j + 1); // } // } // for (int i_1 = 0; i_1 < 64; i_1++) { // for (int j_1 = 0; j_1 < 32; j_1++) { - // Y[i_1, j_1] = sigmoid(X(i_1, j_1)); + // D[i_1, j_1] = (C[i_1, j_1]) - i_1; // } // } // } - // To introduce statements let's first look at their three main types (in - // fact, there are more than 3 types, but the other types would be easy to - // understand once the overall structure is clear): - // 1) Block - // 2) For - // 3) Store - // - // A `Block` statement is simply a list of other statements. - // A `For` is a statement representing one axis of computation. It contains - // an index variable (Var), boundaries of the axis (start and end - both are - // `Expr`s), and a `Block` statement body. - // A `Store` represents an assignment to a tensor element. It contains a Buf - // representing the target tensor, a list of expressions for indices of the - // element, and the value to be stored, which is an arbitrary expression. - - // Once we've constructed the loop nest, we can apply various tranformations - // to it. To begin with, let's inline computation of X into computation of Y - // and see what happens to our statements. - loopnest.computeInline(loopnest.getLoopBodyFor(X)); - std::cout << *loopnest.root_stmt() << std::endl; + // Now we can apply the inlining transformation: + nest.computeInline(C.buf()); + std::cout << "Stmt after inlining:" << std::endl + << *nest.root_stmt() << std::endl; // Prints: + // Stmt after inlining: // { // for (int i = 0; i < 64; i++) { // for (int j = 0; j < 32; j++) { - // Y[i, j] = sigmoid((A[i, j]) + (B[i, j])); + // D[i, j] = i * (j + 1) - i; // } // } // } - // - // As you can see, the first two loops have disappeared and the expression - // for X[i,j] has been inserted into the Y[i,j] computation. - - // Loop transformations can be composed, so we can do something else with - // our loop nest now. Let's split the inner loop with a factor of 9, for - // instance. - std::vector loops = loopnest.getLoopStmtsFor(Y); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - ForPtr j_inner; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - ForPtr j_tail; - int split_factor = 9; - loopnest.splitWithTail( - loops[1], // loops[0] is the outer loop, loops[1] is inner - split_factor, - &j_inner, // further transformations - &j_tail); - // loops[1] will become the outer loop, j_outer, after splitWithTail. - std::cout << *loopnest.root_stmt() << std::endl; + + // We can also apply algebraic simplification to a statement: + StmtPtr simplified = IRSimplifier::simplify(nest.root_stmt()); + std::cout << "Stmt after simplification:" << std::endl + << *simplified << std::endl; // Prints: + // Stmt after simplification: // { // for (int i = 0; i < 64; i++) { - // for (int j_outer = 0; j_outer < (32 - 0) / 9; j_outer++) { - // for (int j_inner = 0; j_inner < 9; j_inner++) { - // Y[i, j_outer * 9 + j_inner] = sigmoid((A[i, j_outer * 9 + ... + // for (int j = 0; j < 32; j++) { + // D[i, j] = i * j; + // } + // } + // } + + // Many loopnest transformations are stateless and can be applied without + // creating a LoopNest object. In fact, we plan to make all transformations + // stateless. + // splitWithTail is one such transformation: it splits an iteration space + // of a given loop into two with a given factor. + ForPtr outer_loop = to(to(simplified)->stmts().front()); + LoopNest::splitWithTail(outer_loop, 13); + // Call simplifier once more to fold some arithmetic. + simplified = IRSimplifier::simplify(simplified); + std::cout << "Stmt after splitWithTail:" << std::endl + << *simplified << std::endl; + // Prints: + // Stmt after splitWithTail: + // { + // for (int i_outer = 0; i_outer < 4; i_outer++) { + // for (int i_inner = 0; i_inner < 13; i_inner++) { + // for (int j = 0; j < 32; j++) { + // D[i_inner + 13 * i_outer, j] = i_inner * j + 13 * (i_outer * j); // } // } - // for (int j_tail = 0; j_tail < (32 - 0) % 9; j_tail++) { - // Y[i, j_tail + ((32 - 0) / 9) * 9] = sigmoid((A[i, j_tail + ... + // } + // for (int i_tail = 0; i_tail < 12; i_tail++) { + // for (int j = 0; j < 32; j++) { + // D[i_tail + 52, j] = i_tail * j + 52 * j; // } // } // } - // TODO: List all available transformations - // TODO: Show how statements can be constructed manually + // NNC supports a wide range of loop nest transformations, which we are not + // listing here. Please refer to documentation in + // https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/tensorexpr/loopnest.h + // for more details. } std::cout << "*** Codegen ***" << std::endl; @@ -348,21 +344,23 @@ int main(int argc, char* argv[]) { // An ultimate goal of tensor expressions is to be provide a mechanism to // execute a given computation in the fastest possible way. So far we've // looked at how we could describe what computation we're interested in, but - // we haven't looked at how to actually execute it. So far all we've been - // dealing with was just symbols with no actual data associated, in this - // section we would look at how we can bridge that gap. + // we haven't looked at how to actually execute it. + // + // All we've been dealing with was just symbols with no actual data + // associated, in this section we would look at how we can bridge that gap. // Let's start by constructing a simple computation for us to work with: - Placeholder A("A", kInt, {64, 32}); - Placeholder B("B", kInt, {64, 32}); - Tensor* X = Compute( + BufHandle A("A", {64, 32}, kInt); + BufHandle B("B", {64, 32}, kInt); + Tensor X = Compute( "X", {{64, "i"}, {32, "j"}}, [&](const VarHandle& i, const VarHandle& j) { return A.load(i, j) + B.load(i, j); }); - // And let's lower it to a loop nest, as we did in the previous section: + // And let's lower it to a loop nest, as we did in the previous section. We + // can pass Tensor object directly: LoopNest loopnest({X}); std::cout << *loopnest.root_stmt() << std::endl; // Prints: @@ -429,6 +427,115 @@ int main(int argc, char* argv[]) { // X[10] = A[10] + B[10] = 8 } - // TODO: Show how TorchScript IR is translated to TE + std::cout << "*** Lowering TorchScript IR to TensorExpr IR ***" << std::endl; + { + // This section requires a LLVM-enabled PyTorch build, so we have to use a + // guard: +#ifdef TORCH_ENABLE_LLVM + + // Often we would like to convert a TorchScript IR to TE rather than + // construct TE IR from scratch. NNC provides an API to perform such + // lowering: it takes a TorchScript graph and returns an object that can be + // used to invoke the generated kernel. + // This API is currently used by the TorchScript JIT fuser and can also be + // used ahead of time to pre-compile parts of a model. + // + // To get familiar with this API let's first start with defining a simple + // TorchScript graph: + const auto graph_string = R"IR( + graph(%A : Float(5, 3, strides=[3, 1], device=cpu), + %B : Float(5, 3, strides=[3, 1], device=cpu)): + %AB : Float(5, 3, strides=[3, 1]) = aten::mul(%A, %B) + %one : int = prim::Constant[value=1]() + %AAB : Float(5, 3, strides=[3, 1]) = aten::mul(%A, %AB) + %AAB_plus_B: Float(5, 3, strides=[3, 1]) = aten::add(%AAB, %B, %one) + return (%AAB_plus_B))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + // This graph defines a simple computation of A*A*B + B where A and B are + // input 5x3 tensors. + + // To lower this TorchScript graph to TE, we just need to create a + // TensorExprKernel object. In its constructor it constructs the + // corresponding TE IR and compiles it for the given backend (in this + // example for CPU using LLVM compiler). + TensorExprKernel kernel(graph); + + // We can retrieve the generated TE stmt from the kernel object: + StmtPtr kernel_stmt = kernel.getCodeGenStmt(); + std::cout << "TE Stmt constructed from TorchScript: " << std::endl + << *kernel_stmt << std::endl; + // Prints: + // TE Stmt constructed from TorchScript: + // { + // for (int v = 0; v < 5; v++) { + // for (int _tail_tail = 0; _tail_tail < 3; _tail_tail++) { + // aten_add[_tail_tail + 3 * v] = (tA[_tail_tail + 3 * v]) * + // ((tA[_tail_tail + 3 * v]) * (tB[_tail_tail + 3 * v])) + + // (tB[_tail_tail + 3 * v]); + // } + // } + // } + + // We can also examine generated LLVM IR and assembly code: + std::cout << "Generated LLVM IR: " << std::endl; + auto ir_str = kernel.getCodeText("ir"); + printLinesToFrom(ir_str, 15, 20); + // Prints: + // Generated LLVM IR: + // %9 = bitcast float* %2 to <8 x float>* + // %10 = load <8 x float>, <8 x float>* %9 ... + // %11 = bitcast float* %5 to <8 x float>* + // %12 = load <8 x float>, <8 x float>* %11 ... + // %13 = fmul <8 x float> %10, %12 + // %14 = fmul <8 x float> %10, %13 + + std::cout << "Generated assembly: " << std::endl; + auto asm_str = kernel.getCodeText("asm"); + printLinesToFrom(asm_str, 10, 15); + // Prints: + // Generated assembly: + // vmulps %ymm1, %ymm0, %ymm2 + // vfmadd213ps %ymm1, %ymm0, %ymm2 + // vmovups %ymm2, (%rax) + // vmovss 32(%rcx), %xmm0 + // vmovss 32(%rdx), %xmm1 + // vmulss %xmm1, %xmm0, %xmm2 + + // We can also execute the generated kernel: + auto A = + at::ones({5, 3}, torch::TensorOptions(torch::kCPU).dtype(at::kFloat)) * + 2.0; + auto B = + at::ones({5, 3}, torch::TensorOptions(torch::kCPU).dtype(at::kFloat)) * + 3.0; + std::vector inputs = {A, B}; + std::vector stack = torch::fmap(inputs); + kernel.run(stack); + auto R = stack[0].toTensor(); + + // Let's print one of the elements from the result tensor to verify that the + // computation did happen and was correct: + std::cout << "R[2][2] = " << R[2][2] << std::endl; + // Prints: + // R[2][2] = 15 + // [ CPUFloatType{} ] +#endif + } return 0; } + +void printLinesToFrom(const std::string& input_str, int from, int to) { + std::istringstream f(input_str); + std::string s; + int idx = 0; + while (getline(f, s)) { + if (idx > from) { + std::cout << s << "\n"; + } + if (idx++ > to) { + break; + } + } +} diff --git a/test/cpp_api_parity/parity-tracker.md b/test/cpp_api_parity/parity-tracker.md index 869ef300f6c85..88e1848f7da78 100644 --- a/test/cpp_api_parity/parity-tracker.md +++ b/test/cpp_api_parity/parity-tracker.md @@ -99,6 +99,7 @@ torch::nn::Identity|Yes|No torch::nn::Linear|Yes|No torch::nn::Bilinear|Yes|No torch::nn::Flatten|Yes|No +torch::nn::Unflatten|Yes|No torch::nn::Dropout|Yes|No torch::nn::Dropout2d|Yes|No torch::nn::Dropout3d|Yes|No diff --git a/test/cpp_extensions/msnpu_extension.cpp b/test/cpp_extensions/ort_extension.cpp similarity index 78% rename from test/cpp_extensions/msnpu_extension.cpp rename to test/cpp_extensions/ort_extension.cpp index e47347c40fbfa..b646f3b14939d 100644 --- a/test/cpp_extensions/msnpu_extension.cpp +++ b/test/cpp_extensions/ort_extension.cpp @@ -10,10 +10,10 @@ Tensor get_tensor(caffe2::TypeMeta dtype, IntArrayRef size) { Storage( Storage::use_byte_size_t(), 0, - at::DataPtr(nullptr, Device(DeviceType::MSNPU, 0)), + at::DataPtr(nullptr, Device(DeviceType::ORT, 0)), nullptr, false), - DispatchKey::MSNPU, + DispatchKey::ORT, dtype); // This is a hack to workaround the shape checks in _convolution. tensor_impl->set_sizes_contiguous(size); @@ -52,7 +52,7 @@ std::tuple fake_convolution_backward( get_tensor(input.dtype(), {})); } -TORCH_LIBRARY_IMPL(aten, MSNPU, m) { +TORCH_LIBRARY_IMPL(aten, ORT, m) { m.impl("empty.memory_format", empty_override); m.impl("add.out", add_out_override); m.impl("convolution_overrideable", fake_convolution); @@ -61,34 +61,34 @@ TORCH_LIBRARY_IMPL(aten, MSNPU, m) { // TODO: Extend this to exercise multi-device setting. In that case, // we need to add a thread local variable to track the current device. -struct MSNPUGuardImpl final : public c10::impl::DeviceGuardImplInterface { - static constexpr DeviceType static_type = DeviceType::MSNPU; - MSNPUGuardImpl() {} - MSNPUGuardImpl(DeviceType t) { - AT_ASSERT(t == DeviceType::MSNPU); +struct ORTGuardImpl final : public c10::impl::DeviceGuardImplInterface { + static constexpr DeviceType static_type = DeviceType::ORT; + ORTGuardImpl() {} + ORTGuardImpl(DeviceType t) { + AT_ASSERT(t == DeviceType::ORT); } DeviceType type() const override { - return DeviceType::MSNPU; + return DeviceType::ORT; } Device exchangeDevice(Device d) const override { - AT_ASSERT(d.type() == DeviceType::MSNPU); + AT_ASSERT(d.type() == DeviceType::ORT); AT_ASSERT(d.index() == 0); return d; } Device getDevice() const override { - return Device(DeviceType::MSNPU, 0); + return Device(DeviceType::ORT, 0); } void setDevice(Device d) const override { - AT_ASSERT(d.type() == DeviceType::MSNPU); + AT_ASSERT(d.type() == DeviceType::ORT); AT_ASSERT(d.index() == 0); } void uncheckedSetDevice(Device d) const noexcept override { } Stream getStream(Device d) const noexcept override { - return Stream(Stream::DEFAULT, Device(DeviceType::MSNPU, 0)); + return Stream(Stream::DEFAULT, Device(DeviceType::ORT, 0)); } Stream exchangeStream(Stream s) const noexcept override { - return Stream(Stream::DEFAULT, Device(DeviceType::MSNPU, 0)); + return Stream(Stream::DEFAULT, Device(DeviceType::ORT, 0)); } DeviceIndex deviceCount() const noexcept override { return 1; @@ -99,23 +99,23 @@ struct MSNPUGuardImpl final : public c10::impl::DeviceGuardImplInterface { const Stream& stream, const DeviceIndex device_index, const EventFlag flag) const override { - TORCH_CHECK(false, "MSNPU backend doesn't support events."); + TORCH_CHECK(false, "ORT backend doesn't support events."); } void block( void* event, const Stream& stream) const override { - TORCH_CHECK(false, "MSNPU backend doesn't support events."); + TORCH_CHECK(false, "ORT backend doesn't support events."); } bool queryEvent(void* event) const override { - TORCH_CHECK(false, "MSNPU backend doesn't support events."); + TORCH_CHECK(false, "ORT backend doesn't support events."); } void destroyEvent( void* event, const DeviceIndex device_index) const noexcept override { } }; -constexpr DeviceType MSNPUGuardImpl::static_type; -C10_REGISTER_GUARD_IMPL(MSNPU, MSNPUGuardImpl); +constexpr DeviceType ORTGuardImpl::static_type; +C10_REGISTER_GUARD_IMPL(ORT, ORTGuardImpl); int get_test_int() { return test_int; diff --git a/test/cpp_extensions/setup.py b/test/cpp_extensions/setup.py index 8f77938ae3226..7888d0e3a88bb 100644 --- a/test/cpp_extensions/setup.py +++ b/test/cpp_extensions/setup.py @@ -21,7 +21,7 @@ 'torch_test_cpp_extension.cpp', ['extension.cpp'], extra_compile_args=CXX_FLAGS), CppExtension( - 'torch_test_cpp_extension.msnpu', ['msnpu_extension.cpp'], + 'torch_test_cpp_extension.ort', ['ort_extension.cpp'], extra_compile_args=CXX_FLAGS), CppExtension( 'torch_test_cpp_extension.rng', ['rng_extension.cpp'], diff --git a/test/custom_operator/test_custom_ops.cpp b/test/custom_operator/test_custom_ops.cpp index 7c6a187df1465..ec22568c5a3ea 100644 --- a/test/custom_operator/test_custom_ops.cpp +++ b/test/custom_operator/test_custom_ops.cpp @@ -30,7 +30,7 @@ Result get_operator_from_registry_and_execute(const char* op_name, Args&&... arg torch::jit::Stack stack; torch::jit::push(stack, std::forward(args)...); - op->getOperation()(&stack); + op->getOperation()(stack); TORCH_INTERNAL_ASSERT(1 == stack.size()); return torch::jit::pop(stack).to(); diff --git a/test/custom_operator/test_custom_ops.py b/test/custom_operator/test_custom_ops.py index 3937abde91476..356b4932d49ac 100644 --- a/test/custom_operator/test_custom_ops.py +++ b/test/custom_operator/test_custom_ops.py @@ -44,8 +44,8 @@ def test_calling_custom_op_with_autograd(self): output.sum().backward(go, False, True) grad = torch.ones(5, 5) - self.assertTrue(torch.allclose(x.grad, y + grad)) - self.assertTrue(torch.allclose(y.grad, x + grad * 2)) + self.assertEqual(x.grad, y + grad) + self.assertEqual(y.grad, x + grad * 2) # Test with optional arg. x.grad.zero_() @@ -56,9 +56,9 @@ def test_calling_custom_op_with_autograd(self): go = torch.ones((), requires_grad=True) output.sum().backward(go, False, True) - self.assertTrue(torch.allclose(x.grad, y + grad)) - self.assertTrue(torch.allclose(y.grad, x + grad * 2)) - self.assertTrue(torch.allclose(z.grad, grad)) + self.assertEqual(x.grad, y + grad) + self.assertEqual(y.grad, x + grad * 2) + self.assertEqual(z.grad, grad) def test_calling_custom_op_with_autograd_in_nograd_mode(self): with torch.no_grad(): diff --git a/test/distributed/_sharded_tensor/test_sharded_tensor.py b/test/distributed/_sharded_tensor/test_sharded_tensor.py index 829855f6be2c5..77e35b76f3731 100644 --- a/test/distributed/_sharded_tensor/test_sharded_tensor.py +++ b/test/distributed/_sharded_tensor/test_sharded_tensor.py @@ -1,5 +1,8 @@ from functools import wraps +import math import io +import itertools +import pickle import sys import torch import torch.distributed as dist @@ -15,6 +18,12 @@ EnumerableShardingSpec, ShardMetadata ) +from torch.distributed._sharded_tensor.api import ( + CreateOp, + TensorInitParams, + TensorProperties, + _create_tensor_from_params, +) from torch.testing._internal.common_distributed import ( MultiProcessTestCase, requires_nccl, @@ -22,10 +31,11 @@ TEST_SKIPS, ) from torch.testing._internal.common_utils import ( + TestCase, TEST_WITH_DEV_DBG_ASAN, run_tests, + sandcastle_skip_if, ) - if TEST_WITH_DEV_DBG_ASAN: print("Skip dev-asan as torch + multiprocessing spawn have known issues", file=sys.stderr) sys.exit(0) @@ -115,6 +125,179 @@ def wrapper(self): self.destroy_comms() return wrapper +class TestShardedTensorMetadata(TestCase): + def test_serialize_and_deserialize(self): + shard_metadatas = [ + ShardMetadata( + shard_offsets=[0, 0], + shard_lengths=[5, 5], + placement="rank:0/cuda:0", + ), + ShardMetadata( + shard_offsets=[0, 5], + shard_lengths=[5, 5], + placement="rank:1/cuda:1", + ), + ShardMetadata( + shard_offsets=[5, 0], + shard_lengths=[5, 5], + placement="rank:2/cuda:2", + ), + ShardMetadata( + shard_offsets=[5, 5], + shard_lengths=[5, 5], + placement="rank:3/cuda:3", + ) + ] + + dtypes = [ + torch.float, torch.double, torch.cfloat, torch.cdouble, torch.half, + torch.bfloat16, torch.uint8, torch.int8, torch.short, torch.int, + torch.long, torch.bool] + + layouts = [torch.strided, torch.sparse_coo] + requires_grads = [True, False] + memory_formats = [torch.contiguous_format, torch.channels_last, torch.preserve_format] + pin_memories = [True, False] + + for tensor_properties_input in itertools.product(dtypes, layouts, requires_grads, memory_formats, pin_memories): + dtype, layout, requires_grad, memory_format, pin_memory = tensor_properties_input + + expected_st_metadata = _sharded_tensor.ShardedTensorMetadata( + shard_metadatas, + (10, 10), + _sharded_tensor.TensorProperties(dtype, layout, requires_grad, memory_format, pin_memory) + ) + + pickled_obj = pickle.dumps(expected_st_metadata) + st_metadata = pickle.loads(pickled_obj) + self.assertEqual(expected_st_metadata, st_metadata) + +class TestCreateTensorFromParams(TestCase): + @sandcastle_skip_if(torch.cuda.device_count() < 1, 'CUDA GPU is needed') + def test_empty(self): + expected_dtype = torch.double + tensor_properties = TensorProperties( + dtype=expected_dtype, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, ) + tensor_init_params = TensorInitParams(create_op=CreateOp.EMPTY, + tensor_properties=tensor_properties) + local_device = torch.device('cuda:0') + local_tensor = _create_tensor_from_params( + 5, 10, local_device=local_device, tensor_init_params=tensor_init_params) + self.assertEqual(local_device, local_tensor.device) + self.assertEqual(expected_dtype, local_tensor.dtype) + self.assertEqual(torch.strided, local_tensor.layout) + self.assertEqual(False, local_tensor.requires_grad) + + @sandcastle_skip_if(torch.cuda.device_count() < 1, 'CUDA GPU is needed') + def test_ones(self): + expected_dtype = torch.double + tensor_properties = TensorProperties( + dtype=expected_dtype, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, ) + tensor_init_params = TensorInitParams( + create_op=CreateOp.ONES, tensor_properties=tensor_properties) + local_device = torch.device('cuda:0') + h, w = 5, 10 + local_tensor = _create_tensor_from_params( + h, w, local_device=local_device, tensor_init_params=tensor_init_params) + expected_tensor = torch.ones(h, w, device=local_device, dtype=expected_dtype) + self.assertEqual(expected_tensor, local_tensor) + + @sandcastle_skip_if(torch.cuda.device_count() < 1, 'CUDA GPU is needed') + def test_zeros(self): + expected_dtype = torch.int32 + tensor_properties = TensorProperties( + dtype=expected_dtype, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + ) + tensor_init_params = TensorInitParams(create_op=CreateOp.ZEROS, tensor_properties=tensor_properties, ) + local_device = torch.device('cuda:0') + h, w = 5, 10 + local_tensor = _create_tensor_from_params( + h, w, local_device=local_device, tensor_init_params=tensor_init_params) + expected_tensor = torch.zeros(h, w, device=local_device, dtype=expected_dtype) + self.assertEqual(expected_tensor, local_tensor) + + @sandcastle_skip_if(torch.cuda.device_count() < 1, 'CUDA GPU is needed') + def test_rand(self): + expected_dtype = torch.double + tensor_properties = TensorProperties( + dtype=expected_dtype, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + ) + tensor_init_params = TensorInitParams(create_op=CreateOp.RAND, tensor_properties=tensor_properties, ) + local_device = torch.device('cuda:0') + h, w = 5, 10 + seed = 13 + torch.cuda.manual_seed(seed) + local_tensor = _create_tensor_from_params( + h, w, local_device=local_device, tensor_init_params=tensor_init_params) + # reset seed to ensure same random numbers are generated + torch.cuda.manual_seed(seed) + expected_tensor = torch.rand(h, w, device=local_device, dtype=expected_dtype) + self.assertEqual(expected_tensor, local_tensor) + + @sandcastle_skip_if(torch.cuda.device_count() < 1, 'CUDA GPU is needed') + def test_full_with_dtype_inferred(self): + fill_value = 23.5 + tensor_properties = TensorProperties( + # tensor's dtype can be inferred from fill_value + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + ) + tensor_init_params = TensorInitParams( + create_op=CreateOp.FULL, + fill_value=fill_value, + tensor_properties=tensor_properties, ) + local_device = torch.device('cuda:0') + h, w = 5, 10 + local_tensor = _create_tensor_from_params( + h, w, local_device=local_device, tensor_init_params=tensor_init_params) + # local_tensor.dtype is inferred from fill_value (float32). + self.assertEqual(torch.float32, local_tensor.dtype) + expected_tensor = torch.full((h, w), fill_value=fill_value, device=local_device) + self.assertEqual(expected_tensor, local_tensor) + + @sandcastle_skip_if(torch.cuda.device_count() < 1, 'CUDA GPU is needed') + def test_full_with_dtype_overridden(self): + fill_value = 23.5 + tensor_properties = TensorProperties( + # tensor's dtype can be inferred from fill_value + dtype=torch.double, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + ) + tensor_init_params = TensorInitParams( + create_op=CreateOp.FULL, + fill_value=fill_value, + tensor_properties=tensor_properties, ) + local_device = torch.device('cuda:0') + h, w = 5, 10 + local_tensor = _create_tensor_from_params( + h, w, local_device=local_device, tensor_init_params=tensor_init_params) + # local_tensor.dtype is overridden. + self.assertEqual(torch.double, local_tensor.dtype) + expected_tensor = torch.full((h, w), fill_value=fill_value, device=local_device, dtype=torch.double) + self.assertEqual(expected_tensor, local_tensor) class TestShardedTensorChunked(ShardedTensorTestBase, MultiProcessTestCase): @@ -135,19 +318,17 @@ def test_sharded_tensor_metadata(self): sharded_tensor = _sharded_tensor.empty(spec, 10, 20, init_rrefs=True) sharded_tensor_metadata = sharded_tensor.metadata() self.assertEqual(torch.Size([10, 20]), sharded_tensor_metadata.size) - self.assertEqual(torch.float, sharded_tensor_metadata.dtype) - self.assertEqual(torch.strided, sharded_tensor_metadata.layout) - self.assertEqual(False, sharded_tensor_metadata.requires_grad) - self.assertEqual(torch.contiguous_format, sharded_tensor_metadata.memory_format) - self.assertEqual(False, sharded_tensor_metadata.pin_memory) + self.assertEqual(torch.float, sharded_tensor.dtype) + self.assertEqual(torch.strided, sharded_tensor.layout) + self.assertEqual(False, sharded_tensor.requires_grad) + self.assertTrue(sharded_tensor.is_contiguous()) + self.assertFalse(sharded_tensor.is_pinned()) sharded_tensor = _sharded_tensor.empty(spec, 10, 20, requires_grad=True, init_rrefs=True) - sharded_tensor_metadata = sharded_tensor.metadata() - self.assertEqual(True, sharded_tensor_metadata.requires_grad) + self.assertEqual(True, sharded_tensor.requires_grad) sharded_tensor = _sharded_tensor.empty(spec, 10, 20, dtype=torch.double, init_rrefs=True) - sharded_tensor_metadata = sharded_tensor.metadata() - self.assertEqual(torch.double, sharded_tensor_metadata.dtype) + self.assertEqual(torch.double, sharded_tensor.dtype) # Need CPU for pin_memory spec = ChunkShardingSpec( @@ -161,8 +342,12 @@ def test_sharded_tensor_metadata(self): ) sharded_tensor = _sharded_tensor.empty(spec, 10, 20, pin_memory=True, init_rrefs=True) - sharded_tensor_metadata = sharded_tensor.metadata() - self.assertEqual(True, sharded_tensor_metadata.pin_memory) + self.assertEqual(True, sharded_tensor.is_pinned()) + + # test read only properties, they're read only as we can't simply change + # the global metadata without changing the underlying shard's properties + with self.assertRaisesRegex(AttributeError, "can't set attribute"): + sharded_tensor.requires_grad = True @with_comms @skip_if_lt_x_gpu(4) @@ -219,6 +404,131 @@ def test_complete_world_size(self): else: self.assertEqual((3, 20), shard.tensor.size()) + + @with_comms + @skip_if_lt_x_gpu(4) + @requires_nccl() + def test_create_sharded_tensor_with_ones(self): + """ Test _sharded_tensor.ones(...) """ + + spec = ChunkShardingSpec( + dim=0, + placements=[ + "rank:0/cuda:0", + "rank:1/cuda:1", + "rank:2/cuda:2", + "rank:3/cuda:3", + ], + ) + h, w = 10, 20 + sharded_tensor = _sharded_tensor.ones(spec, h, w) + + # Validate local shard is initialized with torch.ones + local_shards = sharded_tensor.local_shards() + self.assertEqual(1, len(local_shards)) + local_shard = local_shards[0].tensor + self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device) + # The split: for rank!=3 ceil(h/4)=3 for rank=3 1 + expected_h = 1 if self.rank == 3 else math.ceil(h / 4) + self.assertEqual((expected_h, w), local_shard.size()) + self.assertEqual(local_shard, torch.ones(expected_h, w)) + + @with_comms + @skip_if_lt_x_gpu(4) + @requires_nccl() + def test_create_sharded_tensor_with_zeros(self): + """ Test _sharded_tensor.zeros(...) """ + + spec = ChunkShardingSpec( + dim=0, + placements=[ + "rank:0/cuda:0", + "rank:1/cuda:1", + "rank:2/cuda:2", + "rank:3/cuda:3", + ], + ) + h, w = 10, 20 + sharded_tensor = _sharded_tensor.zeros(spec, h, w) + + # Validate local shard is initialized with torch.zeros + local_shards = sharded_tensor.local_shards() + self.assertEqual(1, len(local_shards)) + local_shard = local_shards[0].tensor + self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device) + # The split: for rank!=3 ceil(h/4)=3 for rank=3 1 + expected_h = 1 if self.rank == 3 else math.ceil(h / 4) + self.assertEqual((expected_h, w), local_shard.size()) + self.assertEqual(local_shard, torch.zeros(expected_h, w)) + + + @with_comms + @skip_if_lt_x_gpu(4) + @requires_nccl() + def test_create_sharded_tensor_with_rand(self): + """ Test _sharded_tensor.rand(...) """ + + spec = ChunkShardingSpec( + dim=0, + placements=[ + "rank:0/cuda:0", + "rank:1/cuda:1", + "rank:2/cuda:2", + "rank:3/cuda:3", + ], + ) + h, w = 8, 2 + seed = 1234 + + expected_h = 2 + expected_device = torch.device(f"cuda:{self.rank}") + dtype = torch.double + torch.manual_seed(seed) + expected = torch.rand(expected_h, w, device=expected_device, dtype=dtype) + # reset seed to ensure the same random numbers are generated + torch.manual_seed(seed) + sharded_tensor = _sharded_tensor.rand(spec, h, w, dtype=dtype) + + # Validate local shard is initialized with torch.rand + local_shards = sharded_tensor.local_shards() + self.assertEqual(1, len(local_shards)) + local_shard = local_shards[0].tensor + self.assertEqual(expected_device, local_shard.device) + self.assertEqual((expected_h, w), local_shard.size()) + self.assertEqual(expected, local_shard) + + + @with_comms + @skip_if_lt_x_gpu(4) + @requires_nccl() + def test_create_sharded_tensor_with_full(self): + """ Test _sharded_tensor.full(...) """ + + spec = ChunkShardingSpec( + dim=0, + placements=[ + "rank:0/cuda:0", + "rank:1/cuda:1", + "rank:2/cuda:2", + "rank:3/cuda:3", + ], + ) + h, w = 10, 20 + fill_value = 1234 + sharded_tensor = _sharded_tensor.full(spec, size=(h, w), fill_value=fill_value, dtype=torch.int32) + + # Validate local shard is initialized with torch.full + local_shards = sharded_tensor.local_shards() + self.assertEqual(1, len(local_shards)) + local_shard = local_shards[0].tensor + self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device) + # The split: for rank!=3 ceil(h/4)=3 for rank=3 1 + expected_h = 1 if self.rank == 3 else math.ceil(h / 4) + self.assertEqual((expected_h, w), local_shard.size()) + self.assertEqual(local_shard, + torch.full(size=(expected_h, w), fill_value=fill_value, dtype=torch.int32)) + + @with_comms @skip_if_lt_x_gpu(4) @requires_nccl() @@ -714,19 +1024,17 @@ def test_sharded_tensor_metadata(self): sharded_tensor = _sharded_tensor.empty(spec, 10, 10, init_rrefs=True) sharded_tensor_metadata = sharded_tensor.metadata() self.assertEqual(torch.Size([10, 10]), sharded_tensor_metadata.size) - self.assertEqual(torch.float, sharded_tensor_metadata.dtype) - self.assertEqual(torch.strided, sharded_tensor_metadata.layout) - self.assertEqual(False, sharded_tensor_metadata.requires_grad) - self.assertEqual(torch.contiguous_format, sharded_tensor_metadata.memory_format) - self.assertEqual(False, sharded_tensor_metadata.pin_memory) + self.assertEqual(torch.float, sharded_tensor.dtype) + self.assertEqual(torch.strided, sharded_tensor.layout) + self.assertEqual(False, sharded_tensor.requires_grad) + self.assertTrue(sharded_tensor.is_contiguous()) + self.assertFalse(sharded_tensor.is_pinned()) sharded_tensor = _sharded_tensor.empty(spec, 10, 10, requires_grad=True, init_rrefs=True) - sharded_tensor_metadata = sharded_tensor.metadata() - self.assertEqual(True, sharded_tensor_metadata.requires_grad) + self.assertEqual(True, sharded_tensor.requires_grad) sharded_tensor = _sharded_tensor.empty(spec, 10, 10, dtype=torch.double, init_rrefs=True) - sharded_tensor_metadata = sharded_tensor.metadata() - self.assertEqual(torch.double, sharded_tensor_metadata.dtype) + self.assertEqual(torch.double, sharded_tensor.dtype) # Need CPU for pin_memory spec = EnumerableShardingSpec([ @@ -753,8 +1061,7 @@ def test_sharded_tensor_metadata(self): ]) sharded_tensor = _sharded_tensor.empty(spec, 10, 10, pin_memory=True, init_rrefs=True) - sharded_tensor_metadata = sharded_tensor.metadata() - self.assertEqual(True, sharded_tensor_metadata.pin_memory) + self.assertTrue(sharded_tensor.is_pinned()) @with_comms @skip_if_lt_x_gpu(4) @@ -818,6 +1125,45 @@ def test_grid_sharding(self): shard = remote_shard.to_here() self.assertEqual((5, 5), shard.tensor.size()) + @with_comms + @skip_if_lt_x_gpu(4) + @requires_nccl() + def test_create_sharded_tensor_with_ones(self): + """ Test _sharded_tensor.ones(...) """ + + spec = EnumerableShardingSpec([ + ShardMetadata( + shard_offsets=[0, 0], + shard_lengths=[5, 5], + placement="rank:0/cuda:0", + ), + ShardMetadata( + shard_offsets=[0, 5], + shard_lengths=[5, 5], + placement="rank:1/cuda:1", + ), + ShardMetadata( + shard_offsets=[5, 0], + shard_lengths=[5, 5], + placement="rank:2/cuda:2", + ), + ShardMetadata( + shard_offsets=[5, 5], + shard_lengths=[5, 5], + placement="rank:3/cuda:3", + ) + ]) + + sharded_tensor = _sharded_tensor.ones(spec, 10, 10, init_rrefs=True) + self.assertEqual((10, 10), sharded_tensor.size()) + self.assertEqual(1, len(sharded_tensor.local_shards())) + + # Verify local shard is initialized with torch.ones + local_shard = sharded_tensor.local_shards()[0] + self.assertEqual(torch.device(f'cuda:{self.rank}'), local_shard.tensor.device) + self.assertEqual((5, 5), local_shard.tensor.size()) + self.assertEqual(local_shard.tensor, torch.ones(5, 5)) + @skip_if_lt_x_gpu(4) @requires_nccl() def test_uneven_shards(self): @@ -1161,15 +1507,18 @@ def test_init_from_local_shards(self): local_shards = [_sharded_tensor.Shard(torch.randn(5, 5, device=f"cuda:{self.rank}"), local_shard_metadata)] - sharded_tensor_metadata = _sharded_tensor.ShardedTensorMetadata( - shards_metadata=shards_metadata, - size=torch.Size([10, 10]), + tensor_properties = TensorProperties( dtype=torch.get_default_dtype(), layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False, ) + sharded_tensor_metadata = _sharded_tensor.ShardedTensorMetadata( + shards_metadata=shards_metadata, + size=torch.Size([10, 10]), + tensor_properties=tensor_properties, + ) sharded_tensor = _sharded_tensor.init_from_local_shards(local_shards, sharded_tensor_metadata, init_rrefs=True) self.assertEqual((10, 10), sharded_tensor.size()) @@ -1228,15 +1577,19 @@ def test_init_from_local_shards_new_group(self): local_shard_metadata = rank1_shard_metadata if self.rank == 1 else rank3_shard_metadata local_shards.append(_sharded_tensor.Shard(torch.randn(5, 5, device=f"cuda:{self.rank}"), local_shard_metadata)) - sharded_tensor_metadata = _sharded_tensor.ShardedTensorMetadata( - shards_metadata=shards_metadata, - size=torch.Size([10, 5]), + tensor_properties = TensorProperties( dtype=torch.get_default_dtype(), layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False, ) + + sharded_tensor_metadata = _sharded_tensor.ShardedTensorMetadata( + shards_metadata=shards_metadata, + size=torch.Size([10, 5]), + tensor_properties=tensor_properties + ) sharded_tensor = _sharded_tensor.init_from_local_shards(local_shards, sharded_tensor_metadata, new_pg, init_rrefs=True) if self.rank == 1 or self.rank == 3: @@ -1297,15 +1650,18 @@ def test_init_from_local_shards_invalid_shards(self): placement=f"rank:{r}/cuda:{r}" )) - sharded_tensor_metadata = _sharded_tensor.ShardedTensorMetadata( - shards_metadata=shards_metadata, - size=torch.Size([10, 10]), + tensor_properties = TensorProperties( dtype=torch.get_default_dtype(), layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False, ) + sharded_tensor_metadata = _sharded_tensor.ShardedTensorMetadata( + shards_metadata=shards_metadata, + size=torch.Size([10, 10]), + tensor_properties=tensor_properties + ) empty_local_shards = [] with self.assertRaisesRegex(RuntimeError, 'does not match number of local shards metadata'): @@ -1329,7 +1685,7 @@ def test_init_from_local_shards_invalid_shards(self): wrong_dtype_shards = [ _sharded_tensor.Shard(torch.ones(5, 5, device=f"cuda:{self.rank}", dtype=torch.int), local_shard_metadata) ] - with self.assertRaisesRegex(ValueError, 'Local shard tensor dtype does not match with sharded_tensor_metadata'): + with self.assertRaisesRegex(ValueError, 'Local shard tensor dtype does not match with tensor_properties!'): sharded_tensor = _sharded_tensor.init_from_local_shards(wrong_dtype_shards, sharded_tensor_metadata, init_rrefs=True) indices = [[0, 1, 1], [2, 0, 2]] @@ -1339,21 +1695,21 @@ def test_init_from_local_shards_invalid_shards(self): wrong_layout_shards = [ _sharded_tensor.Shard(sparse_tensor, local_shard_metadata) ] - with self.assertRaisesRegex(ValueError, 'Local shard tensor layout does not match with sharded_tensor_metadata'): + with self.assertRaisesRegex(ValueError, 'Local shard tensor layout does not match with tensor_properties!'): sharded_tensor = _sharded_tensor.init_from_local_shards( wrong_layout_shards, sharded_tensor_metadata, init_rrefs=True) wrong_requires_grad_shards = [ _sharded_tensor.Shard(torch.randn(5, 5, device=f"cuda:{self.rank}", requires_grad=True), local_shard_metadata) ] - with self.assertRaisesRegex(ValueError, 'Local shard tensor requires_grad does not match with sharded_tensor_metadata'): + with self.assertRaisesRegex(ValueError, 'Local shard tensor requires_grad does not match with tensor_properties!'): sharded_tensor = _sharded_tensor.init_from_local_shards( wrong_requires_grad_shards, sharded_tensor_metadata, init_rrefs=True) wrong_pin_memory_shards = [ _sharded_tensor.Shard(torch.randn(5, 5, pin_memory=True), local_shard_metadata) ] - with self.assertRaisesRegex(ValueError, 'Local shard tensor pin_memory does not match with sharded_tensor_metadata'): + with self.assertRaisesRegex(ValueError, 'Local shard tensor pin_memory does not match with tensor_properties!'): sharded_tensor = _sharded_tensor.init_from_local_shards( wrong_pin_memory_shards, sharded_tensor_metadata, init_rrefs=True) @@ -1386,15 +1742,18 @@ def test_init_from_local_shards_invalid_shards_overlap(self): placement=f"rank:{r}/cuda:{r}" )) - sharded_tensor_metadata = _sharded_tensor.ShardedTensorMetadata( - shards_metadata=shards_metadata, - size=torch.Size([10, 10]), + tensor_properties = TensorProperties( dtype=torch.get_default_dtype(), layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False, ) + sharded_tensor_metadata = _sharded_tensor.ShardedTensorMetadata( + shards_metadata=shards_metadata, + size=torch.Size([10, 10]), + tensor_properties=tensor_properties + ) local_shard_size = (5, 5) if self.rank != 0 else (6, 6) @@ -1425,15 +1784,18 @@ def test_init_from_local_shards_invalid_shards_gaps(self): placement=f"rank:{r}/cuda:{r}" )) - sharded_tensor_metadata = _sharded_tensor.ShardedTensorMetadata( - shards_metadata=shards_metadata, - size=torch.Size([10, 10]), + tensor_properties = TensorProperties( dtype=torch.get_default_dtype(), layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False, ) + sharded_tensor_metadata = _sharded_tensor.ShardedTensorMetadata( + shards_metadata=shards_metadata, + size=torch.Size([10, 10]), + tensor_properties=tensor_properties + ) local_shard_size = (5, 5) if self.rank != 0 else (4, 4) diff --git a/test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py b/test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py index 1f78d50b604e8..67175b2d22495 100644 --- a/test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py +++ b/test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py @@ -2,7 +2,6 @@ import os import sys -import numpy as np import torch from torch import nn import torch.distributed as dist @@ -21,8 +20,14 @@ requires_nccl, skip_if_lt_x_gpu, ) -from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.common_utils import ( + run_tests, + TEST_WITH_DEV_DBG_ASAN, +) +if TEST_WITH_DEV_DBG_ASAN: + print("Multiprocessing spawn is not compatible with dev/dbg asan", file=sys.stderr) + sys.exit(0) def gpus_for_rank(world_size): visible_devices = list(range(torch.cuda.device_count())) @@ -57,7 +62,7 @@ def forward(self, x, rank): class DistributedDataParallelCommHookTest(MultiProcessTestCase): def setUp(self): super(DistributedDataParallelCommHookTest, self).setUp() - self._fork_processes() + self._spawn_processes() def tearDown(self): try: @@ -99,7 +104,9 @@ def _run_and_get_grads(self, model): # Run backward output.mean().backward() - return [p.grad.data.cpu().numpy() for p in model.parameters()] + # The only layer + param = next(model.parameters()) + return param.grad @requires_nccl() @skip_if_lt_x_gpu(2) @@ -116,7 +123,7 @@ def test_ddp_comm_hook_allreduce_hook(self): # Register hook case, get the hook grads. hook_grads = self._get_grads(process_group, DDPCommHookType.ALLREDUCE) - np.testing.assert_allclose(hook_grads, reference_grads, rtol=1e-5, atol=0) + torch.testing.assert_allclose(hook_grads, reference_grads, rtol=1e-5, atol=0) @requires_nccl() @skip_if_lt_x_gpu(2) @@ -133,7 +140,7 @@ def test_ddp_comm_hook_fp16compress_hook(self): # Register hook case, get the hook grads. hook_grads = self._get_grads(process_group, DDPCommHookType.FP16_COMPRESS) - np.testing.assert_allclose(hook_grads, reference_grads, rtol=1e-5, atol=1e-4) + torch.testing.assert_allclose(hook_grads, reference_grads, rtol=1e-5, atol=1e-4) @requires_nccl() @skip_if_lt_x_gpu(2) @@ -150,7 +157,7 @@ def test_ddp_comm_hook_quantize_per_tensor_hook(self): # Register hook case, get the hook grads. hook_grads = self._get_grads(process_group, DDPCommHookType.QUANTIZE_PER_TENSOR) - np.testing.assert_allclose(hook_grads, reference_grads, rtol=1e-5, atol=1e-4) + torch.testing.assert_allclose(hook_grads, reference_grads, rtol=1e-5, atol=1e-4) @requires_nccl() @skip_if_lt_x_gpu(2) @@ -169,7 +176,58 @@ def test_ddp_comm_hook_quantize_per_channel_hook(self): process_group, DDPCommHookType.QUANTIZE_PER_CHANNEL ) - np.testing.assert_allclose(hook_grads, reference_grads, rtol=1e-5, atol=1e-4) + torch.testing.assert_allclose(hook_grads, reference_grads, rtol=1e-5, atol=1e-4) + + + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_ddp_comm_hook_noop_hook(self): + """ + This unit test verifies the ``noop`` hook registered case and a subsequent allreduce + gives same result with no hook registered case. + """ + store = dist.FileStore(self.file_name, self.world_size) + process_group = dist.ProcessGroupNCCL(store, self.rank, self.world_size) + + # No hook registered case, get the reference grads. + reference_grads = self._get_grads(process_group, None) + # Register hook case, get the hook grads. + hook_grads = self._get_grads(process_group, DDPCommHookType.NOOP) + # Apply a subsequent allreduce to average grads. + hook_grads.div_(self.world_size) + dist.all_reduce(hook_grads, group=process_group) + + torch.testing.assert_allclose(hook_grads, reference_grads, rtol=1e-5, atol=0) + + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_is_last_hook(self): + + store = dist.FileStore(self.file_name, self.world_size) + process_group = dist.ProcessGroupNCCL(store, self.rank, self.world_size) + + def hook(flags, bucket): + flags.append(bucket.is_last()) + fut = torch.futures.Future() + fut.set_result(bucket.buffer()) + return fut + + flags = [] + device_id = gpus_for_rank(self.world_size)[self.rank][0] + model = nn.Sequential( + nn.Linear(2, 4000, bias=False), + *[nn.Linear(4000, 4000, bias=False) for _ in range(10)] + ) + gpu_model = DistributedDataParallel( + model.to(device_id), + device_ids=[device_id], + process_group=process_group, + ) + gpu_model.register_comm_hook(state=flags, hook=hook) + input = torch.randn(10, 2) + gpu_model(input).sum().backward() + self.assertTrue(flags[-1]) + self.assertFalse(any(flags[:-1])) if __name__ == "__main__": diff --git a/test/distributed/algorithms/quantization/test_quantization.py b/test/distributed/algorithms/quantization/test_quantization.py new file mode 100644 index 0000000000000..e60539face11c --- /dev/null +++ b/test/distributed/algorithms/quantization/test_quantization.py @@ -0,0 +1,268 @@ +import torch +import os +import torch.cuda +import sys +import torch.distributed as dist +import torch.distributed.algorithms.quantization.quantization as quant +from torch.distributed.algorithms.quantization.quantization import DQuantType +from torch.testing._internal.common_distributed import ( + MultiProcessTestCase, + requires_gloo, + skip_if_rocm, + skip_if_lt_x_gpu, + requires_nccl, +) +from torch.testing._internal.distributed.distributed_test import ( + apply_hack_for_nccl +) +from torch.testing._internal.common_utils import sandcastle_skip_if, run_tests, TEST_WITH_DEV_DBG_ASAN, NO_MULTIPROCESSING_SPAWN + +torch.backends.cuda.matmul.allow_tf32 = False + +if not dist.is_available(): + print("Distributed not available, skipping tests", file=sys.stderr) + sys.exit(0) + +def _build_tensor(size, value=None, dtype=torch.float, device_id=None): + if value is None: + value = size + if device_id is None: + return torch.empty(size, dtype=dtype).fill_(value) + else: + return torch.empty(size, dtype=dtype).fill_(value).cuda(device_id) +if TEST_WITH_DEV_DBG_ASAN: + print("Skip dev-asan as torch + multiprocessing spawn have known issues", file=sys.stderr) + sys.exit(0) + +if NO_MULTIPROCESSING_SPAWN: + print("Spawn not available, skipping tests.", file=sys.stderr) + sys.exit(0) + +BACKEND = os.environ["BACKEND"] +if BACKEND == "gloo" or BACKEND == "nccl": + class DistQuantizationTests(MultiProcessTestCase): + + def setUp(self): + super(DistQuantizationTests, self).setUp() + self._spawn_processes() + torch.backends.cudnn.flags(allow_tf32=False).__enter__() + + def tearDown(self): + super(DistQuantizationTests, self).tearDown() + try: + os.remove(self.file_name) + except OSError: + pass + + @property + def op_timeout_sec(self): + return 1 + + @property + def world_size(self): + return int(os.environ["WORLD_SIZE"]) + + def _init_multigpu_helper(self): + """Multigpu tests are designed to simulate the multi nodes with multi + GPUs on each node. Nccl backend requires equal #GPUs in each process. + On a single node, all visible GPUs are evenly + divided to subsets, each process only uses a subset. + """ + nGPUs = torch.cuda.device_count() + world_size = self.world_size + visible_devices = range(nGPUs) + + if BACKEND == "nccl": + apply_hack_for_nccl() + + # If rank is lesser than or equal to number of available GPU's + # then each rank can be mapped to corresponding GPU. + nGPUs_per_process = 1 + if world_size > nGPUs: + nGPUs_per_process = nGPUs // world_size + rank_to_GPU = { + i: list( + visible_devices[i * nGPUs_per_process : (i + 1) * nGPUs_per_process] + ) + for i in range(world_size) + } + return rank_to_GPU + + @requires_gloo() + @sandcastle_skip_if(BACKEND != "gloo", "Only gloo backend supports all_gather_fp16") + def test_all_gather_fp16(self): + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='gloo') + device = torch.device(f"cuda:{self.rank}") + group = list(range(0, self.world_size)) + group_id = dist.group.WORLD + self._test_all_gather(group, group_id, self.rank, dtype=torch.float32, qtype=DQuantType.FP16) + + @requires_gloo() + @sandcastle_skip_if(BACKEND != "gloo", "Only gloo backend supports all_gather_fp16") + def test_all_gather_bfp16(self): + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='gloo') + device = torch.device(f"cuda:{self.rank}") + group = list(range(0, self.world_size)) + group_id = dist.group.WORLD + self._test_all_gather(group, group_id, self.rank, dtype=torch.float32, qtype=DQuantType.BFP16) + + @requires_nccl() + @sandcastle_skip_if(BACKEND != "nccl", "Only nccl backend supports all_to_all_fp16") + @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"])) + @skip_if_rocm + def test_all_to_all_fp16(self): + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='nccl') + device = torch.device(f"cuda:{self.rank}") + group = list(range(0, self.world_size)) + group_id = dist.new_group(range(self.world_size)) + rank_to_GPU = self._init_multigpu_helper() + self._test_all_to_all( + group, + group_id, + self.rank, + cuda=True, + rank_to_GPU=rank_to_GPU, + dtype=torch.float32, + qtype=DQuantType.FP16) + + @requires_nccl() + @sandcastle_skip_if(BACKEND != "nccl", "Only nccl backend supports all_to_all_fp16") + @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"])) + @skip_if_rocm + def test_all_to_all_bfp16(self): + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='nccl') + device = torch.device(f"cuda:{self.rank}") + group = list(range(0, self.world_size)) + group_id = dist.new_group(range(self.world_size)) + rank_to_GPU = self._init_multigpu_helper() + self._test_all_to_all( + group, + group_id, + self.rank, + cuda=True, + rank_to_GPU=rank_to_GPU, + dtype=torch.float32, + qtype=DQuantType.BFP16) + + @requires_nccl() + @sandcastle_skip_if(BACKEND != "nccl", "Only nccl backend supports all_to_all_single_fp16") + @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"])) + def test_all_to_all_single_fp16(self): + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='nccl') + device = torch.device(f"cuda:{self.rank}") + group = list(range(0, self.world_size)) + group_id = dist.new_group(range(self.world_size)) + rank_to_GPU = self._init_multigpu_helper() + self._test_all_to_all_single( + group, + group_id, + self.rank, + cuda=True, + rank_to_GPU=rank_to_GPU, + dtype=torch.float32, + qtype=DQuantType.FP16 + ) + + @requires_nccl() + @sandcastle_skip_if(BACKEND != "nccl", "Only nccl backend supports all_to_all_single_bfp16") + @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"])) + def test_all_to_all_single_bfp16(self): + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='nccl') + device = torch.device(f"cuda:{self.rank}") + group = list(range(0, self.world_size)) + group_id = dist.new_group(range(self.world_size)) + rank_to_GPU = self._init_multigpu_helper() + self._test_all_to_all_single( + group, + group_id, + self.rank, + cuda=True, + rank_to_GPU=rank_to_GPU, + dtype=torch.float32, + qtype=DQuantType.BFP16 + ) + + def _test_all_gather( + self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float, qtype=None): + for dest in group: + tensor = _build_tensor([dest + 1, dest + 1], rank, dtype=dtype) + tensors = [_build_tensor([dest + 1, dest + 1], -1, dtype=dtype) for i in group] + expected_tensors = [ + _build_tensor([dest + 1, dest + 1], i, dtype=dtype) for i in group + ] + if cuda: + tensor = tensor.cuda(rank_to_GPU[rank][0]) + tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors] + if tensors[0].dtype == torch.complex64: + tensor_shapes = [torch.view_as_real(tensors[0]).shape] + else: + tensor_shapes = [tensors[0].shape] + allgather = quant.auto_quantize(dist.all_gather, qtype, quant_loss=None) + allgather(tensors, tensor, group=group_id, async_op=False) + + for t1, t2 in zip(tensors, expected_tensors): + self.assertEqual(t1, t2) + + def _test_all_to_all( + self, + group, + group_id, + rank, + cuda=False, + rank_to_GPU=None, + dtype=torch.float, + qtype=None + ): + if group_id is not None: + size = len(group) + in_splits = [i + 1 for i in group] + in_tensors = [ + torch.ones([in_splits[i], size], dtype=dtype) * rank + for i, _ in enumerate(group) + ] + out_tensors = [ + torch.ones([(rank + 1), size], dtype=dtype) for _ in group + ] + expected_tensors = [ + torch.ones([rank + 1, size], dtype=dtype) * i for i in group + ] + if cuda: + in_tensors = [t.cuda(rank_to_GPU[rank][0]) for t in in_tensors] + expected_tensors = [ + t.cuda(rank_to_GPU[rank][0]) for t in expected_tensors + ] + out_tensors = [t.cuda(rank_to_GPU[rank][0]) for t in out_tensors] + quantize_alltoall = quant.auto_quantize(dist.all_to_all, qtype, quant_loss=None) + quantize_alltoall(out_tensors, in_tensors, group=group_id) + for t1, t2 in zip(out_tensors, expected_tensors): + self.assertEqual(t1, t2) + + def _test_all_to_all_single( + self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float, qtype=DQuantType.FP16 + ): + if group_id is not None: + size = len(group) + in_splits = [i + 1 for i in group] + out_splits = [rank + 1 for _ in group] + in_tensor = torch.ones([sum(in_splits), size], dtype=dtype) * rank + out_tensor = torch.ones([(rank + 1) * size, size], dtype=dtype) + expected_tensor = torch.cat( + [torch.ones([rank + 1, size], dtype=dtype) * i for i in group] + ) + if cuda: + rank_to_GPU = rank_to_GPU[rank][0] + in_tensor = in_tensor.cuda(rank_to_GPU) + expected_tensor = expected_tensor.cuda(rank_to_GPU) + out_tensor = out_tensor.cuda(rank_to_GPU) + quantize_alltoall_single = quant.auto_quantize(dist.all_to_all_single, qtype, quant_loss=None) + quantize_alltoall_single(out_tensor, in_tensor, out_splits=out_splits, in_splits=in_splits, group=group_id) + self.assertEqual(out_tensor, expected_tensor) + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/elastic/agent/server/test/local_elastic_agent_test.py b/test/distributed/elastic/agent/server/test/local_elastic_agent_test.py index 9becdeb663ef6..f8972a2be73cf 100644 --- a/test/distributed/elastic/agent/server/test/local_elastic_agent_test.py +++ b/test/distributed/elastic/agent/server/test/local_elastic_agent_test.py @@ -36,8 +36,7 @@ from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer from torch.distributed.rpc.backend_registry import BackendType from torch.testing._internal.common_utils import ( - TEST_WITH_ASAN, - TEST_WITH_TSAN, + TEST_WITH_DEV_DBG_ASAN, sandcastle_skip_if, ) @@ -406,19 +405,19 @@ def dummy_compute(self): self.assertEqual((100, 100), return_value.shape) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_dummy_compute_c10d(self): self.run_test_with_backend(backend="c10d", test_to_run=self.dummy_compute) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_dummy_compute_etcd(self): self.run_test_with_backend(backend="etcd", test_to_run=self.dummy_compute) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_dummy_compute_etcd_v2(self): self.run_test_with_backend(backend="etcd-v2", test_to_run=self.dummy_compute) @@ -431,19 +430,19 @@ def run_happy_function(self): self.assertIsNone(res.return_values[1]) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_run_happy_function_c10d(self): self.run_test_with_backend(backend="c10d", test_to_run=self.run_happy_function) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_run_happy_function_etcd(self): self.run_test_with_backend(backend="etcd", test_to_run=self.run_happy_function) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_run_happy_function_etcd_v2(self): self.run_test_with_backend(backend="etcd-v2", test_to_run=self.run_happy_function) @@ -465,13 +464,13 @@ def check_master_addr_port_override(self): self.assertIsNone(res.return_values[0]) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_check_master_addr_port_override_etcd(self): self.run_test_with_backend(backend="etcd", test_to_run=self.check_master_addr_port_override) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_check_master_addr_port_override_etcd_v2(self): self.run_test_with_backend(backend="etcd-v2", test_to_run=self.check_master_addr_port_override) @@ -484,7 +483,7 @@ def run_check_env_function(self): self.assertFalse(res.is_failed()) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_run_check_env_function_etcd(self): self.run_test_with_backend(backend="etcd", test_to_run=self.run_check_env_function) @@ -497,19 +496,19 @@ def run_function_with_return_value(self): self.assertEqual("foo", res.return_values[1]) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_run_function_with_return_value_c10d(self): self.run_test_with_backend(backend="c10d", test_to_run=self.run_function_with_return_value) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_run_function_with_return_value_etcd(self): self.run_test_with_backend(backend="etcd", test_to_run=self.run_function_with_return_value) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_run_function_with_return_value_etcd_v2(self): self.run_test_with_backend(backend="etcd-v2", test_to_run=self.run_function_with_return_value) @@ -520,19 +519,19 @@ def simple_dist_sum(self): # _dist_sum internally checks that the sum computed is valid @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_simple_dist_sum_c10d(self): self.run_test_with_backend(backend="c10d", test_to_run=self.simple_dist_sum) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_simple_dist_sum_etcd(self): self.run_test_with_backend(backend="etcd", test_to_run=self.simple_dist_sum) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_simple_dist_sum_etcd_v2(self): self.run_test_with_backend(backend="etcd-v2", test_to_run=self.simple_dist_sum) @@ -556,19 +555,19 @@ def run_distributed_sum_homogeneous(self): self.assertSetEqual(set(range(4 + 4)), ranks) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_run_distributed_sum_homogeneous_c10d(self): self.run_test_with_backend(backend="c10d", test_to_run=self.run_distributed_sum_homogeneous) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_run_distributed_sum_homogeneous_etcd(self): self.run_test_with_backend(backend="etcd", test_to_run=self.run_distributed_sum_homogeneous) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_run_distributed_sum_homogeneous_etcd_v2(self): self.run_test_with_backend(backend="etcd-v2", test_to_run=self.run_distributed_sum_homogeneous) @@ -596,19 +595,19 @@ def run_distributed_sum_heterogeneous(self): self.assertSetEqual(set(range(1 + 2 + 3)), ranks) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_run_distributed_sum_heterogeneous_c10d(self): self.run_test_with_backend(backend="c10d", test_to_run=self.run_distributed_sum_heterogeneous) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_run_distributed_sum_heterogeneous_etcd(self): self.run_test_with_backend(backend="etcd", test_to_run=self.run_distributed_sum_heterogeneous) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_run_distributed_sum_heterogeneous_etcd_v2(self): self.run_test_with_backend(backend="etcd-v2", test_to_run=self.run_distributed_sum_heterogeneous) @@ -636,19 +635,19 @@ def run_sad_function(self): self.assertEqual(int(data["extraInfo"]["timestamp"]), failure.timestamp) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_run_sad_function_c10d(self): self.run_test_with_backend(backend="c10d", test_to_run=self.run_sad_function) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_run_sad_function_etcd(self): self.run_test_with_backend(backend="etcd", test_to_run=self.run_sad_function) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_run_sad_function_etcd_v2(self): self.run_test_with_backend(backend="etcd-v2", test_to_run=self.run_sad_function) @@ -668,19 +667,19 @@ def run_bipolar_function(self): self.assertTrue(agent._total_execution_time > 0) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_run_bipolar_function_c10d(self): self.run_test_with_backend(backend="c10d", test_to_run=self.run_bipolar_function) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_run_bipolar_function_etcd(self): self.run_test_with_backend(backend="etcd", test_to_run=self.run_bipolar_function) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_run_bipolar_function_etcd_v2(self): self.run_test_with_backend(backend="etcd-v2", test_to_run=self.run_bipolar_function) @@ -711,13 +710,13 @@ def correct_rank_assignment_heterogeneous(self): ) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_correct_rank_assignment_heterogeneous_etcd(self): self.run_test_with_backend(backend="etcd", test_to_run=self.correct_rank_assignment_heterogeneous) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_correct_rank_assignment_heterogeneous_etcd_v2(self): self.run_test_with_backend(backend="etcd-v2", test_to_run=self.correct_rank_assignment_heterogeneous) @@ -744,13 +743,13 @@ def correct_rank_assignment_homogeneous(self): ) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_correct_rank_assignment_homogeneous_etcd(self): self.run_test_with_backend(backend="etcd", test_to_run=self.correct_rank_assignment_homogeneous) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_correct_rank_assignment_homogeneous_etcd_v2(self): self.run_test_with_backend(backend="etcd-v2", test_to_run=self.correct_rank_assignment_homogeneous) @@ -852,13 +851,13 @@ def double_agent_fault_tolerance(self): self.assertEqual(0, p.exitcode) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_double_agent_fault_tolerance_etcd(self): self.run_test_with_backend(backend="etcd", test_to_run=self.double_agent_fault_tolerance) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_double_agent_fault_tolerance_etcd_v2(self): self.run_test_with_backend(backend="etcd-v2", test_to_run=self.double_agent_fault_tolerance) @@ -905,19 +904,19 @@ def double_agent_elastic(self): self.assertEqual(-signal.SIGKILL, p.exitcode) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_double_agent_elastic_c10d(self): self.run_test_with_backend(backend="c10d", test_to_run=self.double_agent_elastic) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_double_agent_elastic_etcd(self): self.run_test_with_backend(backend="etcd", test_to_run=self.double_agent_elastic) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_double_agent_elastic_etcd_v2(self): self.run_test_with_backend(backend="etcd-v2", test_to_run=self.double_agent_elastic) @@ -955,19 +954,19 @@ def torch_rpc(self): self.assertEqual([f"{msg} from worker"], list(master_retvals.values())) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_torch_rpc_c10d(self): self.run_test_with_backend(backend="c10d", test_to_run=self.torch_rpc) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_torch_rpc_etcd(self): self.run_test_with_backend(backend="etcd", test_to_run=self.torch_rpc) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_torch_rpc_etcd_v2(self): self.run_test_with_backend(backend="etcd-v2", test_to_run=self.torch_rpc) @@ -993,13 +992,13 @@ def workers_drift_success(self): self.assertEqual(rank, output) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_workers_drift_success_etcd(self): self.run_test_with_backend(backend="etcd", test_to_run=self.workers_drift_success) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_workers_drift_success_etcd_v2(self): self.run_test_with_backend(backend="etcd-v2", test_to_run=self.workers_drift_success) @@ -1024,13 +1023,13 @@ def workers_drift_fail(self): self.assertEqual(rank, output) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_workers_drift_fail_etcd(self): self.run_test_with_backend(backend="etcd", test_to_run=self.workers_drift_fail) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_workers_drift_fail_etcd_v2(self): self.run_test_with_backend(backend="etcd-v2", test_to_run=self.workers_drift_fail) @@ -1047,19 +1046,19 @@ def barrier_failed(self, barrier_mock): barrier_mock.assert_called_once() @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_barrier_failed_c10d(self): self.run_test_with_backend(backend="c10d", test_to_run=self.barrier_failed) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_barrier_failed_etcd(self): self.run_test_with_backend(backend="etcd", test_to_run=self.barrier_failed) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_barrier_failed_etcd_v2(self): self.run_test_with_backend(backend="etcd-v2", test_to_run=self.barrier_failed) @@ -1081,19 +1080,19 @@ def shutdown_called(self, start_processes_mock): pcontext_mock.close.assert_called_once() @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_shutdown_called_c10d(self): self.run_test_with_backend(backend="c10d", test_to_run=self.shutdown_called) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_shutdown_called_etcd(self): self.run_test_with_backend(backend="etcd", test_to_run=self.shutdown_called) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_shutdown_called_etcd_v2(self): self.run_test_with_backend(backend="etcd-v2", test_to_run=self.shutdown_called) diff --git a/test/distributed/elastic/multiprocessing/api_test.py b/test/distributed/elastic/multiprocessing/api_test.py index cb1db294d2791..811137a8d83b4 100644 --- a/test/distributed/elastic/multiprocessing/api_test.py +++ b/test/distributed/elastic/multiprocessing/api_test.py @@ -36,6 +36,7 @@ NO_MULTIPROCESSING_SPAWN, TEST_WITH_ASAN, TEST_WITH_TSAN, + TEST_WITH_DEV_DBG_ASAN, IS_IN_CI, IS_WINDOWS, IS_MACOS, @@ -222,15 +223,11 @@ def start_processes_zombie_test( # tests incompatible with tsan or asan -if not (TEST_WITH_ASAN or TEST_WITH_TSAN or IS_WINDOWS or IS_MACOS): +if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS): class StartProcessesTest(unittest.TestCase): def setUp(self): self.test_dir = tempfile.mkdtemp(prefix=f"{self.__class__.__name__}_") - - if NO_MULTIPROCESSING_SPAWN: # python 2.7 doesn't have spawn - self._start_methods = ["fork"] - else: - self._start_methods = ["fork", "spawn"] + self._start_methods = ["spawn"] def tearDown(self): shutil.rmtree(self.test_dir) @@ -316,7 +313,7 @@ def test_pcontext_wait(self): args={0: (1,)}, envs={0: {}}, log_dir=self.log_dir(), - start_method="fork", + start_method="spawn", ) self.assertIsNone(pc.wait(timeout=0.1, period=0.01)) @@ -331,7 +328,7 @@ def test_multiprocess_context_close(self): args={0: (1,)}, envs={0: {}}, log_dir=self.log_dir(), - start_method="fork", + start_method="spawn", ) pids = pc.pids() @@ -386,7 +383,7 @@ def test_void_function(self): self.assertEqual({0: None, 1: None}, results.return_values) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "tests incompatible with asan" ) def test_function_large_ret_val(self): # python multiprocessing.queue module uses pipes and actually PipedQueues @@ -548,7 +545,7 @@ def test_multiprocessing_context_poll_raises_exception(self): # tests incompatible with tsan or asan, the redirect functionality does not work on macos or windows -if not (TEST_WITH_ASAN or TEST_WITH_TSAN or IS_WINDOWS or IS_MACOS): +if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS): class StartProcessesListTest(StartProcessesTest): ######################################## # start_processes as binary tests @@ -629,7 +626,7 @@ def test_binary_redirect_and_tee(self): args={0: ("hello",), 1: ("world",)}, envs={0: {"RANK": "0"}, 1: {"RANK": "1"}}, log_dir=self.log_dir(), - start_method="fork", + start_method="spawn", redirects={0: Std.ERR, 1: Std.NONE}, tee={0: Std.OUT, 1: Std.ERR}, ) @@ -646,7 +643,7 @@ def test_binary_redirect_and_tee(self): # tests incompatible with tsan or asan, the redirect functionality does not work on macos or windows -if not (TEST_WITH_ASAN or TEST_WITH_TSAN or IS_WINDOWS or IS_MACOS or IS_IN_CI): +if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_IN_CI): class StartProcessesNotCITest(StartProcessesTest): def test_wrap_bad(self): none = "" @@ -696,7 +693,11 @@ def test_binary_signal(self): failure = results.failures[0] self.assertNotEqual(signal.SIGSEGV, failure.exitcode) - self.assertEqual("SIGSEGV", failure.signal_name()) + if TEST_WITH_ASAN or TEST_WITH_TSAN: + # ASAN/TSAN exit code is 1. + self.assertEqual("", failure.signal_name()) + else: + self.assertEqual("SIGSEGV", failure.signal_name()) self.assertEqual("", failure.error_file_data["message"]) def test_function_redirect_and_tee(self): @@ -709,7 +710,7 @@ def test_function_redirect_and_tee(self): args={0: ("hello",), 1: ("world",)}, envs={0: {"RANK": "0"}, 1: {"RANK": "1"}}, log_dir=log_dir, - start_method="fork", + start_method="spawn", redirects={0: Std.ERR, 1: Std.NONE}, tee={0: Std.OUT, 1: Std.ERR}, ) diff --git a/test/distributed/elastic/multiprocessing/errors/api_test.py b/test/distributed/elastic/multiprocessing/errors/api_test.py index 14b7ab1d13970..859069004ae71 100644 --- a/test/distributed/elastic/multiprocessing/errors/api_test.py +++ b/test/distributed/elastic/multiprocessing/errors/api_test.py @@ -13,7 +13,6 @@ record, ) from torch.distributed.elastic.multiprocessing.errors.error_handler import _write_error -from torch.testing._internal.common_utils import TEST_WITH_TSAN class SentinelError(Exception): @@ -45,10 +44,6 @@ def read_resource_file(resource_file: str) -> str: return "".join(fp.readlines()) -if TEST_WITH_TSAN: - print("test incompatible with tsan", file=sys.stderr) - sys.exit(0) - class ApiTest(unittest.TestCase): def setUp(self): self.test_dir = tempfile.mkdtemp(prefix=self.__class__.__name__) diff --git a/test/distributed/elastic/timer/local_timer_example.py b/test/distributed/elastic/timer/local_timer_example.py index d73aa67ee75e7..b52c64752e413 100644 --- a/test/distributed/elastic/timer/local_timer_example.py +++ b/test/distributed/elastic/timer/local_timer_example.py @@ -14,8 +14,7 @@ import torch.distributed.elastic.timer as timer import torch.multiprocessing as torch_mp from torch.testing._internal.common_utils import ( - TEST_WITH_ASAN, - TEST_WITH_TSAN, + TEST_WITH_DEV_DBG_ASAN, run_tests, IS_WINDOWS, IS_MACOS, @@ -55,7 +54,7 @@ class LocalTimerExample(unittest.TestCase): unittest. As of now this will SIGSEGV. """ - @sandcastle_skip_if(TEST_WITH_ASAN or TEST_WITH_TSAN, "test is a/tsan incompatible") + @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test is asan incompatible") def test_torch_mp_example(self): # in practice set the max_interval to a larger value (e.g. 60 seconds) mp_queue = mp.get_context("spawn").Queue() @@ -80,18 +79,14 @@ def test_torch_mp_example(self): server.stop() - @sandcastle_skip_if(TEST_WITH_ASAN or TEST_WITH_TSAN, "test is a/tsan incompatible") + @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test is asan incompatible") def test_example_start_method_spawn(self): self._run_example_with(start_method="spawn") - # @sandcastle_skip_if(TEST_WITH_ASAN or TEST_WITH_TSAN, "test is a/tsan incompatible") + # @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test is asan incompatible") # def test_example_start_method_forkserver(self): # self._run_example_with(start_method="forkserver") - @sandcastle_skip_if(TEST_WITH_TSAN, "test is tsan incompatible") - def test_example_start_method_fork(self): - self._run_example_with(start_method="fork") - def _run_example_with(self, start_method): spawn_ctx = mp.get_context(start_method) mp_queue = spawn_ctx.Queue() diff --git a/test/distributed/elastic/timer/local_timer_test.py b/test/distributed/elastic/timer/local_timer_test.py index 4c977113aa42e..f27e5939660e5 100644 --- a/test/distributed/elastic/timer/local_timer_test.py +++ b/test/distributed/elastic/timer/local_timer_test.py @@ -13,19 +13,28 @@ from torch.distributed.elastic.timer.api import TimerRequest from torch.distributed.elastic.timer.local_timer import MultiprocessingRequestQueue from torch.testing._internal.common_utils import ( - TEST_WITH_TSAN, run_tests, IS_WINDOWS, IS_MACOS, - sandcastle_skip_if, + TEST_WITH_DEV_DBG_ASAN, ) # timer is not supported on windows or macos -if not (IS_WINDOWS or IS_MACOS): +if not (IS_WINDOWS or IS_MACOS or TEST_WITH_DEV_DBG_ASAN): + # func2 should time out + def func2(n, mp_queue): + if mp_queue is not None: + timer.configure(timer.LocalTimerClient(mp_queue)) + if n > 0: + with timer.expires(after=0.1): + func2(n - 1, None) + time.sleep(0.2) + class LocalTimerTest(unittest.TestCase): def setUp(self): - self.mp_queue = mp.Queue() + self.ctx = mp.get_context('spawn') + self.mp_queue = self.ctx.Queue() self.max_interval = 0.01 self.server = timer.LocalTimerServer(self.mp_queue, self.max_interval) self.server.start() @@ -62,7 +71,6 @@ def test_happy_path(self): with timer.expires(after=0.5): time.sleep(0.1) - @sandcastle_skip_if(TEST_WITH_TSAN, "test is tsan incompatible") def test_get_timer_recursive(self): """ If a function acquires a countdown timer with default scope, @@ -82,14 +90,7 @@ def func(n): func(4) - # func2 should time out - def func2(n): - if n > 0: - with timer.expires(after=0.1): - func2(n - 1) - time.sleep(0.2) - - p = mp.Process(target=func2, args=(2,)) + p = self.ctx.Process(target=func2, args=(2, self.mp_queue)) p.start() p.join() self.assertEqual(-signal.SIGKILL, p.exitcode) @@ -102,7 +103,6 @@ def _run(mp_queue, timeout, duration): with timer.expires(after=timeout): time.sleep(duration) - @sandcastle_skip_if(TEST_WITH_TSAN, "test is tsan incompatible") def test_timer(self): timeout = 0.1 duration = 1 @@ -124,7 +124,7 @@ def _enqueue_on_interval(mp_queue, n, interval, sem): # timer is not supported on windows or macos -if not (IS_WINDOWS or IS_MACOS): +if not (IS_WINDOWS or IS_MACOS or TEST_WITH_DEV_DBG_ASAN): class MultiprocessingRequestQueueTest(unittest.TestCase): def test_get(self): mp_queue = mp.Queue() @@ -183,7 +183,7 @@ def test_get_less_than_size(self): # timer is not supported on windows or macos -if not (IS_WINDOWS or IS_MACOS): +if not (IS_WINDOWS or IS_MACOS or TEST_WITH_DEV_DBG_ASAN): class LocalTimerServerTest(unittest.TestCase): def setUp(self): self.mp_queue = mp.Queue() @@ -193,7 +193,6 @@ def setUp(self): def tearDown(self): self.server.stop() - @sandcastle_skip_if(TEST_WITH_TSAN, "test is tsan incompatible") def test_watchdog_call_count(self): """ checks that the watchdog function ran wait/interval +- 1 times @@ -226,7 +225,6 @@ def _valid_timer(self, pid, scope): def _release_timer(self, pid, scope): return TimerRequest(worker_id=pid, scope_id=scope, expiration_time=-1) - @sandcastle_skip_if(TEST_WITH_TSAN, "test is tsan incompatible") @mock.patch("os.kill") def test_expired_timers(self, mock_os_kill): """ diff --git a/test/distributed/launcher/api_test.py b/test/distributed/launcher/api_test.py index 954b7e201a351..685e843c10653 100644 --- a/test/distributed/launcher/api_test.py +++ b/test/distributed/launcher/api_test.py @@ -30,8 +30,7 @@ _get_entrypoint_name, ) from torch.testing._internal.common_utils import ( - TEST_WITH_ASAN, - TEST_WITH_TSAN, + TEST_WITH_DEV_DBG_ASAN, sandcastle_skip_if, ) @@ -117,7 +116,7 @@ def get_test_launch_config( rdzv_endpoint=endpoint, monitor_interval=1, rdzv_backend=rdzv_backend, - start_method="fork", + start_method="spawn", max_restarts=0, rdzv_configs=rdzv_configs, ) @@ -128,7 +127,7 @@ def check_works_ran(self, world_size: int): ) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_launch_script_python(self): nnodes = 1 @@ -145,7 +144,7 @@ def test_launch_script_python(self): self.check_works_ran(world_size) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_launch_script_python_local_rank_transfer(self): nnodes = 1 @@ -162,7 +161,7 @@ def test_launch_script_python_local_rank_transfer(self): self.check_works_ran(world_size) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_launch_script_bash(self): nnodes = 1 @@ -177,7 +176,7 @@ def test_launch_script_bash(self): self.check_works_ran(world_size) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_launch_function(self): nnodes = 1 @@ -193,7 +192,7 @@ def test_launch_function(self): self.assertEqual(expected_res, actual_res) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_launch_dist_sum_with_static_rdzv(self): nnodes = 1 @@ -224,7 +223,7 @@ def test_launch_dist_sum_with_static_rdzv(self): self.assertEqual(expected_res, actual_res) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_launch_elastic(self): nproc_per_node = 4 diff --git a/test/distributed/launcher/bin/test_script_init_method.py b/test/distributed/launcher/bin/test_script_init_method.py new file mode 100755 index 0000000000000..299839c40759b --- /dev/null +++ b/test/distributed/launcher/bin/test_script_init_method.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os + +import torch +import torch.distributed as dist +import torch.nn.functional as F + + +def parse_args(): + parser = argparse.ArgumentParser(description="test script") + + parser.add_argument( + "--init_method", + type=str, + required=True, + help="init_method to pass to `dist.init_process_group()` (e.g. env://)", + ) + parser.add_argument( + "--world_size", + type=int, + default=os.getenv("WORLD_SIZE", -1), + help="world_size to pass to `dist.init_process_group()`", + ) + parser.add_argument( + "--rank", + type=int, + default=os.getenv("RANK", -1), + help="rank to pass to `dist.init_process_group()`", + ) + + return parser.parse_args() + + +def main(): + args = parse_args() + + dist.init_process_group( + backend="gloo", + init_method=args.init_method, + world_size=args.world_size, + rank=args.rank, + ) + + rank = dist.get_rank() + world_size = dist.get_world_size() + + # one hot (by rank) tensor of size world_size + # example: + # rank 0, world_size 4 => [1, 0, 0, 0] + # rank 1, world_size 4 => [0, 1, 0, 0] + # ... + t = F.one_hot(torch.tensor(rank), num_classes=world_size) + + # after all_reduce t = tensor.ones(size=world_size) + dist.all_reduce(t) + + # adding all elements in t should equal world_size + derived_world_size = torch.sum(t).item() + if derived_world_size != world_size: + raise RuntimeError( + f"Wrong world size derived. Expected: {world_size}, Got: {derived_world_size}" + ) + + print("Done") + + +if __name__ == "__main__": + main() diff --git a/test/distributed/launcher/bin/test_script_is_torchelastic_launched.py b/test/distributed/launcher/bin/test_script_is_torchelastic_launched.py new file mode 100755 index 0000000000000..fa9729c757b64 --- /dev/null +++ b/test/distributed/launcher/bin/test_script_is_torchelastic_launched.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +This is a test script that launches as part of the test cases in +run_test.py, to validate the correctness of +the method ``torch.distributed.is_torchelastic_launched()``. To do so, +we run this script with and without torchelastic and validate that the +boolean value written to the out_file is indeed what we expect (e.g. +should be False when not launched with torchelastic, True when launched with) +The script itself is not a test case hence no assertions are made in this script. + +see: - test/distributed/launcher/run_test.py#test_is_torchelastic_launched() + - test/distributed/launcher/run_test.py#test_is_not_torchelastic_launched() +""" +import argparse + +import torch.distributed as dist + + +def parse_args(): + parser = argparse.ArgumentParser(description="test script") + parser.add_argument( + "--out_file", + help="file to write indicating whether this script was launched with torchelastic", + ) + return parser.parse_args() + + +def main(): + args = parse_args() + with open(args.out_file, "w") as out: + out.write(f"{dist.is_torchelastic_launched()}") + + +if __name__ == "__main__": + main() diff --git a/test/distributed/launcher/launch_test.py b/test/distributed/launcher/launch_test.py index 2d27269014246..d79a18d39b995 100644 --- a/test/distributed/launcher/launch_test.py +++ b/test/distributed/launcher/launch_test.py @@ -14,8 +14,7 @@ import torch.distributed.launch as launch from torch.distributed.elastic.utils import get_socket_with_port from torch.testing._internal.common_utils import ( - TEST_WITH_ASAN, - TEST_WITH_TSAN, + TEST_WITH_DEV_DBG_ASAN, sandcastle_skip_if, ) @@ -36,7 +35,7 @@ def tearDown(self): shutil.rmtree(self.test_dir) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_launch_without_env(self): nnodes = 1 @@ -49,7 +48,7 @@ def test_launch_without_env(self): f"--nnodes={nnodes}", f"--nproc_per_node={nproc_per_node}", "--monitor_interval=1", - "--start_method=fork", + "--start_method=spawn", "--master_addr=localhost", f"--master_port={master_port}", "--node_rank=0", @@ -58,7 +57,7 @@ def test_launch_without_env(self): launch.main(args) @sandcastle_skip_if( - TEST_WITH_ASAN or TEST_WITH_TSAN, "tests incompatible with tsan or asan" + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" ) def test_launch_with_env(self): nnodes = 1 @@ -71,7 +70,7 @@ def test_launch_with_env(self): f"--nnodes={nnodes}", f"--nproc_per_node={nproc_per_node}", "--monitor_interval=1", - "--start_method=fork", + "--start_method=spawn", "--master_addr=localhost", f"--master_port={master_port}", "--node_rank=0", diff --git a/test/distributed/launcher/run_test.py b/test/distributed/launcher/run_test.py index 7318bbd630687..4ed824c036390 100644 --- a/test/distributed/launcher/run_test.py +++ b/test/distributed/launcher/run_test.py @@ -7,8 +7,10 @@ # LICENSE file in the root directory of this source tree. import multiprocessing as mp import os +import runpy import shutil import subprocess +import sys import tempfile import unittest import uuid @@ -21,9 +23,9 @@ from torch.distributed.elastic.multiprocessing.errors import ChildFailedError from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer from torch.distributed.elastic.utils import get_socket_with_port +from torch.distributed.elastic.utils.distributed import get_free_port from torch.testing._internal.common_utils import ( - TEST_WITH_ASAN, - TEST_WITH_TSAN, + TEST_WITH_DEV_DBG_ASAN, sandcastle_skip_if, ) @@ -100,7 +102,7 @@ def test_launch_user_script_python(self): f"--rdzv_endpoint={self._etcd_endpoint}", f"--rdzv_id={run_id}", "--monitor_interval=1", - "--start_method=fork", + "--start_method=spawn", path("bin/test_script.py"), f"--touch_file_dir={self.test_dir}", ] @@ -123,7 +125,7 @@ def test_launch_user_script_python_caffe2_bc(self): f"--nnodes={nnodes}", f"--nproc_per_node={nproc_per_node}", "--monitor_interval=1", - "--start_method=fork", + "--start_method=spawn", "--master_addr=localhost", f"--master_port={master_port}", "--node_rank=0", @@ -138,7 +140,7 @@ def test_launch_user_script_python_caffe2_bc(self): {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir)) ) - @sandcastle_skip_if(TEST_WITH_ASAN or TEST_WITH_TSAN, "test incompatible with tsan") + @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan") def test_launch_user_script_bash(self): run_id = str(uuid.uuid4().int) nnodes = 1 @@ -151,7 +153,7 @@ def test_launch_user_script_bash(self): f"--rdzv_endpoint={self._etcd_endpoint}", f"--rdzv_id={run_id}", "--monitor_interval=1", - "--start_method=fork", + "--start_method=spawn", "--no_python", ] @@ -169,7 +171,7 @@ def test_launch_user_script_bash(self): {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir)) ) - @sandcastle_skip_if(TEST_WITH_ASAN or TEST_WITH_TSAN, "test incompatible with tsan") + @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan") def test_launch_user_script_default_nproc(self): run_id = str(uuid.uuid4().int) nnodes = 1 @@ -180,7 +182,7 @@ def test_launch_user_script_default_nproc(self): f"--rdzv_endpoint={self._etcd_endpoint}", f"--rdzv_id={run_id}", "--monitor_interval=1", - "--start_method=fork", + "--start_method=spawn", "--no_python", ] @@ -198,7 +200,7 @@ def test_launch_user_script_default_nproc(self): {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir)) ) - @sandcastle_skip_if(TEST_WITH_ASAN or TEST_WITH_TSAN, "test incompatible with tsan") + @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan") def test_launch_with_env_vars(self): run_id = str(uuid.uuid4().int) nnodes = 1 @@ -211,7 +213,7 @@ def test_launch_with_env_vars(self): os.environ["PET_RDZV_ENDPOINT"] = self._etcd_endpoint os.environ["PET_RDZV_ID"] = run_id os.environ["PET_MONITOR_INTERVAL"] = "1" - os.environ["PET_START_METHOD"] = "fork" + os.environ["PET_START_METHOD"] = "spawn" os.environ["PET_NO_PYTHON"] = "1" script_args = [path("bin/test_script.sh"), f"{self.test_dir}"] @@ -241,7 +243,7 @@ def _test_nproc_launch_configuration(self, nproc_type, expected_number): f"--rdzv_endpoint={self._etcd_endpoint}", f"--rdzv_id={run_id}", "--monitor_interval=1", - "--start_method=fork", + "--start_method=spawn", "--no_python", ] @@ -256,27 +258,27 @@ def _test_nproc_launch_configuration(self, nproc_type, expected_number): {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir)) ) - @sandcastle_skip_if(TEST_WITH_ASAN or TEST_WITH_TSAN, "test incompatible with tsan") + @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan") def test_nproc_launch_auto_configurations(self): self._test_nproc_launch_configuration("auto", os.cpu_count()) - @sandcastle_skip_if(TEST_WITH_ASAN or TEST_WITH_TSAN, "test incompatible with tsan") + @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan") def test_nproc_launch_number_configurations(self): self._test_nproc_launch_configuration("4", 4) - @sandcastle_skip_if(TEST_WITH_ASAN or TEST_WITH_TSAN, "test incompatible with tsan") + @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan") def test_nproc_launch_unknown_configurations(self): with self.assertRaises(ValueError): self._test_nproc_launch_configuration("unknown", 4) - @sandcastle_skip_if(TEST_WITH_ASAN or TEST_WITH_TSAN, "test incompatible with tsan") + @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan") @patch("torch.cuda.is_available", return_value=True) @patch("torch.cuda.device_count", return_value=3) def test_nproc_gpu_launch_configurations(self, _mock1, _mock2): self._test_nproc_launch_configuration("auto", 3) self._test_nproc_launch_configuration("gpu", 3) - @sandcastle_skip_if(TEST_WITH_ASAN or TEST_WITH_TSAN, "test incompatible with tsan") + @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan") def test_launch_elastic(self): run_id = str(uuid.uuid4().int) min_nodes = 1 @@ -291,7 +293,7 @@ def test_launch_elastic(self): f"--rdzv_endpoint={self._etcd_endpoint}", f"--rdzv_id={run_id}", "--monitor_interval=1", - "--start_method=fork", + "--start_method=spawn", path("bin/test_script.py"), f"--touch_file_dir={self.test_dir}", ] @@ -304,7 +306,7 @@ def test_launch_elastic(self): ) @mock.patch("torch.distributed.elastic.events.record") - @sandcastle_skip_if(TEST_WITH_ASAN or TEST_WITH_TSAN, "test incompatible with tsan") + @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan") def test_launch_elastic_worker_raise_exception(self, record_mock): """ Asserts that when the worker program fails and lancher raieses exception @@ -323,7 +325,7 @@ def test_launch_elastic_worker_raise_exception(self, record_mock): f"--rdzv_id={run_id}", "--monitor_interval=1", "--max_restarts=0", - "--start_method=fork", + "--start_method=spawn", path("bin/test_script.py"), "--fail", ] @@ -332,7 +334,7 @@ def test_launch_elastic_worker_raise_exception(self, record_mock): record_mock.assert_called_once() - @sandcastle_skip_if(TEST_WITH_ASAN or TEST_WITH_TSAN, "test incompatible with tsan") + @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan") @mock.patch( "torch.distributed.elastic.agent.server.local_elastic_agent.LocalElasticAgent.run" ) @@ -354,7 +356,7 @@ def test_launch_elastic_agent_raise_exception(self, record_mock, mock_agent_run) f"--rdzv_id={run_id}", "--monitor_interval=1", "--max_restarts=0", - "--start_method=fork", + "--start_method=spawn", path("bin/test_script.py"), f"--touch_file_dir={self.test_dir}", ] @@ -364,7 +366,7 @@ def test_launch_elastic_agent_raise_exception(self, record_mock, mock_agent_run) launch.main(args) record_mock.assert_called_once() - @sandcastle_skip_if(TEST_WITH_ASAN or TEST_WITH_TSAN, "test incompatible with tsan") + @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan") def test_launch_standalone(self): nnodes = 1 nproc_per_node = 4 @@ -374,7 +376,7 @@ def test_launch_standalone(self): f"--nproc_per_node={nproc_per_node}", "--standalone", "--monitor_interval=1", - "--start_method=fork", + "--start_method=spawn", path("bin/test_script.py"), f"--touch_file_dir={self.test_dir}", ] @@ -386,7 +388,7 @@ def test_launch_standalone(self): {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir)) ) - @sandcastle_skip_if(TEST_WITH_ASAN or TEST_WITH_TSAN, "test incompatible with tsan") + @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan") def test_launch_run_path(self): nnodes = 1 nproc_per_node = 4 @@ -396,7 +398,7 @@ def test_launch_run_path(self): f"--nnodes={nnodes}", f"--nproc_per_node={nproc_per_node}", "--monitor_interval=1", - "--start_method=fork", + "--start_method=spawn", path("bin/test_script.py"), f"--touch_file_dir={self.test_dir}", ] @@ -408,7 +410,7 @@ def test_launch_run_path(self): {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir)) ) - @sandcastle_skip_if(TEST_WITH_ASAN or TEST_WITH_TSAN, "test incompatible with tsan") + @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan") def test_launch_elastic_multiple_agents(self): run_id = str(uuid.uuid4().int) min_nodes = 1 @@ -423,7 +425,7 @@ def test_launch_elastic_multiple_agents(self): f"--rdzv_endpoint={self._etcd_endpoint}", f"--rdzv_id={run_id}", "--monitor_interval=1", - "--start_method=fork", + "--start_method=spawn", path("bin/test_script.py"), f"--touch_file_dir={self.test_dir}", ] @@ -462,7 +464,7 @@ def test_launch_shutdown(self, agent_mock_cls): f"--nnodes={nnodes}", f"--nproc_per_node={nproc_per_node}", "--monitor_interval=1", - "--start_method=fork", + "--start_method=spawn", path("bin/test_script.py"), f"--touch_file_dir={self.test_dir}", ] @@ -476,3 +478,117 @@ def test_launch_shutdown(self, agent_mock_cls): param_mock.return_value = rdzv_handler_mock launch.main(args) rdzv_handler_mock.shutdown.assert_called_once() + + @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan") + def test_is_torchelastic_launched(self): + # launch test script with torchelastic and validate that + # torch.distributed.is_torchelastic_launched() returns True + + out_file = f"{os.path.join(self.test_dir, 'out')}" + + launch.main( + [ + "--run_path", + "--nnodes=1", + "--nproc_per_node=1", + "--monitor_interval=1", + path("bin/test_script_is_torchelastic_launched.py"), + f"--out_file={out_file}", + ] + ) + + with open(out_file, "r") as fp: + is_torchelastic_launched = fp.readline() + self.assertEqual("True", is_torchelastic_launched) + + def test_is_not_torchelastic_launched(self): + # launch test script without torchelastic and validate that + # torch.distributed.is_torchelastic_launched() returns False + + out_file = f"{os.path.join(self.test_dir, 'out')}" + + # need to run the script with runpy in the same interpreter + # as the test because otherwise (depending on the environment) + # it will not find torch as a dependency + with patch.object( + sys, + "argv", + [ + path("bin/test_script_is_torchelastic_launched.py"), + f"--out_file={out_file}", + ], + ): + runpy.run_path(sys.argv[0], run_name="__main__") + with open(out_file, "r") as fp: + is_torchelastic_launched = fp.readline() + self.assertEqual("False", is_torchelastic_launched) + + def test_init_method_tcp(self): + port = get_free_port() + with patch.object( + sys, + "argv", + [ + path("bin/test_script_init_method.py"), + f"--init_method=tcp://localhost:{port}", + "--rank=0", + "--world_size=1", + ], + ): + runpy.run_path(sys.argv[0], run_name="__main__") + # nothing to validate, just make sure it runs + + @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan") + def test_init_method_tcp_with_torchelastic(self): + port = get_free_port() + launch.main( + [ + "--run_path", + "--nnodes=1", + "--nproc_per_node=4", + "--master_addr=localhost", + f"--master_port={port}", + "--monitor_interval=1", + path("bin/test_script_init_method.py"), + f"--init_method=tcp://localhost:{port}", + ] + ) + # nothing to validate, just make sure it runs + + def test_init_method_env(self): + port = get_free_port() + with patch.dict( + os.environ, + { + "RANK": "0", + "WORLD_SIZE": "1", + "MASTER_ADDR": "localhost", + "MASTER_PORT": str(port), + }, + ), patch.object( + sys, + "argv", + [ + path("bin/test_script_init_method.py"), + "--init_method=env://", + ], + ): + runpy.run_path(sys.argv[0], run_name="__main__") + # nothing to validate, just make sure it runs + + @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan") + def test_init_method_env_with_torchelastic(self): + port = get_free_port() + launch.main( + [ + "--run_path", + "--nnodes=1", + "--nproc_per_node=4", + "--master_addr=localhost", + f"--master_port={port}", + "--monitor_interval=1", + path("bin/test_script_init_method.py"), + "--init_method=env://", + ] + ) + # nothing to validate, just make sure it runs diff --git a/test/distributed/rpc/cuda/test_tensorpipe_agent.py b/test/distributed/rpc/cuda/test_tensorpipe_agent.py index 5647434f6f53e..7cb35f9f73aa1 100644 --- a/test/distributed/rpc/cuda/test_tensorpipe_agent.py +++ b/test/distributed/rpc/cuda/test_tensorpipe_agent.py @@ -15,7 +15,6 @@ from torch.testing._internal.distributed.rpc_utils import ( GENERIC_CUDA_TESTS, TENSORPIPE_CUDA_TESTS, - MultiProcess, generate_tests, ) @@ -25,7 +24,6 @@ "TensorPipe", TensorPipeRpcAgentTestFixture, GENERIC_CUDA_TESTS + TENSORPIPE_CUDA_TESTS, - MultiProcess.SPAWN, __name__, ) ) diff --git a/test/distributed/rpc/test_faulty_agent.py b/test/distributed/rpc/test_faulty_agent.py index 7c26643ab6b60..cb889115be8a1 100644 --- a/test/distributed/rpc/test_faulty_agent.py +++ b/test/distributed/rpc/test_faulty_agent.py @@ -15,7 +15,6 @@ ) from torch.testing._internal.distributed.rpc_utils import ( FAULTY_AGENT_TESTS, - MultiProcess, generate_tests, ) @@ -28,7 +27,6 @@ "Faulty", FaultyRpcAgentTestFixture, FAULTY_AGENT_TESTS, - MultiProcess.SPAWN, __name__, ) ) diff --git a/test/distributed/rpc/test_tensorpipe_agent.py b/test/distributed/rpc/test_tensorpipe_agent.py index 32b0e1c69357a..b741bc443c460 100644 --- a/test/distributed/rpc/test_tensorpipe_agent.py +++ b/test/distributed/rpc/test_tensorpipe_agent.py @@ -16,7 +16,6 @@ from torch.testing._internal.distributed.rpc_utils import ( GENERIC_TESTS, TENSORPIPE_TESTS, - MultiProcess, generate_tests, ) @@ -29,7 +28,6 @@ "TensorPipe", TensorPipeRpcAgentTestFixture, GENERIC_TESTS + TENSORPIPE_TESTS, - MultiProcess.SPAWN, __name__, ) ) diff --git a/test/distributed/test_c10d_common.py b/test/distributed/test_c10d_common.py index 6aa5c64658415..33939d093ca3f 100644 --- a/test/distributed/test_c10d_common.py +++ b/test/distributed/test_c10d_common.py @@ -28,9 +28,13 @@ TestCase, load_tests, run_tests, - TEST_WITH_TSAN, + TEST_WITH_DEV_DBG_ASAN, ) +if TEST_WITH_DEV_DBG_ASAN: + print("Multiprocessing spawn is not compatible with dev/dbg asan", file=sys.stderr) + sys.exit(0) + # load_tests from common_utils is used to automatically filter tests for # sharding on sandcastle. This line silences flake warnings load_tests = load_tests @@ -438,37 +442,31 @@ def fut_then(fut): return fut.then(fut_then) -# TSAN is not fork-safe since we're forking in a multi-threaded environment -if not TEST_WITH_TSAN: - - class DistributedDataParallelTest( - AbstractDistributedDataParallelTest, MultiProcessTestCase - ): - def setUp(self): - super(DistributedDataParallelTest, self).setUp() - if sys.platform == "win32": - self._spawn_processes() - else: - self._fork_processes() - - def test_invalid_powerSGD_state(self): - for start_powerSGD_iter, use_error_feedback, warm_start in product( - [0, 1], [True, False], [True, False] +class DistributedDataParallelTest( + AbstractDistributedDataParallelTest, MultiProcessTestCase +): + def setUp(self): + super(DistributedDataParallelTest, self).setUp() + self._spawn_processes() + + def test_invalid_powerSGD_state(self): + for start_powerSGD_iter, use_error_feedback, warm_start in product( + [0, 1], [True, False], [True, False] + ): + if not use_error_feedback and not warm_start: + continue + with self.assertRaisesRegex( + ValueError, + "Expect `start_powerSGD_iter` > 1 if `use_error_feedback` or `warm_start` is enabled, " + "because PowerSGD can only be applied after the first two iterations in DDP.", ): - if not use_error_feedback and not warm_start: - continue - with self.assertRaisesRegex( - ValueError, - "Expect `start_powerSGD_iter` > 1 if `use_error_feedback` or `warm_start` is enabled, " - "because PowerSGD can only be applied after the first two iterations in DDP.", - ): - state = powerSGD.PowerSGDState( - process_group=None, - matrix_approximation_rank=1, - start_powerSGD_iter=start_powerSGD_iter, - use_error_feedback=use_error_feedback, - warm_start=warm_start, - ) + state = powerSGD.PowerSGDState( + process_group=None, + matrix_approximation_rank=1, + start_powerSGD_iter=start_powerSGD_iter, + use_error_feedback=use_error_feedback, + warm_start=warm_start, + ) class ComputeBucketAssignmentTest(TestCase): @@ -656,49 +654,42 @@ def _test_sequence_num_set_new_group(self, backend): dist.all_gather_object(obj_list, subgroup_seq, group=subgroup) self.assertEqual(len(set(obj_list)), 1) +class CommTest(AbstractCommTest, MultiProcessTestCase): + def setUp(self): + super(CommTest, self).setUp() + self._spawn_processes() -# TSAN is not fork-safe since we're forking in a multi-threaded environment -if not TEST_WITH_TSAN: + def tearDown(self): + super(CommTest, self).tearDown() + try: + os.remove(self.file_name) + except OSError: + pass - class CommTest(AbstractCommTest, MultiProcessTestCase): - def setUp(self): - super(CommTest, self).setUp() - if sys.platform == "win32": - self._spawn_processes() - else: - self._fork_processes() - - def tearDown(self): - super(CommTest, self).tearDown() - try: - os.remove(self.file_name) - except OSError: - pass - - def test_distributed_debug_mode(self): - # Default should be off - default_debug_mode = dist._get_debug_mode() - self.assertEqual(default_debug_mode, dist._DistributedDebugLevel.OFF) - mapping = { - "OFF": dist._DistributedDebugLevel.OFF, - "INFO": dist._DistributedDebugLevel.INFO, - "DETAIL": dist._DistributedDebugLevel.DETAIL, - } - invalid_debug_modes = ["foo", 0, 1, -1] - - for mode in mapping.keys(): - os.environ["TORCH_DISTRIBUTED_DEBUG"] = str(mode) - set_debug_mode = dist._get_debug_mode() - self.assertEqual( - set_debug_mode, - mapping[mode], - f"Expected {mode} to map to {mapping[mode]} but got {set_debug_mode}", - ) + def test_distributed_debug_mode(self): + # Default should be off + default_debug_mode = dist._get_debug_mode() + self.assertEqual(default_debug_mode, dist._DistributedDebugLevel.OFF) + mapping = { + "OFF": dist._DistributedDebugLevel.OFF, + "INFO": dist._DistributedDebugLevel.INFO, + "DETAIL": dist._DistributedDebugLevel.DETAIL, + } + invalid_debug_modes = ["foo", 0, 1, -1] + + for mode in mapping.keys(): + os.environ["TORCH_DISTRIBUTED_DEBUG"] = str(mode) + set_debug_mode = dist._get_debug_mode() + self.assertEqual( + set_debug_mode, + mapping[mode], + f"Expected {mode} to map to {mapping[mode]} but got {set_debug_mode}", + ) - for mode in invalid_debug_modes: - os.environ["TORCH_DISTRIBUTED_DEBUG"] = str(mode) - with self.assertRaisesRegex(RuntimeError, "to be one of"): - dist._get_debug_mode() + for mode in invalid_debug_modes: + os.environ["TORCH_DISTRIBUTED_DEBUG"] = str(mode) + with self.assertRaisesRegex(RuntimeError, "to be one of"): + dist._get_debug_mode() if __name__ == "__main__": diff --git a/test/distributed/test_c10d_gloo.py b/test/distributed/test_c10d_gloo.py index 54f29f3b11a7b..789d76e9d115a 100644 --- a/test/distributed/test_c10d_gloo.py +++ b/test/distributed/test_c10d_gloo.py @@ -43,17 +43,9 @@ TestCase, run_tests, retry_on_connect_failures, - TEST_WITH_TSAN, sandcastle_skip, ) -if TEST_WITH_TSAN: - print( - "Skip as TSAN is not fork-safe since we're forking in a multi-threaded environment", - file=sys.stderr, - ) - sys.exit(0) - def simple_reduce_tests(rank, world_size): tests = [ @@ -218,12 +210,7 @@ def _create_process_group_gloo(self, store, rank, world_size, opts): def setUp(self): super(ProcessGroupGlooTest, self).setUp() - - # For Windows platform, Python does not support fork, change it to spawn here. - if sys.platform == "win32": - self._spawn_processes() - else: - self._fork_processes() + self._spawn_processes() def opts(self, threads=2): opts = c10d.ProcessGroupGloo._Options() @@ -272,43 +259,43 @@ def test_broadcast_checks(self): t2 = torch.zeros([1], dtype=torch.float64) t3 = torch.zeros([2], dtype=torch.float32) - with self.assertRaisesRegex(ValueError, "invalid root rank"): + with self.assertRaisesRegex(RuntimeError, "invalid root rank"): opts = c10d.BroadcastOptions() opts.rootRank = -1 opts.rootTensor = 0 pg.broadcast([t1], opts) - with self.assertRaisesRegex(ValueError, "invalid root rank"): + with self.assertRaisesRegex(RuntimeError, "invalid root rank"): opts = c10d.BroadcastOptions() opts.rootRank = self.world_size opts.rootTensor = 0 pg.broadcast([t1], opts) - with self.assertRaisesRegex(ValueError, "invalid root tensor"): + with self.assertRaisesRegex(RuntimeError, "invalid root tensor"): opts = c10d.BroadcastOptions() opts.rootRank = self.rank opts.rootTensor = -1 pg.broadcast([t1], opts) - with self.assertRaisesRegex(ValueError, "invalid root tensor"): + with self.assertRaisesRegex(RuntimeError, "invalid root tensor"): opts = c10d.BroadcastOptions() opts.rootRank = self.rank opts.rootTensor = 1 pg.broadcast([t1], opts) - with self.assertRaisesRegex(ValueError, "invalid root tensor"): + with self.assertRaisesRegex(RuntimeError, "invalid root tensor"): opts = c10d.BroadcastOptions() opts.rootRank = self.rank opts.rootTensor = 0 pg.broadcast([], opts) - with self.assertRaisesRegex(ValueError, "invalid tensor type"): + with self.assertRaisesRegex(RuntimeError, "invalid tensor type"): opts = c10d.BroadcastOptions() opts.rootRank = self.rank opts.rootTensor = 0 pg.broadcast([t1, t2], opts) - with self.assertRaisesRegex(ValueError, "invalid tensor size"): + with self.assertRaisesRegex(RuntimeError, "invalid tensor size"): opts = c10d.BroadcastOptions() opts.rootRank = self.rank opts.rootTensor = 0 @@ -407,15 +394,15 @@ def test_allreduce_checks(self): t2 = torch.zeros([1], dtype=torch.float64) t3 = torch.zeros([2], dtype=torch.float32) - with self.assertRaisesRegex(ValueError, "requires non-empty tensor list"): + with self.assertRaisesRegex(RuntimeError, "requires non-empty tensor list"): opts = c10d.AllreduceOptions() pg.allreduce([], opts) - with self.assertRaisesRegex(ValueError, "invalid tensor type"): + with self.assertRaisesRegex(RuntimeError, "invalid tensor type"): opts = c10d.AllreduceOptions() pg.allreduce([t1, t2], opts) - with self.assertRaisesRegex(ValueError, "invalid tensor size"): + with self.assertRaisesRegex(RuntimeError, "invalid tensor size"): opts = c10d.AllreduceOptions() pg.allreduce([t1, t3], opts) @@ -566,19 +553,19 @@ def test_allreduce_coalesced_checks(self): t2 = torch.zeros(1, dtype=torch.float64) t3 = torch.sparse_coo_tensor([[0]], [1], size=(1,)) - with self.assertRaisesRegex(ValueError, "requires non-empty tensor list"): + with self.assertRaisesRegex(RuntimeError, "requires non-empty tensor list"): opts = c10d.AllreduceCoalescedOptions() pg.allreduce_coalesced([], opts) - with self.assertRaisesRegex(ValueError, "tensors must all have the same type"): + with self.assertRaisesRegex(RuntimeError, "tensors must all have the same type"): opts = c10d.AllreduceCoalescedOptions() pg.allreduce_coalesced([t1, t2], opts) - with self.assertRaisesRegex(ValueError, "invalid tensor layout at index"): + with self.assertRaisesRegex(RuntimeError, "invalid tensor layout at index"): opts = c10d.AllreduceCoalescedOptions() pg.allreduce_coalesced([t1, t3], opts) - with self.assertRaisesRegex(ValueError, "unsupported layout"): + with self.assertRaisesRegex(RuntimeError, "unsupported layout"): opts = c10d.AllreduceCoalescedOptions() pg.allreduce_coalesced([t3, t3.clone()], opts) @@ -592,7 +579,7 @@ def test_allreduce_coalesced_checks_cuda(self): t1 = torch.zeros(1, dtype=torch.float32) - with self.assertRaisesRegex(ValueError, "unsupported device type"): + with self.assertRaisesRegex(RuntimeError, "unsupported device type"): opts = c10d.AllreduceCoalescedOptions() pg.allreduce_coalesced([t1.cuda(), t1.cuda()], opts) @@ -660,21 +647,21 @@ def test_sparse_allreduce_checks(self): t2 = torch.sparse_coo_tensor([[0]], [1], size=(2,)) t3 = torch.sparse_coo_tensor([[0]], [1], size=(4,)) - with self.assertRaisesRegex(ValueError, "requires non-empty tensor list"): + with self.assertRaisesRegex(RuntimeError, "requires non-empty tensor list"): opts = c10d.AllreduceOptions() pg.allreduce([], opts) - with self.assertRaisesRegex(ValueError, "invalid tensor layout"): + with self.assertRaisesRegex(RuntimeError, "invalid tensor layout"): opts = c10d.AllreduceOptions() pg.allreduce([t1, t2], opts) - with self.assertRaisesRegex(ValueError, "invalid tensor size"): + with self.assertRaisesRegex(RuntimeError, "invalid tensor size"): opts = c10d.AllreduceOptions() pg.allreduce([t2, t3], opts) # Sparse allreduce only works with c10d.ReduceOp.SUM. for op in [c10d.ReduceOp.PRODUCT, c10d.ReduceOp.MIN, c10d.ReduceOp.MAX]: - with self.assertRaisesRegex(ValueError, "unsupported reduction operation"): + with self.assertRaisesRegex(RuntimeError, "unsupported reduction operation"): opts = c10d.AllreduceOptions() opts.reduceOp = op pg.allreduce([t3], opts) @@ -718,36 +705,36 @@ def test_scatter_checks(self): t2 = torch.zeros([1], dtype=torch.float64) t3 = torch.zeros([2], dtype=torch.float32) - with self.assertRaisesRegex(ValueError, "invalid root rank"): + with self.assertRaisesRegex(RuntimeError, "invalid root rank"): opts = c10d.ScatterOptions() opts.rootRank = -1 pg.scatter([t1], [], opts) - with self.assertRaisesRegex(ValueError, "invalid root rank"): + with self.assertRaisesRegex(RuntimeError, "invalid root rank"): opts = c10d.ScatterOptions() opts.rootRank = self.world_size pg.scatter([t1], [], opts) with self.assertRaisesRegex( - ValueError, "requires a single-element output tensor list" + RuntimeError, "requires a single-element output tensor list" ): opts = c10d.ScatterOptions() opts.rootRank = 0 pg.scatter([], [], opts) with self.assertRaisesRegex( - ValueError, "requires a single-element output tensor list" + RuntimeError, "requires a single-element output tensor list" ): opts = c10d.ScatterOptions() opts.rootRank = 0 pg.scatter([t1, t1], [], opts) - with self.assertRaisesRegex(ValueError, "requires a single-element input list"): + with self.assertRaisesRegex(RuntimeError, "requires a single-element input list"): opts = c10d.ScatterOptions() opts.rootRank = self.rank pg.scatter([t1], [], opts) - with self.assertRaisesRegex(ValueError, "requires a single-element input list"): + with self.assertRaisesRegex(RuntimeError, "requires a single-element input list"): opts = c10d.ScatterOptions() opts.rootRank = self.rank pg.scatter([t1], [[t1] * self.world_size, [t1] * self.world_size], opts) @@ -756,7 +743,7 @@ def test_scatter_checks(self): incorrect_list_size = self.world_size - 1 err_str = "Incorrect input list size {}. Input list size should be {}" with self.assertRaisesRegex( - ValueError, err_str.format(incorrect_list_size, desired_list_size) + RuntimeError, err_str.format(incorrect_list_size, desired_list_size) ): opts = c10d.ScatterOptions() opts.rootRank = self.rank @@ -764,23 +751,23 @@ def test_scatter_checks(self): incorrect_list_size = self.world_size + 1 with self.assertRaisesRegex( - ValueError, err_str.format(incorrect_list_size, desired_list_size) + RuntimeError, err_str.format(incorrect_list_size, desired_list_size) ): opts = c10d.ScatterOptions() opts.rootRank = self.rank pg.scatter([t1], [[t1] * incorrect_list_size], opts) - with self.assertRaisesRegex(ValueError, "invalid tensor type"): + with self.assertRaisesRegex(RuntimeError, "invalid tensor type"): opts = c10d.ScatterOptions() opts.rootRank = self.rank pg.scatter([t1], [[t2] * self.world_size], opts) - with self.assertRaisesRegex(ValueError, "invalid tensor size"): + with self.assertRaisesRegex(RuntimeError, "invalid tensor size"): opts = c10d.ScatterOptions() opts.rootRank = self.rank pg.scatter([t1], [[t3] * self.world_size], opts) - with self.assertRaisesRegex(ValueError, "requires empty input on non-root"): + with self.assertRaisesRegex(RuntimeError, "requires empty input on non-root"): opts = c10d.ScatterOptions() opts.rootRank = (self.rank + 1) % self.world_size pg.scatter([t1], [[t1] * self.world_size], opts) @@ -885,39 +872,39 @@ def test_gather_checks(self): t2 = torch.zeros([1], dtype=torch.float64) t3 = torch.zeros([2], dtype=torch.float32) - with self.assertRaisesRegex(ValueError, "invalid root rank"): + with self.assertRaisesRegex(RuntimeError, "invalid root rank"): opts = c10d.GatherOptions() opts.rootRank = -1 pg.gather([], [t1], opts) - with self.assertRaisesRegex(ValueError, "invalid root rank"): + with self.assertRaisesRegex(RuntimeError, "invalid root rank"): opts = c10d.GatherOptions() opts.rootRank = self.world_size pg.gather([], [t1], opts) with self.assertRaisesRegex( - ValueError, "requires a single-element input tensor list" + RuntimeError, "requires a single-element input tensor list" ): opts = c10d.GatherOptions() opts.rootRank = 0 pg.gather([], [], opts) with self.assertRaisesRegex( - ValueError, "requires a single-element input tensor list" + RuntimeError, "requires a single-element input tensor list" ): opts = c10d.GatherOptions() opts.rootRank = 0 pg.gather([], [t1, t1], opts) with self.assertRaisesRegex( - ValueError, "requires a single-element output list" + RuntimeError, "requires a single-element output list" ): opts = c10d.GatherOptions() opts.rootRank = self.rank pg.gather([], [t1], opts) with self.assertRaisesRegex( - ValueError, "requires a single-element output list" + RuntimeError, "requires a single-element output list" ): opts = c10d.GatherOptions() opts.rootRank = self.rank @@ -927,7 +914,7 @@ def test_gather_checks(self): incorrect_list_size = self.world_size - 1 err_str = "Incorrect output list size {}. Output list size should be {}" with self.assertRaisesRegex( - ValueError, err_str.format(incorrect_list_size, desired_list_size) + RuntimeError, err_str.format(incorrect_list_size, desired_list_size) ): opts = c10d.GatherOptions() opts.rootRank = self.rank @@ -935,23 +922,23 @@ def test_gather_checks(self): incorrect_list_size = self.world_size + 1 with self.assertRaisesRegex( - ValueError, err_str.format(incorrect_list_size, desired_list_size) + RuntimeError, err_str.format(incorrect_list_size, desired_list_size) ): opts = c10d.GatherOptions() opts.rootRank = self.rank pg.gather([[t1] * incorrect_list_size], [t1], opts) - with self.assertRaisesRegex(ValueError, "invalid tensor type"): + with self.assertRaisesRegex(RuntimeError, "invalid tensor type"): opts = c10d.GatherOptions() opts.rootRank = self.rank pg.gather([[t2] * self.world_size], [t1], opts) - with self.assertRaisesRegex(ValueError, "invalid tensor size"): + with self.assertRaisesRegex(RuntimeError, "invalid tensor size"): opts = c10d.GatherOptions() opts.rootRank = self.rank pg.gather([[t3] * self.world_size], [t1], opts) - with self.assertRaisesRegex(ValueError, "requires empty output on non-root"): + with self.assertRaisesRegex(RuntimeError, "requires empty output on non-root"): opts = c10d.GatherOptions() opts.rootRank = (self.rank + 1) % self.world_size pg.gather([[t1] * self.world_size], [t1], opts) @@ -1052,39 +1039,39 @@ def test_allgather_checks(self): t2 = torch.zeros([1], dtype=torch.float64) t3 = torch.zeros([2], dtype=torch.float32) - with self.assertRaisesRegex(ValueError, "requires non-empty input tensor list"): + with self.assertRaisesRegex(RuntimeError, "requires non-empty input tensor list"): pg.allgather([], []) with self.assertRaisesRegex( - ValueError, "requires input/output tensor lists to have the same length" + RuntimeError, "requires input/output tensor lists to have the same length" ): pg.allgather([], [t1]) with self.assertRaisesRegex( - ValueError, "requires input/output tensor lists to have the same length" + RuntimeError, "requires input/output tensor lists to have the same length" ): pg.allgather([[t1] * self.world_size, [t1] * self.world_size], [t1]) - with self.assertRaisesRegex(ValueError, "invalid output tensor list"): + with self.assertRaisesRegex(RuntimeError, "invalid output tensor list"): pg.allgather([[t1] * (self.world_size - 1)], [t1]) - with self.assertRaisesRegex(ValueError, "invalid output tensor list"): + with self.assertRaisesRegex(RuntimeError, "invalid output tensor list"): pg.allgather([[t1] * (self.world_size + 1)], [t1]) - with self.assertRaisesRegex(ValueError, "invalid tensor type"): + with self.assertRaisesRegex(RuntimeError, "invalid tensor type"): pg.allgather( [[t1, t1] * (self.world_size), [t1, t1] * (self.world_size)], [t1, t2] ) - with self.assertRaisesRegex(ValueError, "invalid tensor size"): + with self.assertRaisesRegex(RuntimeError, "invalid tensor size"): pg.allgather( [[t1, t1] * (self.world_size), [t1, t1] * (self.world_size)], [t1, t3] ) - with self.assertRaisesRegex(ValueError, "invalid tensor type"): + with self.assertRaisesRegex(RuntimeError, "invalid tensor type"): pg.allgather([([t1, t2] * (self.world_size))[: self.world_size]], [t1]) - with self.assertRaisesRegex(ValueError, "invalid tensor size"): + with self.assertRaisesRegex(RuntimeError, "invalid tensor size"): pg.allgather([([t1, t3] * (self.world_size))[: self.world_size]], [t1]) def _test_allgather_basics(self, fn): @@ -1173,13 +1160,13 @@ def test_allgather_coalesced_checks(self): # One of output tensors does not match input list. dummy_output_lists[0] = [torch.zeros([0], dtype=torch.float32)] with self.assertRaisesRegex( - ValueError, "invalid size of output tensor at index 0" + RuntimeError, "invalid size of output tensor at index 0" ): c10d.all_gather_coalesced(dummy_output_lists, dummy_input, pg) # One of output tensors does not match input list. dummy_output_lists[0] = [torch.zeros([1], dtype=torch.float64)] - with self.assertRaisesRegex(ValueError, "invalid tensor type at index 0"): + with self.assertRaisesRegex(RuntimeError, "invalid tensor type at index 0"): c10d.all_gather_coalesced(dummy_output_lists, dummy_input, pg) # Output lists have too many elements @@ -1187,7 +1174,7 @@ def test_allgather_coalesced_checks(self): [torch.zeros([1], dtype=torch.float32)] for _ in range(self.world_size + 1) ] with self.assertRaisesRegex( - ValueError, "output lists should be equal to world size" + RuntimeError, "output lists should be equal to world size" ): c10d.all_gather_coalesced(dummy_output_lists, dummy_input, pg) @@ -1207,26 +1194,26 @@ def test_reduce_checks(self): t1 = torch.zeros([1], dtype=torch.float32) - with self.assertRaisesRegex(ValueError, "invalid root rank"): + with self.assertRaisesRegex(RuntimeError, "invalid root rank"): opts = c10d.ReduceOptions() opts.rootRank = -1 opts.rootTensor = 0 pg.reduce([t1], opts) - with self.assertRaisesRegex(ValueError, "invalid root rank"): + with self.assertRaisesRegex(RuntimeError, "invalid root rank"): opts = c10d.ReduceOptions() opts.rootRank = self.world_size opts.rootTensor = 0 pg.reduce([t1], opts) - with self.assertRaisesRegex(ValueError, "invalid root tensor"): + with self.assertRaisesRegex(RuntimeError, "invalid root tensor"): opts = c10d.ReduceOptions() opts.rootRank = self.rank opts.rootTensor = 1 pg.reduce([t1], opts) with self.assertRaisesRegex( - ValueError, "requires a single-element tensor list" + RuntimeError, "requires a single-element tensor list" ): opts = c10d.ReduceOptions() opts.rootRank = self.rank @@ -1425,10 +1412,7 @@ class DistributedDataParallelTest( ): def setUp(self): super(DistributedDataParallelTest, self).setUp() - if sys.platform == "win32": - self._spawn_processes() - else: - self._fork_processes() + self._spawn_processes() def _test_gloo_backend( self, devices, device_ids, multi_device=False, gradient_as_bucket_view=False @@ -2197,10 +2181,7 @@ def test_forward_backward_optimizer(self): class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase): def setUp(self): super(CommTest, self).setUp() - if sys.platform == "win32": - self._spawn_processes() - else: - self._fork_processes() + self._spawn_processes() def tearDown(self): super(CommTest, self).tearDown() diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 2e5045153b149..1378aa07f0903 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -39,31 +39,20 @@ with_nccl_blocking_wait, ) from torch.testing._internal.common_utils import ( - IS_WINDOWS, TestCase, run_tests, retry_on_connect_failures, TEST_WITH_DEV_DBG_ASAN, - TEST_WITH_TSAN, + TEST_WITH_ROCM, sandcastle_skip, sandcastle_skip_if, ) from torch.utils.checkpoint import checkpoint +from torch.distributed.optim import functional_optim_map -if not IS_WINDOWS: - from torch.distributed.optim.functional_sgd import _FunctionalSGD - from torch.distributed.optim.functional_adam import _FunctionalAdam - _SUPPORTED_OPTIM_MAPPING = { - _FunctionalSGD: torch.optim.SGD, - _FunctionalAdam: torch.optim.Adam - } - -if TEST_WITH_TSAN: - print( - "Skip as TSAN is not fork-safe since we're forking in a multi-threaded environment", - file=sys.stderr, - ) - sys.exit(0) +from torch.distributed.optim.functional_sgd import _FunctionalSGD +from torch.distributed.optim.functional_adam import _FunctionalAdam +from torch.distributed.optim.functional_adamw import _FunctionalAdamW if TEST_WITH_DEV_DBG_ASAN: print( @@ -71,6 +60,11 @@ ) sys.exit(0) +# bfloat16 is only supported by CUDA 11+ +BFLOAT16_AVAILABLE = ( + torch.cuda.is_available() + and torch.version.cuda is not None + and int(torch.version.cuda.split('.')[0]) >= 11) class RendezvousEnvTest(TestCase): @retry_on_connect_failures @@ -1561,15 +1555,23 @@ def allreduce_hook( def _test_default_ddp_comm_hooks_nccl(self, gradient_as_bucket_view=False): """ - This unit test verifies whether default Python DDP communication hooks ALLREDUCE and FP16_COMPRESS - can give the same result with the case of no hook registered. + This unit test verifies whether default Python DDP communication hooks ALLREDUCE, FP16_COMPRESS + and BF16_COMPRESS, can give the same result with the case of no hook registered. """ store = c10d.FileStore(self.file_name, self.world_size) process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size) # For these default DDP comm hooks, the only state is process group. state = process_group - for hook in [default.allreduce_hook, default.fp16_compress_hook]: + hook_options = [default.allreduce_hook, default.fp16_compress_hook] + if ( + not TEST_WITH_ROCM + and BFLOAT16_AVAILABLE + and c10d.is_nccl_available() + and torch.cuda.nccl.version() >= (2, 9, 7) + ): + hook_options.append(default.bf16_compress_hook) + for hook in hook_options: # Get GPU model with the hook registered. # The first arg 'process_group' is used for initializing the test environment, # so it cannot be replaced by 'state', although they have the same value. @@ -1605,6 +1607,31 @@ def _test_fp16_compress_wrapper(self, gradient_as_bucket_view=False): # check whether the grads are equal to what DDP without hook would return. self._run_and_verify_hook(gpu_model, 8, 0.25 * torch.ones(2, 2)) + def _test_bf16_compress_wrapper(self, gradient_as_bucket_view=False): + """ + This unit test verifies whether wrapping the ALLREDUCE and POWER_SGD hooks with + the BF16_WRAPPER can give the same result as when there is no hook registered. + """ + store = c10d.FileStore(self.file_name, self.world_size) + process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size) + powerSGD_state = powerSGD.PowerSGDState(process_group=process_group) + + hook_args = [ + (powerSGD.powerSGD_hook, powerSGD_state), + (default.allreduce_hook, process_group), + ] + + for hook, state in hook_args: + gpu_model = self._gpu_model_with_ddp_comm_hook( + process_group, + default.bf16_compress_wrapper(hook), + gradient_as_bucket_view, + state, + ) + + # check whether the grads are equal to what DDP without hook would return. + self._run_and_verify_hook(gpu_model, 8, 0.25 * torch.ones(2, 2)) + def _test_hook_then_optimizer( self, functional_optim_cls, @@ -1637,7 +1664,8 @@ def _test_hook_then_optimizer( gpu_model_allreduce = self._gpu_model_with_ddp_comm_hook( process_group, default.allreduce_hook, gradient_as_bucket_view, hook_state ) - sgd = _SUPPORTED_OPTIM_MAPPING.get(functional_optim_cls)( + mapping = {v: k for k, v in functional_optim_map.items()} + sgd = mapping.get(functional_optim_cls)( gpu_model_allreduce.parameters(), *functional_optim_args, **functional_optim_kwargs, @@ -1710,6 +1738,17 @@ def test_default_ddp_comm_hooks_nccl(self): def test_fp16_compress_wrapper_nccl(self): self._test_fp16_compress_wrapper() + @requires_nccl() + @requires_nccl_version((2, 9, 7), "Need NCCL 2.9.7+ for BF16_COMPRESS") + @sandcastle_skip_if( + not BFLOAT16_AVAILABLE, + "BFloat16 is only supported by CUDA 11+", + ) + @skip_if_lt_x_gpu(2) + @skip_if_rocm + def test_bf16_compress_wrapper_nccl(self): + self._test_bf16_compress_wrapper() + @requires_nccl() @skip_if_lt_x_gpu(2) def test_hook_then_sgd_nccl(self): @@ -1737,6 +1776,20 @@ def test_hook_then_sgd_nccl_grad_as_bucket_view(self): gradient_as_bucket_view=True ) + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_hook_then_adamw_nccl(self): + adamw_lr = 1e-2 + adamw_betas = (0.9, 0.99) + adamw_eps = 1e-6 + self._test_hook_then_optimizer( + _FunctionalAdamW, + adamw_lr, + betas=adamw_betas, + eps=adamw_eps, + gradient_as_bucket_view=True + ) + @requires_nccl() @skip_if_lt_x_gpu(2) def test_hook_then_adam_nccl(self): @@ -1795,6 +1848,17 @@ def test_default_ddp_comm_hooks_nccl_is_view(self): def test_fp16_compress_wrapper_is_view(self): self._test_fp16_compress_wrapper(gradient_as_bucket_view=True) + @requires_nccl() + @requires_nccl_version((2, 9, 7), "Need NCCL 2.9.7+ for BF16_COMPRESS") + @sandcastle_skip_if( + not BFLOAT16_AVAILABLE, + "BFloat16 is only supported by CUDA 11+", + ) + @skip_if_lt_x_gpu(2) + @skip_if_rocm + def test_bf16_compress_wrapper_is_view(self): + self._test_bf16_compress_wrapper(gradient_as_bucket_view=True) + @requires_nccl() @skip_if_lt_x_gpu(2) def test_builtin_ddp_comm_hooks_nccl_grad_is_view(self): diff --git a/test/distributed/test_c10d_spawn_gloo.py b/test/distributed/test_c10d_spawn_gloo.py index 8e5e0519356cf..21f43f7ca95f6 100644 --- a/test/distributed/test_c10d_spawn_gloo.py +++ b/test/distributed/test_c10d_spawn_gloo.py @@ -11,7 +11,7 @@ from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU from torch.testing._internal.common_distributed import requires_gloo, \ create_device, MultiProcessTestCase, skip_if_lt_x_gpu -from torch.testing._internal.common_utils import TestCase, run_tests, sandcastle_skip_if, TEST_WITH_TSAN, TEST_WITH_DEV_DBG_ASAN +from torch.testing._internal.common_utils import TestCase, run_tests, sandcastle_skip_if, TEST_WITH_DEV_DBG_ASAN # Fails on Python-3.9, see https://github.com/pytorch/pytorch/issues/51619 if sys.version_info < (3, 9): @@ -76,102 +76,100 @@ def test_shared_allgather_chunk_gloo(self): self.world_size) -# TSAN is not fork-safe since we're forking in a multi-threaded environment -if not TEST_WITH_TSAN: - class DistributedDataParallelSingleProcessTest(TestCase): - def setUp(self): - self.rank = 0 - self.world_size = 1 - self.file = tempfile.NamedTemporaryFile(delete=False) # noqa: P201 - - def tearDown(self): - try: - os.remove(self.file.name) - except OSError: - pass - - def _test_base(self, net, inp, check_allclose=True): - store = c10d.FileStore(self.file.name, self.world_size) - process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size) - if inp[0].is_cuda: - device_ids = [torch.cuda.current_device()] - else: - device_ids = None +class DistributedDataParallelSingleProcessTest(TestCase): + def setUp(self): + self.rank = 0 + self.world_size = 1 + self.file = tempfile.NamedTemporaryFile(delete=False) # noqa: P201 - ddp = nn.parallel.DistributedDataParallel( - copy.deepcopy(net), - device_ids=device_ids, - process_group=process_group - ) + def tearDown(self): + try: + os.remove(self.file.name) + except OSError: + pass - net_opt = torch.optim.Adam(net.parameters(), lr=0.001) - ddp_opt = torch.optim.Adam(ddp.parameters(), lr=0.001) + def _test_base(self, net, inp, check_allclose=True): + store = c10d.FileStore(self.file.name, self.world_size) + process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size) + if inp[0].is_cuda: + device_ids = [torch.cuda.current_device()] + else: + device_ids = None - for i, j in zip(ddp.parameters(), net.parameters()): - self.assertTrue(i.allclose(j)) + ddp = nn.parallel.DistributedDataParallel( + copy.deepcopy(net), + device_ids=device_ids, + process_group=process_group + ) - for _ in range(10): - net_out = net(*inp) - ddp_out = ddp(*inp) + net_opt = torch.optim.Adam(net.parameters(), lr=0.001) + ddp_opt = torch.optim.Adam(ddp.parameters(), lr=0.001) - net_out.sum().backward() - ddp_out.sum().backward() + for i, j in zip(ddp.parameters(), net.parameters()): + self.assertTrue(i.allclose(j)) - net_opt.step() - ddp_opt.step() + for _ in range(10): + net_out = net(*inp) + ddp_out = ddp(*inp) - if check_allclose: - for i, j in zip(ddp.parameters(), net.parameters()): - self.assertTrue(i.allclose(j)) + net_out.sum().backward() + ddp_out.sum().backward() - @requires_gloo() - def test_cpu(self): - self._test_base(nn.Linear(2, 2), [torch.randn(30, 2)]) + net_opt.step() + ddp_opt.step() - @requires_gloo() - @sandcastle_skip_if(not TEST_CUDA, "At least 1 CUDA GPUS needed") - def test_cuda(self): - self._test_base(nn.Linear(2, 2).to(0), [torch.randn(30, 2).to(0)]) + if check_allclose: + for i, j in zip(ddp.parameters(), net.parameters()): + self.assertTrue(i.allclose(j)) - @requires_gloo() - @sandcastle_skip_if(not TEST_CUDA, "At least 1 CUDA GPUS needed") - def test_rnn(self): - # This test is inspired by the bug reported in - # https://github.com/pytorch/pytorch/issues/36268 - BATCH_SIZE = 12 # Divisible by 2, 3, 4 - INPUT_DIM = 256 - OUTPUT_DIM = 256 - HIDDEN_DIM = 256 - N_LAYERS = 3 - SEQ_LEN = 100 - - class Net(nn.Module): - def __init__(self, input_dim, hidden_dim, output_dim, hidden_layers): - super(Net, self).__init__() - self.input_dim = input_dim - self.hidden_dim = hidden_dim - self.output_dim = output_dim - self.hidden_layers = hidden_layers - - self.lstm = nn.LSTM(input_dim, hidden_dim, hidden_layers, batch_first=True) - self.h2o = nn.Linear(hidden_dim, output_dim) - - def forward(self, x, y): - self.lstm.flatten_parameters() - h_t, _ = self.lstm(x) - output = self.h2o(h_t) - loss = nn.functional.mse_loss(output, y) - return loss - - net = Net(INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM, N_LAYERS).to(0) - inp = [ - torch.randn((BATCH_SIZE, SEQ_LEN, INPUT_DIM)).to(0), - torch.rand((BATCH_SIZE, SEQ_LEN, OUTPUT_DIM)).to(0) - ] - - # Not checking result allclose as the parameter inconsistency exist - # prior to this change. See #37079 - self._test_base(net, inp, check_allclose=False) + @requires_gloo() + def test_cpu(self): + self._test_base(nn.Linear(2, 2), [torch.randn(30, 2)]) + + @requires_gloo() + @sandcastle_skip_if(not TEST_CUDA, "At least 1 CUDA GPUS needed") + def test_cuda(self): + self._test_base(nn.Linear(2, 2).to(0), [torch.randn(30, 2).to(0)]) + + @requires_gloo() + @sandcastle_skip_if(not TEST_CUDA, "At least 1 CUDA GPUS needed") + def test_rnn(self): + # This test is inspired by the bug reported in + # https://github.com/pytorch/pytorch/issues/36268 + BATCH_SIZE = 12 # Divisible by 2, 3, 4 + INPUT_DIM = 256 + OUTPUT_DIM = 256 + HIDDEN_DIM = 256 + N_LAYERS = 3 + SEQ_LEN = 100 + + class Net(nn.Module): + def __init__(self, input_dim, hidden_dim, output_dim, hidden_layers): + super(Net, self).__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.output_dim = output_dim + self.hidden_layers = hidden_layers + + self.lstm = nn.LSTM(input_dim, hidden_dim, hidden_layers, batch_first=True) + self.h2o = nn.Linear(hidden_dim, output_dim) + + def forward(self, x, y): + self.lstm.flatten_parameters() + h_t, _ = self.lstm(x) + output = self.h2o(h_t) + loss = nn.functional.mse_loss(output, y) + return loss + + net = Net(INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM, N_LAYERS).to(0) + inp = [ + torch.randn((BATCH_SIZE, SEQ_LEN, INPUT_DIM)).to(0), + torch.rand((BATCH_SIZE, SEQ_LEN, OUTPUT_DIM)).to(0) + ] + + # Not checking result allclose as the parameter inconsistency exist + # prior to this change. See #37079 + self._test_base(net, inp, check_allclose=False) # Skip dev-asan as torch + multiprocessing spawn have known issues diff --git a/test/distributed/test_distributed_fork.py b/test/distributed/test_distributed_fork.py deleted file mode 100644 index c707a313a5e47..0000000000000 --- a/test/distributed/test_distributed_fork.py +++ /dev/null @@ -1,113 +0,0 @@ -import os -import sys -import tempfile -from functools import wraps -import torch -import torch.cuda -import torch.distributed as dist -from torch.testing._internal.common_utils import TEST_WITH_TSAN - -if not dist.is_available(): - print("Distributed not available, skipping tests", file=sys.stderr) - sys.exit(0) - -from torch.testing._internal.common_utils import TestCase, find_free_port, run_tests -from torch.distributed.distributed_c10d import _get_default_group -from torch.testing._internal.distributed.distributed_test import ( - DistributedTest, TestDistBackend -) - -torch.backends.cuda.matmul.allow_tf32 = False - -CPP_EXTENSIONS_WARNING = """ -Ninja (https://ninja-build.org) must be available to run C++ extensions tests, -but it could not be found. Install ninja with `pip install ninja` -or `conda install ninja`. -""" - -BACKEND = os.environ["BACKEND"] -INIT_METHOD = os.getenv("INIT_METHOD", "env://") - - -def skip_if_no_ninja(func): - - @wraps(func) - def wrapper(*args, **kwargs): - try: - import torch.utils.cpp_extension - torch.utils.cpp_extension.verify_ninja_availability() - except RuntimeError: - print(CPP_EXTENSIONS_WARNING) - return 0 - - return func(*args, **kwargs) - - return wrapper - -if TEST_WITH_TSAN: - print("Skip as TSAN is not fork-safe since we're forking in a multi-threaded environment", file=sys.stderr) - sys.exit(0) - -if BACKEND == "gloo" or BACKEND == "nccl": - - class TestDistBackendWithFork(TestDistBackend, DistributedTest._DistTestBase): - - def setUp(self): - super().setUp() - self._fork_processes() - torch.backends.cudnn.flags(allow_tf32=False).__enter__() - - -elif BACKEND == "mpi": - WORLD_SIZE = os.environ["WORLD_SIZE"] - dist.init_process_group(init_method=INIT_METHOD, backend="mpi") - - class TestMPIWithFork(TestCase, DistributedTest._DistTestBase): - pass - -elif BACKEND == "test": - class TestBackendDynamicLoad(TestCase): - def setUp(self): - super(TestBackendDynamicLoad, self).setUp() - - def _load_test_backend(self): - temp_dir = tempfile.mkdtemp() - src = "{}/../cpp_extensions/cpp_c10d_extension.cpp".format(os.path.abspath(os.path.dirname(__file__))) - extension = torch.utils.cpp_extension.load( - name="torch_test", - sources=[src], - build_directory=temp_dir - ) - - @skip_if_no_ninja - def test_backend_apis(self): - self._load_test_backend() - - os.environ['WORLD_SIZE'] = '1' - os.environ['MASTER_ADDR'] = '127.0.0.1' - os.environ['MASTER_PORT'] = str(find_free_port()) - os.environ['RANK'] = '0' - - dist.init_process_group(backend='test', init_method='env://', world_size=1, rank=0) - self.assertEqual(dist.get_rank(), 0) - self.assertEqual(dist.get_world_size(), 1) - - process_group = _get_default_group() - work = process_group.allreduce([torch.rand(1), torch.rand(1)]) - self.assertTrue(work.wait()) - self.assertTrue(work.is_completed()) - self.assertTrue(work.is_success()) - - work = process_group.broadcast([torch.rand(1)]) - self.assertTrue(work.wait()) - self.assertTrue(work.is_completed()) - self.assertTrue(work.is_success()) - - dist.destroy_process_group() - -if __name__ == "__main__": - assert ( - not torch.cuda._initialized - ), "test_distributed must not have initialized CUDA context on main process" - - run_tests() diff --git a/test/distributed/test_jit_c10d.py b/test/distributed/test_jit_c10d.py index be392730b3fad..65d82fb033b7d 100644 --- a/test/distributed/test_jit_c10d.py +++ b/test/distributed/test_jit_c10d.py @@ -6,7 +6,7 @@ from typing import List from torch.testing._internal.common_distributed import requires_nccl, create_tcp_store -from torch.testing._internal.common_utils import load_tests, TEST_WITH_TSAN, run_tests, sandcastle_skip_if +from torch.testing._internal.common_utils import load_tests, run_tests, sandcastle_skip_if from torch.testing._internal.jit_utils import JitTestCase # load_tests from common_utils is used to automatically filter tests for @@ -29,10 +29,6 @@ def unique_process_group_name(prefix): now = int(time.time() * 1000) return "%s_%d" % (prefix, now) -if TEST_WITH_TSAN: - print("Skip as TSAN is not fork-safe since we're forking in a multi-threaded environment", file=sys.stderr) - sys.exit(0) - class ProcessGroupNCCLJitTest(JitTestCase): MAIN_PROCESS_RANK = 0 diff --git a/test/distributed/test_launcher.py b/test/distributed/test_launcher.py index 85ba293966f2d..422c88b6bdee5 100644 --- a/test/distributed/test_launcher.py +++ b/test/distributed/test_launcher.py @@ -11,8 +11,7 @@ sys.exit(0) from torch.testing._internal.common_utils import ( - TEST_WITH_ASAN, - TEST_WITH_TSAN, + TEST_WITH_DEV_DBG_ASAN, TestCase, run_tests, ) @@ -21,14 +20,14 @@ def path(script): return os.path.join(os.path.dirname(__file__), script) -if TEST_WITH_ASAN: - print("Skip ASAN as torch + multiprocessing spawn have known issues", file=sys.stderr) - sys.exit(0) -if TEST_WITH_TSAN: - print("Skip as TSAN is not fork-safe since we're forking in a multi-threaded environment", file=sys.stderr) +if TEST_WITH_DEV_DBG_ASAN: + print( + "Skip ASAN as torch + multiprocessing spawn have known issues", file=sys.stderr + ) sys.exit(0) + class TestDistributedLaunch(TestCase): def test_launch_user_script(self): nnodes = 1 @@ -41,7 +40,7 @@ def test_launch_user_script(self): f"--nnodes={nnodes}", f"--nproc_per_node={nproc_per_node}", "--monitor_interval=1", - "--start_method=fork", + "--start_method=spawn", "--master_addr=localhost", f"--master_port={master_port}", "--node_rank=0", diff --git a/test/distributed/test_pg_wrapper.py b/test/distributed/test_pg_wrapper.py index de3a66712bffe..abf77d4fdaa02 100644 --- a/test/distributed/test_pg_wrapper.py +++ b/test/distributed/test_pg_wrapper.py @@ -20,7 +20,6 @@ ) from torch.testing._internal.common_utils import ( run_tests, - TEST_WITH_TSAN, TEST_WITH_DEV_DBG_ASAN, ) @@ -28,11 +27,7 @@ class AbstractProcessGroupWrapperTest(MultiProcessTestCase): def setUp(self): super(AbstractProcessGroupWrapperTest, self).setUp() - # For Windows platform, Python does not support fork, change it to spawn here. - if sys.platform == "win32": - self._spawn_processes() - else: - self._fork_processes() + self._spawn_processes() def _validate_error(self, exception, op_type, rank, tensor): err = str(exception) @@ -291,91 +286,89 @@ def test_collective_shape_mismatch(self): self._test_collective_shape_mismatch(pg, use_cuda=True) -# TSAN is not fork-safe since we're forking in a multi-threaded environment -if not TEST_WITH_TSAN: - @requires_gloo() - class ProcessGroupGlooWrapperTest(AbstractProcessGroupWrapperTest): - def setUp(self): - super(ProcessGroupGlooWrapperTest, self).setUp() - - def opts(self, threads=2, timeout=10.0): - opts = c10d.ProcessGroupGloo._Options() - opts._timeout = timeout - opts._devices = [create_device(interface=LOOPBACK)] - opts._threads = threads - return opts - - def _create_wrapper_pg(self, with_new_group=False, timeout=10.0): - store = c10d.FileStore(self.file_name, self.world_size) - c10d.init_process_group( - backend="gloo", rank=self.rank, world_size=self.world_size, store=store +@requires_gloo() +class ProcessGroupGlooWrapperTest(AbstractProcessGroupWrapperTest): + def setUp(self): + super(ProcessGroupGlooWrapperTest, self).setUp() + + def opts(self, threads=2, timeout=10.0): + opts = c10d.ProcessGroupGloo._Options() + opts._timeout = timeout + opts._devices = [create_device(interface=LOOPBACK)] + opts._threads = threads + return opts + + def _create_wrapper_pg(self, with_new_group=False, timeout=10.0): + store = c10d.FileStore(self.file_name, self.world_size) + c10d.init_process_group( + backend="gloo", rank=self.rank, world_size=self.world_size, store=store + ) + if with_new_group: + pg = c10d.new_group(backend="gloo") + else: + _pg = c10d.ProcessGroupGloo( + store, self.rank, self.world_size, self.opts(timeout=timeout) ) - if with_new_group: - pg = c10d.new_group(backend="gloo") - else: - _pg = c10d.ProcessGroupGloo( - store, self.rank, self.world_size, self.opts(timeout=timeout) - ) - pg = c10d._create_process_group_wrapper( - _pg, - "unused", - store, - self.rank, - self.world_size, - timeout=timeout, - ) - return pg - - def test_collective_hang(self): - pg = self._create_wrapper_pg(timeout=2.0) - self._test_collective_hang(pg) - - # NOTE: these tests are separated by debug level instead of combined into - # one due to https://github.com/pytorch/pytorch/issues/55967, they can be - # combined after that is resolved. - @with_dist_debug_levels(levels=["DETAIL"]) - def test_collectives_op_mismatch_debug_mode(self): - pg = self._create_wrapper_pg(with_new_group=True) - self._test_collectives_op_mismatch(pg) - - @with_dist_debug_levels(levels=["OFF"]) - def test_collectives_op_mismatch(self): - pg = self._create_wrapper_pg(with_new_group=False) - self._test_collectives_op_mismatch(pg) - - @with_dist_debug_levels(levels=["DETAIL"]) - def test_collective_shape_mismatch_debug_mode(self): - pg = self._create_wrapper_pg(with_new_group=True) - self._test_collective_shape_mismatch(pg) - - @with_dist_debug_levels(levels=["OFF"]) - def test_collective_shape_mismatch(self): - pg = self._create_wrapper_pg(with_new_group=False) - self._test_collective_shape_mismatch(pg) - - @skip_if_lt_x_gpu(4) - @with_dist_debug_levels(levels=["DETAIL"]) - def test_collectives_op_mismatch_cuda_debug_mode(self): - pg = self._create_wrapper_pg(with_new_group=True) - self._test_collectives_op_mismatch(pg, use_cuda=True) - - @skip_if_lt_x_gpu(4) - @with_dist_debug_levels(levels=["OFF"]) - def test_collectives_op_mismatch_cuda(self): - pg = self._create_wrapper_pg(with_new_group=False) - self._test_collectives_op_mismatch(pg, use_cuda=True) - - @skip_if_lt_x_gpu(4) - @with_dist_debug_levels(levels=["DETAIL"]) - def test_collective_shape_mismatch_cuda_debug_mode(self): - pg = self._create_wrapper_pg(with_new_group=True) - self._test_collective_shape_mismatch(pg, use_cuda=True) - - @skip_if_lt_x_gpu(4) - @with_dist_debug_levels(levels=["OFF"]) - def test_collective_shape_mismatch_cuda(self): - pg = self._create_wrapper_pg(with_new_group=False) - self._test_collective_shape_mismatch(pg, use_cuda=True) + pg = c10d._create_process_group_wrapper( + _pg, + "unused", + store, + self.rank, + self.world_size, + timeout=timeout, + ) + return pg + + def test_collective_hang(self): + pg = self._create_wrapper_pg(timeout=2.0) + self._test_collective_hang(pg) + + # NOTE: these tests are separated by debug level instead of combined into + # one due to https://github.com/pytorch/pytorch/issues/55967, they can be + # combined after that is resolved. + @with_dist_debug_levels(levels=["DETAIL"]) + def test_collectives_op_mismatch_debug_mode(self): + pg = self._create_wrapper_pg(with_new_group=True) + self._test_collectives_op_mismatch(pg) + + @with_dist_debug_levels(levels=["OFF"]) + def test_collectives_op_mismatch(self): + pg = self._create_wrapper_pg(with_new_group=False) + self._test_collectives_op_mismatch(pg) + + @with_dist_debug_levels(levels=["DETAIL"]) + def test_collective_shape_mismatch_debug_mode(self): + pg = self._create_wrapper_pg(with_new_group=True) + self._test_collective_shape_mismatch(pg) + + @with_dist_debug_levels(levels=["OFF"]) + def test_collective_shape_mismatch(self): + pg = self._create_wrapper_pg(with_new_group=False) + self._test_collective_shape_mismatch(pg) + + @skip_if_lt_x_gpu(4) + @with_dist_debug_levels(levels=["DETAIL"]) + def test_collectives_op_mismatch_cuda_debug_mode(self): + pg = self._create_wrapper_pg(with_new_group=True) + self._test_collectives_op_mismatch(pg, use_cuda=True) + + @skip_if_lt_x_gpu(4) + @with_dist_debug_levels(levels=["OFF"]) + def test_collectives_op_mismatch_cuda(self): + pg = self._create_wrapper_pg(with_new_group=False) + self._test_collectives_op_mismatch(pg, use_cuda=True) + + @skip_if_lt_x_gpu(4) + @with_dist_debug_levels(levels=["DETAIL"]) + def test_collective_shape_mismatch_cuda_debug_mode(self): + pg = self._create_wrapper_pg(with_new_group=True) + self._test_collective_shape_mismatch(pg, use_cuda=True) + + @skip_if_lt_x_gpu(4) + @with_dist_debug_levels(levels=["OFF"]) + def test_collective_shape_mismatch_cuda(self): + pg = self._create_wrapper_pg(with_new_group=False) + self._test_collective_shape_mismatch(pg, use_cuda=True) if __name__ == "__main__": diff --git a/test/distributions/test_distributions.py b/test/distributions/test_distributions.py index 85e4dbacd4b6a..319b55795addb 100644 --- a/test/distributions/test_distributions.py +++ b/test/distributions/test_distributions.py @@ -387,6 +387,12 @@ def is_all_nan(tensor): }, { 'rate': 0.2, + }, + { + 'rate': torch.tensor([0.0], requires_grad=True), + }, + { + 'rate': 0.0, } ]), Example(RelaxedBernoulli, [ @@ -667,7 +673,7 @@ def is_all_nan(tensor): ]), Example(Poisson, [ { - 'rate': torch.tensor([0.0], requires_grad=True), + 'rate': torch.tensor([-0.1], requires_grad=True), }, { 'rate': -1.0, @@ -1315,17 +1321,29 @@ def test_poisson_shape(self): def test_poisson_log_prob(self): rate = torch.randn(2, 3).abs().requires_grad_() rate_1d = torch.randn(1).abs().requires_grad_() + rate_zero = torch.zeros([], requires_grad=True) - def ref_log_prob(idx, x, log_prob): - l = rate.view(-1)[idx].detach() + def ref_log_prob(ref_rate, idx, x, log_prob): + l = ref_rate.view(-1)[idx].detach() expected = scipy.stats.poisson.logpmf(x, l) self.assertEqual(log_prob, expected, atol=1e-3, rtol=0) set_rng_seed(0) - self._check_log_prob(Poisson(rate), ref_log_prob) + self._check_log_prob(Poisson(rate), lambda *args: ref_log_prob(rate, *args)) + self._check_log_prob(Poisson(rate_zero), lambda *args: ref_log_prob(rate_zero, *args)) self._gradcheck_log_prob(Poisson, (rate,)) self._gradcheck_log_prob(Poisson, (rate_1d,)) + # We cannot check gradients automatically for zero rates because the finite difference + # approximation enters the forbidden parameter space. We instead compare with the + # theoretical results. + dist = Poisson(rate_zero) + s = dist.sample() + dist.log_prob(s).backward() + torch.testing.assert_allclose(rate_zero.grad, -1.0) + dist.log_prob(torch.ones_like(rate_zero)).backward() + torch.testing.assert_allclose(rate_zero.grad, torch.inf) + @unittest.skipIf(IS_MACOS, "See https://github.com/pytorch/pytorch/issues/60347") @unittest.skipIf(not TEST_NUMPY, "Numpy not found") def test_poisson_sample(self): diff --git a/test/expect/TestAutograd.test_function-x_grad_desc.expect b/test/expect/TestAutograd.test_function-x_grad_desc.expect index b6fdb63db272a..68242e2ffae90 100644 --- a/test/expect/TestAutograd.test_function-x_grad_desc.expect +++ b/test/expect/TestAutograd.test_function-x_grad_desc.expect @@ -1 +1 @@ -CopyBackwards(None, AddBackward0(ExpandBackward(AccumulateGrad()), MulBackward0(ExpandBackward(AccumulateGrad()), AccumulateGrad()))) \ No newline at end of file +CopyBackwards(None, AddBackward0(ExpandBackward0(AccumulateGrad()), MulBackward0(ExpandBackward0(AccumulateGrad()), AccumulateGrad()))) \ No newline at end of file diff --git a/test/expect/TestAutograd.test_function-y_grad_desc.expect b/test/expect/TestAutograd.test_function-y_grad_desc.expect index e32d5888e1e7a..88db87320a92e 100644 --- a/test/expect/TestAutograd.test_function-y_grad_desc.expect +++ b/test/expect/TestAutograd.test_function-y_grad_desc.expect @@ -1 +1 @@ -CopyBackwards(None, AddBackward0(MulBackward0(ExpandBackward(AccumulateGrad()), None), MulBackward0(ExpandBackward(AccumulateGrad()), AccumulateGrad()))) \ No newline at end of file +CopyBackwards(None, AddBackward0(MulBackward0(ExpandBackward0(AccumulateGrad()), None), MulBackward0(ExpandBackward0(AccumulateGrad()), AccumulateGrad()))) \ No newline at end of file diff --git a/test/expect/TestFXAPIBackwardCompatibility.test_class_member_back_compat-fx_backcompat_class_members.expect b/test/expect/TestFXAPIBackwardCompatibility.test_class_member_back_compat-fx_backcompat_class_members.expect new file mode 100644 index 0000000000000..5c3630a3169f7 --- /dev/null +++ b/test/expect/TestFXAPIBackwardCompatibility.test_class_member_back_compat-fx_backcompat_class_members.expect @@ -0,0 +1,19 @@ +torch.fx._symbolic_trace.ProxyableClassMeta [] +torch.fx._symbolic_trace.Tracer ['call_module', 'create_arg', 'create_args_for_root', 'is_leaf_module', 'path_of_module', 'trace'] +torch.fx.graph.Graph ['call_function', 'call_method', 'call_module', 'create_node', 'eliminate_dead_code', 'erase_node', 'flatten_inps', 'get_attr', 'graph_copy', 'inserting_after', 'inserting_before', 'lint', 'node_copy', 'nodes', 'output', 'owning_module', 'placeholder', 'print_tabular', 'python_code', 'unflatten_outs'] +torch.fx.graph.PythonCode [] +torch.fx.graph_module.GraphModule ['add_submodule', 'code', 'delete_all_unused_submodules', 'delete_submodule', 'graph', 'recompile', 'to_folder'] +torch.fx.immutable_collections.immutable_dict ['clear', 'pop', 'popitem', 'update'] +torch.fx.immutable_collections.immutable_list ['append', 'clear', 'extend', 'insert', 'pop', 'remove'] +torch.fx.interpreter.Interpreter ['call_function', 'call_method', 'call_module', 'fetch_args_kwargs_from_env', 'fetch_attr', 'get_attr', 'map_nodes_to_values', 'output', 'placeholder', 'run', 'run_node'] +torch.fx.interpreter.Transformer ['call_function', 'call_module', 'get_attr', 'placeholder', 'transform'] +torch.fx.node.Node ['all_input_nodes', 'append', 'args', 'format_node', 'is_impure', 'kwargs', 'next', 'normalized_arguments', 'prepend', 'prev', 'replace_all_uses_with', 'replace_input_with', 'stack_trace', 'update_arg', 'update_kwarg'] +torch.fx.passes.shape_prop.ShapeProp ['propagate', 'run_node'] +torch.fx.passes.shape_prop.TensorMetadata ['dtype', 'is_quantized', 'memory_format', 'q_scale', 'q_zero_point', 'qscheme', 'requires_grad', 'shape', 'stride'] +torch.fx.passes.split_module.Partition [] +torch.fx.proxy.Attribute ['node'] +torch.fx.proxy.GraphAppendingTracer [] +torch.fx.proxy.Proxy ['keys'] +torch.fx.proxy.TraceError [] +torch.fx.proxy.TracerBase ['check_mutable_operations', 'create_arg', 'create_node', 'create_proxy', 'iter', 'keys', 'proxy', 'record_stack_traces', 'to_bool'] +torch.fx.subgraph_rewriter.Match ['anchor', 'nodes_map'] \ No newline at end of file diff --git a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect new file mode 100644 index 0000000000000..20d392fa9cbb1 --- /dev/null +++ b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect @@ -0,0 +1,73 @@ +torch.fx._symbolic_trace.Tracer.__init__(self, autowrap_modules: Tuple[Callable] = (,), autowrap_functions: Tuple[Callable, ...] = (,), enable_cpatching: bool = False, param_shapes_constant: bool = False) -> None +torch.fx._symbolic_trace.Tracer.call_module(self, m: torch.nn.modules.module.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any +torch.fx._symbolic_trace.Tracer.create_arg(self, a: Any) -> 'Argument' +torch.fx._symbolic_trace.Tracer.is_leaf_module(self, m: torch.nn.modules.module.Module, module_qualified_name: str) -> bool +torch.fx._symbolic_trace.Tracer.path_of_module(self, mod: torch.nn.modules.module.Module) -> str +torch.fx._symbolic_trace.Tracer.trace(self, root: Union[torch.nn.modules.module.Module, Callable[..., Any]], concrete_args: Optional[Dict[str, Any]] = None) -> torch.fx.graph.Graph +torch.fx._symbolic_trace.symbolic_trace(root: Union[torch.nn.modules.module.Module, Callable[..., Any]], concrete_args: Optional[Dict[str, Any]] = None, enable_cpatching: bool = False) -> torch.fx.graph_module.GraphModule +torch.fx._symbolic_trace.wrap(fn_or_name: Union[str, Callable]) +torch.fx.graph.Graph.__init__(self, owning_module: Optional[GraphModule] = None, tracer_cls: Optional[Type[Tracer]] = None) +torch.fx.graph.Graph.call_function(self, the_function: Callable[..., Any], args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, type_expr: Optional[Any] = None) -> torch.fx.node.Node +torch.fx.graph.Graph.call_method(self, method_name: str, args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, type_expr: Optional[Any] = None) -> torch.fx.node.Node +torch.fx.graph.Graph.call_module(self, module_name: str, args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, type_expr: Optional[Any] = None) -> torch.fx.node.Node +torch.fx.graph.Graph.create_node(self, op: str, target: 'Target', args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, name: Optional[str] = None, type_expr: Optional[Any] = None) -> torch.fx.node.Node +torch.fx.graph.Graph.eliminate_dead_code(self) +torch.fx.graph.Graph.erase_node(self, to_erase: torch.fx.node.Node) -> None +torch.fx.graph.Graph.get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> torch.fx.node.Node +torch.fx.graph.Graph.graph_copy(self, g: 'Graph', val_map: Dict[torch.fx.node.Node, torch.fx.node.Node], return_output_node = False) -> 'Optional[Argument]' +torch.fx.graph.Graph.inserting_after(self, n: Optional[torch.fx.node.Node] = None) +torch.fx.graph.Graph.inserting_before(self, n: Optional[torch.fx.node.Node] = None) +torch.fx.graph.Graph.lint(self) +torch.fx.graph.Graph.node_copy(self, node: torch.fx.node.Node, arg_transform: Callable[[torch.fx.node.Node], Argument] = >) -> torch.fx.node.Node +torch.fx.graph.Graph.output(self, result: 'Argument', type_expr: Optional[Any] = None) +torch.fx.graph.Graph.placeholder(self, name: str, type_expr: Optional[Any] = None) -> torch.fx.node.Node +torch.fx.graph.Graph.print_tabular(self) +torch.fx.graph.Graph.python_code(self, root_module: str) -> torch.fx.graph.PythonCode +torch.fx.graph_module.GraphModule.__init__(self, root: Union[torch.nn.modules.module.Module, Dict[str, Any]], graph: torch.fx.graph.Graph, class_name: str = 'GraphModule') +torch.fx.graph_module.GraphModule.add_submodule(self, target: str, m: torch.nn.modules.module.Module) -> bool +torch.fx.graph_module.GraphModule.delete_all_unused_submodules(self) -> None +torch.fx.graph_module.GraphModule.delete_submodule(self, target: str) -> bool +torch.fx.graph_module.GraphModule.recompile(self) -> torch.fx.graph.PythonCode +torch.fx.graph_module.reduce_deploy_graph_module(importer: Callable, body: Dict[Any, Any], import_block: str) -> torch.nn.modules.module.Module +torch.fx.graph_module.reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.modules.module.Module +torch.fx.graph_module.reduce_package_graph_module(importer: Callable, body: Dict[Any, Any], generated_module_name: str) -> torch.nn.modules.module.Module +torch.fx.interpreter.Interpreter.__init__(self, module: torch.fx.graph_module.GraphModule, garbage_collect_values: bool = True) +torch.fx.interpreter.Interpreter.call_function(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any +torch.fx.interpreter.Interpreter.call_method(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any +torch.fx.interpreter.Interpreter.call_module(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any +torch.fx.interpreter.Interpreter.fetch_args_kwargs_from_env(self, n: torch.fx.node.Node) -> Tuple[Tuple, Dict] +torch.fx.interpreter.Interpreter.fetch_attr(self, target: str) +torch.fx.interpreter.Interpreter.get_attr(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any +torch.fx.interpreter.Interpreter.map_nodes_to_values(self, args: torch.fx.node.Argument, n: torch.fx.node.Node) -> torch.fx.node.Argument +torch.fx.interpreter.Interpreter.output(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any +torch.fx.interpreter.Interpreter.placeholder(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any +torch.fx.interpreter.Interpreter.run(self, *args, initial_env: Optional[Dict[torch.fx.node.Node, Any]] = None) -> Any +torch.fx.interpreter.Interpreter.run_node(self, n: torch.fx.node.Node) -> Any +torch.fx.interpreter.Transformer.__init__(self, module) +torch.fx.interpreter.Transformer.call_function(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any +torch.fx.interpreter.Transformer.call_module(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any +torch.fx.interpreter.Transformer.get_attr(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> torch.fx.proxy.Proxy +torch.fx.interpreter.Transformer.placeholder(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> torch.fx.proxy.Proxy +torch.fx.interpreter.Transformer.transform(self) -> torch.fx.graph_module.GraphModule +torch.fx.node.Node.__init__(self, graph: 'Graph', name: str, op: str, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Argument], return_type: Optional[Any] = None) -> None +torch.fx.node.Node.append(self, x: 'Node') -> None +torch.fx.node.Node.format_node(self, placeholder_names: List[str] = None, maybe_return_typename: List[str] = None) -> Optional[str] +torch.fx.node.Node.prepend(self, x: 'Node') -> None +torch.fx.node.Node.replace_all_uses_with(self, replace_with: 'Node') -> List[Node] +torch.fx.node.Node.replace_input_with(self, old_input: 'Node', new_input: 'Node') +torch.fx.node.Node.update_arg(self, idx: int, arg: torch.fx.node.Argument) -> None +torch.fx.node.Node.update_kwarg(self, key: str, arg: torch.fx.node.Argument) -> None +torch.fx.node.map_aggregate(a: torch.fx.node.Argument, fn: Callable[[torch.fx.node.Argument], torch.fx.node.Argument]) -> torch.fx.node.Argument +torch.fx.node.map_arg(a: torch.fx.node.Argument, fn: Callable[[torch.fx.node.Node], torch.fx.node.Argument]) -> torch.fx.node.Argument +torch.fx.passes.split_module.split_module(m: torch.fx.graph_module.GraphModule, root_m: torch.nn.modules.module.Module, split_callback: Callable[[torch.fx.node.Node], int]) +torch.fx.proxy.Attribute.__init__(self, root: torch.fx.proxy.Proxy, attr: str) +torch.fx.proxy.Proxy.__init__(self, node: torch.fx.node.Node, tracer: 'Optional[TracerBase]' = None) +torch.fx.proxy.Proxy.keys(self) +torch.fx.proxy.TracerBase.create_arg(self, a: Any) -> torch.fx.node.Argument +torch.fx.proxy.TracerBase.create_node(self, kind: str, target: torch.fx.node.Target, args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, torch.fx.node.Argument], name: Optional[str] = None, type_expr: Optional[Any] = None) -> torch.fx.node.Node +torch.fx.proxy.TracerBase.create_proxy(self, kind: str, target: torch.fx.node.Target, args: Tuple[Any, ...], kwargs: Dict[str, Any], name: Optional[str] = None, type_expr: Optional[Any] = None, proxy_factory_fn: Callable[[torch.fx.node.Node], Proxy] = None) +torch.fx.proxy.TracerBase.iter(self, obj: 'Proxy') -> Iterator +torch.fx.proxy.TracerBase.keys(self, obj: 'Proxy') -> Any +torch.fx.proxy.TracerBase.proxy(self, node: torch.fx.node.Node) -> 'Proxy' +torch.fx.proxy.TracerBase.to_bool(self, obj: 'Proxy') -> bool +torch.fx.subgraph_rewriter.replace_pattern(gm: torch.fx.graph_module.GraphModule, pattern: Callable, replacement: Callable) -> List[torch.fx.subgraph_rewriter.Match] \ No newline at end of file diff --git a/test/expect/TestScript.test_annot_ast_mypy_fn.expect b/test/expect/TestScript.test_annot_ast_mypy_fn.expect index 4b15b27b48112..36888d04876ef 100644 --- a/test/expect/TestScript.test_annot_ast_mypy_fn.expect +++ b/test/expect/TestScript.test_annot_ast_mypy_fn.expect @@ -6,4 +6,4 @@ foo(bool x, (Tensor, Tensor) y) -> ((bool, bool)) foo(float[3] x, (Tensor, Tensor) y) -> ((float[], float[])) foo(int[2] x, (Tensor, Tensor) y) -> ((int[], int[])) foo(int[] x, (Tensor, Tensor) y) -> ((int[], int[])) -foo(int? x, (Tensor, Tensor) y) -> ((int?, int?)) \ No newline at end of file +foo(int? x, (Tensor, Tensor) y) -> ((int?, int?)) diff --git a/test/expect/TestScript.test_annot_ast_mypy_method.expect b/test/expect/TestScript.test_annot_ast_mypy_method.expect index 9c0dcd14deeec..b6c19a6002483 100644 --- a/test/expect/TestScript.test_annot_ast_mypy_method.expect +++ b/test/expect/TestScript.test_annot_ast_mypy_method.expect @@ -6,4 +6,4 @@ foo( self, bool x, (Tensor, Tensor) y) -> ((bool, bool)) foo( self, float[3] x, (Tensor, Tensor) y) -> ((float[], float[])) foo( self, int[2] x, (Tensor, Tensor) y) -> ((int[], int[])) foo( self, int[] x, (Tensor, Tensor) y) -> ((int[], int[])) -foo( self, int? x, (Tensor, Tensor) y) -> ((int?, int?)) \ No newline at end of file +foo( self, int? x, (Tensor, Tensor) y) -> ((int?, int?)) diff --git a/test/expect/TestScript.test_annot_ast_py3_fn.expect b/test/expect/TestScript.test_annot_ast_py3_fn.expect index 4b15b27b48112..36888d04876ef 100644 --- a/test/expect/TestScript.test_annot_ast_py3_fn.expect +++ b/test/expect/TestScript.test_annot_ast_py3_fn.expect @@ -6,4 +6,4 @@ foo(bool x, (Tensor, Tensor) y) -> ((bool, bool)) foo(float[3] x, (Tensor, Tensor) y) -> ((float[], float[])) foo(int[2] x, (Tensor, Tensor) y) -> ((int[], int[])) foo(int[] x, (Tensor, Tensor) y) -> ((int[], int[])) -foo(int? x, (Tensor, Tensor) y) -> ((int?, int?)) \ No newline at end of file +foo(int? x, (Tensor, Tensor) y) -> ((int?, int?)) diff --git a/test/expect/TestScript.test_annot_ast_py3_method.expect b/test/expect/TestScript.test_annot_ast_py3_method.expect index 9c0dcd14deeec..b6c19a6002483 100644 --- a/test/expect/TestScript.test_annot_ast_py3_method.expect +++ b/test/expect/TestScript.test_annot_ast_py3_method.expect @@ -6,4 +6,4 @@ foo( self, bool x, (Tensor, Tensor) y) -> ((bool, bool)) foo( self, float[3] x, (Tensor, Tensor) y) -> ((float[], float[])) foo( self, int[2] x, (Tensor, Tensor) y) -> ((int[], int[])) foo( self, int[] x, (Tensor, Tensor) y) -> ((int[], int[])) -foo( self, int? x, (Tensor, Tensor) y) -> ((int?, int?)) \ No newline at end of file +foo( self, int? x, (Tensor, Tensor) y) -> ((int?, int?)) diff --git a/test/expect/TestScript.test_annot_string_mypy_fn.expect b/test/expect/TestScript.test_annot_string_mypy_fn.expect index 4b15b27b48112..36888d04876ef 100644 --- a/test/expect/TestScript.test_annot_string_mypy_fn.expect +++ b/test/expect/TestScript.test_annot_string_mypy_fn.expect @@ -6,4 +6,4 @@ foo(bool x, (Tensor, Tensor) y) -> ((bool, bool)) foo(float[3] x, (Tensor, Tensor) y) -> ((float[], float[])) foo(int[2] x, (Tensor, Tensor) y) -> ((int[], int[])) foo(int[] x, (Tensor, Tensor) y) -> ((int[], int[])) -foo(int? x, (Tensor, Tensor) y) -> ((int?, int?)) \ No newline at end of file +foo(int? x, (Tensor, Tensor) y) -> ((int?, int?)) diff --git a/test/expect/TestScript.test_annot_string_mypy_method.expect b/test/expect/TestScript.test_annot_string_mypy_method.expect index 9c0dcd14deeec..b6c19a6002483 100644 --- a/test/expect/TestScript.test_annot_string_mypy_method.expect +++ b/test/expect/TestScript.test_annot_string_mypy_method.expect @@ -6,4 +6,4 @@ foo( self, bool x, (Tensor, Tensor) y) -> ((bool, bool)) foo( self, float[3] x, (Tensor, Tensor) y) -> ((float[], float[])) foo( self, int[2] x, (Tensor, Tensor) y) -> ((int[], int[])) foo( self, int[] x, (Tensor, Tensor) y) -> ((int[], int[])) -foo( self, int? x, (Tensor, Tensor) y) -> ((int?, int?)) \ No newline at end of file +foo( self, int? x, (Tensor, Tensor) y) -> ((int?, int?)) diff --git a/test/expect/TestScript.test_annot_string_py3_fn.expect b/test/expect/TestScript.test_annot_string_py3_fn.expect index 4b15b27b48112..36888d04876ef 100644 --- a/test/expect/TestScript.test_annot_string_py3_fn.expect +++ b/test/expect/TestScript.test_annot_string_py3_fn.expect @@ -6,4 +6,4 @@ foo(bool x, (Tensor, Tensor) y) -> ((bool, bool)) foo(float[3] x, (Tensor, Tensor) y) -> ((float[], float[])) foo(int[2] x, (Tensor, Tensor) y) -> ((int[], int[])) foo(int[] x, (Tensor, Tensor) y) -> ((int[], int[])) -foo(int? x, (Tensor, Tensor) y) -> ((int?, int?)) \ No newline at end of file +foo(int? x, (Tensor, Tensor) y) -> ((int?, int?)) diff --git a/test/expect/TestScript.test_annot_string_py3_method.expect b/test/expect/TestScript.test_annot_string_py3_method.expect index 9c0dcd14deeec..b6c19a6002483 100644 --- a/test/expect/TestScript.test_annot_string_py3_method.expect +++ b/test/expect/TestScript.test_annot_string_py3_method.expect @@ -6,4 +6,4 @@ foo( self, bool x, (Tensor, Tensor) y) -> ((bool, bool)) foo( self, float[3] x, (Tensor, Tensor) y) -> ((float[], float[])) foo( self, int[2] x, (Tensor, Tensor) y) -> ((int[], int[])) foo( self, int[] x, (Tensor, Tensor) y) -> ((int[], int[])) -foo( self, int? x, (Tensor, Tensor) y) -> ((int?, int?)) \ No newline at end of file +foo( self, int? x, (Tensor, Tensor) y) -> ((int?, int?)) diff --git a/test/fx/test_gradual_type.py b/test/fx/test_gradual_type.py index 203cf6b7e306e..37e8db1e5cf4b 100644 --- a/test/fx/test_gradual_type.py +++ b/test/fx/test_gradual_type.py @@ -9,7 +9,14 @@ from torch.fx.experimental.rewriter import RewritingTracer from torch.fx import GraphModule from torch.fx.passes.shape_prop import ShapeProp -from torch.fx.experimental.unification import Var + +try: + import sympy + HAS_SYMPY = True +except ImportError: + HAS_SYMPY = False +skipIfNoSympy = unittest.skipIf(not HAS_SYMPY, "no sympy") + try: from torchvision.models import resnet50 @@ -19,13 +26,6 @@ HAS_TORCHVISION = False skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") -# try: -# from unification import Var -# HAS_UNIFICATION = True -# except ImportError: -# HAS_UNIFICATION = False -# skipIfNoUnification = unittest.skipIf(not HAS_UNIFICATION, "no unification") - def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): """3x3 convolution with padding""" return torch.nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, @@ -270,10 +270,9 @@ def forward(self, x: TensorType((1, 2, 3, 5))): def test_type_check_batch_norm_2D(self): class BasicBlock(torch.nn.Module): - def __init__(self, inplanes, planes, norm_layer=None): + def __init__(self, inplanes, planes): super(BasicBlock, self).__init__() - if norm_layer is None: - norm_layer = torch.nn.BatchNorm2d + norm_layer = torch.nn.BatchNorm2d self.bn1 = norm_layer(planes) def forward(self, x: TensorType((2, 2, 5, 4))): @@ -302,10 +301,9 @@ def forward(self, x: TensorType((2, 2, 5, 4))): def test_type_check_batch_norm_2D_false(self): class BasicBlock(torch.nn.Module): - def __init__(self, inplanes, planes, norm_layer=None): + def __init__(self, inplanes, planes): super(BasicBlock, self).__init__() - if norm_layer is None: - norm_layer = torch.nn.BatchNorm2d + norm_layer = torch.nn.BatchNorm2d self.bn1 = norm_layer(planes) def forward(self, x: TensorType((2, 2, 5))): @@ -325,10 +323,9 @@ def forward(self, x: TensorType((2, 2, 5))): def test_type_check_batch_norm_2D_broadcast(self): class BasicBlock(torch.nn.Module): - def __init__(self, inplanes, planes, norm_layer=None): + def __init__(self, inplanes, planes): super(BasicBlock, self).__init__() - if norm_layer is None: - norm_layer = torch.nn.BatchNorm2d + norm_layer = torch.nn.BatchNorm2d self.bn1 = norm_layer(planes) def forward(self, x: Dyn): @@ -363,10 +360,9 @@ def forward(self, x: Dyn): def test_type_check_conv2D(self): class BasicBlock(torch.nn.Module): - def __init__(self, inplanes, planes, stride=1, norm_layer=None): + def __init__(self, inplanes, planes, stride=1): super(BasicBlock, self).__init__() - if norm_layer is None: - norm_layer = torch.nn.BatchNorm2d + norm_layer = torch.nn.BatchNorm2d self.conv1 = conv3x3(inplanes, planes, stride) self.bn1 = norm_layer(planes) @@ -394,10 +390,9 @@ def forward(self, x: Dyn): def test_type_check_conv2D_2(self): class BasicBlock(torch.nn.Module): - def __init__(self, inplanes, planes, stride=1, norm_layer=None): + def __init__(self, inplanes, planes, stride=1): super(BasicBlock, self).__init__() - if norm_layer is None: - norm_layer = torch.nn.BatchNorm2d + norm_layer = torch.nn.BatchNorm2d self.conv1 = conv3x3(inplanes, planes, stride) self.bn1 = norm_layer(planes) @@ -434,7 +429,6 @@ def forward(self, x: TensorType((5, 2, 3, 4))): with self.assertRaises(TypeError): tc.type_check() - def test_type_check_conv2D_2_fully_static(self): annotation_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14), (10, Dyn, 13, 14), (Dyn, Dyn, Dyn, 3)] @@ -522,16 +516,14 @@ def forward(self, x): assert n.type == TensorType(output_types[i]) assert is_consistent(n.type, TensorType(b.size())) - def test_typecheck_basicblock(self): class BasicBlock(torch.nn.Module): expansion = 1 def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, - base_width=64, dilation=1, norm_layer=None): + base_width=64, dilation=1): super(BasicBlock, self).__init__() - if norm_layer is None: - norm_layer = torch.nn.BatchNorm2d + norm_layer = torch.nn.BatchNorm2d if groups != 1 or base_width != 64: raise ValueError('BasicBlock only supports groups=1 and base_width=64') if dilation > 1: @@ -643,7 +635,6 @@ def forward(self, x: TensorType((1, Dyn, 3, 5, Dyn))): if n.op == 'output': assert n.type == TensorType((1, Dyn, 5, Dyn)) - def test_type_check_flatten3(self): class M(torch.nn.Module): def forward(self, x: TensorType((2, 3, 4, 5))): @@ -661,7 +652,6 @@ def forward(self, x: TensorType((2, 3, 4, 5))): c = r.constraints assert c == [Equality(2, 2)] - def test_type_typechecl_maxpool2d_3dinput(self): class BasicBlock(torch.nn.Module): @@ -770,7 +760,6 @@ def forward(self, x): assert n.type == TensorType(output_types[i]) assert is_consistent(n.type, TensorType(b.size())) - def test_flatten_fully_static(self): annotation_list = [Dyn, TensorType((2, 5, 6, 9)), TensorType((10, 15, 13, 14)), TensorType((10, Dyn, 13, 14)), TensorType((Dyn, Dyn, Dyn, 10))] @@ -816,6 +805,7 @@ def forward(self, x): if n.op == 'output': assert is_consistent(n.type, TensorType(b.size())) + @skipIfNoSympy @skipIfNoTorchVision def test_resnet50(self): gm_run = symbolic_trace(resnet50()) @@ -859,14 +849,13 @@ def test_resnet50(self): batch_sizes.add(n.type.__args__[0]) assert (len(batch_sizes) == 1) - + @skipIfNoSympy def test_type_check_batch_norm_symbolic(self): class BasicBlock(torch.nn.Module): - def __init__(self, inplanes, planes, norm_layer=None): + def __init__(self, inplanes, planes): super(BasicBlock, self).__init__() - if norm_layer is None: - norm_layer = torch.nn.BatchNorm2d + norm_layer = torch.nn.BatchNorm2d self.bn1 = norm_layer(planes) def forward(self, x: Dyn): @@ -884,15 +873,15 @@ def forward(self, x: Dyn): infer_symbolic_types(traced) - - my_types = iter([TensorType[(2, 2, Var(7), 4)], - TensorType[(2, 2, Var(7), 4)], - TensorType[(2, 2, Var(7), 4)], - TensorType[(2, 2, Var(7), 4)]]) + my_types = iter([TensorType[(2, 2, sympy.symbols('~7'), 4)], + TensorType[(2, 2, sympy.symbols('~7'), 4)], + TensorType[(2, 2, sympy.symbols('~7'), 4)], + TensorType[(2, 2, sympy.symbols('~7'), 4)]]) for n in graph.nodes: assert n.type == next(my_types) + @skipIfNoSympy def test_symbolic_add_with_broadcast(self): class M(torch.nn.Module): def forward(self, x: TensorType((1, 2, 3, Dyn)), y: TensorType((2, 3, 4))): @@ -911,16 +900,17 @@ def forward(self, x: TensorType((1, 2, 3, Dyn)), y: TensorType((2, 3, 4))): infer_symbolic_types(symbolic_traced) - expected_ph_types = [TensorType((1, 2, 3, Var(0))), + expected_ph_types = [TensorType((1, 2, 3, sympy.symbols('~0'))), TensorType((2, 3, 4)), - TensorType((1, 2, 3, Var(1))), - TensorType((1, 2, 3, Var(1)))] + TensorType((1, 2, 3, sympy.symbols('~1'))), + TensorType((1, 2, 3, sympy.symbols('~1')))] expected_iter = iter(expected_ph_types) + for n in symbolic_traced.graph.nodes: assert n.type == next(expected_iter) - + @skipIfNoSympy def test_symbolic_add_with_broadcast_2(self): class M(torch.nn.Module): def forward(self, x: TensorType((1, 2)), y: TensorType((Dyn, 2))): @@ -934,13 +924,80 @@ def forward(self, x: TensorType((1, 2)), y: TensorType((Dyn, 2))): r.refine() expected_ph_types = [TensorType((1, 2)), - TensorType((Var(1), 2)), - TensorType((Var(1), 2)), - TensorType((Var(1), 2))] + TensorType((sympy.symbols('~1'), 2)), + TensorType((sympy.symbols('~1'), 2)), + TensorType((sympy.symbols('~1'), 2))] expected_iter = iter(expected_ph_types) for n in symbolic_traced.graph.nodes: assert n.type == next(expected_iter) + @skipIfNoSympy + def test_type_check_conv2D_types(self): + class BasicBlock(torch.nn.Module): + def __init__(self, inplanes, planes, stride=1): + super(BasicBlock, self).__init__() + norm_layer = torch.nn.BatchNorm2d + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + + def forward(self, x: Dyn): + identity = x + out: TensorType((2, 2, Dyn, 4)) = self.conv1(x) + out += identity + return out + + B = BasicBlock(2, 2) + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(B) + traced = GraphModule(ast_rewriter.root, graph, "gm") + tc = GraphTypeChecker({}, traced) + tc.type_check() + infer_symbolic_types(traced) + + for n in traced.graph.nodes: + if n.op == 'call_module': + assert isinstance(n.type.__args__[2], sympy.floor) + assert isinstance(n.type.__args__[3], sympy.floor) + + @skipIfNoSympy + def test_type_check_symbolic_inferenceconv2D_maxpool2d_flatten(self): + + class BasicBlock(torch.nn.Module): + def __init__(self): + super(BasicBlock, self).__init__() + + self.conv1 = torch.nn.Conv2d(3, 6, 5) + self.pool = torch.nn.MaxPool2d(2, 2) + self.conv2 = torch.nn.Conv2d(6, 16, 5) + self.fc1 = torch.nn.Linear(5, 120) + self.pool2 = torch.nn.AdaptiveAvgPool2d((6, 7)) + + def forward(self, x : TensorType((4, 3, Dyn, Dyn))): + out = self.conv1(x) + out = self.pool(out) + out = self.conv2(out) + out = self.pool(out) + out = self.fc1(out) + out = self.pool2(out) + out = torch.flatten(out, 1) + return out + + B = BasicBlock() + ast_rewriter = RewritingTracer() + traced = symbolic_trace(B) + tc = GraphTypeChecker({}, traced) + tc.type_check() + infer_symbolic_types(traced) + + for n in traced.graph.nodes: + if n.target == 'conv1': + assert n.type == TensorType((4, 6, sympy.floor((sympy.symbols('~0') - 4)), + sympy.floor((sympy.symbols('~1') - 4)))) + + elif n.target == 'conv2': + assert n.type == TensorType((4, 16, sympy.floor((sympy.symbols('~4') - 4)), + sympy.floor((sympy.symbols('~5') - 4)))) + if __name__ == '__main__': unittest.main() diff --git a/test/jit/test_freezing.py b/test/jit/test_freezing.py index 1e2037e59a0ba..e9317b11412a9 100644 --- a/test/jit/test_freezing.py +++ b/test/jit/test_freezing.py @@ -1877,7 +1877,7 @@ def forward(self, x): N, C, H, W, = 10, 3, 224, 224 inp = torch.randn(N, C, H, W) self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph) - self.assertTrue(torch.allclose(model(inp), mod(inp))) + self.assertEqual(model(inp), mod(inp)) @unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled") def test_pool2d_batchnorm(self): @@ -1901,7 +1901,7 @@ def test_pool2d_batchnorm(self): self.run_pass('dce', mod.graph) self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph) FileCheck().check("aten::to_dense").check_next("return").run(mod.graph) - self.assertTrue(torch.allclose(sub_model(inp), mod(inp))) + self.assertEqual(sub_model(inp), mod(inp)) @unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled") def test_pool3d_batchnorm(self): @@ -1925,7 +1925,46 @@ def test_pool3d_batchnorm(self): self.run_pass('dce', mod.graph) self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph) FileCheck().check("aten::to_dense").check_next("return").run(mod.graph) - self.assertTrue(torch.allclose(sub_model(inp), mod(inp))) + self.assertEqual(sub_model(inp), mod(inp)) + + @unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled") + @skipIfNoTorchVision + def test_layernorm(self): + with set_default_dtype(torch.float): + + class ResidualLayernorm(torch.nn.Module): + def __init__(self, op, layernorm, **kwargs): + super(ResidualLayernorm, self).__init__() + self.op = op + self.layernorm = layernorm + + def forward(self, x): + y = self.op(x) + return self.layernorm(y) + y + + model = torchvision.models.resnet18() + N, C, H, W, = 10, 3, 224, 224 + for param in ((model.conv1, [W // 2], torch.randn(N, C, H, W)), + (model.conv1, [H // 2, W // 2], torch.randn(N, C, H, W)), + (torch.nn.Linear(H, W), [W], torch.randn(N, C, W)),): + + for layernorm in (torch.nn.LayerNorm(param[1]), + torch.nn.LayerNorm(param[1], elementwise_affine=False)): + # to generate non inplace tests we extend the use of layernorm's input + for inplace in (True, False): + sub_model = torch.nn.Sequential(param[0], layernorm) if inplace else ResidualLayernorm(param[0], layernorm) + sub_model.eval() + mod = torch.jit.freeze(torch.jit.script(sub_model)) + self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph) + # if weight and bias are present and shape is the last dimension + # we should convert `aten::layer_norm` to `prim::MKLDNNLayerNorm` + if layernorm.elementwise_affine and len(param[1]) == 1: + inplace_suffix = "_" if inplace else "" + (FileCheck().check("prim::MKLDNNLayerNorm" + inplace_suffix). + check_count("aten::to_dense", 1, exactly=True).run(mod.graph)) + else: + FileCheck().check_count("aten::to_dense", 1, exactly=True).check("aten::layer_norm").run(mod.graph) + self.assertEqual(sub_model(param[2]), mod(param[2]), rtol=1e-04, atol=1e-04) @unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled") @skipIfNoTorchVision @@ -1940,6 +1979,7 @@ def __init__(self, min_val, max_val, **kwargs): def forward(self, x): return torch.clamp(x, self.min_val, self.max_val) + N, C, H, W, = 10, 3, 224, 224 activations = [ torch.nn.Hardswish(), torch.nn.Hardsigmoid(), @@ -1960,11 +2000,10 @@ def forward(self, x): sub_model = torch.nn.Sequential(model.conv1, activation) sub_model.eval() mod = torch.jit.freeze(torch.jit.script(sub_model)) - N, C, H, W, = 10, 3, 224, 224 inp = torch.randn(N, C, H, W) self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph) FileCheck().check_count("aten::to_dense", 1, exactly=True).run(mod.graph) - self.assertTrue(torch.allclose(sub_model(inp), mod(inp))) + self.assertEqual(sub_model(inp), mod(inp)) @unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled") def test_hardswish_hardsigmoid(self): @@ -1991,7 +2030,7 @@ def test_hardswish_hardsigmoid(self): x = torch.rand(size) # `inplace=False` is intentional, otherwise we modify the input # and we aren't testing aten impls anyways - self.assertTrue(torch.allclose(aten_op(x, inplace=False), m(x).to_dense())) + self.assertEqual(aten_op(x, inplace=False), m(x).to_dense()) @unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled") def test_scalar_mul(self): diff --git a/test/jit/test_ignorable_args.py b/test/jit/test_ignorable_args.py index b195e3cc4faaa..fb63c1973bf0e 100644 --- a/test/jit/test_ignorable_args.py +++ b/test/jit/test_ignorable_args.py @@ -1,5 +1,6 @@ import os import sys +import torch from torch._C import parse_ir from torch.testing import FileCheck @@ -43,3 +44,9 @@ def test_slice_ignorable_args_for_slice(self): # because in %16, %15 and %0 are default values for the schema. FileCheck().check("torch.slice(torch.slice(torch.tensor(_0), 0, 2), 1, None, 1)").run(src) self.assertEqual(function(), function_copy()) + + def test_add_out_ignorable_args(self): + @torch.jit.script + def fn(x: torch.Tensor, y: torch.Tensor): + torch.add(x, y, out=y) + FileCheck().check("torch.add(x, y, out=y)").run(fn.code) diff --git a/test/jit/test_isinstance.py b/test/jit/test_isinstance.py index 93b2605748516..5fd2b87965607 100644 --- a/test/jit/test_isinstance.py +++ b/test/jit/test_isinstance.py @@ -310,3 +310,12 @@ def fn(x: Any): x: int = 2 fn(x) self.assertEqual(len(w), 0) + + def test_empty_container_special_cases(self): + # Should not throw "Boolean value of Tensor with no values is + # ambiguous" error + torch._jit_internal.check_empty_containers(torch.Tensor([])) + + # Should not throw "Boolean value of Tensor with more than + # one value is ambiguous" error + torch._jit_internal.check_empty_containers(torch.rand(2, 3)) diff --git a/test/jit/test_jit_utils.py b/test/jit/test_jit_utils.py index 11d974bfe64c4..b344f82e96ced 100644 --- a/test/jit/test_jit_utils.py +++ b/test/jit/test_jit_utils.py @@ -77,3 +77,18 @@ def fn_hybrid_args(x, /, y, *args, **kwargs): self.assertEqual( [], torch._jit_internal.get_callable_argument_names(fn_hybrid_args)) + + def test_checkscriptassertraisesregex(self): + def fn(): + tup = (1, 2) + return tup[2] + + self.checkScriptRaisesRegex(fn, (), Exception, "range", name="fn") + + s = dedent(""" + def fn(): + tup = (1, 2) + return tup[2] + """) + + self.checkScriptRaisesRegex(s, (), Exception, "range", name="fn") diff --git a/test/jit/test_list_dict.py b/test/jit/test_list_dict.py index d8434515291ab..10f5e879099a0 100644 --- a/test/jit/test_list_dict.py +++ b/test/jit/test_list_dict.py @@ -92,7 +92,7 @@ def reassign_from_empty_literal(): if 1 == 1: x = [1, 2, 3] return - with self.assertRaisesRegexWithHighlight(RuntimeError, r"previously has type List\[Tensor\]", "x"): + with self.assertRaisesRegexWithHighlight(RuntimeError, r"previously had type List\[Tensor\]", "x"): self.checkScript(reassign_from_empty_literal, (), optimize=False) def reassign_from_empty_builtin(): @@ -113,7 +113,7 @@ def reassign_bad_type(): if 1 == 1: x = [1.0] return - with self.assertRaisesRegexWithHighlight(RuntimeError, "previously has type", "x"): + with self.assertRaisesRegexWithHighlight(RuntimeError, "previously had type", "x"): self.checkScript(reassign_bad_type, (), optimize=False) def reassign_nested(): @@ -123,7 +123,7 @@ def reassign_nested(): if 1 == 1: x = [1.0] return - with self.assertRaisesRegexWithHighlight(RuntimeError, "previously has type", "x"): + with self.assertRaisesRegexWithHighlight(RuntimeError, "previously had type", "x"): self.checkScript(reassign_nested, (), optimize=False) def test_del(self): diff --git a/test/jit/test_optimize_for_mobile_preserve_debug_info.py b/test/jit/test_optimize_for_mobile_preserve_debug_info.py new file mode 100644 index 0000000000000..c08f3b5838fae --- /dev/null +++ b/test/jit/test_optimize_for_mobile_preserve_debug_info.py @@ -0,0 +1,261 @@ +import torch +import torch._C +import torch.backends.xnnpack +import torch.nn.functional as F +from torch.testing._internal.jit_utils import JitTestCase + +class TestOptimizeForMobilePreserveDebugInfo(JitTestCase): + def check_replacement( + self, + model, + replacements, + jit_pass, + ): + """ + model: Model which optimization is performed on + replacements: Dict mapping from nodes' kinds in the optimized model + to the kinds of nodes they replaced in the original model + jit_pass: Function to perform optimization + """ + + original_kinds = set(replacements.values()) + original_source_ranges = { + node.kind(): node.sourceRange() + for node in model.graph.nodes() + if node.kind() in original_kinds + } + + jit_pass(model._c) + + for node in model.graph.nodes(): + if node.kind() in replacements: + self.assertEqual( + node.sourceRange(), + original_source_ranges[replacements[node.kind()]], + ) + + def test_replace_conv1d_with_conv2d(self): + class TestConv1d(torch.nn.Module): + def __init__(self, weight, bias): + super(TestConv1d, self).__init__() + self.weight = weight + self.bias = bias + + def forward(self, x): + return F.conv1d(x, self.weight, self.bias) + + self.check_replacement( + model=torch.jit.script( + TestConv1d( + weight=torch.rand(3, 3, 3), + bias=torch.rand(3), + ), + ), + replacements={ + "prim::ListUnpack": "aten::conv1d", + "prim::ListConstruct": "aten::conv1d", + "aten::unsqueeze": "aten::conv1d", + "aten::conv2d": "aten::conv1d", + "aten::squeeze": "aten::conv1d", + }, + jit_pass=torch._C._jit_pass_transform_conv1d_to_conv2d, + ) + + def test_insert_pre_packed_linear_before_inline_and_conv_2d_op(self): + class TestPrepackedLinearBeforeInlineAndConv2dOp(torch.nn.Module): + def __init__( + self, + linear_weight, + linear_bias, + conv2d_weight, + conv2d_bias, + conv_transpose2d_weight, + conv_transpose2d_bias, + ): + super( + TestPrepackedLinearBeforeInlineAndConv2dOp, + self, + ).__init__() + self.linear_weight = linear_weight.float() + self.linear_bias = linear_bias.float() + self.conv2d_weight = conv2d_weight.float() + self.conv2d_bias = conv2d_bias.float() + self.conv_transpose2d_weight = conv_transpose2d_weight.float() + self.conv_transpose2d_bias = conv_transpose2d_bias.float() + + def forward(self, x): + linear_res = F.linear( + x.float(), + self.linear_weight, + self.linear_bias, + ) + conv2d_res = F.conv2d( + input=linear_res.unsqueeze(dim=0).float(), + weight=self.conv2d_weight, + bias=self.conv2d_bias, + ) + return F.conv_transpose2d( + input=conv2d_res, + weight=self.conv_transpose2d_weight, + bias=self.conv_transpose2d_bias, + ) + + minibatch = 1 + in_channels = 6 + iH = 4 + iW = 5 + out_channels = 6 + kH = 2 + kW = 3 + + self.check_replacement( + model=torch.jit.script( + TestPrepackedLinearBeforeInlineAndConv2dOp( + linear_weight=torch.rand(iW, 3), + linear_bias=torch.rand(iW), + conv2d_weight=torch.rand(out_channels, in_channels, kH, kW), + conv2d_bias=torch.rand(out_channels), + conv_transpose2d_weight=torch.rand( + out_channels, + in_channels, + kH, + kW, + ), + conv_transpose2d_bias=torch.rand(out_channels), + ), + ), + replacements={ + "prepacked::linear_clamp_prepack": "prim::CallFunction", + "prepacked::linear_clamp_run": "prim::CallFunction", + "prepacked::conv2d_clamp_prepack": "aten::conv2d", + "prepacked::conv2d_clamp_run": "aten::conv2d", + "prepacked::conv2d_transpose_clamp_prepack": + "aten::conv_transpose2d", + "prepacked::conv2d_transpose_clamp_run": + "aten::conv_transpose2d", + }, + jit_pass=torch._C._jit_pass_insert_prepacked_ops, + ) + + def test_insert_pre_packed_linear_op(self): + self.check_replacement( + model=torch.jit.trace(torch.nn.Linear(5, 4), torch.rand(3, 2, 5)), + replacements={ + "prepacked::linear_clamp_prepack": "aten::linear", + "prepacked::linear_clamp_run": "aten::linear" + }, + jit_pass=torch._C._jit_pass_insert_prepacked_ops, + ) + + def run_test_fuse_activation_with_pack_ops_linear_conv2d( + self, + linear_activation, + linear_activation_kind, + conv2d_activation, + conv2d_activation_kind, + ): + class TestFuseActivationLinearConv2d(torch.nn.Module): + def __init__( + self, + linear_weight, + linear_bias, + conv2d_weight, + conv2d_bias, + ): + super(TestFuseActivationLinearConv2d, self).__init__() + self.linear_weight = linear_weight + self.linear_bias = linear_bias + self.conv2d_weight = conv2d_weight + self.conv2d_bias = conv2d_bias + + def forward(self, x): + x = F.linear( + input=x, + weight=self.linear_weight, + bias=self.linear_bias, + ) + x = linear_activation(x) + x = F.conv2d( + input=x.unsqueeze(dim=0), + weight=self.conv2d_weight, + bias=self.conv2d_bias, + ) + return conv2d_activation(x) + + linear_in_features = 5 + linear_out_features = 4 + conv2d_in_channels = 3 + conv2d_out_channels = 4 + conv2d_kernel = 2 + x_shape = (3, 2, 5) + + model = torch.jit.trace( + TestFuseActivationLinearConv2d( + linear_weight=torch.nn.Parameter( + data=torch.rand( + linear_out_features, + linear_in_features, + ), + requires_grad=False, + ), + linear_bias=torch.nn.Parameter( + data=torch.rand(linear_out_features), + requires_grad=False, + ), + conv2d_weight=torch.rand( + conv2d_out_channels, + conv2d_in_channels, + conv2d_kernel, + conv2d_kernel, + ), + conv2d_bias=torch.rand(conv2d_out_channels), + ), + torch.rand(x_shape), + ) + + torch._C._jit_pass_insert_prepacked_ops(model._c) + + self.check_replacement( + model=model, + replacements={ + "prepacked::linear_clamp_prepack": + "prepacked::linear_clamp_prepack", + "prepacked::linear_clamp_run": linear_activation_kind, + "prepacked::conv2d_clamp_prepack": + "prepacked::conv2d_clamp_prepack", + "prepacked::conv2d_clamp_run": conv2d_activation_kind, + }, + jit_pass=torch._C._jit_pass_fuse_clamp_w_prepacked_linear_conv, + ) + + def test_fuse_activation_with_pack_ops_linear_conv2d_1(self): + self.run_test_fuse_activation_with_pack_ops_linear_conv2d( + linear_activation=F.hardtanh, + linear_activation_kind="aten::hardtanh", + conv2d_activation=F.hardtanh_, + conv2d_activation_kind="aten::hardtanh_" + ) + + def test_fuse_activation_with_pack_ops_linear_conv2d_2(self): + self.run_test_fuse_activation_with_pack_ops_linear_conv2d( + linear_activation=F.hardtanh_, + linear_activation_kind="aten::hardtanh_", + conv2d_activation=F.hardtanh, + conv2d_activation_kind="aten::hardtanh" + ) + + def test_fuse_activation_with_pack_ops_linear_conv2d_3(self): + self.run_test_fuse_activation_with_pack_ops_linear_conv2d( + linear_activation=F.relu, + linear_activation_kind="aten::relu", + conv2d_activation=F.relu_, + conv2d_activation_kind="aten::relu_" + ) + + def test_fuse_activation_with_pack_ops_linear_conv2d_4(self): + self.run_test_fuse_activation_with_pack_ops_linear_conv2d( + linear_activation=F.relu_, + linear_activation_kind="aten::relu_", + conv2d_activation=F.relu, + conv2d_activation_kind="aten::relu" + ) diff --git a/test/jit/test_pdt.py b/test/jit/test_pdt.py index b04a66e5dfcd9..468eb2787814b 100644 --- a/test/jit/test_pdt.py +++ b/test/jit/test_pdt.py @@ -40,7 +40,7 @@ def forward(self, x) -> Any: make_global(TestPDTModel) pdt_model = TestPDTModel() inp: List[Tuple[Any, ...]] = [(20, ), (2.7, ), (False, ), ] - scripted_pdt_model = torch.jit._script_pdt(pdt_model, example_inputs={pdt_model: inp}) + scripted_pdt_model = torch.jit.script(pdt_model, example_inputs={pdt_model: inp}) self.assertEqual(scripted_pdt_model(50), pdt_model(50)) self.assertEqual(scripted_pdt_model(1.8), pdt_model(1.8)) self.assertTrue(scripted_pdt_model(True), pdt_model(True)) @@ -67,7 +67,7 @@ def forward(self, x): inner_pdt_model = NestedPDTInner() wrapped_pdt_model = NestedModulePDTWrapper(inner_pdt_model) inp: List[Tuple[Any, ...]] = [(20, ), (False, )] - scripted_pdt_model = torch.jit._script_pdt(wrapped_pdt_model, example_inputs={wrapped_pdt_model: inp}) + scripted_pdt_model = torch.jit.script(wrapped_pdt_model, example_inputs={wrapped_pdt_model: inp}) self.assertEqual(scripted_pdt_model(30), wrapped_pdt_model(30)) self.assertEqual(scripted_pdt_model(1.9), wrapped_pdt_model(1.9)) self.assertTrue(scripted_pdt_model(True), wrapped_pdt_model(True)) @@ -95,8 +95,8 @@ def forward(self, x): outer_pdt_model = NestedModulePDTOuter(inner_pdt_model) inner_input: List[Tuple[Any, ...]] = [(10, 10), (1.9, 20), ] outer_input: List[Tuple[Any, ...]] = [(20, ), (False, )] - scripted_pdt_model = torch.jit._script_pdt(outer_pdt_model, example_inputs={inner_pdt_model: inner_input, - outer_pdt_model: outer_input, }) + scripted_pdt_model = torch.jit.script(outer_pdt_model, example_inputs={inner_pdt_model: inner_input, + outer_pdt_model: outer_input, }) self.assertEqual(scripted_pdt_model(30), outer_pdt_model(30)) self.assertEqual(scripted_pdt_model(1.9), outer_pdt_model(1.9)) self.assertTrue(scripted_pdt_model(True), outer_pdt_model(True)) @@ -119,7 +119,7 @@ def fun(self, x): make_global(NestedFunctionInForward) pdt_model = NestedFunctionInForward() inp: List[Tuple[Any, ...]] = [(-1, ), (False, )] - scripted_pdt_model = torch.jit._script_pdt(pdt_model, example_inputs={pdt_model: inp}) + scripted_pdt_model = torch.jit.script(pdt_model, example_inputs={pdt_model: inp}) self.assertEqual(scripted_pdt_model(30), pdt_model(30)) self.assertEqual(scripted_pdt_model(True), pdt_model(True)) @@ -142,7 +142,7 @@ def fn(self, x, y) -> Any: make_global(TestModelWithExport) pdt_model = TestModelWithExport() inp: List[Tuple[Any, ...]] = [(20, 10, ), (2.7, 8.9, ), ] - scripted_pdt_model = torch.jit._script_pdt(pdt_model, example_inputs={pdt_model.fn: inp}) + scripted_pdt_model = torch.jit.script(pdt_model, example_inputs={pdt_model.fn: inp}) self.assertEqual(scripted_pdt_model.fn(10, 90), pdt_model.fn(10, 90)) self.assertEqual(scripted_pdt_model.fn(1.8, 2.2), pdt_model.fn(1.8, 2.2)) self.assertTrue(scripted_pdt_model.fn(torch.ones(1), 2), pdt_model.fn(torch.ones(1), 2)) @@ -155,7 +155,7 @@ def test_sum(self, a): make_global(PDTModel) pdt_model = PDTModel() inp: List[Tuple[Any, ...]] = [([10, 20, ], ), ] - scripted_pdt_model = torch.jit._script_pdt(PDTModel, example_inputs={pdt_model.test_sum: inp}) + scripted_pdt_model = torch.jit.script(PDTModel, example_inputs={pdt_model.test_sum: inp}) script_model = scripted_pdt_model() self.assertEqual(script_model.test_sum([10, 20, 30, ], ), pdt_model.test_sum([10, 20, 30, ], )) @@ -174,8 +174,8 @@ def test_substring(self, a, b): pdt_model = PDTModelWithManyMethods() list_inp: List[Tuple[Any, ...]] = [([1.2, 2.3, ], ), ] str_inp: List[Tuple[Any, ...]] = [("abc", "b", ), ] - scripted_pdt_model = torch.jit._script_pdt(PDTModelWithManyMethods, example_inputs={pdt_model.test_list_to_dict: list_inp, - pdt_model.test_substring: str_inp}) + scripted_pdt_model = torch.jit.script(PDTModelWithManyMethods, example_inputs={pdt_model.test_list_to_dict: list_inp, + pdt_model.test_substring: str_inp}) script_model = scripted_pdt_model() self.assertEqual(script_model.test_list_to_dict([1.1, 2.2, 3.3, ], ), pdt_model.test_list_to_dict([1.1, 2.2, 3.3, ], )) self.assertEqual(script_model.test_substring("helloworld", "world", ), pdt_model.test_substring("helloworld", "world", )) @@ -195,8 +195,8 @@ def test_find(self, a, b): pdt_model_two = PDTModelTwo() dict_inp: List[Tuple[Any, ...]] = [({1.2: True, 2.3: False, }, 1.2), ] list_inp: List[Tuple[Any, ...]] = [(["abc", "b", ], "c"), ] - scripted_pdt_model_one = torch.jit._script_pdt(PDTModelOne, example_inputs={pdt_model_one.test_find: dict_inp}) - scripted_pdt_model_two = torch.jit._script_pdt(PDTModelTwo, example_inputs={pdt_model_two.test_find: list_inp}) + scripted_pdt_model_one = torch.jit.script(PDTModelOne, example_inputs={pdt_model_one.test_find: dict_inp}) + scripted_pdt_model_two = torch.jit.script(PDTModelTwo, example_inputs={pdt_model_two.test_find: list_inp}) script_model_one, script_model_two = scripted_pdt_model_one(), scripted_pdt_model_two() self.assertEqual(script_model_one.test_find({1.1: True, 2.2: True, 3.3: False, }, 4.4), @@ -209,28 +209,28 @@ def test_sum(a, b): return a + b make_global(test_sum) - scripted_fn_add = torch.jit._script_pdt(test_sum, example_inputs=[(3, 4)]) + scripted_fn_add = torch.jit.script(test_sum, example_inputs=[(3, 4)]) self.assertEqual(scripted_fn_add(10, 2), test_sum(10, 2)) def test_sub(a, b): return a - b make_global(test_sub) - scripted_fn_sub = torch.jit._script_pdt(test_sub, example_inputs=[(3.9, 4.10)]) + scripted_fn_sub = torch.jit.script(test_sub, example_inputs=[(3.9, 4.10)]) self.assertEqual(scripted_fn_sub(6.5, 2.9), test_sub(6.5, 2.9)) def test_mul(a, b): return a * b make_global(test_mul) - scripted_fn_mul = torch.jit._script_pdt(test_mul, example_inputs=[(-10, 9)]) + scripted_fn_mul = torch.jit.script(test_mul, example_inputs=[(-10, 9)]) self.assertEqual(scripted_fn_mul(-1, 3), test_mul(-1, 3)) def test_args_complex(real, img): return torch.complex(real, img) make_global(test_args_complex) - scripted_fn_complex = torch.jit._script_pdt(test_args_complex, example_inputs=[(torch.rand(3, 4), torch.rand(3, 4))]) + scripted_fn_complex = torch.jit.script(test_args_complex, example_inputs=[(torch.rand(3, 4), torch.rand(3, 4))]) arg1, arg2 = torch.rand(3, 4), torch.rand(3, 4) self.assertEqual(scripted_fn_complex(arg1, arg2), test_args_complex(arg1, arg2)) @@ -241,7 +241,7 @@ def test_bool(a): return 0 make_global(test_bool) - scripted_fn_bool = torch.jit._script_pdt(test_bool, example_inputs=[(True,)]) + scripted_fn_bool = torch.jit.script(test_bool, example_inputs=[(True,)]) self.assertEqual(scripted_fn_bool(True), test_bool(True)) def test_str(a): @@ -251,7 +251,7 @@ def test_str(a): return True make_global(test_str) - scripted_fn_str = torch.jit._script_pdt(test_str, example_inputs=[("",)]) + scripted_fn_str = torch.jit.script(test_str, example_inputs=[("",)]) self.assertEqual(scripted_fn_str("abc"), test_str("abc")) def test_pdt_list_and_tuple(self): @@ -260,24 +260,24 @@ def test_list_and_tuple(a): make_global(test_list_and_tuple) - scripted_fn_float_list_input = torch.jit._script_pdt(test_list_and_tuple, example_inputs=[([4.9, 8.9],)]) + scripted_fn_float_list_input = torch.jit.script(test_list_and_tuple, example_inputs=[([4.9, 8.9],)]) self.assertEqual(scripted_fn_float_list_input([11.9, 7.6]), test_list_and_tuple([11.9, 7.6])) - scripted_fn_bool_list_input = torch.jit._script_pdt(test_list_and_tuple, example_inputs=[([True, False, True],)]) + scripted_fn_bool_list_input = torch.jit.script(test_list_and_tuple, example_inputs=[([True, False, True],)]) self.assertEqual(scripted_fn_bool_list_input([True, True, True]), test_list_and_tuple([True, True, True])) - scripted_fn_int_list_input = torch.jit._script_pdt(test_list_and_tuple, example_inputs=[([3, 4, 5], )]) + scripted_fn_int_list_input = torch.jit.script(test_list_and_tuple, example_inputs=[([3, 4, 5], )]) self.assertEqual(scripted_fn_int_list_input([1, 2, 3]), test_list_and_tuple([1, 2, 3])) - scripted_fn_float_tuple_input = torch.jit._script_pdt(test_list_and_tuple, example_inputs=[((4.9, 8.9),)]) + scripted_fn_float_tuple_input = torch.jit.script(test_list_and_tuple, example_inputs=[((4.9, 8.9),)]) self.assertEqual(scripted_fn_float_tuple_input((11.9, 7.6)), test_list_and_tuple((11.9, 7.6))) - scripted_fn_bool_tuple_input = torch.jit._script_pdt(test_list_and_tuple, - example_inputs=[((True, False, True),)]) + scripted_fn_bool_tuple_input = torch.jit.script(test_list_and_tuple, + example_inputs=[((True, False, True),)]) self.assertEqual(scripted_fn_bool_tuple_input((True, True, True)), test_list_and_tuple((True, True, True))) - scripted_fn_int_tuple_input = torch.jit._script_pdt(test_list_and_tuple, example_inputs=[((3, 4, 5), )]) + scripted_fn_int_tuple_input = torch.jit.script(test_list_and_tuple, example_inputs=[((3, 4, 5), )]) self.assertEqual(scripted_fn_int_tuple_input((1, 2, 3)), test_list_and_tuple((1, 2, 3))) def test_nested_list_and_tuple(self): @@ -295,22 +295,22 @@ def test_nested_tuple(inp): make_global(test_nested_list, test_nested_tuple) list_inp = [[1, 2, 3, ], [5, 6, 7, ]] - scripted_fn = torch.jit._script_pdt(test_nested_list, example_inputs=[(list_inp, ), ]) + scripted_fn = torch.jit.script(test_nested_list, example_inputs=[(list_inp, ), ]) inp = [[0, 4, 7, ], [8, 11, ], [6, -1, -20, ]] self.assertEqual(scripted_fn(inp, ), test_nested_list(inp, )) list_inp = ([1, 2, 3, ], [5, 6, 7, ]) - scripted_fn = torch.jit._script_pdt(test_nested_list, example_inputs=[(list_inp, ), ]) + scripted_fn = torch.jit.script(test_nested_list, example_inputs=[(list_inp, ), ]) inp = ([0, 4, 7, ], [8, 11, ], [6, -1, -20, ]) self.assertEqual(scripted_fn(inp, ), test_nested_list(inp, )) tup_inp = [(1.0, 2.6, 3.7, ), (5.7, 6.1, 1.7, )] - scripted_fn = torch.jit._script_pdt(test_nested_tuple, example_inputs=[(tup_inp, ), ]) + scripted_fn = torch.jit.script(test_nested_tuple, example_inputs=[(tup_inp, ), ]) inp = [(1.0, 4.1, 7.4, ), (4.8, 1.1, -1.2, ), (6.3, -1.3, -2.0, )] self.assertEqual(scripted_fn(inp, ), test_nested_tuple(inp, )) tup_inp = ((True, False, True, ), (False, False, False, )) - scripted_fn = torch.jit._script_pdt(test_nested_tuple, example_inputs=[(tup_inp, ), ]) + scripted_fn = torch.jit.script(test_nested_tuple, example_inputs=[(tup_inp, ), ]) inp = ((True, True, True, ), (False, False, True, )) self.assertEqual(scripted_fn(inp, ), test_nested_tuple(inp, )) @@ -324,11 +324,11 @@ def test_dict_int_list(a): make_global(test_dict, test_dict_int_list) str_bool_inp = {'foo' : True, 'bar': False} - scripted_fn = torch.jit._script_pdt(test_dict, example_inputs=[(str_bool_inp,)]) + scripted_fn = torch.jit.script(test_dict, example_inputs=[(str_bool_inp,)]) self.assertEqual(scripted_fn({'foo' : False, 'bar': True}, ), test_dict({'foo' : False, 'bar': True}, )) str_list_inp = {0 : [True, False], 1: [False, True]} - scripted_fn = torch.jit._script_pdt(test_dict_int_list, example_inputs=[(str_list_inp,)]) + scripted_fn = torch.jit.script(test_dict_int_list, example_inputs=[(str_list_inp,)]) self.assertEqual(scripted_fn({0 : [False, False], 1: [True, True]}, ), test_dict_int_list({0 : [False, False], 1: [True, True]}, )) @@ -349,14 +349,14 @@ def test_multiple_type_refinement(a): make_global(test_multiple_types, test_multiple_type_refinement) - scripted_fn = torch.jit._script_pdt(test_multiple_types, example_inputs=[(1,), ("abc", ), (8.9,), ([3, 4, 5], )]) + scripted_fn = torch.jit.script(test_multiple_types, example_inputs=[(1,), ("abc", ), (8.9,), ([3, 4, 5], )]) self.assertEqual(scripted_fn(10), test_multiple_types(10)) self.assertEqual(scripted_fn("def"), test_multiple_types("def")) self.assertEqual(scripted_fn(7.89999), test_multiple_types(7.89999)) self.assertEqual(scripted_fn([10, 11, 14]), test_multiple_types([10, 11, 14])) - scripted_fn = torch.jit._script_pdt(test_multiple_type_refinement, example_inputs=[(1,), ("abc", ), (8.9,), - ([3, 4, 5],), (True, ), ({"a": True}, ), ]) + scripted_fn = torch.jit.script(test_multiple_type_refinement, example_inputs=[(1,), ("abc", ), (8.9,), + ([3, 4, 5],), (True, ), ({"a": True}, ), ]) self.assertEqual(scripted_fn(10), test_multiple_type_refinement(10)) self.assertEqual(scripted_fn("def"), test_multiple_type_refinement("def")) self.assertEqual(scripted_fn(7.89999), test_multiple_type_refinement(7.89999)) @@ -381,7 +381,7 @@ def test_model(a, m): make_global(UserDefinedClass, test_model) user_class = UserDefinedClass() - scripted_fn = torch.jit._script_pdt(test_model, example_inputs=[(10, user_class, ), (10.9, user_class, ), ]) + scripted_fn = torch.jit.script(test_model, example_inputs=[(10, user_class, ), (10.9, user_class, ), ]) self.assertEqual(scripted_fn(100, user_class, ), test_model(100, user_class)) self.assertEqual(scripted_fn(1.9, user_class, ), test_model(1.9, user_class)) @@ -403,7 +403,7 @@ def test_model_with_args(a, m): make_global(ClassWithArgs, test_model_with_args) user_class = ClassWithArgs(False) - scripted_fn = torch.jit._script_pdt(test_model_with_args, example_inputs=[(10, user_class, ), (10.9, user_class, ), ]) + scripted_fn = torch.jit.script(test_model_with_args, example_inputs=[(10, user_class, ), (10.9, user_class, ), ]) self.assertEqual(scripted_fn(100, ClassWithArgs(True), ), test_model_with_args(100, ClassWithArgs(True))) def test_nn_parameter_as_arg(self): @@ -420,7 +420,7 @@ def forward(self, y): make_global(TestNNParameter) pdt_model = TestNNParameter() - scripted_fn = torch.jit._script_pdt(pdt_model, example_inputs={pdt_model: [(10, ), ], }) + scripted_fn = torch.jit.script(pdt_model, example_inputs={pdt_model: [(10, ), ], }) self.assertEqual(scripted_fn(20), pdt_model(20)) def test_fx_tracing_with_typing(self): @@ -434,7 +434,7 @@ def forward(self, a) -> FXModelOutput: make_global(FXModel, FXModelOutput) pdt_model = FXModel() - scripted_fn = torch.jit._script_pdt(pdt_model, example_inputs={pdt_model: [([10, 20, ], ), ], }) + scripted_fn = torch.jit.script(pdt_model, example_inputs={pdt_model: [([10, 20, ], ), ], }) self.assertEqual(scripted_fn([20]), pdt_model([20])) def test_nonetype_as_optional_of_type(self): @@ -446,52 +446,11 @@ def test_none(a) -> Any: make_global(test_none) - scripted_fn = torch.jit._script_pdt(test_none, example_inputs=[(None, ), (10.6, )]) + scripted_fn = torch.jit.script(test_none, example_inputs=[(None, ), (10.6, )]) self.assertEqual(scripted_fn(30.9, ), test_none(30.9, )) - scripted_fn = torch.jit._script_pdt(test_none, example_inputs=[(None, ), (10, )]) + scripted_fn = torch.jit.script(test_none, example_inputs=[(None, ), (10, )]) self.assertEqual(scripted_fn(2, ), test_none(2, )) - scripted_fn = torch.jit._script_pdt(test_none, example_inputs=[(None, ), (torch.Tensor(1), )]) + scripted_fn = torch.jit.script(test_none, example_inputs=[(None, ), (torch.Tensor(1), )]) self.assertEqual(scripted_fn(torch.ones(1), ), test_none(torch.ones(1), )) - - class TestForwardWithNoneType(torch.nn.Module): - def forward(self, a): - count = 0 - for i, val in enumerate(a): - if val is None: - count += 1 - return count - - make_global(TestForwardWithNoneType) - pdt_model = TestForwardWithNoneType() - - # Test List[Optional[float]] as input - scripted_model = torch.jit._script_pdt(pdt_model, example_inputs=[([None, ], ), ([2.9, ], )]) - self.assertEqual(scripted_model([2.8, 6.7, 3.8, None, ]), pdt_model([2.8, 6.7, 3.8, None, ])) - - # Test Tuple[Optional[int]] as input - scripted_model = torch.jit._script_pdt(pdt_model, example_inputs=[((5.1, ), ), ((None, ), ), ]) - self.assertEqual(scripted_model((6.2, None, 10.6, 80.1, None, )), pdt_model((6.2, None, 10.6, 80.1, None, ))) - - # Test List[Optional[int]] as input - scripted_model = torch.jit._script_pdt(pdt_model, example_inputs=[([None, ], ), ([2, ], )]) - self.assertEqual(scripted_model([2, None, 6, 8, ]), pdt_model([2, None, 6, 8, ])) - - # Test Tuple[Optional[int]] as input - scripted_model = torch.jit._script_pdt(pdt_model, example_inputs=[((None, ), ), ((5, ), )]) - self.assertEqual(scripted_model((2, None, 6, 8)), pdt_model((2, None, 6, 8, ))) - - # Test Tuple[Optional[float]] as input - scripted_model = torch.jit._script_pdt(pdt_model, example_inputs=[((None, ), ), ((5, ), )]) - self.assertEqual(scripted_model((2, None, 6, 8)), pdt_model((2, None, 6, 8, ))) - - # Test Tuple[Optional[torch.Tensor]] as input - scripted_model = torch.jit._script_pdt(pdt_model, example_inputs=[(((torch.ones(1), ), (None, ), ), )]) - self.assertEqual(scripted_model((torch.ones(1), torch.ones(1), None)), - pdt_model((torch.ones(1), torch.ones(1), None))) - - # Test List[Optional[torch.Tensor]] as input - scripted_model = torch.jit._script_pdt(pdt_model, example_inputs=[([None, ], ), ([torch.ones(1), ], )]) - self.assertEqual(scripted_model([torch.ones(1), torch.ones(1), None]), - pdt_model([torch.ones(1), torch.ones(1), None])) diff --git a/test/jit/test_profiler.py b/test/jit/test_profiler.py index aa8be0518385f..b9ed9d0b78eb5 100644 --- a/test/jit/test_profiler.py +++ b/test/jit/test_profiler.py @@ -29,8 +29,6 @@ def setUp(self): torch._C._debug_set_fusion_group_inlining(False) self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu() torch._C._jit_set_te_must_use_llvm_cpu(False) - self.old_fuse_parallel = torch._C._jit_texpr_parallel_cpu_enabled() - torch._C._jit_set_texpr_parallel_cpu_enabled(True) def tearDown(self): torch._C._jit_set_profiling_executor(self.prev_exec) @@ -42,7 +40,6 @@ def tearDown(self): torch._C._jit_set_texpr_reductions_enabled(self.old_reduction_enabled) torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining) torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu) - torch._C._jit_set_texpr_parallel_cpu_enabled(self.old_fuse_parallel) def test_tensor_type_not_determined_by_inputs(self): @torch.jit.script diff --git a/test/jit/test_symbolic_shape_analysis.py b/test/jit/test_symbolic_shape_analysis.py index 33dc515c51145..6d4e33cda852f 100644 --- a/test/jit/test_symbolic_shape_analysis.py +++ b/test/jit/test_symbolic_shape_analysis.py @@ -3,7 +3,6 @@ import operator from torch.testing import FileCheck -from typing import List if __name__ == '__main__': @@ -60,15 +59,6 @@ def prop_shapes_on_graph(inp0, inp1): self.assertEqual(output_shape[1], sym2) self.assertEqual(output_shape[2], sym3) - def test_sharing_of_list_len(self): - @torch.jit.script - def foo(x, out: List[int]): - return torch.nn.functional.adaptive_avg_pool2d(x, out) - - self.run_pass("inline", foo.graph) - torch._C._jit_pass_propagate_shapes_on_graph(foo.graph) - FileCheck().check("Tensor(*, *)").check_same("adaptive_avg_pool2d").run(foo.graph) - def test_shared_shape_graph(self): @torch.jit.script def foo(x, y): @@ -165,3 +155,25 @@ def foo2(x, y): inputs[1].setType(inputs[1].type().with_sizes([5, 8, sym1])) torch._C._jit_pass_propagate_shapes_on_graph(graph) self.assertEqual(next(graph.outputs()).type().symbolic_sizes(), [5, 8, sym1]) + + def test_adaptive_avg_pool2d(self): + inps = [ + [(1, 64, 8, 9), (5, 7)], + [(1, 64, 10, 9), (7)], + [(1, 64, 10, 9), (5, None)], + [(1, 8, 4, 3), (None, None)], + [(1, 8, 4, 3), (None, 5)], + ] + + for inp in inps: + t = torch.randn(*inp[0]) + out_size = torch.nn.functional.adaptive_avg_pool2d(t, inp[1]).size() + + def foo(x): + return torch.nn.functional.adaptive_avg_pool2d(x, inp[1]) + + fn = torch.jit.trace(foo, (t,)) + torch._C._jit_erase_non_input_shape_information(fn.graph) + torch._C._jit_pass_peephole(fn.graph) + torch._C._jit_pass_constant_propagation(fn.graph) + self.checkShapeAnalysis(out_size, fn.graph, assert_propagation=True) diff --git a/test/jit/test_tracer.py b/test/jit/test_tracer.py index 247072fb3e94d..1d95dc8d0d8a4 100644 --- a/test/jit/test_tracer.py +++ b/test/jit/test_tracer.py @@ -163,13 +163,13 @@ def forward(self, x, y): eager_out = mod(*test_inputs) traced_out = traced_func(*test_inputs) self.assertNotWarn(lambda: traced_func(*test_inputs), "Shouldn't throw slicing related warn here") - self.assertTrue(torch.allclose(eager_out, traced_out)) + self.assertEqual(eager_out, traced_out) test_inputs = (torch.randint(0, 50, (50, 50)), torch.tensor(12)) eager_out = mod(*test_inputs) traced_out = traced_func(*test_inputs) self.assertNotWarn(lambda: traced_func(*test_inputs), "Shouldn't throw slicing related warn here") - self.assertTrue(torch.allclose(eager_out, traced_out)) + self.assertEqual(eager_out, traced_out) def test_typeas_trace_check(self): diff --git a/test/jit/test_typing.py b/test/jit/test_typing.py index f60f25f782e95..125197c87bbb1 100644 --- a/test/jit/test_typing.py +++ b/test/jit/test_typing.py @@ -92,10 +92,9 @@ def fn(x): graph = torch.jit.script(fn).graph - print(graph) - # Check that we're making a `List[Tuple[str, Any]]` - FileCheck().check(r"(str, Any)[] = prim::ListConstruct").run(graph) + FileCheck().check("(str, Union[Tensor, Dict(str, Tensor)])" + "[] = prim::ListConstruct()").run(graph) def test_list_type_refinement_defaults_to_Any_list_comprehension(self): def fn(x): @@ -116,10 +115,9 @@ def fn(x): graph = torch.jit.script(fn).graph - print(graph) - # Check that we're making a `List[Tuple[str, Any]]` - FileCheck().check(r"(str, Any)[] = prim::ListConstruct").run(graph) + FileCheck().check("(str, Union[Tensor, Dict(str, Tensor)])" + "[] = prim::ListConstruct()").run(graph) def test_list_type_refinement_annotation_element_mismatch(self): def fn(): @@ -145,7 +143,8 @@ def fn(x): graph = torch.jit.script(fn).graph - FileCheck().check(r"Dict(str, Any) = prim::DictConstruct").run(graph) + FileCheck().check("Dict(str, Union[Tensor, Dict(str, Tensor)])" + " = prim::DictConstruct").run(graph) def test_dict_type_refinement_defaults_to_Any_dict_comprehension(self): def fn(x): @@ -161,7 +160,8 @@ def fn(x): graph = torch.jit.script(fn).graph - FileCheck().check("Dict(str, Any) = prim::DictConstruct").run(graph) + FileCheck().check("Dict(str, Union[Tensor, Dict(str, Tensor)])" + " = prim::DictConstruct").run(graph) def test_dict_type_refinement_annotation_key_mismatch(self): def fn(): diff --git a/test/jit/test_union.py b/test/jit/test_union.py new file mode 100644 index 0000000000000..df909a6e8100f --- /dev/null +++ b/test/jit/test_union.py @@ -0,0 +1,657 @@ +import io +import os +import sys + +import torch +from torch.testing import FileCheck +from enum import Enum +from typing import Dict, List, Optional, Tuple, Union + +# Make the helper files in test/ importable +pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +sys.path.append(pytorch_test_dir) +from torch.testing._internal.jit_utils import JitTestCase, make_global + +if __name__ == '__main__': + raise RuntimeError("This test file is not meant to be run directly, use:\n\n" + "\tpython test/test_jit.py TESTNAME\n\n" + "instead.") + +class TestUnion(JitTestCase): + """ + This class tests the functionality of `Union`. + + Note: It's important to be able to refine the type of a `Union` to + one of its internal types. Currently, there are differences in the + way Python expects `isinstance` checks and the way TorchScript + expects `isinstance` checks. This means that we can't use + `checkScript` in our test cases because either the eager mode or the + script mode wouldn't run! So, some test cases have separate but + equivalent functions to emulate `checkScript`. + """ + + def test_union_with_scalar_values(self): + def fn(x: Union[int, float]) -> str: + return "foo" + + self.checkScript(fn, (1,)) + self.checkScript(fn, (1.0,)) + + scripted = torch.jit.script(fn) + + with self.assertRaisesRegex(RuntimeError, "Expected a member of" + r" Union\[float, int\] but " + "instead found type str"): + scripted("1") + + def test_union_with_collections(self): + def fn(x: Union[Dict[str, int], List[int]]) -> str: + return "foo" + + self.checkScript(fn, ({"foo": 1, "bar": 2, "baz": 3},)) + self.checkScript(fn, ([1, 2, 3],)) + + scripted = torch.jit.script(fn) + + with self.assertRaisesRegex(RuntimeError, "Expected a member of" + r" Union\[List\[int\], Dict\[str, " + r"int\]\] but instead found type " + r"Dict\[str, str\]"): + scripted({"foo": "bar", "baz": "qux"}) + + with self.assertRaisesRegex(RuntimeError, "Expected a member of" + r" Union\[List\[int\], Dict\[str, " + r"int\]\] but instead found type " + r"List\[str\]"): + scripted(["foo", "bar", "baz"]) + + with self.assertRaisesRegex(RuntimeError, "Expected a member of" + r" Union\[List\[int\], Dict\[str, " + r"int\]\] but instead found type " + "str"): + scripted("1") + + def test_union_with_enum(self): + class Color(Enum): + RED = 1 + GREEN = 2 + + make_global(Color) + + def fn(x: Union[str, Color]) -> str: + return "foo" + + self.checkScript(fn, (Color.RED,)) + self.checkScript(fn, ("red",)) + + scripted = torch.jit.script(fn) + + with self.assertRaisesRegex(RuntimeError, "Expected a member of" + r" Union\[__torch__.jit.test_union." + r"Color, str\] but instead found " + "type int"): + scripted(1) + + def test_union_in_class_constructor(self): + + @torch.jit.script + class A(object): # noqa: B903 + def __init__(self, x: Union[int, str]) -> None: + self.x = x + + def fn(x: Union[str, int]) -> A: + return A(x) + + self.assertEqual(fn("foo").x, "foo") + self.assertEqual(fn(1).x, 1) + + scripted = torch.jit.script(fn) + + with self.assertRaisesRegex(RuntimeError, "Expected a member of" + r" Union\[int, str\] but instead " + r"found type List\[str\]"): + scripted(["foo", "bar", "baz"]) + + def test_union_return_type(self): + def fn(x: int) -> Union[int, str]: + return "foo" + + self.checkScript(fn, (1,)) + + def test_union_as_annotation(self): + def fn() -> Union[int, str]: + x: Union[int, str] = "foo" + return x + + self.checkScript(fn, ()) + + def test_union_as_annotation_in_typed_container(self): + def fn() -> None: + l: List[Union[int, str]] = [] + u1: Union[int, str] = "foo" + u2: Union[int, str] = 1 + l.append(u1) + l.append(u2) + + self.checkScript(fn, ()) + + def test_union_as_annotation_py2(self): + def fn(): + # type: () -> Union[int, str] + x: Union[int, str] = "foo" + return x + + self.checkScript(fn, ()) + + def test_union_as_internal_tuple_type(self): + def fn(): + t: Tuple[Union[int, str], Union[int, str]] = (1, "foo") + return t + + self.checkScript(fn, ()) + + def test_union_variable_can_be_reassigned(self): + @torch.jit.script + def aux1(i: int): + return int(i ** 2) + + @torch.jit.script + def aux2(s: str): + return s + s + + def fn() -> Union[int, str]: + x: Union[int, str] = "foo" + i: int = 1 + x = i + y: int = aux1(x) + z: str = aux2(str(y)) + x = z + return x + + self.checkScript(fn, ()) + + def test_union_does_not_replace_existing_annotated_type(self): + def fn(): + x: List[int] = [1, 2, 3] + x.append("foo") + return x + + with self.assertRaisesRegex(RuntimeError, "Could not match type str"): + scripted = torch.jit.script(fn) + scripted() + + def test_union_does_not_replace_existing_annotated_type_union(self): + def fn(): + x: List[Union[int, str]] = [1, "foo", 3] + x.append(2.0) + return x + + with self.assertRaisesRegex(RuntimeError, "Could not match type float"): + scripted = torch.jit.script(fn) + scripted() + + def test_union_does_not_replace_existing_annotated_type_empty_container(self): + def fn(): + x: List[int] = [] + x.append("foo") + return x + + with self.assertRaisesRegex(RuntimeError, "Could not match type str"): + scripted = torch.jit.script(fn) + scripted() + + def test_unions_of_unions_are_flattened(self): + @torch.jit.script + def fn(x: Union[Union[int, str], float]) -> str: + return "foo" + + s = fn.graph + + FileCheck().check("x : Union[float, int, str]") \ + .run(s) + + def test_unions_of_a_single_argument_vanish(self): + @torch.jit.script + def fn(x: Union[int]) -> str: + return "foo" + + s = fn.graph + + FileCheck().check("x : int") \ + .run(s) + + def test_union_redundant_arguments_are_skipped(self): + @torch.jit.script + def fn(x: Union[int, str, int]) -> str: + return "foo" + + s = fn.graph + + FileCheck().check("x : Union[int, str]") \ + .run(s) + + def test_union_redundant_arguments_are_skipped_optional(self): + @torch.jit.script + def fn(x: Union[int, Optional[float], Optional[int]]) -> str: + return "foo" + + s = fn.graph + + FileCheck().check("x : Union[float, int, NoneType]") \ + .run(s) + + def test_union_redundant_arguments_are_skipped_subtyping(self): + @torch.jit.script + def fn(x: Union[str, Tuple[Optional[int], int], Tuple[int, int]]) -> str: + return "foo" + + s = fn.graph + + FileCheck().check("x : Union[(int?, int), str]") \ + .run(s) + + def test_union_redundant_arguments_are_skipped_container(self): + @torch.jit.script + def fn(x: Union[List[str], List[float], List[str]]) -> str: + return "foo" + + s = fn.graph + + FileCheck().check("x : Union[float[], str[]]") \ + .run(s) + + def test_union_argument_order_is_ignored(self): + @torch.jit.script + def fn1(x: Union[int, str]) -> str: + return "foo" + + @torch.jit.script + def fn2(x: Union[str, int]) -> str: + return "foo" + + for s in (fn1.graph, fn2.graph): + FileCheck().check("x : Union[int, str]") \ + .run(s) + + def test_union_argument_order_is_ignored_container(self): + @torch.jit.script + def fn1(x: Union[List[str], List[int]]) -> str: + return "foo" + + @torch.jit.script + def fn2(x: Union[List[int], List[str]]) -> str: + return "foo" + + for s in (fn1.graph, fn2.graph): + FileCheck().check("x : Union[int[], str[]]") \ + .run(s) + + def test_union_T_None_is_equivalent_to_optional_T(self): + @torch.jit.script + def inner(x: Union[int, None]) -> int: + if x is not None: + return x + else: + return 5 + + @torch.jit.script + def fn1() -> int: + a: Optional[int] = 5 + b: Optional[int] = None + a_ = inner(a) + b_ = inner(b) + return a_ + b_ + + self.assertEqual(fn1(), 10) + + @torch.jit.script + def inner2(x: Optional[int]) -> int: + if x is not None: + return x + else: + return 5 + + @torch.jit.script + def fn2() -> int: + a: Union[int, None] = 5 + b: Union[int, None] = None + a_ = inner(a) + b_ = inner(b) + return a_ + b_ + + self.assertEqual(fn2(), 10) + + def test_union_optional_of_union_is_flattened(self): + @torch.jit.script + def fn(flag: int) -> Union[str, int, None]: + y: Union[int, str, None] = "foo" + if flag == 0: + x: Optional[Union[int, str]] = y + elif flag == 1: + x: Optional[Union[int, str]] = 1 + else: + x: Optional[Union[int, str]] = None + return x + + # Can't use `checkScript` because it will flag the fact that + # the original code has `Optional[Union[int, str]]` but the + # saved/loaded code has `Union[int, NoneType, str]` (even + # though this is exactly what we want) + self.assertEqual(fn(0), "foo") + self.assertEqual(fn(1), 1) + self.assertEqual(fn(2), None) + + buffer = io.BytesIO() + torch.jit.save(fn, buffer) + buffer = io.BytesIO(buffer.getvalue()) + l = torch.jit.load(buffer) + + s = l.code + + FileCheck().check("Union[int, NoneType, str]") \ + .check("Union[int, NoneType, str]") \ + .run(s) + + def test_union_subclasses_larger_union(self): + def fn() -> Union[int, str, torch.Tensor]: + x: Union[int, str] = "foo" + return x + + self.checkScript(fn, ()) + + # TODO: We would like to eventually support this. The issue is being + # tracked at https://github.com/pytorch/pytorch/issues/58167 + def test_union_as_dict_key(self): + def fn(): + x: Dict[Union[int, str], str] = {} + x["foo"] = "bar" + x[1] = 2 + return x[1] + + with self.assertRaisesRegex(RuntimeError, "only int, float, " + "complex, Tensor and string keys " + "are supported"): + torch.jit.script(fn) + + def test_union_as_dict_value(self): + def fn(): + x: Dict[str, Union[int, str]] = {} + x["foo"] = "bar" + x["baz"] = 2 + return x["baz"] + + self.checkScript(fn, ()) + + def test_union_module_with_union_instance_variable(self): + class M(torch.nn.Module): + + x: Union[int, str] + + def __init__(self, x: Union[int, str]): + super().__init__() + self.x: Union[int, str] = x + + def forward(self, y: Union[int, str]): + self.x = y + return self.x + + self.checkModule(M(2,), (1,)) + self.checkModule(M("bar"), ("foo",)) + + def test_union_module_with_union_class_variable(self): + class M(torch.nn.Module): + x: Union[int, str] = "foo" + + def __init__(self, y: int): + super().__init__() + x = y + + def forward(self, z: str): + x = z + return x + + self.checkModule(M(1), ("foo",)) + + def test_union_type_refinement(self): + def fn(x: Union[int, str]) -> str: + if isinstance(x, str): + z = x + "bar" + return x + else: + return "baz" + + self.checkScript(fn, ("foo",)) + self.checkScript(fn, (1,)) + + def test_union_type_refinement_union_rhs(self): + def fn(x: int) -> str: + if torch.jit.isinstance(x, Union[int, str]): + return "bar" + else: + return "baz" + + self.checkScript(fn, (1,)) + + def test_union_type_refinement_tuple_rhs(self): + def fn(x: Union[int, float, List[str]]) -> str: + if isinstance(x, (int, float)): + if isinstance(x, int): + return str(x) + else: + return "foo" + else: + if len(x): + return x[0] + else: + return "bar" + + self.checkScript(fn, (1,)) + self.checkScript(fn, (1.0,)) + self.checkScript(fn, (["a", "b", "c"],)) + + def test_union_type_refinement_tuple_rhs_noncontained_type(self): + def fn(x: Union[int, List[str]]) -> str: + if isinstance(x, (int, float)): + y = x + x + return str(y) + else: + if len(x): + return x[0] + else: + return "bar" + + self.checkScript(fn, (1,)) + self.checkScript(fn, (["a", "b", "c"],)) + + def test_union_type_refinement_tuple_rhs_union(self): + @torch.jit.script + def fn(x: int) -> str: + if torch.jit.isinstance(x, (Union[int, str], float)): + y = x + x + return str(y) + else: + return "foo" + + # TODO: There's currently an unrelated bug in + # `torch.jit.isinstance` that makes it fail for tuple literals. + # Posted here: https://github.com/pytorch/pytorch/issues/60095 + # Change `assertEqual` to `checkScript` when the bug is fixed + self.assertEqual(fn(1), "2") + + def test_union_type_refinement_statically_false(self): + @torch.jit.script + def fn(x: int) -> str: + if torch.jit.isinstance(x, (Union[str, float], List[str], str)): + z = x + "foo" + return z + else: + return "bar" + + s = fn.graph + + # Check that we don't have any branching statements + FileCheck().check_not("block0()") \ + .check_not("block1()") \ + .run(s) + + def test_union_type_refinement_statically_true(self): + @torch.jit.script + def fn(x: Union[List[int], int]) -> Union[List[int], int]: + if not torch.jit.isinstance(x, (int, List[int])): + return x + else: + l = [1, 2, 3] + y: Union[List[int], int] = l + return y + + s = fn.graph + + # Check that we don't have any branching statements + FileCheck().check_not("block0()") \ + .check_not("block1()") \ + .run(s) + + def test_union_type_refinement_partial_static_refinement_tuple_rhs(self): + def fn(x: Union[List[int], int]) -> int: + if torch.jit.isinstance(x, (int, float, str)): + # We should know that `x` is an `int` here + z = x + 1 + return z + else: + return 100 + + self.checkScript(fn, ([1, 2, 3],)) + self.checkScript(fn, (1,)) + + def test_union_type_refinement_partial_static_refinement_union_rhs(self): + def fn(x: Union[List[int], int]) -> int: + if torch.jit.isinstance(x, Union[int, float, str]): + # We should know that `x` is an `int` here + z = x + 1 + return z + else: + return 100 + + self.checkScript(fn, ([1, 2, 3],)) + self.checkScript(fn, (1,)) + + def test_union_type_refinement_internal_declaration(self): + def fn(flag: bool) -> str: + x: Union[int, str, None] = None + if (flag): + y = "foo" + else: + y = 1 + if isinstance(x, str): + return x + else: + return "bar" + + self.checkScript(fn, (True,)) + self.checkScript(fn, (False,)) + + def test_union_branching_with_union_return_and_homogenous_types(self): + def fn(x: int) -> Union[int, str]: + if x % 2: + return "foo" + else: + return "bar" + + self.checkScript(fn, (1,)) + self.checkScript(fn, (8,)) + + def test_union_branching_does_not_autoinfer_undeclared_union(self): + def fn(x: int) -> str: + if x % 2: + y = "foo" + else: + y = x + if isinstance(y, str): + return y + else: + return "bar" + + with self.assertRaisesRegex(RuntimeError, "y is set to type str" + " in the true branch and type int " + "in the false branch"): + torch.jit.script(fn) + + def test_union_branching_does_not_widen_existing_inferred_type(self): + def fn(x: int) -> str: + y = "foo" + if x % 2: + y = "bar" + else: + y = x + if isinstance(y, str): + return y + else: + return "baz" + + with self.assertRaisesRegex(RuntimeError, "previously had type " + "str but is now being assigned to a" + " value of type int"): + torch.jit.script(fn) + + def test_union_schema_matching_on_internal_type(self): + def fn(x: Union[List[int], Dict[str, int]]) -> int: + if torch.jit.isinstance(x, List[int]): + return x[0] + else: + return list(x.values())[0] + + self.checkScript(fn, ([1, 2, 3],)) + self.checkScript(fn, ({"foo": 1, "bar": 2, "baz": 3},)) + + def test_union_subtractive_refinement(self): + def fn(x: Union[List[int], int]) -> int: + if not isinstance(x, int): + x.append(1) + return x[0] + else: + return x + + self.checkScript(fn, (1,)) + self.checkScript(fn, ([1, 2, 3],)) + + def test_union_subtractive_refinement_with_container(self): + def fn(x: Union[List[int], int]) -> int: + if not torch.jit.isinstance(x, List[int]): + return x + else: + x.append(1) + return x[0] + + self.checkScript(fn, (1,)) + self.checkScript(fn, ([1, 2, 3],)) + + def test_union_memory_aliasing(self): + def fn(): + x : List[torch.Tensor] = [] + z : List[Optional[List[torch.Tensor]]] = [] + z.append(x) + x_alias = z[0] + if torch.jit.isinstance(x_alias, List[torch.Tensor]): + x_alias.append(torch.tensor(3)) + return x + + self.checkScript(fn, ()) + + def test_union_serialization_preserves_type_annotations(self): + # This function will fail after being torch.jit.save'd and + # torch.jit.load'd if the type annotations aren't preserved + # for Union during serialization. We need the `Union[str, int]` + # annotation to make sure that `y` is typed as a Union instead + # of as a str in one branch and an int in the other + def fn(x: int) -> str: + if x % 2: + y: Union[str, int] = "bar" + else: + y: Union[str, int] = x + if isinstance(y, str): + return y + else: + return "baz" + + self.checkScript(fn, (1,)) + self.checkScript(fn, (8,)) diff --git a/test/mobile/test_bytecode.py b/test/mobile/test_bytecode.py index 5511e6a63b085..95baa86d5763e 100644 --- a/test/mobile/test_bytecode.py +++ b/test/mobile/test_bytecode.py @@ -228,7 +228,7 @@ def test_bytecode_values_for_all_backport_functions(self): # # Load model and run forward method # mobile_module = _load_for_lite_interpreter(str(tmp_input_model_path)) # mobile_module_result = mobile_module(module_input) - # torch.testing.assert_allclose(mobile_module_result, expected_mobile_module_result) + # torch.testing.assert_close(mobile_module_result, expected_mobile_module_result) # current_to_version -= 1 # # Check backport failure case @@ -270,7 +270,7 @@ def test_backport_bytecode_from_file_to_file(self): module_input = 1 mobile_module_result = mobile_module(module_input) expected_mobile_module_result = 3 * torch.ones([2, 4], dtype=torch.float64) - torch.testing.assert_allclose(mobile_module_result, expected_mobile_module_result) + torch.testing.assert_close(mobile_module_result, expected_mobile_module_result) shutil.rmtree(tmpdirname) # Check just the _backport_for_mobile_to_buffer mechanism but not the function implementations @@ -296,7 +296,7 @@ def test_backport_bytecode_from_file_to_buffer(self): module_input = 1 mobile_module_result = mobile_module(module_input) expected_mobile_module_result = 3 * torch.ones([2, 4], dtype=torch.float64) - torch.testing.assert_allclose(mobile_module_result, expected_mobile_module_result) + torch.testing.assert_close(mobile_module_result, expected_mobile_module_result) def test_get_model_ops_and_info(self): diff --git a/test/mobile/test_lite_script_module.py b/test/mobile/test_lite_script_module.py index 369371fd3279c..a86669ec574b7 100644 --- a/test/mobile/test_lite_script_module.py +++ b/test/mobile/test_lite_script_module.py @@ -48,13 +48,13 @@ def forward(self, x): mobile_module = _load_for_lite_interpreter(buffer) mobile_module_result = mobile_module(input) - torch.testing.assert_allclose(script_module_result, mobile_module_result) + torch.testing.assert_close(script_module_result, mobile_module_result) mobile_module_forward_result = mobile_module.forward(input) - torch.testing.assert_allclose(script_module_result, mobile_module_forward_result) + torch.testing.assert_close(script_module_result, mobile_module_forward_result) mobile_module_run_method_result = mobile_module.run_method("forward", input) - torch.testing.assert_allclose(script_module_result, mobile_module_run_method_result) + torch.testing.assert_close(script_module_result, mobile_module_run_method_result) def test_save_mobile_module_with_debug_info_with_trace(self): class A(torch.nn.Module): @@ -117,13 +117,13 @@ def forward(self, x): mobile_module = _load_for_lite_interpreter(buffer) mobile_module_result = mobile_module(input) - torch.testing.assert_allclose(script_module_result, mobile_module_result) + torch.testing.assert_close(script_module_result, mobile_module_result) mobile_module_forward_result = mobile_module.forward(input) - torch.testing.assert_allclose(script_module_result, mobile_module_forward_result) + torch.testing.assert_close(script_module_result, mobile_module_forward_result) mobile_module_run_method_result = mobile_module.run_method("forward", input) - torch.testing.assert_allclose(script_module_result, mobile_module_run_method_result) + torch.testing.assert_close(script_module_result, mobile_module_run_method_result) def test_find_and_run_method(self): class MyTestModule(torch.nn.Module): @@ -154,7 +154,7 @@ def forward(self, arg): bundled_inputs = mobile_module.run_method("get_all_bundled_inputs") mobile_module_result = mobile_module.forward(*bundled_inputs[0]) - torch.testing.assert_allclose(script_module_result, mobile_module_result) + torch.testing.assert_close(script_module_result, mobile_module_result) def test_method_calls_with_optional_arg(self): class A(torch.nn.Module): @@ -183,7 +183,7 @@ def forward(self, x, one: int = 1): input = torch.tensor([5]) script_module_forward_result = script_module.forward(input) mobile_module_forward_result = mobile_module.forward(input) - torch.testing.assert_allclose( + torch.testing.assert_close( script_module_forward_result, mobile_module_forward_result ) @@ -198,7 +198,7 @@ def forward(self, x, one: int = 1): # now both match again mobile_module_forward_result = mobile_module.forward(input, 2) - torch.testing.assert_allclose( + torch.testing.assert_close( script_module_forward_result, mobile_module_forward_result ) diff --git a/test/onnx/assets/grace_hopper_517x606.jpg b/test/onnx/assets/grace_hopper_517x606.jpg new file mode 100644 index 0000000000000..d2a427810f679 Binary files /dev/null and b/test/onnx/assets/grace_hopper_517x606.jpg differ diff --git a/test/onnx/assets/rgb_pytorch.png b/test/onnx/assets/rgb_pytorch.png new file mode 100644 index 0000000000000..c9d08e6c7da91 Binary files /dev/null and b/test/onnx/assets/rgb_pytorch.png differ diff --git a/test/onnx/expect/TestOperators.test_dropout_default.expect b/test/onnx/expect/TestOperators.test_dropout_default.expect index dcbc25a55045f..550bc65f2700b 100644 --- a/test/onnx/expect/TestOperators.test_dropout_default.expect +++ b/test/onnx/expect/TestOperators.test_dropout_default.expect @@ -5,7 +5,19 @@ graph { node { input: "x" output: "1" - name: "ReduceMax_0" + output: "2" + name: "Dropout_0" + op_type: "Dropout" + attribute { + name: "ratio" + f: 0.5 + type: FLOAT + } + } + node { + input: "1" + output: "3" + name: "ReduceMax_1" op_type: "ReduceMax" attribute { name: "keepdims" @@ -31,7 +43,7 @@ graph { } } output { - name: "1" + name: "3" type { tensor_type { elem_type: 1 diff --git a/test/onnx/test_custom_ops.py b/test/onnx/test_custom_ops.py index 739f267f90a95..04ac9a0066876 100644 --- a/test/onnx/test_custom_ops.py +++ b/test/onnx/test_custom_ops.py @@ -125,5 +125,37 @@ def symbolic_pythonop(g, n, *args, **kwargs): model = MyModule() run_model_test(self, model, input=(x, )) +class TestExportAsContribOps(unittest.TestCase): + opset_version = 14 + keep_initializers_as_inputs = False + onnx_shape_inference = True + + def test_contrib_op_with_loop(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.gelu = torch.nn.GELU() + + def forward(self, x): + res = [] + res2 = [] + for i in range(x.size(0)): + if len(res) > 0: + res2.append(res[0]) + else: + res2.append(self.gelu(x[0])) + res.append(x[0]) + return torch.stack(res), torch.stack(res2) + + def symbolic_custom_gelu(g, input): + return g.op("com.microsoft::Gelu", input).setType(input.type()) + + from torch.onnx import register_custom_op_symbolic + register_custom_op_symbolic("::gelu", symbolic_custom_gelu, 1) + + x = torch.randn(3, 3, 4, requires_grad=True) + model = torch.jit.script(M()) + run_model_test(self, model, input=(x, )) + if __name__ == "__main__": unittest.main() diff --git a/test/onnx/test_models_onnxruntime.py b/test/onnx/test_models_onnxruntime.py index be7f8c62176e8..59909db5958cc 100644 --- a/test/onnx/test_models_onnxruntime.py +++ b/test/onnx/test_models_onnxruntime.py @@ -7,7 +7,7 @@ def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7, opset_versions=None): - opset_versions = opset_versions if opset_versions else [7, 8, 9, 10, 11, 12] + opset_versions = opset_versions if opset_versions else [7, 8, 9, 10, 11, 12, 13, 14] for opset_version in opset_versions: self.opset_version = opset_version diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py index 9fe38ca7b2455..b9e391b540663 100644 --- a/test/onnx/test_operators.py +++ b/test/onnx/test_operators.py @@ -681,7 +681,7 @@ def test_dropout_training(self): def test_dropout_opset12(self): x = torch.randn(3, 4, requires_grad=True) - self.assertONNX(lambda x: torch.max(functional.dropout(x)), x, opset_version=12) + self.assertONNX(lambda x: torch.max(functional.dropout(x, training=False)), x, opset_version=12) def test_dropout_training_opset12(self): x = torch.randn(3, 4, requires_grad=True) diff --git a/test/onnx/test_pytorch_common.py b/test/onnx/test_pytorch_common.py index 0695a989013c7..09ab7a26f4967 100644 --- a/test/onnx/test_pytorch_common.py +++ b/test/onnx/test_pytorch_common.py @@ -60,6 +60,16 @@ def wrapper(self): return wrapper return skip_dec +# skips tests for all opset versions. +def skipForAllOpsetVersions(): + def skip_dec(func): + def wrapper(self): + if self.opset_version: + raise unittest.SkipTest("Skip verify test for unsupported opset_version") + return func(self) + return wrapper + return skip_dec + # Enables tests for scripting, instead of only tracing the model. def enableScriptTest(): def script_dec(func): diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index b92568c03cba2..54a116b57cb1d 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -17,7 +17,7 @@ RnnModelWithPackedSequenceWithoutState) from test_pytorch_common import (skipIfUnsupportedMinOpsetVersion, skipIfUnsupportedOpsetVersion, skipIfNoLapack, disableScriptTest, skipIfONNXShapeInference, - skipIfUnsupportedMaxOpsetVersion) + skipIfUnsupportedMaxOpsetVersion, skipForAllOpsetVersions) from test_pytorch_common import BATCH_SIZE from test_pytorch_common import RNN_BATCH_SIZE, RNN_SEQUENCE_LENGTH, RNN_INPUT_SIZE, RNN_HIDDEN_SIZE from typing import List, Tuple, Optional, Dict @@ -100,7 +100,10 @@ def run_model_test(self, model, batch_size=2, state_dict=None, input_names=None, output_names=None, fixed_batch_size=False, dict_check=True, training=None, remained_onnx_input_idx=None): - model.eval() + if training is not None and training == torch.onnx.TrainingMode.TRAINING: + model.train() + elif training is None or training == torch.onnx.TrainingMode.EVAL: + model.eval() if input is None: input = torch.randn(batch_size, 3, 224, 224, requires_grad=True) with torch.no_grad(): @@ -281,11 +284,14 @@ def _run_test(m, remained_onnx_input_idx): def run_model_test_with_external_data(self, model, input, rtol=0.001, atol=1e-7, example_outputs=None, do_constant_folding=True, dynamic_axes=None, input_names=None, output_names=None, - ort_optim_on=True): + ort_optim_on=True, training=None): import os import tempfile - model.eval() + if training is not None and training == torch.onnx.TrainingMode.TRAINING: + model.train() + elif training is None or training == torch.onnx.TrainingMode.EVAL: + model.eval() with torch.no_grad(): if isinstance(input, torch.Tensor): input = (input,) @@ -490,35 +496,20 @@ def run_word_language_model(self, model_name): # Only support CPU version, since tracer is not working in GPU RNN. self.run_test(model, (x, model.hidden)) - def get_image_from_url(self, url, size=(300, 200)): + def get_image(self, rel_path: str, size: Tuple[int, int]) -> torch.Tensor: import os - from urllib.parse import urlsplit - from urllib import request from PIL import Image from torchvision import transforms - from torch._utils_internal import get_writable_path - - filename = os.path.basename(urlsplit(url)[2]) - data_dir = get_writable_path(os.path.join(os.path.dirname(__file__))) - path = os.path.join(data_dir, filename) - data = request.urlopen(url, timeout=15).read() - with open(path, "wb") as f: - f.write(data) - image = Image.open(path).convert("RGB") - image = image.resize(size, Image.BILINEAR) + data_dir = os.path.join(os.path.dirname(__file__), "assets") + path = os.path.join(data_dir, *rel_path.split("/")) + image = Image.open(path).convert("RGB").resize(size, Image.BILINEAR) - to_tensor = transforms.ToTensor() - return to_tensor(image) + return transforms.ToTensor()(image) - def get_test_images(self): - image_url = "http://farm3.staticflickr.com/2469/3915380994_2e611b1779_z.jpg" - image = self.get_image_from_url(url=image_url, size=(100, 320)) - - image_url2 = "https://pytorch.org/tutorials/_static/img/tv_tutorial/tv_image05.png" - image2 = self.get_image_from_url(url=image_url2, size=(250, 380)) - - return [image], [image2] + def get_test_images(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + return ([self.get_image("grace_hopper_517x606.jpg", (100, 320))], + [self.get_image("rgb_pytorch.png", (250, 380))]) @skipIfUnsupportedMinOpsetVersion(11) @disableScriptTest() # Faster RCNN model is not scriptable @@ -540,10 +531,6 @@ def test_faster_rcnn(self): dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]}, rtol=1e-3, atol=1e-5) def test_paste_mask_in_image(self): - # disable profiling - torch._C._jit_set_profiling_executor(False) - torch._C._jit_set_profiling_mode(False) - masks = torch.rand(10, 1, 26, 26) boxes = torch.rand(10, 4) boxes[:, 2:] += torch.rand(10, 2) @@ -591,10 +578,6 @@ def test_mask_rcnn(self): "scores": [0], "masks": [0, 1, 2]}, rtol=1e-3, atol=1e-5) def test_heatmaps_to_keypoints(self): - # disable profiling - torch._C._jit_set_profiling_executor(False) - torch._C._jit_set_profiling_mode(False) - maps = torch.rand(10, 1, 26, 26) rois = torch.rand(10, 4) from torchvision.models.detection.roi_heads import heatmaps_to_keypoints @@ -1523,7 +1506,7 @@ def list_append(boxes: List[torch.Tensor]): class Min(torch.nn.Module): def forward(self, x): - boxes = [x, x, x] + boxes = [x for _ in range(3)] return list_append(boxes) x = torch.rand(5, 5) @@ -2489,6 +2472,18 @@ def forward(self, x): x = torch.empty(2, 3, 3, dtype=torch.double).uniform_(0, 1) self.run_test(Bernoulli(), x) + # Enable test when fix for allowzero is in ORT + @skipForAllOpsetVersions() + @skipIfUnsupportedMinOpsetVersion(14) + def test_reshape_allowzero(self): + class ReshapeModel(torch.nn.Module): + def forward(self, x): + x = x.reshape(3, 4, 0) + return x + + x = torch.randn(0, 3, 4) + self.run_test(ReshapeModel(), x) + def test_reshape_different_rank(self): class ReshapeModel(torch.nn.Module): def forward(self, x): @@ -3723,6 +3718,22 @@ def forward(self, x, h0, c0): c0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE) self.run_test(LSTMModel(), (input, h0, c0)) + @skipIfUnsupportedMinOpsetVersion(9) + def test_lstm_cell(self): + class LSTMCellModel(torch.nn.Module): + def __init__(self, bias): + super().__init__() + self.lstm_cell = torch.nn.LSTMCell(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, bias=bias) + + def forward(self, x, h0, c0): + return self.lstm_cell(x, (h0, c0)) + + input = torch.randn(BATCH_SIZE, RNN_INPUT_SIZE) + h0 = torch.randn(BATCH_SIZE, RNN_HIDDEN_SIZE) + c0 = torch.randn(BATCH_SIZE, RNN_HIDDEN_SIZE) + for bias in [True, False]: + self.run_test(LSTMCellModel(bias), (input, h0, c0)) + @skipIfUnsupportedMinOpsetVersion(9) def test_lstm_default_init_state(self): class LSTMModel(torch.nn.Module): @@ -4289,7 +4300,7 @@ def forward(self, x): x = torch.tensor([[1, 2], [3, 4]]) self.run_test(RepeatsDimsModel2(), (x,)) - @skipIfUnsupportedMinOpsetVersion(11) + @skipIfUnsupportedMinOpsetVersion(13) def test_dynamic_repeat_interleave(self): class SingleDynamicModel(torch.nn.Module): def forward(self, x): @@ -4311,25 +4322,62 @@ def forward(self, x): self.run_test(NegDynamicModel(), x, test_with_inputs=[another_x], input_names=["input_1"], dynamic_axes={"input_1" : {1 : "w"}}) - class SingleDynamicModel2(torch.nn.Module): + class SingleDynamicModelFloat(torch.nn.Module): def forward(self, x): repeats = torch.tensor([4]) return torch.repeat_interleave(x, repeats, dim=0) - x = torch.tensor([[1, 2], [3, 4]]) - another_x = torch.tensor([[7, 8], [5, 6]]) - self.run_test(SingleDynamicModel2(), x, test_with_inputs=[another_x], + x = torch.tensor([[1.1, 2.1], [3.1, 4.1]]) + another_x = torch.tensor([[7.1, 8.1], [5.1, 6.1]]) + self.run_test(SingleDynamicModelFloat(), x, test_with_inputs=[another_x], input_names=["input_1"], dynamic_axes={"input_1" : {0 : "h"}}) - class AllDynamicModel(torch.nn.Module): - def forward(self, x): - repeats = torch.tensor([4]) - return torch.repeat_interleave(x, repeats, dim=0) + class DynamicRepeatsModel(torch.nn.Module): + def forward(self, x, repeats): + return torch.repeat_interleave(x, repeats, dim=1) - x = torch.tensor([[1, 2, 4, 16], [3, 9, 27, 81], [2, 3, 5, 7]]) + x = torch.tensor([[1, 2, 4], [3, 4, 7]]) another_x = torch.tensor([[7, 8], [5, 6]]) - self.run_test(AllDynamicModel(), x, test_with_inputs=[another_x], - input_names=["input_1"], dynamic_axes={"input_1" : {0 : "h", 1 : "w"}}) + repeats = torch.tensor([2]) + another_repeats = torch.tensor([4]) + self.run_test(DynamicRepeatsModel(), (x, repeats), test_with_inputs=[(another_x, another_repeats)], + input_names=["input_1", "repeats_1"], + dynamic_axes={"input_1" : {1 : "w"}, "repeats_1" : {0 : "r"}}) + + class DynamicRepeatsModel2(torch.nn.Module): + def forward(self, x, repeats): + return torch.repeat_interleave(x, repeats, dim=1) + + x = torch.tensor([[1, 2, 4], [3, 4, 7]]) + repeats = torch.tensor([2]) + another_repeats = torch.tensor([4]) + self.run_test(DynamicRepeatsModel2(), (x, repeats), test_with_inputs=[(x, another_repeats)], + input_names=["input_1", "repeats_1"], + dynamic_axes={"repeats_1" : {0 : "r"}}) + + @skipIfUnsupportedMinOpsetVersion(13) + def test_multiple_dynamic_repeat_interleave(self): + class DynamicRepeatsModel(torch.nn.Module): + def forward(self, x, repeats): + return torch.repeat_interleave(x, repeats, dim=1) + + x = torch.tensor([[1, 2, 4], [3, 4, 7]]) + repeats = torch.tensor([2, 3, 4]) + another_repeats = torch.tensor([4, 3, 2]) + self.run_test(DynamicRepeatsModel(), (x, repeats), test_with_inputs=[(x, another_repeats)], + input_names=["input_1", "repeats_1"], + dynamic_axes={"repeats_1" : {0 : "r"}}) + + class DynamicRepeatsModel2(torch.nn.Module): + def forward(self, x, repeats): + return torch.repeat_interleave(x, repeats, dim=0) + + x = torch.tensor([[1, 2, 4], [3, 4, 7]]) + repeats = torch.tensor([2, 3]) + another_repeats = torch.tensor([4, 3]) + self.run_test(DynamicRepeatsModel2(), (x, repeats), test_with_inputs=[(x, another_repeats)], + input_names=["input_1", "repeats_1"], + dynamic_axes={"repeats_1" : {0 : "r"}}) def test_view(self): class ViewModel(torch.nn.Module): @@ -5651,6 +5699,27 @@ def forward(self, input, other): y = torch.randint(10, (5, )) self.run_test(MatmulModel(), (x, y)) + @skipIfUnsupportedMinOpsetVersion(9) # MatMul long inputs is added in ONNX opset 9. + def test_dot(self): + class MatmulModel(torch.nn.Module): + def forward(self, input, other): + return torch.dot(input, other) + + x = torch.randn(5, requires_grad=True) + y = torch.randn(5, requires_grad=True) + self.run_test(MatmulModel(), (x, y)) + + x = torch.randint(10, (5, )) + y = torch.randint(10, (5, )) + self.run_test(MatmulModel(), (x, y)) + + @disableScriptTest() # SpectralNorm not TorchScript compatible. + def test_spectral_norm(self): + m = torch.nn.utils.spectral_norm(torch.nn.Linear(2, 4)) + + x = torch.randn(6, 2) + self.run_test(m, (x, )) + def test_prelu(self): class PReluModel(torch.nn.Module): def __init__(self): @@ -5693,6 +5762,52 @@ def forward(self, x): x = torch.randn(2, 3, 4) self.run_test(SiLUModel(), (x)) + @skipIfUnsupportedMinOpsetVersion(14) + def test_tril(self): + class trilModel(torch.nn.Module): + def forward(self, x): + return torch.tril(x) + + x = torch.randn(2, 3, 4) + self.run_test(trilModel(), (x)) + + class trilModelwithDiagonal(torch.nn.Module): + def forward(self, x): + return torch.tril(x, diagonal=1) + + x = torch.randn(2, 3, 4) + self.run_test(trilModelwithDiagonal(), (x)) + + class trilModelwithNegDiagonal(torch.nn.Module): + def forward(self, x): + return torch.tril(x, diagonal=-1) + + x = torch.randn(2, 3, 4) + self.run_test(trilModelwithNegDiagonal(), (x)) + + @skipIfUnsupportedMinOpsetVersion(14) + def test_triu(self): + class triuModel(torch.nn.Module): + def forward(self, x): + return torch.triu(x) + + x = torch.randn(2, 3, 4) + self.run_test(triuModel(), (x)) + + class triuModelwithDiagonal(torch.nn.Module): + def forward(self, x): + return torch.triu(x, diagonal=1) + + x = torch.randn(2, 3, 4) + self.run_test(triuModelwithDiagonal(), (x)) + + class trilModelwithNegDiagonal(torch.nn.Module): + def forward(self, x): + return torch.tril(x, diagonal=-1) + + x = torch.randn(2, 3, 4) + self.run_test(trilModelwithNegDiagonal(), (x)) + def test_mish(self): class MishModel(torch.nn.Module): def __init__(self): @@ -7574,44 +7689,75 @@ def test_batchnorm_training(self): class MyModule(torch.nn.Module): def __init__(self): super(MyModule, self).__init__() - self.bn = torch.nn.BatchNorm2d(3, affine=True) + self.bn1 = torch.nn.BatchNorm2d(3, affine=False) + self.cv1 = torch.nn.Conv2d(3, 3, 10) + self.bn2 = torch.nn.BatchNorm2d(3, affine=True) + self.cv2 = torch.nn.Conv2d(3, 3, 10) + self.bn3 = torch.nn.BatchNorm2d(3, affine=False) def forward(self, x): - bn = self.bn(x) - return bn - - model = MyModule() - x = torch.randn(10, 3, 128, 128) + x = self.bn1(x) + x = self.cv1(x) + x = self.bn2(x) + x = self.cv2(x) + x = self.bn3(x) + return x - model.train() - out = model(x) + x = torch.randn(10, 3, 20, 20) * 2 + model_export = MyModule() + self.run_test(model_export, (x,), training=torch.onnx.TrainingMode.TRAINING, rtol=1e-3, atol=1e-5) + model_export.train() + self.run_test(model_export, (x, ), training=torch.onnx.TrainingMode.PRESERVE, rtol=1e-3, atol=1e-5) - # state after 1 train epoch - running_mean = model.bn.running_mean - running_var = model.bn.running_var - saved_mean = x.mean((0, 2, 3)) - saved_var = x.var((0, 2, 3), correction=1) + def test_batchnorm_training_mode_fix_layer(self): + class MyModule(torch.nn.Module): + def __init__(self): + super(MyModule, self).__init__() + self.bn1 = torch.nn.BatchNorm2d(3, affine=True) + self.cv1 = torch.nn.Conv2d(3, 3, 10) + self.bn2 = torch.nn.BatchNorm2d(3, affine=False) + self.cv2 = torch.nn.Conv2d(3, 3, 10) + self.bn3 = torch.nn.BatchNorm2d(3, affine=True) + self.bn3.eval() - pytorch_out = [out.detach().numpy(), - running_mean.cpu().numpy(), running_var.cpu().numpy(), - saved_mean.cpu().numpy(), saved_var.cpu().numpy()] + def forward(self, x): + x = self.bn1(x) + x = self.cv1(x) + x = self.bn2(x) + x = self.cv2(x) + x = self.bn3(x) + return x + x = torch.randn(10, 3, 128, 128) model_export = MyModule() - f = io.BytesIO() + self.run_test(model_export, (x,), training=torch.onnx.TrainingMode.TRAINING, rtol=1e-3, atol=1e-5) + model_export.train() + self.run_test(model_export, (x,), training=torch.onnx.TrainingMode.PRESERVE, rtol=1e-3, atol=1e-5) - ort_sess = convert_to_onnx(model_export, input=(x,), opset_version=self.opset_version, - training=torch.onnx.TrainingMode.TRAINING) - ort_outs = run_ort(ort_sess, input=(x,)) - [np.testing.assert_allclose(p_out, ort_out, atol=10e-3, rtol=10e-3) for p_out, ort_out in zip(pytorch_out, ort_outs)] + def test_batchnorm_eval_mode_train_layer(self): + class MyModule(torch.nn.Module): + def __init__(self): + super(MyModule, self).__init__() + self.bn1 = torch.nn.BatchNorm2d(3, affine=True) + self.cv1 = torch.nn.Conv2d(3, 3, 10) + self.bn2 = torch.nn.BatchNorm2d(3, affine=False) + self.cv2 = torch.nn.Conv2d(3, 3, 10) + self.bn3 = torch.nn.BatchNorm2d(3, affine=True) + self.bn3.train() - model_export = torch.jit.script(MyModule()) - ort_sess = convert_to_onnx(model_export, input=(x,), opset_version=self.opset_version, - example_outputs=out, - training=torch.onnx.TrainingMode.TRAINING, - onnx_shape_inference=True) - ort_outs = run_ort(ort_sess, input=(x,)) - [np.testing.assert_allclose(p_out, ort_out, atol=10e-3, rtol=10e-3) for p_out, ort_out in - zip(pytorch_out, ort_outs)] + def forward(self, x): + x = self.bn1(x) + x = self.cv1(x) + x = self.bn2(x) + x = self.cv2(x) + x = self.bn3(x) + return x + + x = torch.randn(10, 3, 128, 128) + model_export = MyModule() + self.run_test(model_export, (x,), training=torch.onnx.TrainingMode.EVAL, rtol=1e-3, atol=1e-5) + model_export.eval() + self.run_test(model_export, (x,), training=torch.onnx.TrainingMode.PRESERVE, rtol=1e-3, atol=1e-5) @skipIfUnsupportedMinOpsetVersion(12) def test_dropout_training(self): @@ -7626,7 +7772,6 @@ def forward(self, x): model = MyModule() x = torch.randn(10) - model.train() ort_sess = convert_to_onnx(model, input=(x,), opset_version=self.opset_version, @@ -7663,7 +7808,6 @@ def forward(self, x): nb_elements = torch.numel(input) model.train() - ort_sess = convert_to_onnx(model, input=(x,), opset_version=self.opset_version, training=torch.onnx.TrainingMode.TRAINING) ort_outs = run_ort(ort_sess, input=(x,)) @@ -7705,29 +7849,10 @@ def forward(self, x): bn = self.bn(x) return bn - model = MyModule() + model_export = MyModule() x = torch.randn(10, 3, 128, 128) - ort_sess1 = convert_to_onnx(model, input=(x,), opset_version=self.opset_version, - training=torch.onnx.TrainingMode.TRAINING) - ort_outs1 = run_ort(ort_sess1, input=(x,)) - ort_sess2 = convert_to_onnx(model, input=(x,), opset_version=self.opset_version, - training=torch.onnx.TrainingMode.EVAL) - ort_outs2 = run_ort(ort_sess2, input=(x,)) - [np.testing.assert_allclose(ort_out1, ort_out2, atol=1e-7, rtol=0.001) for ort_out1, ort_out2 in - zip(ort_outs1, ort_outs2)] - - script_model = torch.jit.script(model) - outputs = model(x) - ort_sess1 = convert_to_onnx(script_model, input=(x,), opset_version=self.opset_version, - example_outputs=outputs, - training=torch.onnx.TrainingMode.TRAINING) - ort_outs1 = run_ort(ort_sess1, input=(x,)) - ort_sess2 = convert_to_onnx(script_model, input=(x,), opset_version=self.opset_version, - example_outputs=outputs, - training=torch.onnx.TrainingMode.EVAL) - ort_outs2 = run_ort(ort_sess2, input=(x,)) - [np.testing.assert_allclose(ort_out1, ort_out2, atol=1e-7, rtol=0.001) for ort_out1, ort_out2 in - zip(ort_outs1, ort_outs2)] + self.run_test(model_export, (x,), training=torch.onnx.TrainingMode.EVAL) + self.run_test(model_export, (x,), training=torch.onnx.TrainingMode.TRAINING, rtol=1e-3, atol=1e-5) def test_multiple_conv_bn(self): class MyModule(torch.nn.Module): @@ -7754,16 +7879,10 @@ def forward(self, x): x = self.relu(x) return x - model = MyModule() + model_export = MyModule() x = torch.randn(2, 3, 224, 224) - ort_sess1 = convert_to_onnx(model, input=(x,), opset_version=self.opset_version, - training=torch.onnx.TrainingMode.TRAINING) - ort_outs1 = run_ort(ort_sess1, input=(x,)) - ort_sess2 = convert_to_onnx(model, input=(x,), opset_version=self.opset_version, - training=torch.onnx.TrainingMode.EVAL) - ort_outs2 = run_ort(ort_sess2, input=(x,)) - [np.testing.assert_allclose(ort_out1, ort_out2, atol=1e-7, rtol=0.001) for ort_out1, ort_out2 in - zip(ort_outs1, ort_outs2)] + self.run_test(model_export, (x,), training=torch.onnx.TrainingMode.TRAINING, rtol=1e-3, atol=1e-5) + self.run_test(model_export, (x,), training=torch.onnx.TrainingMode.EVAL) def test_script_custom_class_error(self): class BoxCoder(object): @@ -9547,5 +9666,12 @@ def setup_rnn_tests(): keep_initializers_as_inputs=False, onnx_shape_inference=True)) +# opset 14 tests +TestONNXRuntime_opset14 = type(str("TestONNXRuntime_opset14"), + (unittest.TestCase,), + dict(TestONNXRuntime.__dict__, opset_version=14, + keep_initializers_as_inputs=False, + onnx_shape_inference=True)) + if __name__ == "__main__": unittest.main() diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py index 71f52b306b8c4..b87fa06d648a4 100644 --- a/test/onnx/test_utility_funs.py +++ b/test/onnx/test_utility_funs.py @@ -5,7 +5,8 @@ from torch.onnx import utils, OperatorExportTypes, TrainingMode from torch.onnx.symbolic_helper import _set_opset_version, _set_operator_export_type, _set_onnx_shape_inference import torch.utils.cpp_extension -from test_pytorch_common import skipIfUnsupportedMinOpsetVersion, skipIfUnsupportedOpsetVersion +from test_pytorch_common import (skipIfUnsupportedMinOpsetVersion, + skipIfUnsupportedMaxOpsetVersion) import caffe2.python.onnx.backend as backend from verify import verify @@ -36,7 +37,10 @@ def _model_to_graph(self, model, input, operator_export_type=OperatorExportTypes.ONNX, input_names=None, dynamic_axes=None): - + if training == torch.onnx.TrainingMode.TRAINING: + model.train() + elif training == torch.onnx.TrainingMode.EVAL: + model.eval() # Need disable onnx_shape_inference for this test because it puts const node to initializers. _set_onnx_shape_inference(False) utils._validate_dynamic_axes(dynamic_axes, model, None, None) @@ -100,19 +104,20 @@ def forward(self, x, y, t): def test_output_list(self): class PaddingLayer(torch.jit.ScriptModule): @torch.jit.script_method - def forward(self, input_t): - # type: (Tensor) -> Tensor - for i in range(2): + def forward(self, input_t, n): + # type: (Tensor, int) -> Tensor + for i in range(n): input_t = input_t * 2 return input_t input_t = torch.ones(size=[10], dtype=torch.long) + n = 2 model = torch.jit.script(PaddingLayer()) - example_output = model(input_t) + example_output = model(input_t, n) with self.assertRaises(RuntimeError): torch.onnx.export(model, - (input_t, ), + (input_t, n), "test.onnx", opset_version=self.opset_version, example_outputs=[example_output]) @@ -635,7 +640,7 @@ def test_aten_fallthrough(self): # Test aten export of op with no symbolic class Module(torch.nn.Module): def forward(self, x): - return torch.triu(x) + return torch.erfc(x) x = torch.randn(2, 3, 4) _set_opset_version(self.opset_version) @@ -643,8 +648,7 @@ def forward(self, x): operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH, input_names=['x'], dynamic_axes={'x': [0, 1, 2]}) iter = graph.nodes() - assert next(iter).kind() == "onnx::Constant" - assert next(iter).kind() == "aten::triu" + assert next(iter).kind() == "aten::erfc" def test_custom_op_fallthrough(self): # Test custom op @@ -731,7 +735,7 @@ def forward(self, x): assert next(iter).kind() == "aten::dequantize" # prim::ListConstruct is exported as onnx::SequenceConstruct for opset >= 11 - @skipIfUnsupportedOpsetVersion([11, 12, 13]) + @skipIfUnsupportedMaxOpsetVersion(10) def test_prim_fallthrough(self): # Test prim op class PrimModule(torch.jit.ScriptModule): @@ -811,11 +815,11 @@ def forward(self, x): model = torch.jit.script(MyModule()) x = torch.randn(10, 3, 128, 128) example_outputs = model(x) - f = io.BytesIO() _set_opset_version(self.opset_version) _set_operator_export_type(OperatorExportTypes.ONNX) graph, _, __ = self._model_to_graph(model, (x,), do_constant_folding=True, example_outputs=example_outputs, operator_export_type=OperatorExportTypes.ONNX, + training=torch.onnx.TrainingMode.TRAINING, input_names=['x'], dynamic_axes={'x': [0, 1, 2, 3]}) graph_input_params = [param.debugName() for param in graph.inputs()] @@ -861,6 +865,7 @@ def test_fuse_resnet18(self): model = torchvision.models.resnet18(pretrained=True) x = torch.randn(2, 3, 224, 224, requires_grad=True) graph, _, __ = self._model_to_graph(model, (x, ), + training=TrainingMode.EVAL, input_names=['x'], dynamic_axes={'x': [0, 1, 2, 3]}) for node in graph.nodes(): @@ -880,7 +885,6 @@ def __init__(self): def forward(self, x, y): return f(x, y) - model = MyModule() input_1 = torch.tensor(11) input_2 = torch.tensor(12) _set_opset_version(self.opset_version) @@ -917,20 +921,10 @@ def forward(self, x, y): (TestCase,), dict(TestUtilityFuns.__dict__, opset_version=13)) -# opset 11 tests -TestUtilityFuns_opset11_new_jit_API = type(str("TestUtilityFuns_opset11_new_jit_API"), - (TestCase,), - dict(TestUtilityFuns.__dict__, opset_version=11)) - -# opset 12 tests -TestUtilityFuns_opset12_new_jit_API = type(str("TestUtilityFuns_opset12_new_jit_API"), - (TestCase,), - dict(TestUtilityFuns.__dict__, opset_version=12)) - -# opset 13 tests -TestUtilityFuns_opset13_new_jit_API = type(str("TestUtilityFuns_opset13_new_jit_API"), - (TestCase,), - dict(TestUtilityFuns.__dict__, opset_version=13)) +# opset 14 tests +TestUtilityFuns_opset14 = type(str("TestUtilityFuns_opset14"), + (TestCase,), + dict(TestUtilityFuns.__dict__, opset_version=14)) if __name__ == "__main__": diff --git a/test/package/test_directory_reader.py b/test/package/test_directory_reader.py index 93968d6e1bf92..576a7f0c064cd 100644 --- a/test/package/test_directory_reader.py +++ b/test/package/test_directory_reader.py @@ -61,7 +61,7 @@ def test_loading_pickle(self): importer = PackageImporter(Path(temp_dir) / Path(filename).name) dir_mod = importer.load_pickle("model", "model.pkl") input = torch.rand(1, 3, 224, 224) - self.assertTrue(torch.allclose(dir_mod(input), resnet(input))) + self.assertEqual(dir_mod(input), resnet(input)) def test_loading_module(self): """ diff --git a/test/package/test_model.py b/test/package/test_model.py index f5e08b6bfa83c..dc67ff5d89d2e 100644 --- a/test/package/test_model.py +++ b/test/package/test_model.py @@ -49,7 +49,7 @@ def test_resnet(self): # test that it works input = torch.rand(1, 3, 224, 224) ref = resnet(input) - self.assertTrue(torch.allclose(r2(input), ref)) + self.assertEqual(r2(input), ref) # functions exist also to get at the private modules in each package torchvision = i.import_module("torchvision") @@ -81,7 +81,7 @@ def test_resnet(self): i2 = PackageImporter(f2) r3 = i2.load_pickle("model", "model.pkl") - self.assertTrue(torch.allclose(r3(input), ref)) + self.assertEqual(r3(input), ref) @skipIfNoTorchVision def test_model_save(self): @@ -159,7 +159,7 @@ def load(): r = the_model(input) results.append(r) - self.assertTrue(torch.allclose(*results)) + self.assertEqual(*results) @skipIfNoTorchVision def test_script_resnet(self): @@ -188,7 +188,7 @@ def test_script_resnet(self): loaded = torch.jit.load(f2) input = torch.rand(1, 3, 224, 224) - self.assertTrue(torch.allclose((loaded(input)), resnet(input))) + self.assertEqual(loaded(input), resnet(input)) if __name__ == "__main__": diff --git a/test/package/test_package_fx.py b/test/package/test_package_fx.py index 7f31014a8ec04..64d431c0a3e6b 100644 --- a/test/package/test_package_fx.py +++ b/test/package/test_package_fx.py @@ -36,7 +36,7 @@ def forward(self, x): pi = PackageImporter(f) loaded_traced = pi.load_pickle("model", "model.pkl") input = torch.rand(2, 3) - self.assertTrue(torch.allclose(loaded_traced(input), traced(input))) + self.assertEqual(loaded_traced(input), traced(input)) def test_package_then_fx(self): from package_a.test_module import SimpleTest @@ -52,7 +52,7 @@ def test_package_then_fx(self): loaded = pi.load_pickle("model", "model.pkl") traced = symbolic_trace(loaded) input = torch.rand(2, 3) - self.assertTrue(torch.allclose(loaded(input), traced(input))) + self.assertEqual(loaded(input), traced(input)) def test_package_fx_package(self): from package_a.test_module import SimpleTest @@ -87,7 +87,7 @@ def test_package_fx_package(self): loaded2 = pi2.load_pickle("model", "model.pkl") input = torch.rand(2, 3) - self.assertTrue(torch.allclose(loaded(input), loaded2(input))) + self.assertEqual(loaded(input), loaded2(input)) def test_package_fx_with_imports(self): import package_a.subpackage @@ -158,7 +158,7 @@ def __init__(self, root, graph, info): self.assertEqual(loaded_gm.info, "secret") input_x = torch.randn(3) - self.assertTrue(torch.allclose(loaded_gm(input_x), gm(input_x))) + self.assertEqual(loaded_gm(input_x), gm(input_x)) if __name__ == "__main__": diff --git a/test/package/test_package_script.py b/test/package/test_package_script.py index 3bbaed0501ca1..ecacd79fb6bf7 100644 --- a/test/package/test_package_script.py +++ b/test/package/test_package_script.py @@ -51,7 +51,7 @@ def test_package_interface(self): input = torch.tensor(1) - self.assertTrue(torch.allclose(scripted(input), scripted_loaded(input))) + self.assertEqual(scripted(input), scripted_loaded(input)) def test_different_package_interface(self): """Test a case where the interface defined in the package is @@ -149,7 +149,7 @@ def __init__(self, x): input = torch.rand(2, 3) loaded_script_class = diff_fake.MyScriptClass(input) orig_script_class = fake.MyScriptClass(input) - self.assertTrue(torch.allclose(loaded_script_class.bar, orig_script_class.foo)) + self.assertEqual(loaded_script_class.bar, orig_script_class.foo) def test_save_scriptmodule(self): """ @@ -506,7 +506,7 @@ def test_save_shared_tensors(self): self.assertTrue(len(file_structure.children[".data"].children) == 1) input = torch.rand(2, 3, 4) - self.assertTrue(torch.allclose(loaded_mod_1(input), mod1(input))) + self.assertEqual(loaded_mod_1(input), mod1(input)) def test_load_shared_tensors(self): """ @@ -630,7 +630,7 @@ def test_saving_and_scripting_packaged_mod(self): loaded_mod = importer_0.load_pickle("model", "model.pkl") input = torch.rand(2, 3) - self.assertTrue(torch.allclose(loaded_mod(input), orig_mod(input))) + self.assertEqual(loaded_mod(input), orig_mod(input)) scripted_mod = torch.jit.script(loaded_mod) @@ -643,7 +643,7 @@ def test_saving_and_scripting_packaged_mod(self): importer_1 = PackageImporter(buffer_1) loaded_mod_scripted = importer_1.load_pickle("res", "scripted_mod.pkl") - self.assertTrue(torch.allclose(loaded_mod_scripted(input), orig_mod(input))) + self.assertEqual(loaded_mod_scripted(input), orig_mod(input)) def test_mixing_packaged_and_inline_modules(self): """ @@ -680,7 +680,7 @@ def forward(self, input: str): loaded_imported = importer.load_pickle("model", "imported.pkl") input = torch.rand(2, 3) - self.assertTrue(torch.allclose(loaded_imported(input), imported_mod(input))) + self.assertEqual(loaded_imported(input), imported_mod(input)) self.assertEqual(loaded_inline("input"), inline_mod("input")) @skipIfNoTorchVision @@ -721,8 +721,8 @@ def a_non_torch_leaf(a, b): loaded_imported = importer.load_pickle("model", "imported.pkl") input = torch.rand(2, 3) - self.assertTrue(torch.allclose(loaded_imported(input), imported_mod(input))) - self.assertTrue(torch.allclose(loaded_inline(input), inline_mod(input))) + self.assertEqual(loaded_imported(input), imported_mod(input)) + self.assertEqual(loaded_inline(input), inline_mod(input)) def test_tensor_sharing_pickle(self): """Test that saving a ScriptModule and a separately saving a tensor diff --git a/test/quantization/core/test_quantized_module.py b/test/quantization/core/test_quantized_module.py index 10d5831e87758..51e62174cc081 100644 --- a/test/quantization/core/test_quantized_module.py +++ b/test/quantization/core/test_quantized_module.py @@ -2,11 +2,8 @@ import torch.nn as nn import torch.nn.intrinsic as nni import torch.nn.intrinsic.quantized as nniq -import torch.nn.intrinsic.quantized._reference as nniqr import torch.nn.quantized as nnq -import torch.nn.quantized._reference as nnqr import torch.nn.quantized.dynamic as nnqd -import torch.nn.functional as F import torch.quantization from torch.quantization import ( @@ -70,24 +67,21 @@ def test_linear_api(self): [4, 8], [True, False], [True, False], - [True, False], [True, False]) for (batch_size, in_features, out_features, use_bias, - use_fused, per_channel, is_reference) in options: + use_fused, per_channel) in options: self._test_linear_api_impl( batch_size, in_features, out_features, use_bias, use_fused, - per_channel, is_reference) + per_channel) - def _test_linear_api_impl(self, batch_size, in_features, out_features, use_bias, use_fused, per_channel, is_reference): + def _test_linear_api_impl(self, batch_size, in_features, out_features, use_bias, use_fused, per_channel): if torch.backends.quantized.engine == 'qnnpack': per_channel = False - # (use_fused, is_reference) -> quantized class + # use_fused -> quantized class class_map = { - (True, True) : nniqr.LinearReLU, - (True, False) : nniq.LinearReLU, - (False, True) : nnqr.Linear, - (False, False) : nnq.Linear, + True: nniq.LinearReLU, + False: nnq.Linear, } W = torch.rand(out_features, in_features).float() @@ -107,10 +101,9 @@ def _test_linear_api_impl(self, batch_size, in_features, out_features, use_bias, B = torch.rand(out_features).float() if use_bias else None scale = 0.5 zero_point = 3 - qlinear = class_map[(use_fused, is_reference)](in_features, out_features) + qlinear = class_map[use_fused](in_features, out_features) - qlinear_copy = qlinear # deepcopy does not work right now - # qlinear_copy = copy.deepcopy(qlinear) + qlinear_copy = copy.deepcopy(qlinear) self.checkScriptable(qlinear_copy, [[X_q]], check_save_load=True) # Run module with default-initialized parameters. # This tests that the constructor is correct. @@ -127,21 +120,11 @@ def _test_linear_api_impl(self, batch_size, in_features, out_features, use_bias, # Check if the module implementation matches calling the # ops directly - if is_reference: - weight = qlinear._qweight - bias = qlinear._bias - weight_dequant = weight.dequantize() - X_q_dq = X_q.dequantize() - Z_ref = F.linear(X_q_dq, weight_dequant, bias) - if use_fused: - Z_ref = F.relu(Z_ref, inplace=True) - Z_ref = torch.quantize_per_tensor(Z_ref, scale, zero_point, torch.quint8) + W_pack = qlinear._packed_params._packed_params + if use_fused: + Z_ref = torch.ops.quantized.linear_relu(X_q, W_pack, scale, zero_point) else: - W_pack = qlinear._packed_params._packed_params - if use_fused: - Z_ref = torch.ops.quantized.linear_relu(X_q, W_pack, scale, zero_point) - else: - Z_ref = torch.ops.quantized.linear(X_q, W_pack, scale, zero_point) + Z_ref = torch.ops.quantized.linear(X_q, W_pack, scale, zero_point) self.assertEqual(Z_ref, Z_q) self.assertTrue( @@ -163,28 +146,24 @@ def _test_linear_api_impl(self, batch_size, in_features, out_features, use_bias, else: self.assertEqual(model_dict[key], loaded_dict[key]) - loaded_qlinear = class_map[(use_fused, is_reference)]( + loaded_qlinear = class_map[use_fused]( in_features, out_features) loaded_qlinear.load_state_dict(loaded_dict) - if is_reference: - self.assertEqual(qlinear._qweight, loaded_qlinear._qweight) - self.assertEqual(qlinear._bias, loaded_qlinear._bias) - else: - linear_unpack = torch.ops.quantized.linear_unpack - self.assertEqual(linear_unpack(qlinear._packed_params._packed_params), - linear_unpack(loaded_qlinear._packed_params._packed_params)) + linear_unpack = torch.ops.quantized.linear_unpack + self.assertEqual(linear_unpack(qlinear._packed_params._packed_params), + linear_unpack(loaded_qlinear._packed_params._packed_params)) self.assertEqual(qlinear.scale, loaded_qlinear.scale) self.assertEqual(qlinear.zero_point, loaded_qlinear.zero_point) - # make sure loaded_qlinear has the same dir as qlinear since - # scripting the module will add __overloads__ to __dict__ - self.checkScriptable(loaded_qlinear, [[X_q]], check_save_load=True) + # scripting will add __overloads__ to __dict__, which is why we script a copy + # to be able to do the check in the next line + self.checkScriptable(copy.deepcopy(loaded_qlinear), [[X_q]], check_save_load=True) self.assertTrue(dir(qlinear) == dir(loaded_qlinear)) self.assertEqual(qlinear._weight_bias(), loaded_qlinear._weight_bias()) - if not is_reference: - self.assertEqual(qlinear._weight_bias(), torch.ops.quantized.linear_unpack(qlinear._packed_params._packed_params)) + self.assertEqual(qlinear._weight_bias(), torch.ops.quantized.linear_unpack(qlinear._packed_params._packed_params)) Z_q2 = loaded_qlinear(X_q) self.assertEqual(Z_q, Z_q2) + # Test serialization b = io.BytesIO() torch.save(qlinear, b) b.seek(0) @@ -193,6 +172,25 @@ def _test_linear_api_impl(self, batch_size, in_features, out_features, use_bias, self.assertEqual(qlinear.scale, loaded.scale) self.assertEqual(qlinear.zero_point, loaded.zero_point) + # Test copy and deepcopy + copied_linear = copy.copy(qlinear) + self.assertEqual(copied_linear.bias(), qlinear.bias()) + self.assertEqual(copied_linear.scale, qlinear.scale) + self.assertEqual(copied_linear.zero_point, + qlinear.zero_point) + Y_copied = copied_linear(X_q) + np.testing.assert_array_almost_equal( + Z_q.int_repr().numpy(), Y_copied.int_repr().numpy(), decimal=0) + + deepcopied_linear = copy.deepcopy(qlinear) + self.assertEqual(deepcopied_linear.bias(), qlinear.bias()) + self.assertEqual(deepcopied_linear.scale, qlinear.scale) + self.assertEqual(deepcopied_linear.zero_point, + qlinear.zero_point) + Y_deepcopied = copied_linear(X_q) + np.testing.assert_array_almost_equal( + Z_q.int_repr().numpy(), Y_deepcopied.int_repr().numpy(), decimal=0) + # Test JIT self.checkScriptable(qlinear, [[X_q]], check_save_load=True) @@ -230,12 +228,11 @@ def test_quant_dequant_api(self): self.assertEqual(rqr, rqr2) def _test_conv_api_impl( - self, module_name, qconv_module, conv_module, batch_size, - in_channels_per_group, input_feature_map_size, out_channels_per_group, - groups, kernel_size, stride, padding, padding_mode, dilation, - X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point, - use_bias, use_fused, use_channelwise, is_reference - ): + self, module_name, qconv_module, conv_module, batch_size, + in_channels_per_group, input_feature_map_size, out_channels_per_group, + groups, kernel_size, stride, padding, padding_mode, dilation, + X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point, + use_bias, use_fused, use_channelwise): for i in range(len(kernel_size)): assume(input_feature_map_size[i] + 2 * padding[i] >= dilation[i] * (kernel_size[i] - 1) + 1) @@ -264,8 +261,7 @@ def _test_conv_api_impl( # Test members self.assertTrue(module_name == qconv_module._get_name(), module_name + " " + qconv_module._get_name()) - if not is_reference: - self.assertTrue(hasattr(qconv_module, '_packed_params')) + self.assertTrue(hasattr(qconv_module, '_packed_params')) self.assertTrue(hasattr(qconv_module, 'scale')) self.assertTrue(hasattr(qconv_module, 'zero_point')) @@ -294,9 +290,8 @@ def _test_conv_api_impl( # For example, the result of round(2.5) + 1 is 3 while round(2.5 + 1) is # 4 assuming the rounding mode is round-to-nearest, ties-to-even. # skip numerics checking for reference module - if not is_reference: - np.testing.assert_array_almost_equal( - Y_exp.int_repr().numpy(), Y_act.int_repr().numpy(), decimal=0) + np.testing.assert_array_almost_equal( + Y_exp.int_repr().numpy(), Y_act.int_repr().numpy(), decimal=0) # Test serialization of quantized Conv Module using state_dict model_dict = qconv_module.state_dict() @@ -316,8 +311,7 @@ def _test_conv_api_impl( self.assertTrue(dir(loaded_qconv_module) == dir(qconv_module)) self.assertTrue(module_name == loaded_qconv_module._get_name()) - if not is_reference: - self.assertTrue(hasattr(loaded_qconv_module, '_packed_params')) + self.assertTrue(hasattr(loaded_qconv_module, '_packed_params')) self.assertTrue(hasattr(loaded_qconv_module, '_weight_bias')) self.assertEqual(qconv_module.weight(), loaded_qconv_module.weight()) @@ -327,9 +321,8 @@ def _test_conv_api_impl( self.assertEqual(qconv_module.zero_point, loaded_qconv_module.zero_point) Y_loaded = loaded_qconv_module(X_q) - if not is_reference: - np.testing.assert_array_almost_equal( - Y_exp.int_repr().numpy(), Y_loaded.int_repr().numpy(), decimal=0) + np.testing.assert_array_almost_equal( + Y_exp.int_repr().numpy(), Y_loaded.int_repr().numpy(), decimal=0) # Test serialization b = io.BytesIO() @@ -349,9 +342,8 @@ def _test_conv_api_impl( self.assertEqual(copied_conv.zero_point, qconv_module.zero_point) Y_copied = copied_conv(X_q) - if not is_reference: - np.testing.assert_array_almost_equal( - Y_exp.int_repr().numpy(), Y_copied.int_repr().numpy(), decimal=0) + np.testing.assert_array_almost_equal( + Y_exp.int_repr().numpy(), Y_copied.int_repr().numpy(), decimal=0) deepcopied_conv = copy.deepcopy(qconv_module) self.assertEqual(deepcopied_conv.bias(), qconv_module.bias()) @@ -359,9 +351,8 @@ def _test_conv_api_impl( self.assertEqual(deepcopied_conv.zero_point, qconv_module.zero_point) Y_deepcopied = copied_conv(X_q) - if not is_reference: - np.testing.assert_array_almost_equal( - Y_exp.int_repr().numpy(), Y_deepcopied.int_repr().numpy(), decimal=0) + np.testing.assert_array_almost_equal( + Y_exp.int_repr().numpy(), Y_deepcopied.int_repr().numpy(), decimal=0) # JIT testing self.checkScriptable( @@ -396,9 +387,8 @@ def test_conv1d_api(self): [True, False], # use_bias [True, False], # use_fused [True, False], # use_channelwise - [True, False] # is_reference ) - for pad_mode, use_bias, use_fused, use_channelwise, is_reference in options: + for pad_mode, use_bias, use_fused, use_channelwise in options: if torch.backends.quantized.engine == "qnnpack": use_channelwise = False batch_size = 2 @@ -426,15 +416,13 @@ def test_conv1d_api(self): Y_zero_point = 4 if torch.backends.quantized.engine == 'qnnpack': use_channelwise = False - # (use_fused, is_reference) -> quantized class + # use_fused -> quantized class class_map = { - (True, True): (nniqr.ConvReLU1d, "QuantizedConvReLU1d(Reference)"), - (True, False): (nniq.ConvReLU1d, "QuantizedConvReLU1d"), - (False, True): (nnqr.Conv1d, "QuantizedConv1d(Reference)"), - (False, False): (nnq.Conv1d, "QuantizedConv1d") + True: (nniq.ConvReLU1d, "QuantizedConvReLU1d"), + False: (nnq.Conv1d, "QuantizedConv1d") } - qconv_cls, module_name = class_map[(use_fused, is_reference)] + qconv_cls, module_name = class_map[use_fused] qconv_module = qconv_cls( in_channels, out_channels, kernel, stride, pad, dilation, groups, use_bias, padding_mode=pad_mode @@ -453,7 +441,7 @@ def test_conv1d_api(self): in_channels_per_group, input_feature_map_size, out_channels_per_group, groups, kernel_size, stride, pad, pad_mode, dilation, X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, - Y_zero_point, use_bias, use_fused, use_channelwise, is_reference) + Y_zero_point, use_bias, use_fused, use_channelwise) @override_qengines def test_conv2d_api(self): @@ -462,9 +450,8 @@ def test_conv2d_api(self): [True, False], # use_bias [True, False], # use_fused [True, False], # use_channelwise - [True, False] # is_reference ) - for pad_mode, use_bias, use_fused, use_channelwise, is_reference in options: + for pad_mode, use_bias, use_fused, use_channelwise in options: if torch.backends.quantized.engine == "qnnpack": use_channelwise = False batch_size = 2 @@ -494,15 +481,13 @@ def test_conv2d_api(self): W_zero_point = [3] Y_scale = 5.0 Y_zero_point = 4 - # (use_fused, is_reference) -> quantized class + # use_fused -> quantized class class_map = { - (True, True): (nniqr.ConvReLU2d, "QuantizedConvReLU2d(Reference)"), - (True, False): (nniq.ConvReLU2d, "QuantizedConvReLU2d"), - (False, True): (nnqr.Conv2d, "QuantizedConv2d(Reference)"), - (False, False): (nnq.Conv2d, "QuantizedConv2d") + True: (nniq.ConvReLU2d, "QuantizedConvReLU2d"), + False: (nnq.Conv2d, "QuantizedConv2d") } - qconv_cls, module_name = class_map[(use_fused, is_reference)] + qconv_cls, module_name = class_map[use_fused] qconv_module = qconv_cls( in_channels, out_channels, kernel_size, stride, padding, dilation, groups, use_bias, padding_mode=pad_mode @@ -521,7 +506,7 @@ def test_conv2d_api(self): in_channels_per_group, input_feature_map_size, out_channels_per_group, groups, kernel_size, stride, padding, pad_mode, dilation, X_scale, X_zero_point, W_scale, W_zero_point, - Y_scale, Y_zero_point, use_bias, use_fused, use_channelwise, is_reference) + Y_scale, Y_zero_point, use_bias, use_fused, use_channelwise) @skipIfNoFBGEMM def test_conv3d_api(self): @@ -529,9 +514,8 @@ def test_conv3d_api(self): [True, False], # use_bias [True, False], # use_fused [True, False], # use_channelwise - [True, False] # is_reference ) - for use_bias, use_fused, use_channelwise, is_reference in options: + for use_bias, use_fused, use_channelwise in options: if torch.backends.quantized.engine == "qnnpack": use_channelwise = False batch_size = 2 @@ -566,16 +550,14 @@ def test_conv3d_api(self): W_zero_point = [3] Y_scale = 5.0 Y_zero_point = 4 - # (use_fused, is_reference) -> quantized class + # use_fused -> quantized class class_map = { - (True, True): (nniqr.ConvReLU3d, "QuantizedConvReLU3d(Reference)"), - (True, False): (nniq.ConvReLU3d, "QuantizedConvReLU3d"), - (False, True): (nnqr.Conv3d, "QuantizedConv3d(Reference)"), - (False, False): (nnq.Conv3d, "QuantizedConv3d") + True: (nniq.ConvReLU3d, "QuantizedConvReLU3d"), + False: (nnq.Conv3d, "QuantizedConv3d") } with override_quantized_engine('fbgemm'): - qconv_cls, module_name = class_map[(use_fused, is_reference)] + qconv_cls, module_name = class_map[use_fused] qconv_module = qconv_cls( in_channels, out_channels, kernel_size, stride, padding, dilation, groups, use_bias, padding_mode=pad_mode @@ -595,7 +577,7 @@ def test_conv3d_api(self): out_channels_per_group, groups, kernel_size, stride, padding, pad_mode, dilation, X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point, use_bias, use_fused, - use_channelwise, is_reference) + use_channelwise) def test_pool_api(self): """Tests the correctness of the pool module. diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index d0a2dea45e8e3..49b7c96847612 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -1617,8 +1617,8 @@ def test_qtopk(self, X, k, dim, largest, sorted): quantized_out = torch.topk(qX, k, dim=dim, largest=largest, sorted=sorted) assert(len(unquantized_out) == len(quantized_out)) - torch.testing.assert_allclose(quantized_out[0].dequantize(), unquantized_out[0]) - torch.testing.assert_allclose(quantized_out[1], unquantized_out[1]) + torch.testing.assert_close(quantized_out[0].dequantize(), unquantized_out[0]) + torch.testing.assert_close(quantized_out[1], unquantized_out[1]) @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=4, max_dims=4, min_side=1, max_side=10), @@ -1643,8 +1643,8 @@ def test_qtopk_nhwc(self, X, k, dim, largest, sorted): quantized_out = torch.topk(qX, k, dim=dim, largest=largest, sorted=sorted) assert(len(unquantized_out) == len(quantized_out)) - torch.testing.assert_allclose(quantized_out[0].dequantize(), unquantized_out[0]) - torch.testing.assert_allclose(quantized_out[1], unquantized_out[1]) + torch.testing.assert_close(quantized_out[0].dequantize(), unquantized_out[0]) + torch.testing.assert_close(quantized_out[1], unquantized_out[1]) """Tests quantize concatenation (both fused and not).""" @@ -1846,7 +1846,7 @@ def test_cat_nhwc(self, X, relu): else: out = torch.ops.quantized.cat([qX, qY], dim=1, scale=scale, zero_point=zero_point) - torch.testing.assert_allclose(out.dequantize(), ref.dequantize()) + torch.testing.assert_close(out.dequantize(), ref.dequantize()) self.assertNotEqual(out.stride(), sorted(out.stride())) @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=1, max_dims=5, @@ -2414,6 +2414,9 @@ def test_custom_module_lstm(self): custom_module_config = { 'float_to_observed_custom_module_class': { torch.nn.LSTM: torch.nn.quantizable.LSTM + }, + 'observed_to_quantized_custom_module_class': { + torch.nn.quantizable.LSTM: torch.nn.quantizable.LSTM } } @@ -2460,7 +2463,8 @@ def test_custom_module_lstm(self): self.assertEqual(y_ref, y) # Quantize - lstm_quantized = torch.quantization.convert(lstm_prepared) + lstm_quantized = torch.quantization.convert( + lstm_prepared, convert_custom_config_dict=custom_module_config) qy = lstm_quantized(qx) snr = _snr(y, qy) @@ -2602,7 +2606,6 @@ class TestDynamicQuantizedLinear(TestCase): def test_qlinear(self, batch_size, input_channels, output_channels, use_bias, use_relu, use_multi_dim_input, use_channelwise, reduce_range): if torch.backends.quantized.engine == 'qnnpack': - use_relu = False reduce_range = False qlinear_prepack = torch.ops.quantized.linear_prepack @@ -2779,6 +2782,38 @@ def test_qlinear_legacy(self, batch_size, input_channels, output_channels): self.assertEqual(Y_fp32, Y_fp32_ref, msg="torch.ops.quantized.fbgemm_linear_dynamic results are off") + @skipIfNoFBGEMM + def test_qlinear_dynamic_fp16(self): + + options = itertools.product( + (2, 4), # batch_size + (4, 5, 12), # input_channels + (4, 7, 8), # output_channels + (True, False), # use_bias + (True, False), # use_relu + ) + for batch_size, input_channels, output_channels, use_bias, use_relu in options: + qlinear_prepack = torch.ops.quantized.linear_prepack_fp16 + if use_relu: + qlinear_dynamic = torch.ops.quantized.linear_relu_dynamic_fp16 + else: + qlinear_dynamic = torch.ops.quantized.linear_dynamic_fp16 + + x = torch.randn(batch_size, input_channels) + w = torch.randn(output_channels, input_channels) + bias = torch.randn(output_channels) if use_bias else None + + w_packed = qlinear_prepack(w, bias) + out = qlinear_dynamic(x, w_packed) + + # qlinear_dynamic_fp16 uses FP32 activation tensors and FP16 weight tensors + # output is FP32 + w_fp16 = w.to(torch.float16).to(torch.float32) + ref = F.linear(x, w_fp16, bias) + if use_relu: + ref.relu_() + + self.assertEqual(out, ref) class TestDynamicQuantizedRNNOp(TestCase): """Tests the correctness of the dynamic quantized lstm/gru.""" @@ -3314,6 +3349,9 @@ def embedding_bag_rowwise_offsets_run( if bit_rate == 4: pt_op = torch.ops.quantized.embedding_bag_4bit_rowwise_offsets pt_prepack_op = torch.ops.quantized.embedding_bag_4bit_prepack + elif bit_rate == 2: + pt_op = torch.ops.quantized.embedding_bag_2bit_rowwise_offsets + pt_prepack_op = torch.ops.quantized.embedding_bag_2bit_prepack weights = torch.from_numpy((np.random.random_sample(( num_embeddings, embedding_dim)) + 1).astype(np.float32)) @@ -3400,8 +3438,7 @@ def get_reference_result( num_embeddings, embedding_dim, include_last_offset, weights, per_sample_weights, indices, offsets) - torch.testing.assert_allclose(reference_result, result, atol=atol, - rtol=rtol) + torch.testing.assert_close(reference_result, result, atol=atol, rtol=rtol) if bit_rate == 8 or bit_rate == 4: @@ -3424,7 +3461,7 @@ def get_reference_result( per_sample_weights=per_sample_weights, compressed_indices_mapping=torch.tensor(mapping_table), include_last_offset=include_last_offset) - torch.testing.assert_allclose(reference_result, result, atol=atol, rtol=rtol) + torch.testing.assert_close(reference_result, result, atol=atol, rtol=rtol) @@ -3480,6 +3517,33 @@ def test_embedding_bag_4bit(self, num_embeddings, sparsity=sparsity, atol=0.1, rtol=1e-2) + """ Tests the correctness of the embedding_bag_2bit quantized operator """ + @given(num_embeddings=st.integers(10, 100), + embedding_dim=st.integers(5, 50).filter(lambda x: x % 8 == 0), + num_offsets=st.integers(1, 20), + use_32bit_indices=st.booleans(), + use_32bit_offsets=st.booleans(), + enable_per_sample_weights=st.booleans(), + include_last_offset=st.booleans(), + fallback_to_no_sparse=st.booleans(), + sparsity=st.sampled_from([0.0, 0.5, 0.7])) + def test_embedding_bag_2bit(self, num_embeddings, + embedding_dim, num_offsets, + use_32bit_indices, + use_32bit_offsets, + enable_per_sample_weights, + include_last_offset, + fallback_to_no_sparse, + sparsity): + self.embedding_bag_rowwise_offsets_run(2, num_embeddings, + embedding_dim, num_offsets, + use_32bit_indices, use_32bit_offsets, + enable_per_sample_weights, + include_last_offset, + fallback_to_no_sparse, + sparsity=sparsity, + atol=1.0, rtol=1e-1) + """ Tests the correctness of the quantized embedding lookup operator """ @given(num_embeddings=st.integers(10, 100), embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0)) @@ -3510,7 +3574,7 @@ def test_embedding_byte(self, num_embeddings, embedding_dim): qresult = quant_op(packed_weight, indices, pruned_weights=False) ref = torch.embedding(weights, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False) - torch.testing.assert_allclose(ref, qresult, atol=0.005, rtol=1e-3) + torch.testing.assert_close(ref, qresult, atol=0.005, rtol=1e-3) def test_embedding_2d_indices(self): @@ -3533,7 +3597,7 @@ def test_embedding_2d_indices(self): qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=torch.quint8) packed_weight = prepack_op(qweight) qresult = quant_op(packed_weight, indices, pruned_weights=False) - torch.testing.assert_allclose(ref, qresult, atol=0.05, rtol=1e-3) + torch.testing.assert_close(ref, qresult, atol=0.05, rtol=1e-3) def test_embedding_bag_2d_indices(self): """ @@ -3555,7 +3619,7 @@ def test_embedding_bag_2d_indices(self): pt_prepack_op = torch.ops.quantized.embedding_bag_byte_prepack q_weights = pt_prepack_op(weights) qresult = pt_op(q_weights, indices, mode=0, pruned_weights=False) - torch.testing.assert_allclose(result, qresult, atol=0.05, rtol=1e-3) + torch.testing.assert_close(result, qresult, atol=0.05, rtol=1e-3) # Test TorchBind based embedding_bag operator obs = PerChannelMinMaxObserver(dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0) @@ -3569,7 +3633,7 @@ def test_embedding_bag_2d_indices(self): packed_weight = torch.ops.quantized.embedding_bag_prepack(qweight) qresult = torch.ops.quantized.embedding_bag_byte(packed_weight, indices, mode=0) - torch.testing.assert_allclose(result, qresult, atol=0.05, rtol=1e-3) + torch.testing.assert_close(result, qresult, atol=0.05, rtol=1e-3) class TestQuantizedConv(TestCase): diff --git a/test/quantization/core/test_workflow_module.py b/test/quantization/core/test_workflow_module.py index 2298653e878f8..b7782ecf9c1bd 100644 --- a/test/quantization/core/test_workflow_module.py +++ b/test/quantization/core/test_workflow_module.py @@ -205,11 +205,11 @@ def test_per_channel_observers(self, qdtype, qscheme, ch_axis, reduce_range): if reduce_range: ref_scales = [s * 255 / 127 for s in ref_scales] ref_zero_points = [math.floor(z / 2) for z in ref_zero_points] - self.assertTrue(torch.allclose(qparams[0], torch.tensor(ref_scales, dtype=qparams[0].dtype), atol=0.0001)) + self.assertEqual(qparams[0], torch.tensor(ref_scales, dtype=qparams[0].dtype), rtol=1e-5, atol=0.0001) if qscheme == torch.per_channel_affine_float_qparams: - self.assertTrue(torch.allclose(qparams[1], torch.tensor(ref_zero_points, dtype=qparams[1].dtype), atol=1)) + self.assertEqual(qparams[1], torch.tensor(ref_zero_points, dtype=qparams[1].dtype), rtol=1e-5, atol=1) else: - self.assertTrue(torch.allclose(qparams[1], torch.tensor(ref_zero_points, dtype=qparams[1].dtype))) + self.assertEqual(qparams[1], torch.tensor(ref_zero_points, dtype=qparams[1].dtype)) # Test for serializability diff --git a/test/quantization/core/test_workflow_ops.py b/test/quantization/core/test_workflow_ops.py index 9fcf5ac138f3e..60cd04345be85 100644 --- a/test/quantization/core/test_workflow_ops.py +++ b/test/quantization/core/test_workflow_ops.py @@ -312,13 +312,13 @@ def test_forward_per_tensor_half_precision_numerics(self): X1 = torch.randn(5, 5).to(torch.float16) Y1 = torch.fake_quantize_per_tensor_affine(X1, scale, zero, mini, maxi) Y1r = _fake_quantize_per_tensor_affine_reference(X1, scale, zero, mini, maxi) - self.assertTrue(torch.allclose(Y1, Y1r, rtol=tolerance, atol=tolerance)) + self.assertEqual(Y1, Y1r, rtol=tolerance, atol=tolerance) # to force overflow X2 = torch.tensor(2**15 + .01).to(torch.float16) Y2 = torch.fake_quantize_per_tensor_affine(X2, scale, zero, mini, maxi) Y2r = _fake_quantize_per_tensor_affine_reference(X2, scale, zero, mini, maxi) - self.assertTrue(torch.allclose(Y2, Y2r, rtol=tolerance, atol=tolerance)) + self.assertEqual(Y2, Y2r, rtol=tolerance, atol=tolerance) scale = 10 @@ -326,7 +326,7 @@ def test_forward_per_tensor_half_precision_numerics(self): X3 = torch.tensor(2**-24).to(torch.float16) Y3 = torch.fake_quantize_per_tensor_affine(X3, scale, zero, mini, maxi) Y3r = _fake_quantize_per_tensor_affine_reference(X3, scale, zero, mini, maxi) - self.assertTrue(torch.allclose(Y3, Y3r, rtol=tolerance, atol=tolerance)) + self.assertEqual(Y3, Y3r, rtol=tolerance, atol=tolerance) def _test_forward_per_tensor_cachemask_impl(self, device): float_types = (torch.float32, torch.float16, torch.float64) @@ -347,7 +347,7 @@ def _test_forward_per_tensor_cachemask_impl(self, device): X, scale, zero_point, quant_min, quant_max) Y_ref = _fake_quantize_per_tensor_affine_reference( X, scale, zero_point, quant_min, quant_max).to(device) - self.assertTrue(torch.allclose(Y_test, Y_ref, rtol=tolerance, atol=tolerance)) + self.assertEqual(Y_test, Y_ref, rtol=tolerance, atol=tolerance) self.assertTrue(Y_test.dtype == float_type) def test_forward_per_tensor_cachemask_cpu(self): @@ -380,14 +380,14 @@ def _test_backward_per_tensor_cachemask_impl(self, device): X, scale, zero_point, quant_min, quant_max) Y_ref = _fake_quantize_per_tensor_affine_reference( X, scale, zero_point, quant_min, quant_max).to(device) - self.assertTrue(torch.allclose(Y_test, Y_ref, rtol=tolerance, atol=tolerance)) + self.assertEqual(Y_test, Y_ref, rtol=tolerance, atol=tolerance) # backward pass dout = torch.rand_like(X, dtype=torch.float).to(device) dX = _fake_quantize_per_tensor_affine_grad_reference( dout, X, scale, zero_point, quant_min, quant_max) Y_test.backward(dout) - self.assertTrue(torch.allclose(dX, X.grad)) + self.assertEqual(dX, X.grad) self.assertTrue(X.grad.dtype == float_type) def test_backward_per_tensor_cachemask_cpu(self): @@ -729,14 +729,14 @@ def test_forward_per_channel_half_precision_numerics(self): X1 = torch.randn(4, 5).to(torch.float16) Y1 = torch.fake_quantize_per_channel_affine(X1, scale, zero, axis, mini, maxi) Y1r = _fake_quantize_per_channel_affine_reference(X1, scale, zero, axis, mini, maxi) - self.assertTrue(torch.allclose(Y1, Y1r, rtol=tolerance, atol=tolerance)) + self.assertEqual(Y1, Y1r, rtol=tolerance, atol=tolerance) # to force overflow X2 = torch.randn(4, 5).to(torch.float16) X2[0, 0] = 2**15 + .01 Y2 = torch.fake_quantize_per_channel_affine(X2, scale, zero, axis, mini, maxi) Y2r = _fake_quantize_per_channel_affine_reference(X2, scale, zero, axis, mini, maxi) - self.assertTrue(torch.allclose(Y2, Y2r, rtol=tolerance, atol=tolerance)) + self.assertEqual(Y2, Y2r, rtol=tolerance, atol=tolerance) scale = torch.zeros(5) + 10 @@ -745,7 +745,7 @@ def test_forward_per_channel_half_precision_numerics(self): X3[0, 0] = 2**-24 Y3 = torch.fake_quantize_per_channel_affine(X3, scale, zero, axis, mini, maxi) Y3r = _fake_quantize_per_channel_affine_reference(X3, scale, zero, axis, mini, maxi) - self.assertTrue(torch.allclose(Y3, Y3r, rtol=tolerance, atol=tolerance)) + self.assertEqual(Y3, Y3r, rtol=tolerance, atol=tolerance) def _test_learnable_forward_per_channel(self, X_base, device, scale_base, zero_point_base, axis): r"""Tests the forward path of the learnable FakeQuantizePerTensorAffine op. @@ -1160,7 +1160,7 @@ def test_fused_obs_fake_quant_backward_op(self, device) -> None: dX = _fake_quantize_per_tensor_affine_grad_reference( dout, x, x_scale, x_zero_point, 0, 255) - self.assertTrue(torch.allclose(dX, x.grad)) + self.assertEqual(dX, x.grad) self.assertTrue(x.grad.dtype == torch.float32) @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),) @@ -1206,7 +1206,7 @@ def test_fused_backward_op_fake_quant_off(self, device) -> None: dX = _fake_quantize_per_tensor_affine_grad_reference( dout, x, x_scale, x_zero_point, 0, 255) - self.assertTrue(torch.allclose(dX, x.grad)) + self.assertEqual(dX, x.grad) self.assertTrue(x.grad.dtype == torch.float32) if __name__ == '__main__': diff --git a/test/quantization/eager/test_quantize_eager_ptq.py b/test/quantization/eager/test_quantize_eager_ptq.py index 1824da514b733..10cbd928b2b36 100644 --- a/test/quantization/eager/test_quantize_eager_ptq.py +++ b/test/quantization/eager/test_quantize_eager_ptq.py @@ -42,6 +42,7 @@ EmbeddingBagModule, EmbeddingModule, EmbeddingWithLinear, + LinearReluLinearModel, ) # annotated models @@ -995,6 +996,23 @@ def checkQuantized(model): model = quantize_dynamic(NestedModel().eval(), qconfig_dict) checkQuantized(model) + def test_linear_relu_fusion(self): + dtype = torch.qint8 + model = LinearReluLinearModel().eval() + qconfig = default_dynamic_qconfig + qconfig_dict = {'' : qconfig} + torch.quantization.fuse_modules(model, [['fc1', 'relu']], inplace=True) + prepare_dynamic(model, qconfig_dict) + convert_dynamic(model) + + def checkQuantized(model): + self.checkDynamicQuantizedLinearRelu(model.fc1, dtype) + self.checkDynamicQuantizedLinear(model.fc2, dtype) + self.checkScriptable(model, self.calib_data, check_save_load=True) + self.checkNoQconfig(model) + + checkQuantized(model) + @given(qconfig=st.sampled_from([per_channel_dynamic_qconfig, default_dynamic_qconfig]), dtype=st.sampled_from([torch.qint8, torch.float16])) def test_quantized_rnn(self, qconfig, dtype): diff --git a/test/quantization/fx/test_equalize_fx.py b/test/quantization/fx/test_equalize_fx.py index 7c17d1296daac..a74b1744e7cc3 100644 --- a/test/quantization/fx/test_equalize_fx.py +++ b/test/quantization/fx/test_equalize_fx.py @@ -217,10 +217,10 @@ def test_input_weight_eq_observer(self, ndim, input_qdtype, input_qscheme, weigh ref_zero_points = -128 if weight_qdtype is torch.qint8 else 0 ref_zero_points = ref_zero_points - np.round(ref_min_weights_scaled / ref_scales) - self.assertTrue(torch.allclose(weight_qparams[0], torch.tensor( - ref_scales, dtype=weight_qparams[0].dtype), atol=0.0001)) - self.assertTrue(torch.allclose(weight_qparams[1], torch.tensor( - ref_zero_points, dtype=weight_qparams[1].dtype), atol=1)) + self.assertEqual(weight_qparams[0], torch.tensor( + ref_scales, dtype=weight_qparams[0].dtype), rtol=1e-5, atol=0.0001) + self.assertEqual(weight_qparams[1], torch.tensor( + ref_zero_points, dtype=weight_qparams[1].dtype), rtol=1e-5, atol=1) def test_input_weight_equalization_prepare(self): """ Tests that graphs created after prepare_fx is as expected @@ -783,7 +783,7 @@ def test_input_weight_equalization_results(self): prepared(x) equalized_and_quantized = convert_fx(prepared) # Check if compile equalized_and_quantized_output = equalized_and_quantized(x) - self.assertTrue(torch.allclose(quantized_output, equalized_and_quantized_output, atol=0.1)) + self.assertEqual(quantized_output, equalized_and_quantized_output, rtol=1e-5, atol=0.1) @skipIfNoFBGEMM def test_selective_equalization(self): diff --git a/test/quantization/fx/test_numeric_suite_fx.py b/test/quantization/fx/test_numeric_suite_fx.py index d605eba34d922..3e627f5e14419 100644 --- a/test/quantization/fx/test_numeric_suite_fx.py +++ b/test/quantization/fx/test_numeric_suite_fx.py @@ -646,7 +646,6 @@ def _op_is_unmatchable(op): # these ops do not have quantized equivalents ops_to_skip = [ torch.bmm, - torch.sum, torch.div, torch.sub, operator.truediv, @@ -662,6 +661,9 @@ def _op_is_unmatchable(op): # RNNDynamicQuantizeHandler pass elif qhandler_cls == qp.DefaultNodeQuantizeHandler: + # torch.sum does not have quantized equivalents + if base_op == torch.sum: + continue self.assertTrue( _op_in_base_sets_of_related_ops(base_op), f"{base_op} not in sets of related ops") @@ -1832,8 +1834,8 @@ def test_loggers_preserve_qat_numerics(self): mp_ns, mc_ns = add_loggers('fp32', mp, 'int8', mc, OutputLogger) ref_fp32_ns = mp_ns(datum) ref_int8_ns = mc_ns(datum) - self.assertTrue(torch.allclose(ref_fp32, ref_fp32_ns)) - self.assertTrue(torch.allclose(ref_int8, ref_int8_ns)) + self.assertEqual(ref_fp32, ref_fp32_ns) + self.assertEqual(ref_int8, ref_int8_ns) @skipIfNoFBGEMM def test_shadow_loggers_preserve_qat_numerics(self): @@ -1850,7 +1852,7 @@ def test_shadow_loggers_preserve_qat_numerics(self): mc_shadows_mp = add_shadow_loggers('int8', mc, 'fp32', mp, OutputLogger) ref_shadow = mc_shadows_mp(datum) - self.assertTrue(torch.allclose(ref_fp32, ref_shadow)) + self.assertEqual(ref_fp32, ref_shadow) class TestFXNumericSuiteCoreAPIsModels(FXNumericSuiteQuantizationTestCase): """ diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py index 7940eb73114c6..9682da14483df 100644 --- a/test/quantization/fx/test_quantize_fx.py +++ b/test/quantization/fx/test_quantize_fx.py @@ -3,9 +3,11 @@ import torch.nn.functional as F import torch.nn as nn import torch.nn.quantized as nnq +import torch.nn.quantized._reference as nnqr import torch.nn.quantized.dynamic as nnqd import torch.nn.intrinsic as nni import torch.nn.intrinsic.quantized as nniq +import torch.nn.intrinsic.quantized.dynamic as nniqd import torch.multiprocessing as mp # graph mode quantization based on fx @@ -314,6 +316,38 @@ def test_qconfig_fused_module(self): self.checkGraphModuleNodes(quantized, expected_node_list=node_list) + def test_problematic_fuse_example(self): + class LinearRelu(nn.Sequential): + def __init__(self): + super().__init__( + nn.Linear(5, 5), + nn.ReLU(), + ) + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.lin_relu = LinearRelu() + self.linear = nn.Linear(5, 5) + + def forward(self, x): + x = self.lin_relu(x) + x = self.linear(x) + return x + + model = M().eval() + # these qconfigs somehow fail equality where default_qconfig does not + qconfig_dict = { + "": None, + "object_type": [ + (torch.nn.Linear, get_default_qconfig('fbgemm')), + (torch.nn.ReLU, get_default_qconfig('fbgemm')), + ], + } + m = prepare_fx(model, qconfig_dict) + + self.checkGraphModuleNodes(m, expected_node=ns.call_module(torch.nn.intrinsic.modules.fused.LinearReLU)) + def test_fuse_custom_config_dict_validity(self): r""" Verifies that if a user passes an invalid key or makes a typo when @@ -498,7 +532,7 @@ def forward(self, x): Conv1d, conv1d_module_args, (conv1d_input,), - ns.call_module(nn.Conv1d if is_reference else nnq.Conv1d), + ns.call_module(nnqr.Conv1d if is_reference else nnq.Conv1d), None ), ( @@ -506,7 +540,7 @@ def forward(self, x): Conv2d, conv2d_module_args, (conv2d_input,), - ns.call_module(nn.Conv2d if is_reference else nnq.Conv2d), + ns.call_module(nnqr.Conv2d if is_reference else nnq.Conv2d), None ), ( @@ -514,7 +548,7 @@ def forward(self, x): Conv3d, conv3d_module_args, (conv3d_input,), - ns.call_module(nn.Conv3d if is_reference else nnq.Conv3d), + ns.call_module(nnqr.Conv3d if is_reference else nnq.Conv3d), None ), ( @@ -538,7 +572,7 @@ def forward(self, x): LinearModule, (), (linear_module_input,), - ns.call_module(nn.Linear) if is_reference else ns.call_module(nnqd.Linear), + ns.call_module(nnqr.Linear) if is_reference else ns.call_module(nnqd.Linear), None, ), ( @@ -546,7 +580,7 @@ def forward(self, x): LinearModule, (), (linear_module_input,), - ns.call_module(nn.Linear if is_reference else nnq.Linear), + ns.call_module(nnqr.Linear if is_reference else nnq.Linear), None, ), ] @@ -575,6 +609,13 @@ def test_conv_linear_reference(self): """ Test quantizing functional conv and linear with reference option """ tests = self._get_conv_linear_test_cases(is_reference=True) + + def _get_keys(prefix, is_dynamic): + all_keys = [prefix + "." + k for k in ["weight_qscheme", "weight_dtype"]] + if not is_dynamic: + all_keys.extend([prefix + "." + k for k in ["weight_scale", "weight_zero_point"]]) + return all_keys + for (is_dynamic, ModuleClass, module_constructor_inputs, inputs, quantized_node, weight_prepack_node) in tests: quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC @@ -592,22 +633,26 @@ def test_conv_linear_reference(self): def checkWeightQParams(model): for module_name in ("linear", "conv"): if hasattr(model, module_name): - self.assertTrue(hasattr(qr.get_submodule(module_name), "_weight_qparams")) + self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_qscheme")) + self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_scale")) + self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_zero_point")) self.assertTrue("Reference" in qr.get_submodule(module_name)._get_name()) - def checkSerDeser(model): + def checkSerDeser(model, is_dynamic): for module_name in ("linear", "conv"): if hasattr(model, module_name): # make sure seralization works state_dict = copy.deepcopy(model.state_dict()) - self.assertTrue(module_name + "._weight_qparams" in state_dict) - + all_keys = _get_keys(module_name, is_dynamic) + for key in all_keys: + self.assertTrue(key in state_dict) # check load_state_dict restores states module = getattr(model, module_name) - prev_scale = module._weight_qparams["scale"] - module._weight_qparams["scale"] = None + prev_scale = module.weight_scale + module.weight_scale = None model.load_state_dict(state_dict) - self.assertTrue(torch.equal(prev_scale, module._weight_qparams["scale"])) + module = getattr(model, module_name) + self.assertTrue(torch.equal(prev_scale, module.weight_scale)) checkWeightQParams(qr) @@ -615,7 +660,7 @@ def checkSerDeser(model): # make sure the qparams are preserved after copy checkWeightQParams(qr) - checkSerDeser(qr) + checkSerDeser(qr, is_dynamic) @skipIfNoFBGEMM def test_dynamic_quant_weight_observer(self): @@ -2807,6 +2852,177 @@ def forward(self, x): m = convert_fx(m, is_reference=True) m(torch.rand(2, 1, 5, 5)) + def test_preserve_tuple(self): + """ Test tuple input type is preserved + """ + from typing import List + + class LSTM(nn.Module): + def __init__(self): + super().__init__() + self.lstm = nn.LSTM(50, 50, 1) + + def forward(self, inputs: torch.Tensor, state: List[torch.Tensor]): + h = state[0] + c = state[1] + return self.lstm(inputs, (h, c)) + + m = LSTM().eval() + m = prepare_fx(m, {"": default_qconfig}) + # make sure the arg[1] of lstm module is a tuple + for n in m.graph.nodes: + if n.target == "lstm": + self.assertEqual(type(n.args[1]), tuple) + + def test_lowering(self): + class M(torch.nn.Module): + def forward(self, x): + return torch.nn.functional.relu(x) + + m = M().eval() + m = prepare_fx(m, {"": default_qconfig}) + m_copy = copy.deepcopy(m) + m = convert_fx(m) + m_ref = convert_fx(m_copy, is_reference=True) + node_occurrence = { + ns.call_function(torch.quantize_per_tensor): 1, + ns.call_method("dequantize"): 1 + } + node_occurrence_ref = { + ns.call_function(torch.quantize_per_tensor): 2, + ns.call_method("dequantize"): 2 + } + + self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) + self.checkGraphModuleNodes(m_ref, expected_node_occurrence=node_occurrence_ref) + + @skipIfNoFBGEMM + def test_dynamic_with_fusion(self): + """ + Tests that dynamic quantization APIs work with Linear + Relu fusion + """ + class LinearRelu(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(5, 5) + self.relu = torch.nn.ReLU() + + def forward(self, x): + x = self.linear(x) + return self.relu(x) + + class Linear(torch.nn.Module): + def __init__(self): + super().__init__() + self.w = torch.ones(5, 5) + self.b = torch.zeros(5) + + def forward(self, x): + return torch.nn.functional.linear(x, self.w, self.b) + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.mods1 = torch.nn.Sequential(LinearRelu(), LinearRelu()) + self.mods2 = Linear() + self.relu = F.relu + + def forward(self, x): + x = self.mods1(x) + x = self.mods2(x) + x = self.relu(x) + return x + + model = M().eval() + + dynamic_quantized_ops = { + float16_dynamic_qconfig: torch.ops.quantized.linear_relu_dynamic_fp16, + default_dynamic_qconfig: torch.ops.quantized.linear_relu_dynamic + } + for config in [float16_dynamic_qconfig, default_dynamic_qconfig]: + qconfig = { + "": config + } + m = prepare_fx(model, qconfig) + m = convert_fx(m) + m(torch.rand(5, 5)) + node_list = [ + ns.call_module(nniqd.LinearReLU), + ns.call_module(nniqd.LinearReLU), + ns.call_function(dynamic_quantized_ops[config]), + ] + self.checkGraphModuleNodes(m, expected_node_list=node_list) + + def test_ref_linear_module(self): + """ Make sure the numerics for models with ref linear module + matches models with fbgemm/qnnpack module + """ + class M1(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 5) + + def forward(self, x): + return self.linear(x) + + class M2(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 5) + self.relu = torch.nn.ReLU() + + def forward(self, x): + return self.relu(self.linear(x)) + + for M in [M1, M2]: + m = M().eval() + m = prepare_fx(m, {"": default_qconfig}) + m_copy = copy.deepcopy(m) + m = convert_fx(m, is_reference=False) + m_ref = convert_fx(m_copy, is_reference=True) + data = torch.randn(5, 10) + result = m(data) + result_ref = m_ref(data) + self.assertTrue(torch.equal(result, result_ref)) + + def test_ref_conv_module(self): + """ Make sure the numerics for models with ref conv module + matches models with fbgemm/qnnpack module + """ + convs = { + 1: nn.Conv1d, + 2: nn.Conv2d, + 3: nn.Conv3d, + } + + class M1(torch.nn.Module): + def __init__(self, dim): + super().__init__() + self.conv = convs[dim](3, 3, 3) + + def forward(self, x): + return self.conv(x) + + class M2(torch.nn.Module): + def __init__(self, dim): + super().__init__() + self.conv = convs[dim](3, 3, 3) + self.relu = torch.nn.ReLU() + + def forward(self, x): + return self.relu(self.conv(x)) + + for dim, M in itertools.product([1, 2, 3], [M1, M2]): + m = M(dim).eval() + m = prepare_fx(m, {"": default_qconfig}) + m_copy = copy.deepcopy(m) + m = convert_fx(m, is_reference=False) + m_ref = convert_fx(m_copy, is_reference=True) + data = self.img_data_dict[dim][0][0] + result = m(data) + result_ref = m_ref(data) + self.assertTrue(torch.equal(result, result_ref)) + @skipIfNoFBGEMM class TestQuantizeFxOps(QuantizationTestCase): """Unit tests for individual ops @@ -2880,7 +3096,7 @@ def forward(self, x): } quant_type_to_qlinear_relu_fun = { # we don't have linear_relu_dynamic - QuantType.DYNAMIC: ns.call_function(torch.ops.quantized.linear_dynamic), + QuantType.DYNAMIC: ns.call_function(torch.ops.quantized.linear_relu_dynamic), QuantType.STATIC: ns.call_function(torch.ops.quantized.linear_relu), QuantType.QAT: ns.call_function(torch.ops.quantized.linear_relu), } @@ -2961,7 +3177,10 @@ def forward(self, x): if is_reference: qlinear_fun = ns.call_function(torch.nn.functional.linear) else: - qlinear_fun = ns.call_function(torch.ops.quantized.linear_dynamic_fp16) + if has_relu: + qlinear_fun = ns.call_function(torch.ops.quantized.linear_relu_dynamic_fp16) + else: + qlinear_fun = ns.call_function(torch.ops.quantized.linear_dynamic_fp16) prepare_node_occurrence = { # weight ns.call_module(torch.quantization.PlaceholderObserver): 1 @@ -4361,13 +4580,13 @@ def forward(self, x): reference_order_check = [ ns.call_function(torch.quantize_per_tensor), ns.call_method('dequantize'), - ns.call_module(nn.Conv2d), + ns.call_module(nnqr.Conv2d), ns.call_function(torch.quantize_per_tensor), ns.call_method('dequantize'), ns.call_module(nn.Sigmoid), ns.call_function(torch.quantize_per_tensor), ns.call_method('dequantize'), - ns.call_module(nn.Conv2d), + ns.call_module(nnqr.Conv2d), ns.call_function(torch.quantize_per_tensor), ns.call_method('dequantize'), ] @@ -4592,7 +4811,7 @@ def _test_conv_transpose_impl( m2q = torch.quantization.convert(m2p) q_result2 = m2q(data) # verify results match - self.assertTrue(torch.allclose(q_result1, q_result2)) + self.assertEqual(q_result1, q_result2) @unittest.skipUnless('qnnpack' in supported_qengines, "This Pytorch Build has not been built with or does not support QNNPACK") diff --git a/test/quantization/jit/test_deprecated_jit_quant.py b/test/quantization/jit/test_deprecated_jit_quant.py index 662ead35bcf01..68ddb5c346a49 100644 --- a/test/quantization/jit/test_deprecated_jit_quant.py +++ b/test/quantization/jit/test_deprecated_jit_quant.py @@ -99,7 +99,7 @@ def forward(self, x: torch.Tensor, hiddens: torch.Tensor) -> torch.Tensor: self.assertEqual(len(outs), len(ref_outs)) for out, ref_out in zip(outs, ref_outs): - torch.testing.assert_allclose(out, ref_out) + torch.testing.assert_close(out, ref_out) @skipIfNoFBGEMM def test_rnn_quantized(self): @@ -165,32 +165,32 @@ def test_rnn_quantized(self): # Compare int8 quantized to unquantized output_int8, final_hiddens_int8 = cell_int8(x, hiddens) - torch.testing.assert_allclose(output_int8, ref_out) + torch.testing.assert_close(output_int8, ref_out) for out, ref in zip(final_hiddens_int8, ref_hid): - torch.testing.assert_allclose(out, ref) + torch.testing.assert_close(out, ref) # Compare fp16 quantized to unquantized output_fp16, final_hiddens_fp16 = cell_fp16(x, hiddens) - torch.testing.assert_allclose(output_fp16, ref_out) + torch.testing.assert_close(output_fp16, ref_out) for out, ref in zip(final_hiddens_fp16, ref_hid): - torch.testing.assert_allclose(out, ref) + torch.testing.assert_close(out, ref) def compare_quantized_unquantized(ScriptWrapper, cell): wrapper = ScriptWrapper(cell) # Compare quantize scripted module to unquantized script_out, script_hid = wrapper(x, hiddens) - torch.testing.assert_allclose(script_out, ref_out) + torch.testing.assert_close(script_out, ref_out) for out, ref in zip(script_hid, ref_hid): - torch.testing.assert_allclose(out, ref) + torch.testing.assert_close(out, ref) # Compare export/import to unquantized export_import_wrapper = self.getExportImportCopyWithPacking(wrapper) ei_out, ei_hid = export_import_wrapper(x, hiddens) - torch.testing.assert_allclose(ei_out, ref_out) + torch.testing.assert_close(ei_out, ref_out) for out, ref in zip(ei_hid, ref_hid): - torch.testing.assert_allclose(out, ref) + torch.testing.assert_close(out, ref) if isinstance(cell, torch.jit.quantized.QuantizedGRU): class ScriptWrapper(torch.jit.ScriptModule): @@ -252,8 +252,8 @@ def forward(self, x): fb_fp16 = self.getExportImportCopyWithPacking(traced_fp16) y_fp16 = fb_fp16(value) - torch.testing.assert_allclose(y_int8, y_ref, rtol=0.0001, atol=1e-3) - torch.testing.assert_allclose(y_fp16, y_ref, rtol=0.0001, atol=1e-3) + torch.testing.assert_close(y_int8, y_ref, rtol=0.0001, atol=1e-3) + torch.testing.assert_close(y_fp16, y_ref, rtol=0.0001, atol=1e-3) @skipIfNoFBGEMM def test_erase_class_tensor_shapes(self): diff --git a/test/run_test.py b/test/run_test.py old mode 100755 new mode 100644 index e40f580bbe9e7..5d3856ba3e144 --- a/test/run_test.py +++ b/test/run_test.py @@ -4,8 +4,8 @@ import copy from datetime import datetime from distutils.util import strtobool -import modulefinder import os +import pathlib import shutil import signal import subprocess @@ -14,13 +14,21 @@ import torch from torch.utils import cpp_extension -from torch.testing._internal.common_utils import FILE_SCHEMA, IS_IN_CI, TEST_WITH_ROCM, shell, set_cwd +from torch.testing._internal.common_utils import ( + FILE_SCHEMA, + IS_IN_CI, + TEST_WITH_ROCM, + shell, + set_cwd, +) import torch.distributed as dist from typing import Dict, Optional, List +REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent + try: # using tools/ to optimize test run. - sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) + sys.path.append(str(REPO_ROOT)) from tools.testing.test_selections import ( export_S3_test_times, get_shard_based_on_S3, @@ -29,354 +37,233 @@ get_reordered_tests, get_test_case_configs, ) + from tools.testing.modulefinder_determinator import ( + should_run_test, + TARGET_DET_LIST, + ) + HAVE_TEST_SELECTION_TOOLS = True except ImportError: HAVE_TEST_SELECTION_TOOLS = False - print("Unable to import test_selections from tools/testing. Running without test selection stats...") - - -TESTS = [ - 'test_import_time', - 'test_public_bindings', - 'test_type_hints', - 'test_ao_sparsity', - 'test_autograd', - 'benchmark_utils/test_benchmark_utils', - 'test_binary_ufuncs', - 'test_buffer_protocol', - 'test_bundled_inputs', - 'test_complex', - 'test_cpp_api_parity', - 'test_cpp_extensions_aot_no_ninja', - 'test_cpp_extensions_aot_ninja', - 'test_cpp_extensions_jit', - 'distributed/test_c10d_common', - 'distributed/test_c10d_gloo', - 'distributed/test_c10d_nccl', - 'distributed/test_jit_c10d', - 'distributed/test_c10d_spawn_gloo', - 'distributed/test_c10d_spawn_nccl', - 'distributed/test_store', - 'distributed/test_pg_wrapper', - 'distributed/algorithms/test_join', - 'test_cuda', - 'test_jit_cuda_fuser', - 'test_cuda_primary_ctx', - 'test_dataloader', - 'test_datapipe', - 'distributed/test_data_parallel', - 'distributed/test_distributed_fork', - 'distributed/test_distributed_spawn', - 'distributions/test_constraints', - 'distributions/test_distributions', - 'test_dispatch', - 'test_foreach', - 'test_indexing', - 'test_jit', - 'test_linalg', - 'test_logging', - 'test_mkldnn', - 'test_model_dump', - 'test_module_init', - 'test_modules', - 'test_multiprocessing', - 'test_multiprocessing_spawn', - 'distributed/test_nccl', - 'test_native_functions', - 'test_numba_integration', - 'test_nn', - 'test_ops', - 'test_optim', - 'test_functional_optim', - 'test_pytree', - 'test_mobile_optimizer', - 'test_set_default_mobile_cpu_allocator', - 'test_xnnpack_integration', - 'test_vulkan', - 'test_sparse', - 'test_sparse_csr', - 'test_quantization', - 'test_pruning_op', - 'test_spectral_ops', - 'test_serialization', - 'test_shape_ops', - 'test_show_pickle', - 'test_sort_and_select', - 'test_tensor_creation_ops', - 'test_testing', - 'test_torch', - 'test_type_info', - 'test_unary_ufuncs', - 'test_utils', - 'test_view_ops', - 'test_vmap', - 'test_namedtuple_return_api', - 'test_numpy_interop', - 'test_jit_profiling', - 'test_jit_legacy', - 'test_jit_fuser_legacy', - 'test_tensorboard', - 'test_namedtensor', - 'test_reductions', - 'test_type_promotion', - 'test_jit_disabled', - 'test_function_schema', - 'test_overrides', - 'test_jit_fuser_te', - 'test_tensorexpr', - 'test_tensorexpr_pybind', - 'test_openmp', - 'test_profiler', - "distributed/test_launcher", - 'distributed/nn/jit/test_instantiator', - 'distributed/rpc/test_faulty_agent', - 'distributed/rpc/test_tensorpipe_agent', - 'distributed/rpc/cuda/test_tensorpipe_agent', - 'test_determination', - 'test_futures', - 'test_fx', - 'test_fx_experimental', - 'test_functional_autograd_benchmark', - 'test_package', - 'test_license', - 'distributed/pipeline/sync/skip/test_api', - 'distributed/pipeline/sync/skip/test_gpipe', - 'distributed/pipeline/sync/skip/test_inspect_skip_layout', - 'distributed/pipeline/sync/skip/test_leak', - 'distributed/pipeline/sync/skip/test_portal', - 'distributed/pipeline/sync/skip/test_stash_pop', - 'distributed/pipeline/sync/skip/test_tracker', - 'distributed/pipeline/sync/skip/test_verify_skippables', - 'distributed/pipeline/sync/test_balance', - 'distributed/pipeline/sync/test_bugs', - 'distributed/pipeline/sync/test_checkpoint', - 'distributed/pipeline/sync/test_copy', - 'distributed/pipeline/sync/test_deferred_batch_norm', - 'distributed/pipeline/sync/test_dependency', - 'distributed/pipeline/sync/test_inplace', - 'distributed/pipeline/sync/test_microbatch', - 'distributed/pipeline/sync/test_phony', - 'distributed/pipeline/sync/test_pipe', - 'distributed/pipeline/sync/test_pipeline', - 'distributed/pipeline/sync/test_stream', - 'distributed/pipeline/sync/test_transparency', - 'distributed/pipeline/sync/test_worker', - 'distributed/optim/test_zero_redundancy_optimizer', - 'distributed/elastic/timer/api_test', - 'distributed/elastic/timer/local_timer_example', - 'distributed/elastic/timer/local_timer_test', - 'distributed/elastic/events/lib_test', - 'distributed/elastic/metrics/api_test', - 'distributed/elastic/utils/logging_test', - 'distributed/elastic/utils/util_test', - 'distributed/elastic/utils/distributed_test', - 'distributed/elastic/multiprocessing/api_test', - 'distributed/_sharding_spec/test_sharding_spec', - 'distributed/_sharded_tensor/test_sharded_tensor', -] + print( + "Unable to import test_selections from tools/testing. Running without test selection stats..." + ) + + +def discover_tests( + base_dir: Optional[pathlib.Path] = None, + blocklisted_patterns: Optional[List[str]] = None, + blocklisted_tests: Optional[List[str]] = None, + extra_tests: Optional[List[str]] = None) -> List[str]: + """ + Searches for all python files starting with test_ excluding one specified by patterns + """ + def skip_test_p(name: str) -> bool: + rc = False + if blocklisted_patterns is not None: + rc |= any(name.startswith(pattern) for pattern in blocklisted_patterns) + if blocklisted_tests is not None: + rc |= name in blocklisted_tests + return rc + cwd = pathlib.Path(__file__).resolve().parent if base_dir is None else base_dir + all_py_files = list(cwd.glob('**/test_*.py')) + rc = [str(fname.relative_to(cwd))[:-3] for fname in all_py_files] + # Invert slashes on Windows + if sys.platform == "win32": + rc = [name.replace('\\', '/') for name in rc] + rc = [test for test in rc if not skip_test_p(test)] + if extra_tests is not None: + rc += extra_tests + return sorted(rc) + + +TESTS = discover_tests( + blocklisted_patterns=[ + 'ao', + 'bottleneck_test', + 'custom_backend', + 'custom_operator', + 'fx', # executed by test_fx.py + 'jit', # executed by test_jit.py + 'mobile', + 'onnx', + 'package', # executed by test_package.py + 'quantization', # executed by test_quantization.py + ], + blocklisted_tests=[ + 'test_bundled_images', + 'test_cpp_extensions_aot', + 'test_gen_backend_stubs', + 'test_jit_fuser', + 'test_jit_simple', + 'test_jit_string', + 'test_kernel_launch_checks', + 'test_metal', + 'test_nnapi', + 'test_python_dispatch', + 'test_segment_reductions', + 'test_static_runtime', + 'test_throughput_benchmark', + 'test_typing', + "distributed/algorithms/ddp_comm_hooks/test_ddp_hooks", + "distributed/algorithms/quantization/test_quantization", + "distributed/bin/test_script", + "distributed/elastic/multiprocessing/bin/test_script", + "distributed/launcher/bin/test_script", + "distributed/launcher/bin/test_script_init_method", + "distributed/launcher/bin/test_script_is_torchelastic_launched", + "distributed/launcher/bin/test_script_local_rank", + "distributed/test_c10d_spawn", + 'distributions/test_transforms', + 'distributions/test_utils', + ], + extra_tests=[ + "test_cpp_extensions_aot_ninja", + "test_cpp_extensions_aot_no_ninja", + "distributed/elastic/timer/api_test", + "distributed/elastic/timer/local_timer_example", + "distributed/elastic/timer/local_timer_test", + "distributed/elastic/events/lib_test", + "distributed/elastic/metrics/api_test", + "distributed/elastic/utils/logging_test", + "distributed/elastic/utils/util_test", + "distributed/elastic/utils/distributed_test", + "distributed/elastic/multiprocessing/api_test", + ] +) # Tests need to be run with pytest. USE_PYTEST_LIST = [ - 'distributed/pipeline/sync/skip/test_api', - 'distributed/pipeline/sync/skip/test_gpipe', - 'distributed/pipeline/sync/skip/test_inspect_skip_layout', - 'distributed/pipeline/sync/skip/test_leak', - 'distributed/pipeline/sync/skip/test_portal', - 'distributed/pipeline/sync/skip/test_stash_pop', - 'distributed/pipeline/sync/skip/test_tracker', - 'distributed/pipeline/sync/skip/test_verify_skippables', - 'distributed/pipeline/sync/test_balance', - 'distributed/pipeline/sync/test_bugs', - 'distributed/pipeline/sync/test_checkpoint', - 'distributed/pipeline/sync/test_copy', - 'distributed/pipeline/sync/test_deferred_batch_norm', - 'distributed/pipeline/sync/test_dependency', - 'distributed/pipeline/sync/test_inplace', - 'distributed/pipeline/sync/test_microbatch', - 'distributed/pipeline/sync/test_phony', - 'distributed/pipeline/sync/test_pipe', - 'distributed/pipeline/sync/test_pipeline', - 'distributed/pipeline/sync/test_stream', - 'distributed/pipeline/sync/test_transparency', - 'distributed/pipeline/sync/test_worker', - 'distributions/test_constraints', - 'distributions/test_transforms', - 'distributions/test_utils', - 'test_typing', + "distributed/pipeline/sync/skip/test_api", + "distributed/pipeline/sync/skip/test_gpipe", + "distributed/pipeline/sync/skip/test_inspect_skip_layout", + "distributed/pipeline/sync/skip/test_leak", + "distributed/pipeline/sync/skip/test_portal", + "distributed/pipeline/sync/skip/test_stash_pop", + "distributed/pipeline/sync/skip/test_tracker", + "distributed/pipeline/sync/skip/test_verify_skippables", + "distributed/pipeline/sync/test_balance", + "distributed/pipeline/sync/test_bugs", + "distributed/pipeline/sync/test_checkpoint", + "distributed/pipeline/sync/test_copy", + "distributed/pipeline/sync/test_deferred_batch_norm", + "distributed/pipeline/sync/test_dependency", + "distributed/pipeline/sync/test_inplace", + "distributed/pipeline/sync/test_microbatch", + "distributed/pipeline/sync/test_phony", + "distributed/pipeline/sync/test_pipe", + "distributed/pipeline/sync/test_pipeline", + "distributed/pipeline/sync/test_stream", + "distributed/pipeline/sync/test_transparency", + "distributed/pipeline/sync/test_worker", + "distributions/test_constraints", + "distributions/test_transforms", + "distributions/test_utils", + "test_typing", "distributed/elastic/events/lib_test", "distributed/elastic/agent/server/test/api_test", ] WINDOWS_BLOCKLIST = [ - 'distributed/nn/jit/test_instantiator', - 'distributed/rpc/test_faulty_agent', - 'distributed/rpc/test_tensorpipe_agent', - 'distributed/rpc/cuda/test_tensorpipe_agent', - 'distributed/test_distributed_fork', - 'distributed/pipeline/sync/skip/test_api', - 'distributed/pipeline/sync/skip/test_gpipe', - 'distributed/pipeline/sync/skip/test_inspect_skip_layout', - 'distributed/pipeline/sync/skip/test_leak', - 'distributed/pipeline/sync/skip/test_portal', - 'distributed/pipeline/sync/skip/test_stash_pop', - 'distributed/pipeline/sync/skip/test_tracker', - 'distributed/pipeline/sync/skip/test_verify_skippables', - 'distributed/pipeline/sync/test_balance', - 'distributed/pipeline/sync/test_bugs', - 'distributed/pipeline/sync/test_checkpoint', - 'distributed/pipeline/sync/test_copy', - 'distributed/pipeline/sync/test_deferred_batch_norm', - 'distributed/pipeline/sync/test_dependency', - 'distributed/pipeline/sync/test_inplace', - 'distributed/pipeline/sync/test_microbatch', - 'distributed/pipeline/sync/test_phony', - 'distributed/pipeline/sync/test_pipe', - 'distributed/pipeline/sync/test_pipeline', - 'distributed/pipeline/sync/test_stream', - 'distributed/pipeline/sync/test_transparency', - 'distributed/pipeline/sync/test_worker', + "distributed/nn/jit/test_instantiator", + "distributed/rpc/test_faulty_agent", + "distributed/rpc/test_tensorpipe_agent", + "distributed/rpc/cuda/test_tensorpipe_agent", + "distributed/pipeline/sync/skip/test_api", + "distributed/pipeline/sync/skip/test_gpipe", + "distributed/pipeline/sync/skip/test_inspect_skip_layout", + "distributed/pipeline/sync/skip/test_leak", + "distributed/pipeline/sync/skip/test_portal", + "distributed/pipeline/sync/skip/test_stash_pop", + "distributed/pipeline/sync/skip/test_tracker", + "distributed/pipeline/sync/skip/test_verify_skippables", + "distributed/pipeline/sync/test_balance", + "distributed/pipeline/sync/test_bugs", + "distributed/pipeline/sync/test_checkpoint", + "distributed/pipeline/sync/test_copy", + "distributed/pipeline/sync/test_deferred_batch_norm", + "distributed/pipeline/sync/test_dependency", + "distributed/pipeline/sync/test_inplace", + "distributed/pipeline/sync/test_microbatch", + "distributed/pipeline/sync/test_phony", + "distributed/pipeline/sync/test_pipe", + "distributed/pipeline/sync/test_pipeline", + "distributed/pipeline/sync/test_stream", + "distributed/pipeline/sync/test_transparency", + "distributed/pipeline/sync/test_worker", "distributed/elastic/agent/server/test/api_test", - 'distributed/elastic/multiprocessing/api_test', - 'distributed/_sharded_tensor/test_sharded_tensor', + "distributed/elastic/multiprocessing/api_test", + "distributed/_sharded_tensor/test_sharded_tensor", ] ROCM_BLOCKLIST = [ - 'distributed/nn/jit/test_instantiator', - 'distributed/rpc/test_faulty_agent', - 'distributed/rpc/test_tensorpipe_agent', - 'distributed/rpc/cuda/test_tensorpipe_agent', - 'test_determination', - 'test_multiprocessing', - 'test_jit_legacy', - 'test_type_hints', - 'test_openmp', + "distributed/nn/jit/test_instantiator", + "distributed/rpc/test_faulty_agent", + "distributed/rpc/test_tensorpipe_agent", + "distributed/rpc/cuda/test_tensorpipe_agent", + "distributed/_sharded_tensor/test_sharded_tensor", + "test_determination", + "test_multiprocessing", + "test_jit_legacy", + "test_type_hints", + "test_openmp", ] RUN_PARALLEL_BLOCKLIST = [ - 'test_cpp_extensions_jit', - 'test_jit_disabled', - 'test_mobile_optimizer', - 'test_multiprocessing', - 'test_multiprocessing_spawn', - 'test_namedtuple_return_api', - 'test_overrides', - 'test_show_pickle', - 'test_tensorexpr', - 'test_cuda_primary_ctx', -] + [test for test in TESTS if test.startswith('distributed/')] - -WINDOWS_COVERAGE_BLOCKLIST = [ -] - - -# These tests are slow enough that it's worth calculating whether the patch -# touched any related files first. This list was manually generated, but for every -# run with --determine-from, we use another generated list based on this one and the -# previous test stats. -TARGET_DET_LIST = [ - 'distributions/test_distributions', - 'test_nn', - 'test_autograd', - 'test_cpp_extensions_jit', - 'test_jit_legacy', - 'test_dataloader', - 'test_overrides', - 'test_linalg', - 'test_jit', - 'test_jit_profiling', - 'test_torch', - 'test_binary_ufuncs', - 'test_numpy_interop', - 'test_reductions', - 'test_shape_ops', - 'test_sort_and_select', - 'test_testing', - 'test_view_ops', - 'distributed/nn/jit/test_instantiator', - 'distributed/test_distributed_fork', - 'distributed/rpc/test_tensorpipe_agent', - 'distributed/rpc/cuda/test_tensorpipe_agent', - 'distributed/algorithms/ddp_comm_hooks/test_ddp_hooks', - 'distributed/test_distributed_spawn', - 'test_cuda', - 'test_cuda_primary_ctx', - 'test_cpp_extensions_aot_ninja', - 'test_cpp_extensions_aot_no_ninja', - 'test_serialization', - 'test_optim', - 'test_utils', - 'test_multiprocessing', - 'test_tensorboard', - 'distributed/test_c10d_common', - 'distributed/test_c10d_gloo', - 'distributed/test_c10d_nccl', - 'distributed/test_jit_c10d', - 'distributed/test_c10d_spawn_gloo', - 'distributed/test_c10d_spawn_nccl', - 'distributed/test_store', - 'distributed/test_pg_wrapper', - 'test_quantization', - 'test_pruning_op', - 'test_determination', - 'test_futures', - 'distributed/pipeline/sync/skip/test_api', - 'distributed/pipeline/sync/skip/test_gpipe', - 'distributed/pipeline/sync/skip/test_inspect_skip_layout', - 'distributed/pipeline/sync/skip/test_leak', - 'distributed/pipeline/sync/skip/test_portal', - 'distributed/pipeline/sync/skip/test_stash_pop', - 'distributed/pipeline/sync/skip/test_tracker', - 'distributed/pipeline/sync/skip/test_verify_skippables', - 'distributed/pipeline/sync/test_balance', - 'distributed/pipeline/sync/test_bugs', - 'distributed/pipeline/sync/test_checkpoint', - 'distributed/pipeline/sync/test_copy', - 'distributed/pipeline/sync/test_deferred_batch_norm', - 'distributed/pipeline/sync/test_dependency', - 'distributed/pipeline/sync/test_inplace', - 'distributed/pipeline/sync/test_microbatch', - 'distributed/pipeline/sync/test_phony', - 'distributed/pipeline/sync/test_pipe', - 'distributed/pipeline/sync/test_pipeline', - 'distributed/pipeline/sync/test_stream', - 'distributed/pipeline/sync/test_transparency', - 'distributed/pipeline/sync/test_worker', + "test_cpp_extensions_jit", + "test_jit_disabled", + "test_mobile_optimizer", + "test_multiprocessing", + "test_multiprocessing_spawn", + "test_namedtuple_return_api", + "test_overrides", + "test_show_pickle", + "test_tensorexpr", + "test_cuda_primary_ctx", +] + [test for test in TESTS if test.startswith("distributed/")] + +WINDOWS_COVERAGE_BLOCKLIST = [] + +# A subset of our TEST list that validates PyTorch's ops, modules, and autograd function as expected +CORE_TEST_LIST = [ + "test_autograd", + "test_modules", + "test_nn", + "test_ops", + "test_torch" ] # the JSON file to store the S3 test stats -TEST_TIMES_FILE = '.pytorch-test-times.json' +TEST_TIMES_FILE = ".pytorch-test-times.json" # if a test file takes longer than 5 min, we add it to TARGET_DET_LIST SLOW_TEST_THRESHOLD = 300 -_DEP_MODULES_CACHE: Dict[str, set] = {} - DISTRIBUTED_TESTS_CONFIG = {} if dist.is_available(): - DISTRIBUTED_TESTS_CONFIG['test'] = { - 'WORLD_SIZE': '1' - } + DISTRIBUTED_TESTS_CONFIG["test"] = {"WORLD_SIZE": "1"} if not TEST_WITH_ROCM and dist.is_mpi_available(): - DISTRIBUTED_TESTS_CONFIG['mpi'] = { - 'WORLD_SIZE': '3', - 'TEST_REPORT_SOURCE_OVERRIDE': 'dist-mpi' + DISTRIBUTED_TESTS_CONFIG["mpi"] = { + "WORLD_SIZE": "3", + "TEST_REPORT_SOURCE_OVERRIDE": "dist-mpi", } if dist.is_nccl_available(): - DISTRIBUTED_TESTS_CONFIG['nccl'] = { - 'WORLD_SIZE': '2' if torch.cuda.device_count() == 2 else '3', - 'TEST_REPORT_SOURCE_OVERRIDE': 'dist-nccl' + DISTRIBUTED_TESTS_CONFIG["nccl"] = { + "WORLD_SIZE": "2" if torch.cuda.device_count() == 2 else "3", + "TEST_REPORT_SOURCE_OVERRIDE": "dist-nccl", } if dist.is_gloo_available(): - DISTRIBUTED_TESTS_CONFIG['gloo'] = { - 'WORLD_SIZE': '2' if torch.cuda.device_count() == 2 else '3', - 'TEST_REPORT_SOURCE_OVERRIDE': 'dist-gloo' + DISTRIBUTED_TESTS_CONFIG["gloo"] = { + "WORLD_SIZE": "2" if torch.cuda.device_count() == 2 else "3", + "TEST_REPORT_SOURCE_OVERRIDE": "dist-gloo", } # https://stackoverflow.com/questions/2549939/get-signal-names-from-numbers-in-python -SIGNALS_TO_NAMES_DICT = {getattr(signal, n): n for n in dir(signal) - if n.startswith('SIG') and '_' not in n} +SIGNALS_TO_NAMES_DICT = { + getattr(signal, n): n for n in dir(signal) if n.startswith("SIG") and "_" not in n +} CPP_EXTENSIONS_ERROR = """ Ninja (https://ninja-build.org) is required for some of the C++ extensions @@ -387,13 +274,68 @@ PYTORCH_COLLECT_COVERAGE = bool(os.environ.get("PYTORCH_COLLECT_COVERAGE")) -ENABLE_PR_HISTORY_REORDERING = bool(os.environ.get("ENABLE_PR_HISTORY_REORDERING", "0") == "1") +ENABLE_PR_HISTORY_REORDERING = bool( + os.environ.get("ENABLE_PR_HISTORY_REORDERING", "0") == "1" +) JIT_EXECUTOR_TESTS = [ - 'test_jit_cuda_fuser', - 'test_jit_profiling', - 'test_jit_legacy', - 'test_jit_fuser_legacy', + "test_jit_cuda_fuser", + "test_jit_profiling", + "test_jit_legacy", + "test_jit_fuser_legacy", +] + +DISTRIBUTED_TESTS = [ + "distributed/test_data_parallel", + "distributed/test_launcher", + "distributed/nn/jit/test_instantiator", + "distributed/rpc/test_faulty_agent", + "distributed/rpc/test_tensorpipe_agent", + "distributed/rpc/cuda/test_tensorpipe_agent", + "distributed/test_c10d_common", + "distributed/test_c10d_gloo", + "distributed/test_c10d_nccl", + "distributed/test_jit_c10d", + "distributed/test_c10d_spawn_gloo", + "distributed/test_c10d_spawn_nccl", + "distributed/test_store", + "distributed/test_pg_wrapper", + "distributed/algorithms/test_join", + "distributed/test_distributed_spawn", + "distributed/pipeline/sync/skip/test_api", + "distributed/pipeline/sync/skip/test_gpipe", + "distributed/pipeline/sync/skip/test_inspect_skip_layout", + "distributed/pipeline/sync/skip/test_leak", + "distributed/pipeline/sync/skip/test_portal", + "distributed/pipeline/sync/skip/test_stash_pop", + "distributed/pipeline/sync/skip/test_tracker", + "distributed/pipeline/sync/skip/test_verify_skippables", + "distributed/pipeline/sync/test_balance", + "distributed/pipeline/sync/test_bugs", + "distributed/pipeline/sync/test_checkpoint", + "distributed/pipeline/sync/test_copy", + "distributed/pipeline/sync/test_deferred_batch_norm", + "distributed/pipeline/sync/test_dependency", + "distributed/pipeline/sync/test_inplace", + "distributed/pipeline/sync/test_microbatch", + "distributed/pipeline/sync/test_phony", + "distributed/pipeline/sync/test_pipe", + "distributed/pipeline/sync/test_pipeline", + "distributed/pipeline/sync/test_stream", + "distributed/pipeline/sync/test_transparency", + "distributed/pipeline/sync/test_worker", + "distributed/optim/test_zero_redundancy_optimizer", + "distributed/elastic/timer/api_test", + "distributed/elastic/timer/local_timer_example", + "distributed/elastic/timer/local_timer_test", + "distributed/elastic/events/lib_test", + "distributed/elastic/metrics/api_test", + "distributed/elastic/utils/logging_test", + "distributed/elastic/utils/util_test", + "distributed/elastic/utils/distributed_test", + "distributed/elastic/multiprocessing/api_test", + "distributed/_sharding_spec/test_sharding_spec", + "distributed/_sharded_tensor/test_sharded_tensor", ] # Dictionary matching test modules (in TESTS) to lists of test cases (within that test_module) that would be run when @@ -408,7 +350,7 @@ # The file from which the SPECIFIED_TEST_CASES_DICT will be filled, a CSV of test cases that would be run when # options.run_specified_test_cases is enabled. -SPECIFIED_TEST_CASES_FILE: str = '.pytorch_specified_test_cases.csv' +SPECIFIED_TEST_CASES_FILE: str = ".pytorch_specified_test_cases.csv" def print_to_stderr(message): @@ -418,15 +360,18 @@ def print_to_stderr(message): def get_test_case_args(test_module, using_pytest) -> List[str]: args = [] # if test_module not specified or specified with '__all__' then run all tests - if test_module not in SPECIFIED_TEST_CASES_DICT or '__all__' in SPECIFIED_TEST_CASES_DICT[test_module]: + if ( + test_module not in SPECIFIED_TEST_CASES_DICT + or "__all__" in SPECIFIED_TEST_CASES_DICT[test_module] + ): return args if using_pytest: - args.append('-k') - args.append(' or '.join(SPECIFIED_TEST_CASES_DICT[test_module])) + args.append("-k") + args.append(" or ".join(SPECIFIED_TEST_CASES_DICT[test_module])) else: for test in SPECIFIED_TEST_CASES_DICT[test_module]: - args.append('-k') + args.append("-k") args.append(test) return args @@ -434,59 +379,70 @@ def get_test_case_args(test_module, using_pytest) -> List[str]: def get_executable_command(options, allow_pytest, disable_coverage=False): if options.coverage and not disable_coverage: - executable = ['coverage', 'run', '--parallel-mode', '--source=torch'] + executable = ["coverage", "run", "--parallel-mode", "--source=torch"] else: executable = [sys.executable] if options.pytest: if allow_pytest: - executable += ['-m', 'pytest'] + executable += ["-m", "pytest"] else: - print_to_stderr('Pytest cannot be used for this test. Falling back to unittest.') + print_to_stderr( + "Pytest cannot be used for this test. Falling back to unittest." + ) return executable -def run_test(test_module, test_directory, options, launcher_cmd=None, extra_unittest_args=None): +def run_test( + test_module, test_directory, options, launcher_cmd=None, extra_unittest_args=None +): unittest_args = options.additional_unittest_args.copy() if options.verbose: unittest_args.append(f'-{"v"*options.verbose}') # in case of pytest if test_module in RUN_PARALLEL_BLOCKLIST: - unittest_args = [arg for arg in unittest_args if not arg.startswith('--run-parallel')] + unittest_args = [ + arg for arg in unittest_args if not arg.startswith("--run-parallel") + ] if extra_unittest_args: assert isinstance(extra_unittest_args, list) unittest_args.extend(extra_unittest_args) # If using pytest, replace -f with equivalent -x if options.pytest: - unittest_args = [arg if arg != '-f' else '-x' for arg in unittest_args] + unittest_args = [arg if arg != "-f" else "-x" for arg in unittest_args] elif IS_IN_CI: # use the downloaded test cases configuration, not supported in pytest - unittest_args.extend(['--import-slow-tests', '--import-disabled-tests']) + unittest_args.extend(["--import-slow-tests", "--import-disabled-tests"]) # Multiprocessing related tests cannot run with coverage. # Tracking issue: https://github.com/pytorch/pytorch/issues/50661 - disable_coverage = sys.platform == 'win32' and test_module in WINDOWS_COVERAGE_BLOCKLIST + disable_coverage = ( + sys.platform == "win32" and test_module in WINDOWS_COVERAGE_BLOCKLIST + ) # Extra arguments are not supported with pytest - executable = get_executable_command(options, allow_pytest=not extra_unittest_args, - disable_coverage=disable_coverage) + executable = get_executable_command( + options, allow_pytest=not extra_unittest_args, disable_coverage=disable_coverage + ) # TODO: move this logic into common_utils.py instead of passing in "-k" individually # The following logic for running specified tests will only run for non-distributed tests, as those are dispatched # to test_distributed and not run_test (this function) if options.run_specified_test_cases: - unittest_args.extend(get_test_case_args(test_module, 'pytest' in executable)) + unittest_args.extend(get_test_case_args(test_module, "pytest" in executable)) # Can't call `python -m unittest test_*` here because it doesn't run code # in `if __name__ == '__main__': `. So call `python test_*.py` instead. - argv = [test_module + '.py'] + unittest_args + argv = [test_module + ".py"] + unittest_args command = (launcher_cmd or []) + executable + argv - print_to_stderr('Executing {} ... [{}]'.format(command, datetime.now())) + print_to_stderr("Executing {} ... [{}]".format(command, datetime.now())) return shell(command, test_directory) def test_cuda_primary_ctx(test_module, test_directory, options): - return run_test(test_module, test_directory, options, extra_unittest_args=['--subprocess']) + return run_test( + test_module, test_directory, options, extra_unittest_args=["--subprocess"] + ) def _test_cpp_extensions_aot(test_directory, options, use_ninja): @@ -498,46 +454,52 @@ def _test_cpp_extensions_aot(test_directory, options, use_ninja): return 1 # Wipe the build folder, if it exists already - cpp_extensions_test_dir = os.path.join(test_directory, 'cpp_extensions') - cpp_extensions_test_build_dir = os.path.join(cpp_extensions_test_dir, 'build') + cpp_extensions_test_dir = os.path.join(test_directory, "cpp_extensions") + cpp_extensions_test_build_dir = os.path.join(cpp_extensions_test_dir, "build") if os.path.exists(cpp_extensions_test_build_dir): shutil.rmtree(cpp_extensions_test_build_dir) # Build the test cpp extensions modules shell_env = os.environ.copy() - shell_env['USE_NINJA'] = str(1 if use_ninja else 0) - cmd = [sys.executable, 'setup.py', 'install', '--root', './install'] + shell_env["USE_NINJA"] = str(1 if use_ninja else 0) + cmd = [sys.executable, "setup.py", "install", "--root", "./install"] return_code = shell(cmd, cwd=cpp_extensions_test_dir, env=shell_env) if return_code != 0: return return_code - if sys.platform != 'win32': - return_code = shell(cmd, - cwd=os.path.join(cpp_extensions_test_dir, 'no_python_abi_suffix_test'), - env=shell_env) + if sys.platform != "win32": + return_code = shell( + cmd, + cwd=os.path.join(cpp_extensions_test_dir, "no_python_abi_suffix_test"), + env=shell_env, + ) if return_code != 0: return return_code # "install" the test modules and run tests - python_path = os.environ.get('PYTHONPATH', '') + python_path = os.environ.get("PYTHONPATH", "") from shutil import copyfile - test_module = 'test_cpp_extensions_aot' + ('_ninja' if use_ninja else '_no_ninja') - copyfile(test_directory + '/test_cpp_extensions_aot.py', test_directory + '/' + test_module + '.py') + + test_module = "test_cpp_extensions_aot" + ("_ninja" if use_ninja else "_no_ninja") + copyfile( + test_directory + "/test_cpp_extensions_aot.py", + test_directory + "/" + test_module + ".py", + ) try: - cpp_extensions = os.path.join(test_directory, 'cpp_extensions') - install_directory = '' + cpp_extensions = os.path.join(test_directory, "cpp_extensions") + install_directory = "" # install directory is the one that is named site-packages - for root, directories, _ in os.walk(os.path.join(cpp_extensions, 'install')): + for root, directories, _ in os.walk(os.path.join(cpp_extensions, "install")): for directory in directories: - if '-packages' in directory: + if "-packages" in directory: install_directory = os.path.join(root, directory) - assert install_directory, 'install_directory must not be empty' - os.environ['PYTHONPATH'] = os.pathsep.join([install_directory, python_path]) + assert install_directory, "install_directory must not be empty" + os.environ["PYTHONPATH"] = os.pathsep.join([install_directory, python_path]) return run_test(test_module, test_directory, options) finally: - os.environ['PYTHONPATH'] = python_path - if os.path.exists(test_directory + '/' + test_module + '.py'): - os.remove(test_directory + '/' + test_module + '.py') + os.environ["PYTHONPATH"] = python_path + if os.path.exists(test_directory + "/" + test_module + ".py"): + os.remove(test_directory + "/" + test_module + ".py") def test_cpp_extensions_aot_ninja(test_module, test_directory, options): @@ -550,53 +512,73 @@ def test_cpp_extensions_aot_no_ninja(test_module, test_directory, options): def test_distributed(test_module, test_directory, options): # MPI tests are broken with Python-3.9 - mpi_available = subprocess.call('command -v mpiexec', shell=True) == 0 and sys.version_info < (3, 9) + mpi_available = subprocess.call( + "command -v mpiexec", shell=True + ) == 0 and sys.version_info < (3, 9) if options.verbose and not mpi_available: - print_to_stderr( - 'MPI not available -- MPI backend tests will be skipped') + print_to_stderr("MPI not available -- MPI backend tests will be skipped") config = DISTRIBUTED_TESTS_CONFIG for backend, env_vars in config.items(): - if sys.platform == 'win32' and backend != 'gloo': + if sys.platform == "win32" and backend != "gloo": continue - if backend == 'mpi' and not mpi_available: + if backend == "mpi" and not mpi_available: continue for with_init_file in {True, False}: - if sys.platform == 'win32' and not with_init_file: + if sys.platform == "win32" and not with_init_file: continue tmp_dir = tempfile.mkdtemp() if options.verbose: init_str = "with {} init_method" with_init = init_str.format("file" if with_init_file else "env") print_to_stderr( - 'Running distributed tests for the {} backend {}'.format( - backend, with_init)) - os.environ['TEMP_DIR'] = tmp_dir - os.environ['BACKEND'] = backend - os.environ['INIT_METHOD'] = 'env://' + "Running distributed tests for the {} backend {}".format( + backend, with_init + ) + ) + os.environ["TEMP_DIR"] = tmp_dir + os.environ["BACKEND"] = backend + os.environ["INIT_METHOD"] = "env://" os.environ.update(env_vars) if with_init_file: - if test_module in ["test_distributed_fork", "test_distributed_spawn"]: - init_method = f'{FILE_SCHEMA}{tmp_dir}/' + if test_module == "test_distributed_spawn": + init_method = f"{FILE_SCHEMA}{tmp_dir}/" else: - init_method = f'{FILE_SCHEMA}{tmp_dir}/shared_init_file' - os.environ['INIT_METHOD'] = init_method + init_method = f"{FILE_SCHEMA}{tmp_dir}/shared_init_file" + os.environ["INIT_METHOD"] = init_method try: - os.mkdir(os.path.join(tmp_dir, 'barrier')) - os.mkdir(os.path.join(tmp_dir, 'test_dir')) - if backend == 'mpi': + os.mkdir(os.path.join(tmp_dir, "barrier")) + os.mkdir(os.path.join(tmp_dir, "test_dir")) + if backend == "mpi": # test mpiexec for --noprefix option - with open(os.devnull, 'w') as devnull: - allowrunasroot_opt = '--allow-run-as-root' if subprocess.call( - 'mpiexec --allow-run-as-root -n 1 bash -c ""', shell=True, - stdout=devnull, stderr=subprocess.STDOUT) == 0 else '' - noprefix_opt = '--noprefix' if subprocess.call( - f'mpiexec {allowrunasroot_opt} -n 1 --noprefix bash -c ""', shell=True, - stdout=devnull, stderr=subprocess.STDOUT) == 0 else '' - - mpiexec = ['mpiexec', '-n', '3', noprefix_opt, allowrunasroot_opt] - - return_code = run_test(test_module, test_directory, options, - launcher_cmd=mpiexec) + with open(os.devnull, "w") as devnull: + allowrunasroot_opt = ( + "--allow-run-as-root" + if subprocess.call( + 'mpiexec --allow-run-as-root -n 1 bash -c ""', + shell=True, + stdout=devnull, + stderr=subprocess.STDOUT, + ) + == 0 + else "" + ) + noprefix_opt = ( + "--noprefix" + if subprocess.call( + f'mpiexec {allowrunasroot_opt} -n 1 --noprefix bash -c ""', + shell=True, + stdout=devnull, + stderr=subprocess.STDOUT, + ) + == 0 + else "" + ) + + mpiexec = ["mpiexec", "-n", "3", noprefix_opt, allowrunasroot_opt] + + return_code = run_test( + test_module, test_directory, options, launcher_cmd=mpiexec + ) else: return_code = run_test(test_module, test_directory, options) if return_code != 0: @@ -607,16 +589,15 @@ def test_distributed(test_module, test_directory, options): CUSTOM_HANDLERS = { - 'test_cuda_primary_ctx': test_cuda_primary_ctx, - 'test_cpp_extensions_aot_no_ninja': test_cpp_extensions_aot_no_ninja, - 'test_cpp_extensions_aot_ninja': test_cpp_extensions_aot_ninja, - 'distributed/test_distributed_fork': test_distributed, - 'distributed/test_distributed_spawn': test_distributed, + "test_cuda_primary_ctx": test_cuda_primary_ctx, + "test_cpp_extensions_aot_no_ninja": test_cpp_extensions_aot_no_ninja, + "test_cpp_extensions_aot_ninja": test_cpp_extensions_aot_ninja, + "distributed/test_distributed_spawn": test_distributed, } def parse_test_module(test): - return test.split('.')[0] + return test.split(".")[0] class TestChoices(list): @@ -629,127 +610,159 @@ def __contains__(self, item): def parse_args(): parser = argparse.ArgumentParser( - description='Run the PyTorch unit test suite', - epilog='where TESTS is any of: {}'.format(', '.join(TESTS)), - formatter_class=argparse.RawTextHelpFormatter) + description="Run the PyTorch unit test suite", + epilog="where TESTS is any of: {}".format(", ".join(TESTS)), + formatter_class=argparse.RawTextHelpFormatter, + ) parser.add_argument( - '-v', - '--verbose', - action='count', + "-v", + "--verbose", + action="count", default=0, - help='print verbose information and test-by-test results') + help="print verbose information and test-by-test results", + ) + parser.add_argument("--jit", "--jit", action="store_true", help="run all jit tests") + parser.add_argument( + "--distributed-tests", + "--distributed-tests", + action="store_true", + help="run all distributed tests", + ) parser.add_argument( - '--jit', - '--jit', - action='store_true', - help='run all jit tests') + "-core", + "--core", + action="store_true", + help="Only run core tests, or tests that validate PyTorch's ops, modules," + "and autograd. They are defined by CORE_TEST_LIST." + ) parser.add_argument( - '-pt', '--pytest', action='store_true', - help='If true, use `pytest` to execute the tests. E.g., this runs ' - 'TestTorch with pytest in verbose and coverage mode: ' - 'python run_test.py -vci torch -pt') + "-pt", + "--pytest", + action="store_true", + help="If true, use `pytest` to execute the tests. E.g., this runs " + "TestTorch with pytest in verbose and coverage mode: " + "python run_test.py -vci torch -pt", + ) parser.add_argument( - '-c', '--coverage', action='store_true', help='enable coverage', - default=PYTORCH_COLLECT_COVERAGE) + "-c", + "--coverage", + action="store_true", + help="enable coverage", + default=PYTORCH_COLLECT_COVERAGE, + ) parser.add_argument( - '-i', - '--include', - nargs='+', + "-i", + "--include", + nargs="+", choices=TestChoices(TESTS), default=TESTS, - metavar='TESTS', - help='select a set of tests to include (defaults to ALL tests).' - ' tests must be a part of the TESTS list defined in run_test.py') + metavar="TESTS", + help="select a set of tests to include (defaults to ALL tests)." + " tests must be a part of the TESTS list defined in run_test.py", + ) parser.add_argument( - '-x', - '--exclude', - nargs='+', + "-x", + "--exclude", + nargs="+", choices=TESTS, - metavar='TESTS', + metavar="TESTS", default=[], - help='select a set of tests to exclude') + help="select a set of tests to exclude", + ) parser.add_argument( - '-f', - '--first', + "-f", + "--first", choices=TESTS, - metavar='TESTS', - help='select the test to start from (excludes previous tests)') + metavar="TESTS", + help="select the test to start from (excludes previous tests)", + ) parser.add_argument( - '-l', - '--last', + "-l", + "--last", choices=TESTS, - metavar='TESTS', - help='select the last test to run (excludes following tests)') + metavar="TESTS", + help="select the last test to run (excludes following tests)", + ) parser.add_argument( - '--bring-to-front', - nargs='+', + "--bring-to-front", + nargs="+", choices=TestChoices(TESTS), default=[], - metavar='TESTS', - help='select a set of tests to run first. This can be used in situations' - ' where you want to run all tests, but care more about some set, ' - 'e.g. after making a change to a specific component') + metavar="TESTS", + help="select a set of tests to run first. This can be used in situations" + " where you want to run all tests, but care more about some set, " + "e.g. after making a change to a specific component", + ) parser.add_argument( - '--ignore-win-blocklist', - action='store_true', - help='always run blocklisted windows tests') + "--ignore-win-blocklist", + action="store_true", + help="always run blocklisted windows tests", + ) parser.add_argument( - '--determine-from', - help='File of affected source filenames to determine which tests to run.') + "--determine-from", + help="File of affected source filenames to determine which tests to run.", + ) parser.add_argument( - '--continue-through-error', - action='store_true', - help='Runs the full test suite despite one of the tests failing', - default=strtobool(os.environ.get("CONTINUE_THROUGH_ERROR", "False"))) + "--continue-through-error", + action="store_true", + help="Runs the full test suite despite one of the tests failing", + default=strtobool(os.environ.get("CONTINUE_THROUGH_ERROR", "False")), + ) parser.add_argument( - 'additional_unittest_args', - nargs='*', - help='additional arguments passed through to unittest, e.g., ' - 'python run_test.py -i sparse -- TestSparse.test_factory_size_check') + "additional_unittest_args", + nargs="*", + help="additional arguments passed through to unittest, e.g., " + "python run_test.py -i sparse -- TestSparse.test_factory_size_check", + ) parser.add_argument( - '--export-past-test-times', - nargs='?', + "--export-past-test-times", + nargs="?", type=str, const=TEST_TIMES_FILE, - help='dumps test times from previous S3 stats into a file, format JSON', + help="dumps test times from previous S3 stats into a file, format JSON", ) parser.add_argument( - '--shard', + "--shard", nargs=2, type=int, - help='runs a shard of the tests (taking into account other selections), e.g., ' - '--shard 2 3 will break up the selected tests into 3 shards and run the tests ' - 'in the 2nd shard (the first number should not exceed the second)', + help="runs a shard of the tests (taking into account other selections), e.g., " + "--shard 2 3 will break up the selected tests into 3 shards and run the tests " + "in the 2nd shard (the first number should not exceed the second)", ) parser.add_argument( - '--exclude-jit-executor', - action='store_true', - help='exclude tests that are run for a specific jit config' + "--exclude-jit-executor", + action="store_true", + help="exclude tests that are run for a specific jit config", ) parser.add_argument( - '--run-specified-test-cases', - nargs='?', + "--exclude-distributed-tests", + action="store_true", + help="exclude distributed tests", + ) + parser.add_argument( + "--run-specified-test-cases", + nargs="?", type=str, const=SPECIFIED_TEST_CASES_FILE, - help='load specified test cases file dumped from previous OSS CI stats, format CSV. ' - ' If all test cases should run for a please add a single row: \n' - ' test_filename,test_case_name\n' - ' ...\n' - ' ,__all__\n' - ' ...\n' - 'how we use the stats will be based on option "--use-specified-test-cases-by".' + help="load specified test cases file dumped from previous OSS CI stats, format CSV. " + " If all test cases should run for a please add a single row: \n" + " test_filename,test_case_name\n" + " ...\n" + " ,__all__\n" + " ...\n" + 'how we use the stats will be based on option "--use-specified-test-cases-by".', ) parser.add_argument( - '--use-specified-test-cases-by', + "--use-specified-test-cases-by", type=str, - choices=['include', 'bring-to-front'], - default='include', + choices=["include", "bring-to-front"], + default="include", help='used together with option "--run-specified-test-cases". When specified test case ' - 'file is set, this option allows the user to control whether to only run the specified test ' - 'modules or to simply bring the specified modules to front and also run the remaining ' - 'modules. Note: regardless of this option, we will only run the specified test cases ' - ' within a specified test module. For unspecified test modules with the bring-to-front ' - 'option, all test cases will be run, as one may expect.', + "file is set, this option allows the user to control whether to only run the specified test " + "modules or to simply bring the specified modules to front and also run the remaining " + "modules. Note: regardless of this option, we will only run the specified test cases " + " within a specified test module. For unspecified test modules with the bring-to-front " + "option, all test cases will be run, as one may expect.", ) return parser.parse_args() @@ -797,24 +810,44 @@ def exclude_tests(exclude_list, selected_tests, exclude_message=None): for test in tests_copy: if test.startswith(exclude_test): if exclude_message is not None: - print_to_stderr('Excluding {} {}'.format(test, exclude_message)) + print_to_stderr("Excluding {} {}".format(test, exclude_message)) selected_tests.remove(test) return selected_tests def get_selected_tests(options): + # First make sure run specific test cases options are processed. if options.run_specified_test_cases: - if options.use_specified_test_cases_by == 'include': + if options.use_specified_test_cases_by == "include": options.include = list(SPECIFIED_TEST_CASES_DICT.keys()) - elif options.use_specified_test_cases_by == 'bring-to-front': + elif options.use_specified_test_cases_by == "bring-to-front": options.bring_to_front = list(SPECIFIED_TEST_CASES_DICT.keys()) selected_tests = options.include + # filter if there's JIT only and distributed only test options + if options.jit: + selected_tests = list( + filter(lambda test_name: "jit" in test_name, selected_tests) + ) + + if options.distributed_tests: + selected_tests = list( + filter(lambda test_name: test_name in DISTRIBUTED_TESTS, selected_tests) + ) + + # Filter to only run core tests when --core option is specified + if options.core: + selected_tests = list( + filter(lambda test_name: test_name in CORE_TEST_LIST, selected_tests) + ) + + # process reordering if options.bring_to_front: to_front = set(options.bring_to_front) - selected_tests = options.bring_to_front + list(filter(lambda name: name not in to_front, - selected_tests)) + selected_tests = options.bring_to_front + list( + filter(lambda name: name not in to_front, selected_tests) + ) if options.first: first_index = find_test_index(options.first, selected_tests) @@ -822,189 +855,70 @@ def get_selected_tests(options): if options.last: last_index = find_test_index(options.last, selected_tests, find_last_index=True) - selected_tests = selected_tests[:last_index + 1] + selected_tests = selected_tests[: last_index + 1] + # process exclusion if options.exclude_jit_executor: options.exclude.extend(JIT_EXECUTOR_TESTS) + if options.exclude_distributed_tests: + options.exclude.extend(DISTRIBUTED_TESTS) + selected_tests = exclude_tests(options.exclude, selected_tests) - if sys.platform == 'win32' and not options.ignore_win_blocklist: - target_arch = os.environ.get('VSCMD_ARG_TGT_ARCH') - if target_arch != 'x64': - WINDOWS_BLOCKLIST.append('cpp_extensions_aot_no_ninja') - WINDOWS_BLOCKLIST.append('cpp_extensions_aot_ninja') - WINDOWS_BLOCKLIST.append('cpp_extensions_jit') - WINDOWS_BLOCKLIST.append('jit') - WINDOWS_BLOCKLIST.append('jit_fuser') + if sys.platform == "win32" and not options.ignore_win_blocklist: + target_arch = os.environ.get("VSCMD_ARG_TGT_ARCH") + if target_arch != "x64": + WINDOWS_BLOCKLIST.append("cpp_extensions_aot_no_ninja") + WINDOWS_BLOCKLIST.append("cpp_extensions_aot_ninja") + WINDOWS_BLOCKLIST.append("cpp_extensions_jit") + WINDOWS_BLOCKLIST.append("jit") + WINDOWS_BLOCKLIST.append("jit_fuser") - selected_tests = exclude_tests(WINDOWS_BLOCKLIST, selected_tests, 'on Windows') + selected_tests = exclude_tests(WINDOWS_BLOCKLIST, selected_tests, "on Windows") elif TEST_WITH_ROCM: - selected_tests = exclude_tests(ROCM_BLOCKLIST, selected_tests, 'on ROCm') + selected_tests = exclude_tests(ROCM_BLOCKLIST, selected_tests, "on ROCm") + # sharding if options.shard: assert len(options.shard) == 2, "Unexpected shard format" assert min(options.shard) > 0, "Shards must be positive numbers" which_shard, num_shards = options.shard - assert which_shard <= num_shards, "Selected shard must be less than or equal to total number of shards" - assert num_shards <= len(selected_tests), f"Number of shards must be less than {len(selected_tests)}" + assert ( + which_shard <= num_shards + ), "Selected shard must be less than or equal to total number of shards" + assert num_shards <= len( + selected_tests + ), f"Number of shards must be less than {len(selected_tests)}" # TODO: fix this to use test_times_filename, but currently this is not working # because setting the export arg immeidately halts the test execution. - selected_tests = get_shard_based_on_S3(which_shard, num_shards, selected_tests, TEST_TIMES_FILE) - - return selected_tests - - -def test_impact_of_file(filename): - """Determine what class of impact this file has on test runs. - - Possible values: - TORCH - torch python code - CAFFE2 - caffe2 python code - TEST - torch test code - UNKNOWN - may affect all tests - NONE - known to have no effect on test outcome - CI - CI configuration files - """ - parts = filename.split(os.sep) - if parts[0] in ['.jenkins', '.circleci']: - return 'CI' - if parts[0] in ['docs', 'scripts', 'CODEOWNERS', 'README.md']: - return 'NONE' - elif parts[0] == 'torch': - if parts[-1].endswith('.py') or parts[-1].endswith('.pyi'): - return 'TORCH' - elif parts[0] == 'caffe2': - if parts[-1].endswith('.py') or parts[-1].endswith('.pyi'): - return 'CAFFE2' - elif parts[0] == 'test': - if parts[-1].endswith('.py') or parts[-1].endswith('.pyi'): - return 'TEST' - - return 'UNKNOWN' - - -def log_test_reason(file_type, filename, test, options): - if options.verbose: - print_to_stderr( - 'Determination found {} file {} -- running {}'.format( - file_type, - filename, - test, - ) + selected_tests = get_shard_based_on_S3( + which_shard, num_shards, selected_tests, TEST_TIMES_FILE ) - -def get_dep_modules(test): - # Cache results in case of repetition - if test in _DEP_MODULES_CACHE: - return _DEP_MODULES_CACHE[test] - - repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - test_location = os.path.join(repo_root, 'test', test + '.py') - finder = modulefinder.ModuleFinder( - # Ideally exclude all third party modules, to speed up calculation. - excludes=[ - 'scipy', - 'numpy', - 'numba', - 'multiprocessing', - 'sklearn', - 'setuptools', - 'hypothesis', - 'llvmlite', - 'joblib', - 'email', - 'importlib', - 'unittest', - 'urllib', - 'json', - 'collections', - # Modules below are excluded because they are hitting https://bugs.python.org/issue40350 - # Trigger AttributeError: 'NoneType' object has no attribute 'is_package' - 'mpl_toolkits', - 'google', - 'onnx', - # Triggers RecursionError - 'mypy' - ], - ) - # HACK: some platforms default to ascii, so we can't just run_script :( - with open(test_location, 'r', encoding='utf-8') as fp: - finder.load_module('__main__', fp, test_location, ('', 'r', 1)) - - dep_modules = set(finder.modules.keys()) - _DEP_MODULES_CACHE[test] = dep_modules - return dep_modules - - -def determine_target(target_det_list, test, touched_files, options): - test = parse_test_module(test) - # Some tests are faster to execute than to determine. - if test not in target_det_list: - if options.verbose: - print_to_stderr(f'Running {test} without determination') - return True - # HACK: "no_ninja" is not a real module - if test.endswith('_no_ninja'): - test = test[:(-1 * len('_no_ninja'))] - if test.endswith('_ninja'): - test = test[:(-1 * len('_ninja'))] - - dep_modules = get_dep_modules(test) - - for touched_file in touched_files: - file_type = test_impact_of_file(touched_file) - if file_type == 'NONE': - continue - elif file_type == 'CI': - # Force all tests to run if any change is made to the CI - # configurations. - log_test_reason(file_type, touched_file, test, options) - return True - elif file_type == 'UNKNOWN': - # Assume uncategorized source files can affect every test. - log_test_reason(file_type, touched_file, test, options) - return True - elif file_type in ['TORCH', 'CAFFE2', 'TEST']: - parts = os.path.splitext(touched_file)[0].split(os.sep) - touched_module = ".".join(parts) - # test/ path does not have a "test." namespace - if touched_module.startswith('test.'): - touched_module = touched_module.split('test.')[1] - if ( - touched_module in dep_modules - or touched_module == test.replace('/', '.') - ): - log_test_reason(file_type, touched_file, test, options) - return True - - # If nothing has determined the test has run, don't run the test. - if options.verbose: - print_to_stderr(f'Determination is skipping {test}') - - return False + return selected_tests def run_test_module(test: str, test_directory: str, options) -> Optional[str]: test_module = parse_test_module(test) # Printing the date here can help diagnose which tests are slow - print_to_stderr('Running {} ... [{}]'.format(test, datetime.now())) + print_to_stderr("Running {} ... [{}]".format(test, datetime.now())) handler = CUSTOM_HANDLERS.get(test_module, run_test) return_code = handler(test_module, test_directory, options) assert isinstance(return_code, int) and not isinstance( - return_code, bool), 'Return code should be an integer' + return_code, bool + ), "Return code should be an integer" if return_code == 0: return None - message = f'{test} failed!' + message = f"{test} failed!" if return_code < 0: # subprocess.Popen returns the child process' exit signal as # return code -N, where N is the signal number. signal_name = SIGNALS_TO_NAMES_DICT[-return_code] - message += f' Received signal: {signal_name}' + message += f" Received signal: {signal_name}" return message @@ -1014,49 +928,62 @@ def main(): # TODO: move this export & download function in tools/ folder test_times_filename = options.export_past_test_times if test_times_filename: - print(f'Exporting past test times from S3 to {test_times_filename}, no tests will be run.') + print( + f"Exporting past test times from S3 to {test_times_filename}, no tests will be run." + ) export_S3_test_times(test_times_filename) return specified_test_cases_filename = options.run_specified_test_cases if specified_test_cases_filename: - print(f'Loading specified test cases to run from {specified_test_cases_filename}.') + print( + f"Loading specified test cases to run from {specified_test_cases_filename}." + ) global SPECIFIED_TEST_CASES_DICT - SPECIFIED_TEST_CASES_DICT = get_specified_test_cases(specified_test_cases_filename, TESTS) + SPECIFIED_TEST_CASES_DICT = get_specified_test_cases( + specified_test_cases_filename, TESTS + ) - test_directory = os.path.dirname(os.path.abspath(__file__)) + test_directory = str(REPO_ROOT / "test") selected_tests = get_selected_tests(options) if options.verbose: - print_to_stderr('Selected tests: {}'.format(', '.join(selected_tests))) + print_to_stderr("Selected tests: {}".format(", ".join(selected_tests))) if options.coverage and not PYTORCH_COLLECT_COVERAGE: - shell(['coverage', 'erase']) - - if options.jit: - selected_tests = filter(lambda test_name: "jit" in test_name, TESTS) + shell(["coverage", "erase"]) if options.determine_from is not None and os.path.exists(options.determine_from): - slow_tests = get_slow_tests_based_on_S3(TESTS, TARGET_DET_LIST, SLOW_TEST_THRESHOLD) - print('Added the following tests to target_det tests as calculated based on S3:') - print(slow_tests) - with open(options.determine_from, 'r') as fh: + slow_tests = get_slow_tests_based_on_S3( + TESTS, TARGET_DET_LIST, SLOW_TEST_THRESHOLD + ) + print_to_stderr( + "Added the following tests to target_det tests as calculated based on S3:" + ) + print_to_stderr(slow_tests) + with open(options.determine_from, "r") as fh: touched_files = [ - os.path.normpath(name.strip()) for name in fh.read().split('\n') + os.path.normpath(name.strip()) + for name in fh.read().split("\n") if len(name.strip()) > 0 ] # HACK: Ensure the 'test' paths can be traversed by Modulefinder - sys.path.append('test') + sys.path.append(test_directory) selected_tests = [ - test for test in selected_tests - if determine_target(TARGET_DET_LIST + slow_tests, test, touched_files, options) + test + for test in selected_tests + if should_run_test( + TARGET_DET_LIST + slow_tests, test, touched_files, options + ) ] - sys.path.remove('test') + sys.path.remove(test_directory) if IS_IN_CI: - selected_tests = get_reordered_tests(selected_tests, ENABLE_PR_HISTORY_REORDERING) + selected_tests = get_reordered_tests( + selected_tests, ENABLE_PR_HISTORY_REORDERING + ) # downloading test cases configuration to local environment - get_test_case_configs(dirpath=os.path.dirname(os.path.abspath(__file__))) + get_test_case_configs(dirpath=test_directory) has_failed = False failure_messages = [] @@ -1076,8 +1003,8 @@ def main(): finally: if options.coverage: from coverage import Coverage - test_dir = os.path.dirname(os.path.abspath(__file__)) - with set_cwd(test_dir): + + with set_cwd(test_directory): cov = Coverage() if PYTORCH_COLLECT_COVERAGE: cov.load() @@ -1091,5 +1018,6 @@ def main(): print_to_stderr(err) sys.exit(1) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/test/test_autograd.py b/test/test_autograd.py index 7200bd525acf2..e672e4b49e25e 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -24,13 +24,13 @@ from torch.autograd.profiler_util import (_format_time, EventList, FunctionEvent, FunctionEventAvg) import torch.autograd.functional as autogradF from torch.utils.checkpoint import checkpoint +from torch.testing import make_tensor from torch.testing._internal.common_cuda import TEST_CUDA from torch.testing._internal.common_utils import (TestCase, run_tests, skipIfNoLapack, suppress_warnings, slowTest, - load_tests, IS_WINDOWS, IS_MACOS, CudaMemoryLeakCheck, TEST_WITH_ROCM, disable_gc, - gradcheck, gradgradcheck, make_tensor) + gradcheck, gradgradcheck) from torch.autograd import Variable, Function, detect_anomaly, kineto_available from torch.autograd.function import InplaceFunction import torch.autograd.forward_ad as fwAD @@ -42,11 +42,7 @@ onlyCPU, onlyCUDA, onlyOnCPUAndCUDA, dtypes, dtypesIfCUDA, deviceCountAtLeast, skipCUDAIfCudnnVersionLessThan, skipCUDAIf, skipMeta) - - -# load_tests from common_utils is used to automatically filter tests for -# sharding on sandcastle. This line silences flake warnings -load_tests = load_tests +from torch.testing._internal.common_dtype import get_all_dtypes import pickle @@ -2740,36 +2736,6 @@ def test_block_diag(self): lambda a, b, c: torch.block_diag(a, b, c), True, f_args_variable, f_args_tensor) - def test_cat(self): - f_args_variable = (torch.randn(1, S, S, dtype=torch.double, requires_grad=True), - torch.randn(2, S, S, dtype=torch.double, requires_grad=True), - torch.randn(3, S, S, dtype=torch.double, requires_grad=True), - 0) - f_args_tensor = deepcopy(unpack_variables(f_args_variable)) - run_functional_checks(self, "test_cat", "cat", - lambda a, b, c, dim: torch.cat((a, b, c), dim), - True, f_args_variable, f_args_tensor, check_forward_ad=True) - - def test_cat_negdim_1(self): - f_args_variable = (torch.randn(S, S, 1, dtype=torch.double, requires_grad=True), - torch.randn(S, S, 2, dtype=torch.double, requires_grad=True), - torch.randn(S, S, 3, dtype=torch.double, requires_grad=True), - -1) - f_args_tensor = deepcopy(unpack_variables(f_args_variable)) - run_functional_checks(self, "test_cat_negdim_1", "cat", - lambda a, b, c, dim: torch.cat((a, b, c), dim), - True, f_args_variable, f_args_tensor, check_forward_ad=True) - - def test_cat_negdim_2(self): - f_args_variable = (torch.randn(S, 1, S, dtype=torch.double, requires_grad=True), - torch.randn(S, 2, S, dtype=torch.double, requires_grad=True), - torch.randn(S, 3, S, dtype=torch.double, requires_grad=True), - -2) - f_args_tensor = deepcopy(unpack_variables(f_args_variable)) - run_functional_checks(self, "test_cat_negdim_2", "cat", - lambda a, b, c, dim: torch.cat((a, b, c), dim), - True, f_args_variable, f_args_tensor, check_forward_ad=True) - def test_cat_empty_legacy(self): f_args_variable = (torch.randn(0, dtype=torch.double, requires_grad=True), torch.randn(S, S, dtype=torch.double, requires_grad=True)) @@ -2781,14 +2747,6 @@ def test_cat_empty_legacy(self): False, f_args_variable, f_args_tensor, check_forward_ad=True) self.assertTrue(gradcheck(lambda a, b: torch.cat((a, b)), f_args_variable, eps=1e-6, atol=PRECISION)) - def test_cat_empty(self): - f_args_variable = (torch.randn(0, S, dtype=torch.double, requires_grad=True), - torch.randn(S, S, dtype=torch.double, requires_grad=True)) - f_args_tensor = deepcopy(unpack_variables(f_args_variable)) - run_functional_checks(self, "test_cat_empty", "cat", - lambda a, b: torch.cat((a, b)), - True, f_args_variable, f_args_tensor, check_forward_ad=True) - def test_var_mean_differentiable(self): dim = [2, 4] keepdim = False @@ -2801,11 +2759,11 @@ def test_var_mean_differentiable(self): r1 = var1 * var1 * mean1 * mean1 r2 = var2 * var2 * mean2 * mean2 - self.assertTrue(torch.allclose(r1, r2, rtol=0.01, atol=0.0)) + self.assertEqual(r1, r2, rtol=0.01, atol=0.0) torch.autograd.backward(r1, grad) torch.autograd.backward(r2, grad) - self.assertTrue(torch.allclose(input1.grad, input2.grad, rtol=0.01, atol=0.0)) + self.assertEqual(input1.grad, input2.grad, rtol=0.01, atol=0.0) @slowTest @skipIfNoLapack @@ -3004,6 +2962,9 @@ def test_profiler_seq_nr(self): found_bwd_add = found_bwd_sum = False found_empty = False for e in p.function_events: + # Ignore record_function user scope. + if "autograd::engine::evaluate_function" in e.name: + continue if e.name == "aten::add": add_seq_nr = e.sequence_nr self.assertFalse(found_add) @@ -3452,7 +3413,7 @@ def test_inplace_on_view_backward(self): gradient_penalty.backward() fn = gradient_penalty.grad_fn.next_functions[0][0].next_functions[1][0] - self.assertEqual(fn.name(), "ThresholdBackwardBackward") + self.assertEqual(fn.name(), "ThresholdBackwardBackward0") def test_inplace_on_view_weak_grad_fn(self): # Issue 23502: Test that b's grad_fn is preserved. @@ -4861,7 +4822,7 @@ def maybe_check_raise(fn, should_raise): # The 3 elements are for view_as, first output of unbind and second output of unbind run_test(grad_mode=True, requires_grad=False, is_view=True, should_raise_tuple=(None, None, None)) - inp_change_err = "Output {} of UnbindBackward is a view and is being modified inplace." + inp_change_err = "Output {} of UnbindBackward0 is a view and is being modified inplace." run_test(grad_mode=True, requires_grad=True, is_view=True, should_raise_tuple=(None, inp_change_err.format("0"), inp_change_err.format("1"))) leaf_grad_err = "A view was created in no_grad mode and is being modified inplace" @@ -5159,7 +5120,7 @@ def test_autograd_inplace_views_cross_dtype(self): # TODO: this is a bug! # once this is fixed, it should have the transpose removed: - # self.assertTrue(torch.allclose(non_inplace_grad, inplace_grad)) + # self.assertEqual(non_inplace_grad, inplace_grad) self.assertEqual(non_inplace_grad.T, inplace_grad) def test_autograd_multiple_views_python(self): @@ -5477,13 +5438,143 @@ class BadBw(Function): def forward(ctx, foo): return foo.clone() + class BadBw2(Function): + @staticmethod + def forward(ctx, foo): + return foo.clone() + + @staticmethod + def backward(ctx, foo): + return foo + + @staticmethod + def vjp(ctx, foo): + return foo + + class BadJvp(Function): + @staticmethod + def forward(ctx, foo): + return foo.clone() + inp = torch.rand(1, requires_grad=True) with self.assertRaisesRegex(NotImplementedError, "must implement the forward"): BadFw.apply(inp) - with self.assertRaisesRegex(RuntimeError, "must implement the backward"): + with self.assertRaisesRegex(RuntimeError, "must implement either the backward"): BadBw.apply(inp).sum().backward() + with self.assertRaisesRegex(RuntimeError, "Implementing both 'backward' and 'vjp'"): + BadBw2.apply(inp).sum().backward() + + with self.assertRaisesRegex(RuntimeError, "must implement the jvp function"): + with fwAD.dual_level(): + d = fwAD.make_dual(inp, torch.rand_like(inp)) + res = BadJvp.apply(d) + + def test_custom_function_forward_mode_view_checks(self): + flag_to_error = { + "ok": None, + "not_a_view": "jvp is not returning a view", + "not_a_view_of_inp": "jvp is not returning a view of the given", + "not_a_view_of_inp_base": "jvp is not returning a view of the same base", + } + + class ViewFn(Function): + @staticmethod + def forward(ctx, foo, flag): + ctx.flag = flag + ctx.size = foo.size() + return foo.narrow(0, 0, 2) + + @staticmethod + def vjp(ctx, gO): + gI = gO.new_zeros(ctx.size) + gI.narrow(0, 0, 2).copy_(gO) + return gI, None + + @staticmethod + def jvp(ctx, gI, _): + res = gI.narrow(0, 0, 2) + if ctx.flag != "ok": + # Break the view in the gradients! + res = res.clone() + if ctx.flag in ["not_a_view_of_inp", "not_a_view_of_inp_base"]: + # Result should be a view, just of the wrong thing + res = res.view_as(res) + return res + + inp = torch.rand(4, 4, dtype=torch.double, requires_grad=True) + + for flag, msg in flag_to_error.items(): + def test_fn(inp): + if flag == "not_a_view_of_inp_base": + inp = inp.view_as(inp) + return ViewFn.apply(inp, flag) + + if msg is None: + gradcheck(test_fn, inp, check_forward_ad=True) + else: + with self.assertRaisesRegex(RuntimeError, msg): + gradcheck(test_fn, inp, check_forward_ad=True) + + def test_custom_function_forward_mode_inplace_checks(self): + class InplaceFn(Function): + @staticmethod + def forward(ctx, foo, flag): + ctx.mark_dirty(foo) + ctx.flag = flag + foo.mul_(2) + return foo + + @staticmethod + def vjp(ctx, gO): + return 2 * gO, None + + @staticmethod + def jvp(ctx, gI, _): + if ctx.flag: + # Don't do the change inplace + return 2 * gI + else: + gI.mul_(2) + return gI + + inp = torch.rand(4, 4, dtype=torch.double, requires_grad=True) + + def test_fn(inp, flag): + inp = inp.clone() + return InplaceFn.apply(inp, flag) + + gradcheck(test_fn, (inp, False), check_forward_ad=True) + + with self.assertRaisesRegex(RuntimeError, "inplace custom Function is not modifying the forward mode gradients inplace"): + gradcheck(test_fn, (inp, True), check_forward_ad=True) + + def test_custom_function_forward_mode_wrong_formula(self): + class UserFn(Function): + @staticmethod + def forward(ctx, foo, should_fail): + ctx.should_fail = should_fail + return foo * 2 + + @staticmethod + def vjp(ctx, gO): + return 2 * gO, None + + @staticmethod + def jvp(ctx, gI, _): + if ctx.should_fail: + # Wrong gradient formula + return 3 * gI + else: + return 2 * gI + + inp = torch.rand(10, dtype=torch.double, requires_grad=True) + gradcheck(UserFn.apply, (inp, False), check_forward_ad=True) + + with self.assertRaisesRegex(RuntimeError, "Jacobian computed with forward mode mismatch for output 0"): + gradcheck(UserFn.apply, (inp, True), check_forward_ad=True) + def test_custom_function_local_inplace(self): class MyFn(torch.autograd.Function): @staticmethod @@ -6039,101 +6130,6 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks, test_case.assertEqual(self_variable.size(), self_variable.grad.size()) -class TestAutogradComplex(TestCase): - def test_view_func_for_complex_views(self): - # case 1: both parent and child have view_func - x = torch.randn(2, 2, 2, dtype=torch.double, requires_grad=True) - y = x.detach().requires_grad_(True) - - x0 = x.clone() - x1 = torch.view_as_complex(x0) - x2 = torch.view_as_real(x1) - x2.mul_(2) - x2.sum().backward() - - y0 = y.clone() - y0.mul_(2) - y0.sum().backward() - - self.assertEqual(x.grad, y.grad) - - # case 2: parent has view_func but child does not - x = torch.randn(2, 2, 2, dtype=torch.double, requires_grad=True) - y = x.detach().requires_grad_(True) - - def fn(a): - b = a.clone() - b1 = torch.view_as_complex(b) - b2 = b1.reshape(b1.numel()) - return b2 - - x0 = fn(x) - x0.mul_(2) - x0.sum().backward() - - y0 = fn(y) - y1 = y0.mul(2) - y1.sum().backward() - - self.assertEqual(x.grad, y.grad) - - # case 3: parent does not have a view_func but child does - x = torch.randn(10, dtype=torch.cdouble, requires_grad=True) - y = x.detach().requires_grad_(True) - - def fn(a, dim0_size=5): - b = a.clone() - b1 = b.reshape(dim0_size, 2) - b2 = torch.view_as_real(b1) - return b2 - - x0 = fn(x) - x0.mul_(2) - x0.sum().backward() - - y0 = fn(y) - y1 = y0.mul(2) - y1.sum().backward() - - self.assertEqual(x.grad, y.grad) - - def test_view_with_multi_output(self): - x = torch.randn(2, 2, 2, dtype=torch.double) - - x1 = torch.view_as_complex(x) - # Taking an invalid view should always be allowed as long as it is not - # modified inplace - res = x1.unbind(0) - - with self.assertRaisesRegex(RuntimeError, "output of a function that returns multiple views"): - res[0] += torch.rand(2, requires_grad=True) - - x.requires_grad_(True) - x1 = torch.view_as_complex(x) - # Taking an invalid view should always be allowed as long as it is not - # modified inplace - res = x1.unbind(0) - - with self.assertRaisesRegex(RuntimeError, "output of a function that returns multiple views"): - res[0] += torch.rand(2, requires_grad=True) - - def as_identity(self): - # view_as_real and view_as_complex behavior should be like an identity - def func(z): - z_ = torch.view_as_complex(z) - z_select = torch.select(z_, z_.dim() - 1, 0) - z_select_real = torch.view_as_real(z_select) - return z_select_real.sum() - - z = torch.randn(10, 2, 2, dtype=torch.double, requires_grad=True) - gradcheck(func, [z]) - func(z).backward() - - z1 = z.clone().detach().requires_grad_(True) - torch.select(z1, z1.dim() - 2, 0).sum().backward() - - self.assertEqual(z.grad, z1.grad) - class TestAutogradFunctional(TestCase): def _assert_same_struct(self, res, base): # base and res should be Tensors or tuple of Tensors with the same size @@ -8240,6 +8236,12 @@ def test_leaky_relu_inplace_with_zero_slope(self, device): expected = torch.tensor([0., 0., 1.], device=device) self.assertEqual(a.grad, expected) + a_bf16 = torch.tensor([-2., 0., 2.], device=device, dtype=torch.bfloat16, requires_grad=True) + b_bf16 = torch.nn.functional.leaky_relu_(a_bf16.clone(), 0.0) + b_bf16.backward(torch.ones(3, device=device)) + expected_bf16 = torch.tensor([0., 0., 1.], device=device, dtype=torch.bfloat16) + self.assertEqual(a_bf16.grad, expected_bf16) + @onlyOnCPUAndCUDA def test_elu_inplace_with_neg_alpha(self, device): a = torch.tensor([-1., 1.], device=device, requires_grad=True) @@ -8473,7 +8475,7 @@ def test_copy_(self, device): # At the time of writing this test, copy_ is not generated from native_functions.yaml # there was a bug that bfloat16 was not recognized as floating. x = torch.randn(10, device=device, requires_grad=True) - floating_dt = [dt for dt in torch.testing.get_all_dtypes() if dt.is_floating_point] + floating_dt = [dt for dt in get_all_dtypes() if dt.is_floating_point] for dt in floating_dt: y = torch.empty(10, device=device, dtype=dt) y.copy_(x) @@ -9500,6 +9502,11 @@ def fn(x1, x2): torch.autograd.gradcheck(fn, [inp_r, inp_c], check_forward_ad=True) torch.autograd.gradcheck(fn, [inp_c, inp_r], check_forward_ad=True) +# Import test cases from below autograd/ here. These are found +# implicitly by the loader, so Flake8 thinks they are unused, hence +# the suppressions. + +from autograd.test_complex import TestAutogradComplex # noqa: F401 # e.g., TestAutogradDeviceTypeCPU and TestAutogradDeviceTypeCUDA instantiate_device_type_tests( diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py index f952911d206f6..7153902841aa5 100644 --- a/test/test_binary_ufuncs.py +++ b/test/test_binary_ufuncs.py @@ -13,12 +13,17 @@ from torch._six import inf, nan from torch.testing._internal.common_utils import ( TestCase, iter_indices, TEST_WITH_ASAN, run_tests, - torch_to_numpy_dtype_dict, make_tensor, TEST_SCIPY, set_default_dtype) + torch_to_numpy_dtype_dict, TEST_SCIPY, set_default_dtype) from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, onlyCUDA, onlyCPU, dtypes, dtypesIfCUDA, dtypesIfCPU, deviceCountAtLeast, precisionOverride, onlyOnCPUAndCUDA, - skipCUDAIfRocm, skipIf) -from torch.testing import all_types_and_complex_and, integral_types_and + skipCUDAIfRocm, skipIf, ops) +from torch.testing import make_tensor +from torch.testing._internal.common_dtype import ( + all_types_and_complex_and, integral_types_and, get_all_dtypes, get_all_int_dtypes, get_all_math_dtypes, + get_all_complex_dtypes, get_all_fp_dtypes, +) +from torch.testing._internal.common_methods_invocations import binary_ufuncs if TEST_SCIPY: import scipy.special @@ -89,6 +94,74 @@ def _make_tensor(shape, dtype, device, fill_ones=False) -> torch.Tensor: # TODO: update to use opinfos consistently class TestBinaryUfuncs(TestCase): + @ops(binary_ufuncs, allowed_dtypes=(torch.float32,)) + def test_broadcasting(self, device, dtype, op): + for shape_lhs, shape_rhs in ( + ((1,), ()), + ((2,), ()), + ((1,), (2,)), + ((2,), (2,)), + ((2, 1), (2,)), + ((1, 2), (2,)), + ((3, 2), (2,)), + ((3, 2), (3, 2)), + ((1, 3, 2), (2,)), + ((1, 3, 2), (3, 2)), + ((3, 1, 2), (3, 2)), + ((1, 3, 2), (1, 3, 2)), + ((2, 3, 2), ()), + ((2, 3, 2), (2, 3, 2)), + ((3, 1, 2), (1, 3, 2)), + ): + lhs = make_tensor(shape_lhs, device=device, dtype=dtype, **op.lhs_make_tensor_kwargs) + rhs = make_tensor(shape_rhs, device=device, dtype=dtype, **op.rhs_make_tensor_kwargs) + + actual = op(lhs, rhs).shape + expected = torch.broadcast_shapes(shape_lhs, shape_rhs) + + msg = ( + f"On {device}, torch.{op.name} broadcasts inputs of shapes {shape_lhs} and {shape_rhs} incorrectly: " + f"{actual} != {expected}" + ) + self.assertEqual(actual, expected, msg=msg) + + @ops(binary_ufuncs, allowed_dtypes=(torch.float32,)) + def test_broadcast_python_scalar(self, device, dtype, op): + for shape_lhs in ((), (1,), (2,), (1, 2, 3),): + lhs = make_tensor(shape_lhs, device=device, dtype=dtype, **op.lhs_make_tensor_kwargs) + rhs_tensor = make_tensor((), device=device, dtype=dtype, **op.rhs_make_tensor_kwargs) + rhs_python = rhs_tensor.item() + + actual = op(lhs, rhs_python) + expected = op(lhs, rhs_tensor) + + self.assertEqual( + actual.shape, + expected.shape, + msg=f"On {device}, torch.{op.name} broadcasts Python scalars different than 0d tensors.", + ) + + @ops(binary_ufuncs, allowed_dtypes=(torch.float32,)) + def test_not_broadcastable(self, device, dtype, op): + for shape_lhs, shape_rhs in ( + ((2,), (3,)), + ((3, 1), (2, 1)), + ((1, 3, 2), (3,)), + ((3, 1, 2), (2, 1, 2)), + ): + lhs = make_tensor(shape_lhs, device=device, dtype=dtype, **op.lhs_make_tensor_kwargs) + rhs = make_tensor(shape_rhs, device=device, dtype=dtype, **op.rhs_make_tensor_kwargs) + + try: + broadcasted_shape = op(lhs, rhs).shape + except RuntimeError: + continue + + msg = ( + f"On {device}, torch.{op.name} broadcasts inputs shapes {shape_lhs} and {shape_rhs} into " + f"{broadcasted_shape}, although they are not broadcastable." + ) + raise AssertionError(msg) def test_add_broadcast_empty(self, device): # empty + empty @@ -279,7 +352,7 @@ def test_inplace_division(self, device): id_after = id(t) self.assertEqual(id_before, id_after) - @dtypes(*torch.testing.get_all_dtypes(include_bool=False, include_complex=False)) + @dtypes(*get_all_dtypes(include_bool=False, include_complex=False)) def test_div_rounding_modes(self, device, dtype): if dtype.is_floating_point: low, high = -10.0, 10.0 @@ -379,7 +452,7 @@ def test_divide_by_zero_rounding(self, device, dtype): actual = torch.divide(a, zero, rounding_mode=rounding_mode) self.assertEqual(actual, expect, exact_dtype=exact_dtype) - @dtypes(*torch.testing.get_all_dtypes( + @dtypes(*get_all_dtypes( include_bool=False, include_complex=False, include_bfloat16=False)) def test_div_rounding_numpy(self, device, dtype): info = (torch.finfo(dtype) if dtype.is_floating_point @@ -823,7 +896,7 @@ def test_pow_cuda_complex_extremal_failing(self, device, dtype): self.assertEqual(cpu_out, cuda_out) @onlyOnCPUAndCUDA - @dtypes(*(torch.testing.get_all_dtypes(include_bool=False, include_bfloat16=False))) + @dtypes(*(get_all_dtypes(include_bool=False, include_bfloat16=False))) def test_complex_scalar_pow_tensor(self, device, dtype): complexes = [0.5j, 1. + 1.j, -1.5j, 2.2 - 1.6j, 1 + 0j] first_exp = make_tensor((100,), device, dtype, low=-2, high=2) @@ -1184,11 +1257,10 @@ def _wrapped_ifloordiv_scalar(a): # Also tests that reverse operations are equivalent to forward ops # NOTE: division ops are tested separately above def test_binary_ops_with_scalars(self, device): - for ops in ((operator.add, torch.add), - (operator.sub, torch.sub), - (operator.mul, torch.mul), - (operator.truediv, torch.div)): - python_op, torch_op = ops + for python_op, torch_op in ((operator.add, torch.add), + (operator.sub, torch.sub), + (operator.mul, torch.mul), + (operator.truediv, torch.div)): for a, b in product(range(-10, 10), range(-10, 10)): for op in (lambda x: x * .5, lambda x: math.floor(x)): @@ -1215,7 +1287,7 @@ def test_binary_ops_with_scalars(self, device): self.assertEqual(expected, python_op(first, second)) self.assertEqual(expected, torch_op(first, second)) - @dtypes(*product(torch.testing.get_all_dtypes(include_complex=False), torch.testing.get_all_dtypes(include_complex=False))) + @dtypes(*product(get_all_dtypes(include_complex=False), get_all_dtypes(include_complex=False))) def test_maximum_minimum_type_promotion(self, device, dtypes): a = torch.tensor((0, 1), device=device, dtype=dtypes[0]) b = torch.tensor((1, 0), device=device, dtype=dtypes[1]) @@ -1223,7 +1295,7 @@ def test_maximum_minimum_type_promotion(self, device, dtypes): result = op(a, b) self.assertEqual(result.dtype, torch.result_type(a, b)) - @dtypes(*(torch.testing.get_all_int_dtypes() + [torch.bool])) + @dtypes(*(get_all_int_dtypes() + [torch.bool])) def test_maximum_minimum_int_and_bool(self, device, dtype): ops = ((torch.maximum, torch.max, np.maximum), (torch.minimum, torch.min, np.minimum), (torch.fmax, None, np.fmax), (torch.fmin, None, np.fmin)) @@ -1249,7 +1321,7 @@ def test_maximum_minimum_int_and_bool(self, device, dtype): self.assertEqual(out, numpy_result) @precisionOverride({torch.bfloat16: 1e-2}) - @dtypes(*(torch.testing.get_all_fp_dtypes())) + @dtypes(*(get_all_fp_dtypes())) def test_maximum_minimum_float(self, device, dtype): ops = ((torch.maximum, torch.max, np.maximum), (torch.minimum, torch.min, np.minimum), (torch.fmax, None, np.fmax), (torch.fmin, None, np.fmin)) @@ -1277,7 +1349,7 @@ def test_maximum_minimum_float(self, device, dtype): self.assertEqual(tensor_result, numpy_result, exact_dtype=False) self.assertEqual(out, numpy_result, exact_dtype=False) - @dtypes(*(torch.testing.get_all_fp_dtypes())) + @dtypes(*(get_all_fp_dtypes())) def test_maximum_minimum_float_nan_and_inf(self, device, dtype): # np.maximum and np.minimum functions compare input arrays element-wisely. # if one of the elements being compared is a NaN, then that element is returned. @@ -1313,7 +1385,7 @@ def test_maximum_minimum_float_nan_and_inf(self, device, dtype): self.assertEqual(tensor_result, numpy_result) self.assertEqual(out, numpy_result) - @dtypes(*product(torch.testing.get_all_complex_dtypes(), torch.testing.get_all_dtypes())) + @dtypes(*product(get_all_complex_dtypes(), get_all_dtypes())) def test_maximum_minimum_complex(self, device, dtypes): for torch_op in (torch.maximum, torch.minimum, torch.max, torch.min, torch.fmax, torch.fmin): with self.assertRaisesRegex(RuntimeError, '.+not implemented for.+'): @@ -1371,7 +1443,7 @@ def test_mul_intertype_scalar(self, device, dtype): self.assertEqual(x, 4.5) @onlyCPU - @dtypes(*torch.testing.get_all_dtypes()) + @dtypes(*get_all_dtypes()) def test_sub(self, device, dtype): m1 = torch.tensor([2.34, 4.44], dtype=dtype, device=device) m2 = torch.tensor([1.23, 2.33], dtype=dtype, device=device) @@ -1433,8 +1505,8 @@ def test_min_max_binary_op_nan(self, device, dtype): self.assertFalse(torch.isnan(ma[i]), "max(a, b): {}, a: {}, b: {}".format(ma[i], a[i], b[i])) self.assertFalse(torch.isnan(mi[i]), "min(a, b): {}, a: {}, b: {}".format(mi[i], a[i], b[i])) - @dtypes(*product(torch.testing.get_all_dtypes(include_complex=False), - torch.testing.get_all_dtypes(include_complex=False))) + @dtypes(*product(get_all_dtypes(include_complex=False), + get_all_dtypes(include_complex=False))) def test_copysign(self, device, dtypes): def _test_copysign_numpy(a, b): torch_result = torch.copysign(a, b) @@ -1451,7 +1523,7 @@ def _test_copysign_numpy(a, b): expected = torch.from_numpy(np.copysign(np_a, np_b)) # To handle inconsistencies of type promotion between PyTorch and Numpy # Applied for both arguments having integral precision and bfloat16 - types = [torch.bool, torch.bfloat16] + torch.testing.get_all_int_dtypes() + types = [torch.bool, torch.bfloat16] + get_all_int_dtypes() if a.dtype in types or b.dtype in types: promoted_type = torch.promote_types(torch_result.dtype, expected.dtype) torch_result = torch_result.to(promoted_type) @@ -1496,7 +1568,7 @@ def _test_copysign_numpy(a, b): for case in cases: _test_copysign_numpy(torch.tensor([case], device=device, dtype=dtypes[0]), b) - if dtypes[1] in torch.testing.get_all_fp_dtypes(): + if dtypes[1] in get_all_fp_dtypes(): a = make_tensor((10, 10), device=device, dtype=dtypes[0], low=-9, high=9) for case in cases: _test_copysign_numpy(a, torch.tensor([case], device=device, dtype=dtypes[1])) @@ -1548,25 +1620,25 @@ def test_divmul_scalar(self, device, dtype): res = scale * x self.assertEqual(res, expected.to(dtype), atol=0., rtol=0.) - @dtypesIfCUDA(*set(torch.testing.get_all_math_dtypes('cuda')) - {torch.complex64, torch.complex128}) - @dtypes(*set(torch.testing.get_all_math_dtypes('cpu')) - {torch.complex64, torch.complex128}) + @dtypesIfCUDA(*set(get_all_math_dtypes('cuda')) - {torch.complex64, torch.complex128}) + @dtypes(*set(get_all_math_dtypes('cpu')) - {torch.complex64, torch.complex128}) def test_floor_divide_tensor(self, device, dtype): x = torch.randn(10, device=device).mul(30).to(dtype) y = torch.arange(1, 11, dtype=dtype, device=device) - with self.assertWarnsOnceRegex(UserWarning, "floor_divide"): + with self.assertWarnsOnceRegex(UserWarning, "__floordiv__"): z = x // y z_alt = torch.trunc(x.double() / y.double()).to(dtype) self.assertEqual(z.dtype, x.dtype) self.assertEqual(z, z_alt) - @dtypesIfCUDA(*set(torch.testing.get_all_math_dtypes('cuda')) - {torch.complex64, torch.complex128}) - @dtypes(*set(torch.testing.get_all_math_dtypes('cpu')) - {torch.complex64, torch.complex128}) + @dtypesIfCUDA(*set(get_all_math_dtypes('cuda')) - {torch.complex64, torch.complex128}) + @dtypes(*set(get_all_math_dtypes('cpu')) - {torch.complex64, torch.complex128}) def test_floor_divide_scalar(self, device, dtype): x = torch.randn(100, device=device).mul(10).to(dtype) - with self.assertWarnsOnceRegex(UserWarning, "floor_divide"): + with self.assertWarnsOnceRegex(UserWarning, "__floordiv__"): z = x // 3 z_alt = torch.tensor([math.trunc(v.item() / 3.) for v in x], dtype=x.dtype, device=device) @@ -1595,7 +1667,7 @@ def test_floor_divide_out(self, device, dtype): self.assertEqual(o, torch.floor_divide(x.float(), y.float())) @onlyCPU - @dtypes(*torch.testing.get_all_math_dtypes('cpu')) + @dtypes(*get_all_math_dtypes('cpu')) def test_rdiv(self, device, dtype): if dtype is torch.float16: return @@ -1607,7 +1679,7 @@ def test_rdiv(self, device, dtype): z = torch.tensor([30 / v.item() for v in x], device=device) self.assertEqual(y, z, exact_dtype=False) - @dtypes(*torch.testing.get_all_fp_dtypes(include_bfloat16=False)) + @dtypes(*get_all_fp_dtypes(include_bfloat16=False)) def test_fmod_remainder_by_zero_float(self, device, dtype): fn_list = (torch.fmod, torch.remainder) for fn in fn_list: @@ -1619,7 +1691,7 @@ def test_fmod_remainder_by_zero_float(self, device, dtype): @onlyOnCPUAndCUDA # Check Issue https://github.com/pytorch/pytorch/issues/48130 @skipCUDAIfRocm # Error happens on both ROCM and XLA - @dtypes(*torch.testing.get_all_int_dtypes()) + @dtypes(*get_all_int_dtypes()) def test_fmod_remainder_by_zero_integral(self, device, dtype): fn_list = (torch.fmod, torch.remainder) for fn in fn_list: @@ -1644,7 +1716,7 @@ def test_fmod_remainder_by_zero_integral(self, device, dtype): value = 255 if dtype == torch.uint8 else -1 self.assertTrue(torch.all(fn(x, zero) == value)) - @dtypes(*torch.testing.get_all_dtypes(include_bfloat16=False, include_bool=False, include_complex=False)) + @dtypes(*get_all_dtypes(include_bfloat16=False, include_bool=False, include_complex=False)) def test_fmod_remainder(self, device, dtype): # Use numpy as reference def _helper(x, mod, fns_list): @@ -1681,7 +1753,7 @@ def _helper(x, mod, fns_list): # Mods: Integer, Float, Tensor, Non-contiguous Tensor mods = [3, 2.3, mod, mod.t()] # mod with floating-point dtype - if dtype in torch.testing.get_all_int_dtypes(): + if dtype in get_all_int_dtypes(): mod_float = make_tensor((10, 10), device=device, dtype=torch.float, low=-9, high=9) mod[mod == 0] = 1 mods.append(mod_float) @@ -1902,7 +1974,7 @@ def test_floor_divide_zero(self, device, dtype): a // b @unittest.skipIf(TEST_WITH_ASAN, "Integer overflows are not allowed under ASAN") - @dtypes(*torch.testing.get_all_dtypes()) + @dtypes(*get_all_dtypes()) def test_muldiv_scalar(self, device, dtype): x = make_tensor((10, 3), device, dtype, low=None, high=None) s = make_tensor((1,), 'cpu', dtype, low=None, high=None).item() @@ -1912,7 +1984,7 @@ def test_muldiv_scalar(self, device, dtype): self.assertEqual(x / s, x / y) self.assertEqual(s / x, y / x) - @dtypes(*tuple(itertools.combinations_with_replacement(torch.testing.get_all_dtypes(), 2))) + @dtypes(*tuple(itertools.combinations_with_replacement(get_all_dtypes(), 2))) def test_comparison_ops_type_promotion_and_broadcasting(self, device, dtypes): # issue #42660 # testing all combinations of broadcasting and type promotion @@ -2094,8 +2166,8 @@ def test_bitwise_shift_float(self, device): self.assertEqual(torch_op(a, 2.2), expected_op(a, 2.2)) @onlyOnCPUAndCUDA - @dtypes(*list(product(torch.testing.get_all_dtypes(include_complex=False), - torch.testing.get_all_dtypes(include_complex=False)))) + @dtypes(*list(product(get_all_dtypes(include_complex=False), + get_all_dtypes(include_complex=False)))) def test_heaviside(self, device, dtypes): input_dtype = dtypes[0] values_dtype = dtypes[1] @@ -2154,8 +2226,8 @@ def test_heaviside_cross_device(self, device): with self.assertRaisesRegex(RuntimeError, 'Expected all tensors to be on the same device'): torch.heaviside(y, x) - @dtypes(*list(product(torch.testing.get_all_complex_dtypes(), - torch.testing.get_all_complex_dtypes()))) + @dtypes(*list(product(get_all_complex_dtypes(), + get_all_complex_dtypes()))) def test_heaviside_complex(self, device, dtypes): input_dtype = dtypes[0] values_dtype = dtypes[1] @@ -2197,15 +2269,15 @@ def _test_logical(self, device, dtypes, op, a_, b_, expected_res_): getattr(a, op + '_')(b) self.assertEqual(expected_res, a) - @dtypes(*product(torch.testing.get_all_dtypes(), torch.testing.get_all_dtypes())) + @dtypes(*product(get_all_dtypes(), get_all_dtypes())) def test_logical_xor(self, device, dtypes): self._test_logical(device, dtypes, 'logical_xor', [10, 0, 1, 0], [1, 0, 0, 10], [0, 0, 1, 1]) - @dtypes(*product(torch.testing.get_all_dtypes(), torch.testing.get_all_dtypes())) + @dtypes(*product(get_all_dtypes(), get_all_dtypes())) def test_logical_and(self, device, dtypes): self._test_logical(device, dtypes, 'logical_and', [10, 0, 1, 0], [1, 0, 0, 10], [1, 0, 0, 0]) - @dtypes(*product(torch.testing.get_all_dtypes(), torch.testing.get_all_dtypes())) + @dtypes(*product(get_all_dtypes(), get_all_dtypes())) def test_logical_or(self, device, dtypes): self._test_logical(device, dtypes, 'logical_or', [10, 0, 1, 0], [1, 0, 0, 10], [1, 0, 1, 1]) @@ -2309,7 +2381,7 @@ def test_logaddexp2(self, device, dtype): self._test_logaddexp(device, dtype, base2=True) def test_add(self, device): - dtypes = [torch.float, torch.double] + torch.testing.get_all_complex_dtypes() + dtypes = [torch.float, torch.double] + get_all_complex_dtypes() for dtype in dtypes: # [res] torch.add([res,] tensor1, tensor2) m1 = torch.randn(100, 100, dtype=dtype, device=device) @@ -2510,7 +2582,7 @@ def test_bool_tensor_comparison_ops(self, device): torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.bool, device=device)) self.assertFalse(a.equal(b)) - @dtypes(*torch.testing.get_all_dtypes(include_complex=False)) + @dtypes(*get_all_dtypes(include_complex=False)) def test_logical(self, device, dtype): if dtype != torch.bool: x = torch.tensor([1, 2, 3, 4], device=device, dtype=dtype) @@ -2596,6 +2668,11 @@ def test_x(sizes, dim, x, device): test_x((1, 10), 0, [1.0], device) test_x((0, 2), 0, [], device) test_x((0, 2), 1, [1.0, 2.0], device) + test_x((2, 3, 4), -1, [1.0, 2.0, 3.0, 4.0], device) + test_x((2, 3, 4), 0, [1.0, 2.0], device) + test_x((2, 3, 4), 1, [1.0, 2.0, 3.0], device) + test_x((2, 3, 4), 2, [1.0, 2.0, 3.0, 4.0], device) + test_x((2, 2, 4), -1, [[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]], device) with self.assertRaisesRegex( IndexError, 'Dimension out of range'): @@ -2654,6 +2731,10 @@ def test_empty_x(sizes, dim, x, device): test_x((10, 2), 0, [2.0, 3.0, 4.0, 7.0, 11.0, 14.0, 22.0, 26.0, 26.1, 30.3], device) test_x((1, 10), 0, [1.0], device) test_x((0, 2), 1, [1, 2], device) + test_x((2, 3, 4), -1, [1.0, 2.0, 3.0, 4.0], device) + test_x((2, 3, 4), 0, [1.0, 2.0], device) + test_x((2, 3, 4), 1, [1.0, 2.0, 3.0], device) + test_x((2, 3, 4), 2, [1.0, 2.0, 3.0, 4.0], device) test_empty_x((0, 2), 0, [], device) # SciPy failing when x == [], but our version returns empty @@ -2687,8 +2768,8 @@ def test_pow_scalar_overloads_mem_overlap(self, device, dtype): self.unary_check_input_output_mem_overlap( doubles, sz, lambda input, out: torch.pow(42, input, out=out)) - @dtypes(*list(product(torch.testing.get_all_dtypes(include_bool=False), - torch.testing.get_all_dtypes(include_bool=False)))) + @dtypes(*list(product(get_all_dtypes(include_bool=False), + get_all_dtypes(include_bool=False)))) def test_float_power(self, device, dtypes): def to_np(value): if isinstance(value, torch.Tensor) and value.dtype == torch.bfloat16: @@ -2784,8 +2865,8 @@ def _promo_helper(x, y): torch.Tensor.float_power_(base.clone(), exp) @skipIf(not TEST_SCIPY, "Scipy required for the test.") - @dtypes(*product(torch.testing.get_all_dtypes(include_complex=False, include_bfloat16=False), - torch.testing.get_all_dtypes(include_complex=False, include_bfloat16=False))) + @dtypes(*product(get_all_dtypes(include_complex=False, include_bfloat16=False), + get_all_dtypes(include_complex=False, include_bfloat16=False))) def test_xlogy_xlog1py(self, device, dtypes): x_dtype, y_dtype = dtypes @@ -2796,7 +2877,7 @@ def out_variant_helper(torch_fn, x, y): self.assertEqual(expected, out) def xlogy_inplace_variant_helper(x, y): - if x.dtype in torch.testing.get_all_int_dtypes() + [torch.bool]: + if x.dtype in get_all_int_dtypes() + [torch.bool]: with self.assertRaisesRegex(RuntimeError, "can't be cast to the desired output type"): x.clone().xlogy_(y) @@ -2923,10 +3004,10 @@ def _compare_helper(x, y, torch_fn, reference_fn): _compare_helper(t, zeros, *xlog1py_fns) _compare_helper(t, 0., *xlog1py_fns) - @dtypes(*product(torch.testing.get_all_dtypes(include_complex=False, - include_half=False, include_bfloat16=False), - torch.testing.get_all_dtypes(include_complex=False, - include_half=False, include_bfloat16=False))) + @dtypes(*product(get_all_dtypes(include_complex=False, + include_half=False, include_bfloat16=False), + get_all_dtypes(include_complex=False, + include_half=False, include_bfloat16=False))) @skipIf(not TEST_SCIPY, "Scipy required for the test.") def test_zeta(self, device, dtypes): x_dtype, q_dtype = dtypes diff --git a/test/test_buffer_protocol.py b/test/test_buffer_protocol.py index c797b913f033c..619386e6d5665 100644 --- a/test/test_buffer_protocol.py +++ b/test/test_buffer_protocol.py @@ -1,4 +1,5 @@ import torch.testing._internal.common_utils as common +from torch.testing import make_tensor from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, dtypes @@ -23,7 +24,7 @@ def _run_test(self, shape, dtype, count=-1, first=0, offset=None, **kwargs): if offset is None: offset = first * get_dtype_size(dtype) - numpy_original = common.make_tensor(shape, torch.device("cpu"), dtype).numpy() + numpy_original = make_tensor(shape, torch.device("cpu"), dtype).numpy() original = memoryview(numpy_original) # First call PyTorch's version in case of errors. # If this call exits successfully, the NumPy version must also do so. @@ -125,7 +126,7 @@ def test_invalid_positional_args(self, device, dtype): @dtypes(*common.torch_to_numpy_dtype_dict.keys()) def test_shared_buffer(self, device, dtype): - x = common.make_tensor((1,), device, dtype) + x = make_tensor((1,), device, dtype) # Modify the whole tensor arr, tensor = self._run_test(SHAPE, dtype) tensor[:] = x @@ -158,7 +159,7 @@ def test_not_a_buffer(self, device, dtype): @dtypes(*common.torch_to_numpy_dtype_dict.keys()) def test_non_writable_buffer(self, device, dtype): - numpy_arr = common.make_tensor((1,), device, dtype).numpy() + numpy_arr = make_tensor((1,), device, dtype).numpy() byte_arr = numpy_arr.tobytes() with self.assertWarnsOnceRegex(UserWarning, r"The given buffer is not writable."): diff --git a/test/test_bundled_images.py b/test/test_bundled_images.py index 0c95ae39c582d..7efd40178a160 100644 --- a/test/test_bundled_images.py +++ b/test/test_bundled_images.py @@ -67,7 +67,7 @@ def forward(self, arg): self.assertEqual(len(inflated), 1) self.assertEqual(len(inflated[0]), 1) self.assertEqual(raw_data.shape, decoded_data.shape) - self.assertTrue(torch.allclose(raw_data, decoded_data, atol=0.1, rtol=1e-01)) + self.assertEqual(raw_data, decoded_data, atol=0.1, rtol=1e-01) # Check if fb::image_decode_to_NCHW works as expected with open("caffe2/test/test_img/p1.jpg", "rb") as fp: @@ -76,4 +76,4 @@ def forward(self, arg): byte_tensor = torch.tensor(list(fp.read())).byte() im2_tensor = torch.ops.fb.image_decode_to_NCHW(byte_tensor, weight, bias) self.assertEqual(raw_data.shape, im2_tensor.shape) - self.assertTrue(torch.allclose(raw_data, im2_tensor, atol=0.1, rtol=1e-01)) + self.assertEqual(raw_data, im2_tensor, atol=0.1, rtol=1e-01) diff --git a/test/test_bundled_inputs.py b/test/test_bundled_inputs.py index a0fb535da8a86..62263e130fd8b 100644 --- a/test/test_bundled_inputs.py +++ b/test/test_bundled_inputs.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 import io import textwrap -from typing import List +from typing import List, Optional, Dict import torch import torch.utils.bundled_inputs @@ -324,5 +324,118 @@ def forward(self, arg): ) self.assertEqual(bundled_model2.get_all_bundled_inputs(), [(torch.ones(2),)]) + + def test_dict_args(self): + class MyModel(torch.nn.Module): + def forward( + self, + arg1: Optional[Dict[str, torch.Tensor]], + arg2: Optional[List[torch.Tensor]], + arg3: torch.Tensor, + ): + if arg1 is None: + return arg3 + elif arg2 is None: + return arg1["a"] + arg1["b"] + else: + return arg1["a"] + arg1["b"] + arg2[0] + + small_sample = dict( + a=torch.zeros([10, 20]), + b=torch.zeros([1, 1]), + c=torch.zeros([10, 20]), + ) + small_list = [torch.zeros([10, 20])] + + big_sample = dict( + a=torch.zeros([1 << 5, 1 << 8, 1 << 10]), + b=torch.zeros([1 << 5, 1 << 8, 1 << 10]), + c=torch.zeros([1 << 5, 1 << 8, 1 << 10]), + ) + big_list = [torch.zeros([1 << 5, 1 << 8, 1 << 10])] + + def condensed(t): + ret = torch.empty_like(t).flatten()[0].clone().expand(t.shape) + assert ret.storage().size() == 1 + # ret.storage()[0] = 0 + return ret + + def bundle_optional_dict_of_randn(template): + return torch.utils.bundled_inputs.InflatableArg( + value=( + None + if template is None + else {k: condensed(v) for (k, v) in template.items()} + ), + fmt="{}", + fmt_fn=""" + def {}(self, value: Optional[Dict[str, Tensor]]): + if value is None: + return None + output = {{}} + for k, v in value.items(): + output[k] = torch.randn_like(v) + return output + """, + ) + + def bundle_optional_list_of_randn(template): + return torch.utils.bundled_inputs.InflatableArg( + value=(None if template is None else [condensed(v) for v in template]), + fmt="{}", + fmt_fn=""" + def {}(self, value: Optional[List[Tensor]]): + if value is None: + return None + output = [] + for v in value: + output.append(torch.randn_like(v)) + return output + """, + ) + + out : List[str] = [] + sm = torch.jit.script(MyModel()) + original_size = model_size(sm) + small_inputs = ( + bundle_optional_dict_of_randn(small_sample), + bundle_optional_list_of_randn(small_list), + torch.zeros([3, 4]), + ) + big_inputs = ( + bundle_optional_dict_of_randn(big_sample), + bundle_optional_list_of_randn(big_list), + torch.zeros([1 << 5, 1 << 8, 1 << 10]), + ) + + torch.utils.bundled_inputs.augment_model_with_bundled_inputs( + sm, + [ + big_inputs, + small_inputs, + ], + _receive_inflate_expr=out, + ) + augmented_size = model_size(sm) + # assert the size has not increased more than 8KB + self.assertLess(augmented_size, original_size + (1 << 13)) + + loaded = save_and_load(sm) + inflated = loaded.get_all_bundled_inputs() + self.assertEqual(len(inflated[0]), len(small_inputs)) + + methods, _ = torch.utils.bundled_inputs._get_bundled_inputs_attributes_and_methods( + loaded + ) + + # One Function (forward) + # two bundled inputs (big_inputs and small_inputs) + # two args which have InflatableArg with fmt_fn + # 1 * 2 * 2 = 4 + self.assertEqual( + sum([method.startswith("_inflate_helper") for method in methods]), 4 + ) + + if __name__ == '__main__': run_tests() diff --git a/test/test_complex.py b/test/test_complex.py index 45482efbae56d..eee7a6a51534e 100644 --- a/test/test_complex.py +++ b/test/test_complex.py @@ -1,11 +1,12 @@ import torch from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes from torch.testing._internal.common_utils import TestCase, run_tests +from torch.testing._internal.common_dtype import get_all_complex_dtypes devices = (torch.device('cpu'), torch.device('cuda:0')) class TestComplexTensor(TestCase): - @dtypes(*torch.testing.get_all_complex_dtypes()) + @dtypes(*get_all_complex_dtypes()) def test_to_list(self, device, dtype): # test that the complex float tensor has expected values and # there's no garbage value in the resultant list diff --git a/test/test_cpp_extensions_aot.py b/test/test_cpp_extensions_aot.py index 307df0eed5e9a..cf35e6b13265d 100644 --- a/test/test_cpp_extensions_aot.py +++ b/test/test_cpp_extensions_aot.py @@ -19,11 +19,11 @@ try: if HAS_PYTEST: cpp_extension = pytest.importorskip("torch_test_cpp_extension.cpp") - msnpu_extension = pytest.importorskip("torch_test_cpp_extension.msnpu") + ort_extension = pytest.importorskip("torch_test_cpp_extension.ort") rng_extension = pytest.importorskip("torch_test_cpp_extension.rng") else: import torch_test_cpp_extension.cpp as cpp_extension - import torch_test_cpp_extension.msnpu as msnpu_extension + import torch_test_cpp_extension.ort as ort_extension import torch_test_cpp_extension.rng as rng_extension except ImportError as e: raise RuntimeError( @@ -100,45 +100,45 @@ def test_optional(self): self.assertFalse(has_value) -class TestMSNPUTensor(common.TestCase): +class TestORTTensor(common.TestCase): def test_unregistered(self): a = torch.arange(0, 10, device='cpu') with self.assertRaisesRegex(RuntimeError, "Could not run"): - b = torch.arange(0, 10, device='msnpu') + b = torch.arange(0, 10, device='ort') def test_zeros(self): a = torch.empty(5, 5, device='cpu') self.assertEqual(a.device, torch.device('cpu')) - b = torch.empty(5, 5, device='msnpu') - self.assertEqual(b.device, torch.device('msnpu', 0)) - self.assertEqual(msnpu_extension.get_test_int(), 0) + b = torch.empty(5, 5, device='ort') + self.assertEqual(b.device, torch.device('ort', 0)) + self.assertEqual(ort_extension.get_test_int(), 0) self.assertEqual(torch.get_default_dtype(), b.dtype) - c = torch.empty((5, 5), dtype=torch.int64, device='msnpu') - self.assertEqual(msnpu_extension.get_test_int(), 0) + c = torch.empty((5, 5), dtype=torch.int64, device='ort') + self.assertEqual(ort_extension.get_test_int(), 0) self.assertEqual(torch.int64, c.dtype) def test_add(self): - a = torch.empty(5, 5, device='msnpu', requires_grad=True) - self.assertEqual(msnpu_extension.get_test_int(), 0) + a = torch.empty(5, 5, device='ort', requires_grad=True) + self.assertEqual(ort_extension.get_test_int(), 0) - b = torch.empty(5, 5, device='msnpu') - self.assertEqual(msnpu_extension.get_test_int(), 0) + b = torch.empty(5, 5, device='ort') + self.assertEqual(ort_extension.get_test_int(), 0) c = a + b - self.assertEqual(msnpu_extension.get_test_int(), 1) + self.assertEqual(ort_extension.get_test_int(), 1) def test_conv_backend_override(self): # To simplify tests, we use 4d input here to avoid doing view4d( which # needs more overrides) in _convolution. - input = torch.empty(2, 4, 10, 2, device='msnpu', requires_grad=True) - weight = torch.empty(6, 4, 2, 2, device='msnpu', requires_grad=True) - bias = torch.empty(6, device='msnpu') + input = torch.empty(2, 4, 10, 2, device='ort', requires_grad=True) + weight = torch.empty(6, 4, 2, 2, device='ort', requires_grad=True) + bias = torch.empty(6, device='ort') # Make sure forward is overriden out = torch.nn.functional.conv1d(input, weight, bias, 2, 0, 1, 1) - self.assertEqual(msnpu_extension.get_test_int(), 2) + self.assertEqual(ort_extension.get_test_int(), 2) self.assertEqual(out.shape[0], input.shape[0]) self.assertEqual(out.shape[1], weight.shape[0]) @@ -146,7 +146,7 @@ def test_conv_backend_override(self): # Double backward is dispatched to _convolution_double_backward. # It is not tested here as it involves more computation/overrides. grad = torch.autograd.grad(out, input, out, create_graph=True) - self.assertEqual(msnpu_extension.get_test_int(), 3) + self.assertEqual(ort_extension.get_test_int(), 3) self.assertEqual(grad[0].shape, input.shape) diff --git a/test/test_cpp_extensions_jit.py b/test/test_cpp_extensions_jit.py index 073835277e678..89d9af10e0d35 100644 --- a/test/test_cpp_extensions_jit.py +++ b/test/test_cpp_extensions_jit.py @@ -869,11 +869,29 @@ def test_custom_compound_op_autograd(self): gradcheck(torch.ops.my.add, [a, b], eps=1e-2) - @unittest.skipIf(not has_breakpad(), "Breakpad library must be present on system for crash handler") - @unittest.skipIf(TEST_WITH_ASAN, "ASAN disables the crash handler's signal handler") - def test_crash_handler(self): - def run_test(stderr_file, destination): - # Code to enable dumps and trigger a segfault + @staticmethod + def _crash_handler_test_process(stderr_file, destination): + # Code to enable dumps and trigger a segfault + if sys.platform == "win32": + destination = destination.replace("\\", "\\\\") + csrc = textwrap.dedent(f""" + #include + #include + #include + #include + #include + + int fail() {{ + std::wstring_convert> converter; + std::string narrow("{destination}"); + std::wstring wide = converter.from_bytes(narrow); + torch::crash_handler::enable_minidumps(wide.c_str()); + + volatile int* bad = nullptr; + return *bad; + }} + """) + else: csrc = textwrap.dedent(f""" #include @@ -885,29 +903,32 @@ def run_test(stderr_file, destination): }} """) - # Some special stuff to overwrite stderr for a C++ extension - # Copied from: https://stackoverflow.com/questions/8804893/redirect-stdout-from-python-for-c-calls - sys.stdout.flush() - newstdout = os.dup(2) - devnull = os.open(stderr_file, os.O_WRONLY) - os.dup2(devnull, 2) - os.close(devnull) - sys.stdout = os.fdopen(newstdout, 'w') - - module = torch.utils.cpp_extension.load_inline( - name="segfault", - cpp_sources=csrc, - functions=["fail"], - ) - module.fail() + # Some special stuff to overwrite stderr for a C++ extension + # Copied from: https://stackoverflow.com/questions/8804893/redirect-stdout-from-python-for-c-calls + sys.stdout.flush() + newstdout = os.dup(2) + devnull = os.open(stderr_file, os.O_WRONLY) + os.dup2(devnull, 2) + os.close(devnull) + sys.stdout = os.fdopen(newstdout, 'w') + module = torch.utils.cpp_extension.load_inline( + name="segfault", + cpp_sources=csrc, + functions=["fail"], + ) + module.fail() - with tempfile.TemporaryDirectory() as temp_dir, tempfile.NamedTemporaryFile() as stderr: + @unittest.skipIf(TEST_WITH_ASAN, "ASAN disables the crash handler's signal handler") + @unittest.skipIf(not has_breakpad(), "Built without breakpad") + def test_crash_handler(self): + with tempfile.TemporaryDirectory() as temp_dir, tempfile.NamedTemporaryFile(delete=not sys.platform == "win32") as stderr: # Use multiprocessing to spin up a separate process to make catching # the segfault easier - p = Process(target=run_test, args=(stderr.name, temp_dir)) + p = Process(target=self._crash_handler_test_process, args=(stderr.name, temp_dir)) p.start() p.join() + with open(stderr.name) as f: result = f.read().strip() diff --git a/test/test_cuda.py b/test/test_cuda.py index 55bab2ee4ebcd..cddd15a7670e9 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -46,12 +46,15 @@ TEST_LARGE_TENSOR = TEST_CUDA TEST_MEDIUM_TENSOR = TEST_CUDA TEST_CUDNN = TEST_CUDA +TEST_BF16 = False if TEST_CUDA: torch.ones(1).cuda() # initialize cuda context TEST_CUDNN = TEST_CUDA and (TEST_WITH_ROCM or torch.backends.cudnn.is_acceptable(torch.tensor(1., device=torch.device('cuda:0')))) TEST_LARGE_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 12e9 TEST_MEDIUM_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 6e9 + TEST_BF16 = torch.cuda.is_bf16_supported() + types = [ torch.FloatTensor, @@ -2036,7 +2039,7 @@ def test_grad_scaling_unscale(self, dtype=torch.float): else: self.assertEqual(found_inf, 0.0) for grad in grads: - self.assertTrue(torch.allclose(grad, torch.ones_like(grad), atol=1e-7)) + self.assertEqual(grad, torch.ones_like(grad), rtol=1e-5, atol=1e-7) # When passing lists with mismatched dtypes to a raw # _amp_foreach_non_finite_check_and_unscale_ call, @@ -2044,7 +2047,7 @@ def test_grad_scaling_unscale(self, dtype=torch.float): grads = [g.clone(), g.to(dtype=torch.float16)] torch._amp_foreach_non_finite_check_and_unscale_(grads, found_inf, inv_scale) for grad in grads: - self.assertTrue(torch.allclose(grad, torch.ones_like(grad), atol=1e-7)) + self.assertEqual(grad, torch.ones_like(grad), rtol=1e-5, atol=1e-7) # Passing lists with mismatched devices to a raw # _amp_foreach_non_finite_check_and_unscale_ call should raise errors. @@ -2084,7 +2087,7 @@ def perfect_storm_grads(inject_inf): # No inf was injected, ensures unscaling worked normally. self.assertTrue(sum(v.item() for v in found_inf_per_device.values()) == 0) for grad in grads: - self.assertTrue(torch.allclose(grad, torch.ones_like(grad), atol=1e-7)) + self.assertEqual(grad, torch.ones_like(grad), rtol=1e-5, atol=1e-7) else: # inf was injected, ensures inf was found. self.assertTrue(sum(v.item() for v in found_inf_per_device.values()) == 1) @@ -2136,7 +2139,7 @@ def test_grad_scaling_unscale_sparse(self, device="cuda", dtype=torch.float): found_inf.zero_() found_inf = scaler._unscale_grads_(opt, inv_scale, found_inf, False)[cur] self.assertEqual(found_inf, 0.0) - self.assertTrue(torch.allclose(p.grad.to_dense(), (s / 4).to_dense())) + self.assertEqual(p.grad.to_dense(), (s / 4).to_dense()) v = torch.FloatTensor([16., 32., float('inf')]) p.grad = torch.sparse_coo_tensor(i, v, torch.Size([2, 3]), device="cuda", dtype=dtype) @@ -2158,7 +2161,7 @@ def test_grad_scaling_unscale_sparse(self, device="cuda", dtype=torch.float): found_inf.zero_() found_inf = scaler._unscale_grads_(opt, inv_scale, found_inf, True)[cur] self.assertEqual(found_inf, 0.0) - self.assertTrue(torch.allclose(p.grad.to_dense(), (s.half() / 4).to_dense())) + self.assertEqual(p.grad.to_dense(), (s.half() / 4).to_dense()) # Creates fp16 sparse tensor with duplicated indices (uncoalesced). The uncoalesced representation # does not overflow in fp16, but the coalesced representation would, because 64000 + 64000 > fp16 max. @@ -2465,7 +2468,7 @@ def run(model0, model1, optimizer0, optimizer1, try_scaling_api): for c, s in zip(chain(mod_control0.parameters(), mod_control1.parameters()), chain(mod_scaling0.parameters(), mod_scaling1.parameters())): - self.assertTrue(torch.allclose(c, s, atol=1e-7)) + self.assertEqual(c, s, rtol=1e-5, atol=1e-7) @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected") def test_grad_scaling_multigpu(self): @@ -2534,7 +2537,7 @@ def run(model0, model1, optimizer0, optimizer1, try_scaling_api): for c, s in zip(chain(mod_control0.parameters(), mod_control1.parameters()), chain(mod_scaling0.parameters(), mod_scaling1.parameters())): - self.assertTrue(torch.allclose(c, s, atol=1e-7)) + self.assertEqual(c, s, rtol=1e-5, atol=1e-7) def test_cublas_multiple_threads_same_device(self): # Note, these parameters should be very carefully tuned @@ -2707,9 +2710,9 @@ def cast(val, to_type): if add_kwargs is None: add_kwargs = {} - + fast_dtype = torch.bfloat16 if run_as_type == torch.bfloat16 else torch.float16 self.assertFalse(torch.is_autocast_enabled()) - with torch.autocast('cuda', ): + with torch.autocast('cuda', dtype=fast_dtype): self.assertTrue(torch.is_autocast_enabled()) out_type = out_type if out_type is not None else run_as_type @@ -2784,6 +2787,27 @@ def test_autocast_torch_fp16(self): if not skip_test: self._run_autocast_outofplace(op, args, torch.float16) + @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available') + def test_autocast_torch_bf16(self): + with torch.backends.cudnn.flags(enabled=True, deterministic=True): + for op_with_args in self.autocast_lists.torch_fp16: + skip_test = False + op, args = op_with_args[0], op_with_args[1] + if len(op_with_args) == 3: + skip_test = op_with_args[2] # TEST_WITH_ROCM + should_error_from_not_implemented = 'cudnn' in op or 'prelu' in op or 'thnn' in op \ + or 'fused' in op or 'gru' in op or op == '_thnn_fused_lstm_cell' or op == 'lstm_cell' + if not skip_test: + if should_error_from_not_implemented: + with self.assertRaises(RuntimeError, msg=str(op) + ' should not be supported for bfloat16!'): + self._run_autocast_outofplace(op, args, torch.bfloat16) + else: + if torch.cuda.is_bf16_supported(): + self._run_autocast_outofplace(op, args, torch.bfloat16) + else: + with self.assertRaisesRegex(RuntimeError, 'Device does not support bfloat16'): + self._run_autocast_outofplace(op, args, torch.bfloat16) + @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available') def test_autocast_torch_fp32(self): for op_with_args in self.autocast_lists.torch_fp32: @@ -2806,6 +2830,18 @@ def test_autocast_nn_fp16(self): for op, args in self.autocast_lists.nn_fp16: self._run_autocast_outofplace(op, args, torch.float16, module=torch._C._nn) + + + @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available') + def test_autocast_nn_bf16(self): + with torch.backends.cudnn.flags(enabled=True, deterministic=True): + for op, args in self.autocast_lists.nn_fp16: + if torch.cuda.is_bf16_supported(): + self._run_autocast_outofplace(op, args, torch.bfloat16, module=torch._C._nn) + else: + with self.assertRaisesRegex(RuntimeError, 'Device does not support bfloat16'): + self._run_autocast_outofplace(op, args, torch.bfloat16, module=torch._C._nn) + @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available') def test_autocast_nn_fp32(self): for op, args in self.autocast_lists.nn_fp32: @@ -3013,7 +3049,7 @@ def test_autocast_rnn(self): # Autocast wrapper requires at::_cudnn_rnn is autograd-exposed. This check can't guarantee # at::_cudnn_rnn is autograd-exposed, but if it fires, it indicates some funny business has # occurred and we should double check that at::_cudnn_rnn remains autograd-exposed. - self.assertEqual(out.grad_fn.name(), "CudnnRnnBackward") + self.assertEqual(out.grad_fn.name(), "CudnnRnnBackward0") out.sum().backward() grads = [p.grad.clone() for p in rnn.parameters()] @@ -3089,7 +3125,7 @@ def test_graph_capture_simple(self): with torch.cuda.stream(s): a = torch.full((1000,), 1, device="cuda") - g = torch.cuda._Graph() + g = torch.cuda.CUDAGraph() torch.cuda.empty_cache() g.capture_begin() b = a @@ -3125,7 +3161,7 @@ def run(op, kwargs): with torch.cuda.stream(stream): torch.cuda.manual_seed(5) - g = torch.cuda._Graph() + g = torch.cuda.CUDAGraph() torch.cuda.empty_cache() g.capture_begin() graph_out = graph_in @@ -3212,7 +3248,7 @@ def run(module, op, args, kwargs): with torch.cuda.stream(stream): torch.cuda.manual_seed(5) - g = torch.cuda._Graph() + g = torch.cuda.CUDAGraph() torch.cuda.empty_cache() if (module == "torch"): g.capture_begin() @@ -3279,14 +3315,14 @@ def func_with_temps(t, val): s = torch.cuda.Stream() for share_mem in ("Don't share", "via pool()", "via graph_pool_handle()"): - g0 = torch.cuda._Graph() - g1 = torch.cuda._Graph() + g0 = torch.cuda.CUDAGraph() + g1 = torch.cuda.CUDAGraph() a = torch.ones((size,), device="cuda") s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): - g0_args = (torch.cuda._graph_pool_handle(),) if share_mem == "via graph_pool_handle()" else () + g0_args = (torch.cuda.graph_pool_handle(),) if share_mem == "via graph_pool_handle()" else () g0.capture_begin(*g0_args) b = a.clone() for _ in range(5): @@ -3343,8 +3379,8 @@ def func_with_temps(t, val): s = torch.cuda.Stream() for share_mem in ("Don't share", "via pool()", "via graph_pool_handle()"): - g0 = torch.cuda._Graph() - g1 = torch.cuda._Graph() + g0 = torch.cuda.CUDAGraph() + g1 = torch.cuda.CUDAGraph() s0 = torch.cuda.Stream() s1 = torch.cuda.Stream() @@ -3353,7 +3389,7 @@ def func_with_temps(t, val): s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): - g0_args = (torch.cuda._graph_pool_handle(),) if share_mem == "via graph_pool_handle()" else () + g0_args = (torch.cuda.graph_pool_handle(),) if share_mem == "via graph_pool_handle()" else () g0.capture_begin(*g0_args) b = a.clone() for _ in range(5): @@ -3407,13 +3443,13 @@ def test_graph_three_successive(self): for share_mem in ("Don't share", "via pool()", "via graph_pool_handle()"): a = torch.ones((size,), device="cuda") - g0 = torch.cuda._Graph() - g1 = torch.cuda._Graph() - g2 = torch.cuda._Graph() + g0 = torch.cuda.CUDAGraph() + g1 = torch.cuda.CUDAGraph() + g2 = torch.cuda.CUDAGraph() s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): - g0_args = (torch.cuda._graph_pool_handle(),) if share_mem == "via graph_pool_handle()" else () + g0_args = (torch.cuda.graph_pool_handle(),) if share_mem == "via graph_pool_handle()" else () g0.capture_begin(*g0_args) b = a.clone() c = b + 1 @@ -3499,7 +3535,7 @@ def test_graph_memory_stats_and_use_result_after_destroy_graph(self): delta_active_blocks = 1 # We only check the large pool, which isn't affected by rng offset holder delta_active_bytes = numel * elem - g = torch.cuda._Graph() + g = torch.cuda.CUDAGraph() s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): # Allocation stat estimates assume input is created on the same stream as capture_begin() @@ -3573,7 +3609,7 @@ def test_graph_record_stream(self): s0 = torch.cuda.Stream() s1 = torch.cuda.Stream() s2 = torch.cuda.Stream() - g = torch.cuda._Graph() + g = torch.cuda.CUDAGraph() torch.cuda.synchronize() with torch.cuda.stream(s0): @@ -3620,7 +3656,7 @@ def test_graph_cudnn_dropout(self): y = model(x) - g = torch.cuda._Graph() + g = torch.cuda.CUDAGraph() s = torch.cuda.Stream() s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): @@ -3638,7 +3674,7 @@ def test_graph_grad_scaling(self): torch.cuda.empty_cache() scaler = torch.cuda.amp.GradScaler(init_scale=4.) - g = torch.cuda._Graph() + g = torch.cuda.CUDAGraph() s = torch.cuda.Stream() weight = torch.ones((100,), device="cuda", requires_grad=True) @@ -3646,18 +3682,20 @@ def test_graph_grad_scaling(self): static_input = torch.ones_like(weight) static_grad = torch.ones_like(weight) + # warmup + s = torch.cuda.Stream() s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): - # warmup loss = (weight.half() * static_input).sum() scaler.scale(loss).backward() - opt.zero_grad(set_to_none=True) - # capture - g.capture_begin() + torch.cuda.current_stream().wait_stream(s) + + opt.zero_grad(set_to_none=True) + + # capture + with torch.cuda.graph(g): loss = (weight.half() * static_input).sum() scaler.scale(loss).backward() - g.capture_end() - torch.cuda.current_stream().wait_stream(s) input_vals = [5, 20000, 5, 40000] # If the scale gets updated properly, these are the scale, growth tracker, @@ -3678,6 +3716,71 @@ def test_graph_grad_scaling(self): self.assertEqual(scaler._scale, scale) self.assertEqual(scaler._growth_tracker, growth_tracker) + @unittest.skipIf((not TEST_CUDA) or + TEST_WITH_ROCM or + int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs") + def test_graph_make_graphed_callables(self): + torch.manual_seed(5) + torch.cuda.manual_seed(5) + + N, D_in, H, D_out = 640, 4096, 2048, 1024 + + models = [] + for _ in range(2): + model_section1 = torch.nn.Sequential(torch.nn.Linear(D_in, H), + torch.nn.Dropout(p=0.1)).cuda() + model_section2 = torch.nn.Sequential(torch.nn.Linear(H, D_out), + torch.nn.Dropout(p=0.2)).cuda() + models.append(torch.nn.Sequential(model_section1, model_section2)) + + model_graphed = models[0] + model_control = models[1] + + model_graphed.load_state_dict(model_control.state_dict()) + + opt_graphed = torch.optim.SGD(model_graphed.parameters(), lr=0.1) + opt_control = torch.optim.SGD(model_control.parameters(), lr=0.1) + + x = torch.randn(N, D_in, device='cuda') + h = torch.randn(N, H, device='cuda', requires_grad=True) + y_pred = torch.randn(N, D_out, device='cuda', requires_grad=True) + y = torch.randn(N, D_out, device='cuda') + + loss_fn_control = torch.nn.functional.mse_loss + relu_control = torch.nn.functional.relu + + # This is a good stress test. It graphs four callables: two Modules and two python functions. + model_graphed[0], model_graphed[1], relu_graphed, loss_fn_graphed = \ + torch.cuda.make_graphed_callables((model_graphed[0], model_graphed[1], relu_control, loss_fn_control), + ((x,), (h,), (y_pred,), (y_pred, y))) + + real_inputs = [torch.rand_like(x) for _ in range(10)] + real_targets = [torch.rand_like(y) for _ in range(10)] + + for m, opt, relu, loss_fn in zip((model_graphed, model_control), + (opt_graphed, opt_control), + (relu_graphed, relu_control), + (loss_fn_graphed, loss_fn_control)): + # Resets RNC states before iterations for graphed and ungraphed models, + # so dropout math should be bitwise identical for both. + torch.manual_seed(5) + torch.cuda.manual_seed(5) + for data, target in zip(real_inputs, real_targets): + opt.zero_grad(set_to_none=True) + y_pred = m(data) + y_pred = relu(y_pred) + loss = loss_fn(y_pred, target) + loss.backward() + opt.step() + + for p, pc in zip(model_graphed.parameters(), model_control.parameters()): + self.assertEqual(p, pc) + + # We graphed the models in training mode. Eval should still run ungraphed. + model_graphed.eval() + model_control.eval() + self.assertEqual(model_graphed(real_inputs[0]), model_control(real_inputs[0])) + def test_batch_norm_gather_stats(self): input = torch.randn(1, 3, 3, 3, device='cuda') mean, invstd = torch.batch_norm_gather_stats( diff --git a/test/test_dataloader.py b/test/test_dataloader.py index c68d7e2e14b33..5050feca3a373 100644 --- a/test/test_dataloader.py +++ b/test/test_dataloader.py @@ -13,14 +13,27 @@ import warnings import tempfile from torch import multiprocessing as mp -from torch.utils.data import _utils, Dataset, IterableDataset, TensorDataset, DataLoader, ConcatDataset, ChainDataset, Subset +from torch.utils.data import ( + ChainDataset, + ConcatDataset, + DataLoader, + DataLoader2, + Dataset, + IterableDataset, + Subset, + TensorDataset, + communication, + _utils +) from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL from torch.utils.data.dataset import random_split +from torch.utils.data.datapipes.iter import IterableWrapper from torch._utils import ExceptionWrapper from torch.testing._internal.common_utils import (TestCase, run_tests, TEST_NUMPY, IS_WINDOWS, IS_IN_CI, NO_MULTIPROCESSING_SPAWN, skipIfRocm, slowTest, load_tests, TEST_WITH_TSAN, IS_SANDCASTLE) + try: import psutil HAS_PSUTIL = True @@ -33,6 +46,17 @@ else: warnings.warn(err_msg) +try: + import dill + # XXX: By default, dill writes the Pickler dispatch table to inject its + # own logic there. This globally affects the behavior of the standard library + # pickler for any user who transitively depends on this module! + # Undo this extension to avoid altering the behavior of the pickler globally. + dill.extend(use_dill=False) + HAS_DILL = True +except ImportError: + HAS_DILL = False +skipIfNoDill = unittest.skipIf(not HAS_DILL, "no dill") # load_tests from torch.testing._internal.common_utils is used to automatically filter tests for # sharding on sandcastle. This line silences flake warnings @@ -708,7 +732,7 @@ def __getitem__(self, idx): # Should be used as worker_init_fn with TestWorkerInfoDataset. # See _test_get_worker_info below for usage. -def test_worker_info_init_fn(worker_id): +def _test_worker_info_init_fn(worker_id): worker_info = torch.utils.data.get_worker_info() assert worker_id == worker_info.id, "worker_init_fn and worker_info should have consistent id" assert worker_id < worker_info.num_workers, "worker_init_fn and worker_info should have valid id" @@ -738,7 +762,7 @@ def _test_get_worker_info(): dataset = TestWorkerInfoDataset(6, batch_size, num_workers) dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, - worker_init_fn=test_worker_info_init_fn) + worker_init_fn=_test_worker_info_init_fn) it = iter(dataloader) data = [] for d in it: @@ -747,7 +771,7 @@ def _test_get_worker_info(): data = torch.cat(data, 0) for d in data: # each `d` is a [worker_id, worker_pid] pair, which is set in - # test_worker_info_init_fn + # _test_worker_info_init_fn assert d[1] == worker_pids[d[0]] # get_worker_info returns None in main proc after data loading assert torch.utils.data.get_worker_info() is None @@ -1934,6 +1958,49 @@ def test_excessive_thread_creation_warning(self): dataloader = DataLoader(self.dataset, batch_size=2, num_workers=1000) +@unittest.skipIf( + TEST_WITH_TSAN, + "Fails with TSAN with the following error: starting new threads after multi-threaded " + "fork is not supported. Dying (set die_after_fork=0 to override)") +class TestDataLoader2(TestCase): + @skipIfNoDill + def test_basics(self): + # TODO(VitalyFedyunin): This test will start breaking if we remove guaranteed order + # of traversing workers + dp = IterableWrapper(list(range(1000))) + dl = DataLoader(dp, batch_size=3, collate_fn=lambda x: x, num_workers=2) + dl2 = DataLoader2(dp, batch_size=3, collate_fn=lambda x: x, num_workers=2) + dl2_threading = DataLoader2(dp, batch_size=3, collate_fn=lambda x: x, num_workers=2, parallelism_mode='thread') + self.assertEqual(list(dl), list(dl2)) + self.assertEqual(list(dl), list(dl2_threading)) + + + +@unittest.skipIf( + TEST_WITH_TSAN, + "Fails with TSAN with the following error: starting new threads after multi-threaded " + "fork is not supported. Dying (set die_after_fork=0 to override)") +class TestDataLoader2_EventLoop(TestCase): + @skipIfNoDill + def test_basic_threading(self): + def clean_me(process, req_queue, res_queue): + req_queue.put(communication.messages.TerminateRequest()) + _ = res_queue.get() + process.join() + + it = list(range(100)) + numbers_dp = IterableWrapper(it) + (process, req_queue, res_queue, _thread_local_datapipe) = communication.eventloop.SpawnThreadForDataPipeline(numbers_dp) + + process.start() + local_datapipe = communication.iter.QueueWrapper( + communication.protocol.IterDataPipeQueueProtocolClient(req_queue, res_queue)) + + actual = list(local_datapipe) + clean_me(process, req_queue, res_queue) + + self.assertEqual(list(range(100)), actual) + class StringDataset(Dataset): def __init__(self): self.s = '12345' @@ -2253,9 +2320,24 @@ def _run_ind_worker_queue_test(self, batch_size, num_workers): current_worker_idx = 0 def test_ind_worker_queue(self): + max_num_workers = None + if hasattr(os, 'sched_getaffinity'): + try: + max_num_workers = len(os.sched_getaffinity(0)) + except Exception: + pass + if max_num_workers is None: + cpu_count = os.cpu_count() + if cpu_count is not None: + # Use half number of CPUs + max_num_workers = cpu_count // 2 + + if max_num_workers is None: + max_num_workers = 1 + for batch_size in (8, 16, 32, 64): - for num_workers in range(1, 6): - self._run_ind_worker_queue_test(batch_size=batch_size, num_workers=num_workers) + for num_workers in range(0, min(6, max_num_workers)): + self._run_ind_worker_queue_test(batch_size=batch_size, num_workers=num_workers + 1) class SetAffinityDataset(IterableDataset): diff --git a/test/test_datapipe.py b/test/test_datapipe.py index 9a7876e334639..15cb05986b518 100644 --- a/test/test_datapipe.py +++ b/test/test_datapipe.py @@ -34,7 +34,6 @@ import numpy as np import torch -import torch.nn as nn import torch.utils.data.backward_compatibility import torch.utils.data.datapipes as dp import torch.utils.data.graph @@ -54,13 +53,6 @@ basichandlers as decoder_basichandlers, ) -try: - import torchvision.transforms - HAS_TORCHVISION = True -except ImportError: - HAS_TORCHVISION = False -skipIfNoTorchVision = skipIf(not HAS_TORCHVISION, "no torchvision") - try: import dill # XXX: By default, dill writes the Pickler dispatch table to inject its @@ -110,14 +102,54 @@ def create_temp_dir_and_files(): class TestDataChunk(TestCase): + def setUp(self): + self.elements = list(range(10)) + random.shuffle(self.elements) + self.chunk: DataChunk[int] = DataChunk(self.elements) + + def test_getitem(self): + for i in range(10): + self.assertEqual(self.elements[i], self.chunk[i]) + + def test_iter(self): + for ele, dc in zip(self.elements, iter(self.chunk)): + self.assertEqual(ele, dc) + + def test_len(self): + self.assertEqual(len(self.elements), len(self.chunk)) + def test_as_string(self): + self.assertEqual(str(self.chunk), str(self.elements)) + + batch = [self.elements] * 3 + chunks: List[DataChunk[int]] = [DataChunk(self.elements)] * 3 + self.assertEqual(str(batch), str(chunks)) + + def test_sort(self): + chunk: DataChunk[int] = DataChunk(self.elements) + chunk.sort() + self.assertTrue(isinstance(chunk, DataChunk)) + for i, d in enumerate(chunk): + self.assertEqual(i, d) + + def test_reverse(self): + chunk: DataChunk[int] = DataChunk(self.elements) + chunk.reverse() + self.assertTrue(isinstance(chunk, DataChunk)) + for i in range(10): + self.assertEqual(chunk[i], self.elements[9 - i]) + + def test_random_shuffle(self): elements = list(range(10)) chunk: DataChunk[int] = DataChunk(elements) - self.assertEqual(str(chunk), str(elements)) - batch = [elements] * 3 - chunks: List[DataChunk] = [DataChunk(elements)] * 3 - self.assertEqual(str(chunk), str(elements)) + rng = random.Random(0) + rng.shuffle(chunk) + + rng = random.Random(0) + rng.shuffle(elements) + + self.assertEqual(chunk, elements) class TestIterableDataPipeBasic(TestCase): @@ -137,7 +169,7 @@ def tearDown(self): def test_listdirfiles_iterable_datapipe(self): temp_dir = self.temp_dir.name - datapipe = dp.iter.ListDirFiles(temp_dir, '') + datapipe = dp.iter.FileLister(temp_dir, '') count = 0 for pathname in datapipe: @@ -146,7 +178,7 @@ def test_listdirfiles_iterable_datapipe(self): self.assertEqual(count, len(self.temp_files)) count = 0 - datapipe = dp.iter.ListDirFiles(temp_dir, '', recursive=True) + datapipe = dp.iter.FileLister(temp_dir, '', recursive=True) for pathname in datapipe: count = count + 1 self.assertTrue((pathname in self.temp_files) or (pathname in self.temp_sub_files)) @@ -155,13 +187,13 @@ def test_listdirfiles_iterable_datapipe(self): def test_loadfilesfromdisk_iterable_datapipe(self): # test import datapipe class directly from torch.utils.data.datapipes.iter import ( - ListDirFiles, - LoadFilesFromDisk, + FileLister, + FileLoader, ) temp_dir = self.temp_dir.name - datapipe1 = ListDirFiles(temp_dir, '') - datapipe2 = LoadFilesFromDisk(datapipe1) + datapipe1 = FileLister(temp_dir, '') + datapipe2 = FileLoader(datapipe1) count = 0 for rec in datapipe2: @@ -180,9 +212,9 @@ def test_readfilesfromtar_iterable_datapipe(self): tar.add(self.temp_files[0]) tar.add(self.temp_files[1]) tar.add(self.temp_files[2]) - datapipe1 = dp.iter.ListDirFiles(temp_dir, '*.tar') - datapipe2 = dp.iter.LoadFilesFromDisk(datapipe1) - datapipe3 = dp.iter.ReadFilesFromTar(datapipe2) + datapipe1 = dp.iter.FileLister(temp_dir, '*.tar') + datapipe2 = dp.iter.FileLoader(datapipe1) + datapipe3 = dp.iter.TarArchiveReader(datapipe2) # read extracted files before reaching the end of the tarfile for rec, temp_file in itertools.zip_longest(datapipe3, self.temp_files): self.assertTrue(rec is not None and temp_file is not None) @@ -207,9 +239,9 @@ def test_readfilesfromzip_iterable_datapipe(self): myzip.write(self.temp_files[0]) myzip.write(self.temp_files[1]) myzip.write(self.temp_files[2]) - datapipe1 = dp.iter.ListDirFiles(temp_dir, '*.zip') - datapipe2 = dp.iter.LoadFilesFromDisk(datapipe1) - datapipe3 = dp.iter.ReadFilesFromZip(datapipe2) + datapipe1 = dp.iter.FileLister(temp_dir, '*.zip') + datapipe2 = dp.iter.FileLoader(datapipe1) + datapipe3 = dp.iter.ZipArchiveReader(datapipe2) # read extracted files before reaching the end of the zipfile for rec, temp_file in itertools.zip_longest(datapipe3, self.temp_files): self.assertTrue(rec is not None and temp_file is not None) @@ -231,8 +263,8 @@ def test_routeddecoder_iterable_datapipe(self): temp_pngfile_pathname = os.path.join(temp_dir, "test_png.png") png_data = np.array([[[1., 0., 0.], [1., 0., 0.]], [[1., 0., 0.], [1., 0., 0.]]], dtype=np.single) np.save(temp_pngfile_pathname, png_data) - datapipe1 = dp.iter.ListDirFiles(temp_dir, ['*.png', '*.txt']) - datapipe2 = dp.iter.LoadFilesFromDisk(datapipe1) + datapipe1 = dp.iter.FileLister(temp_dir, ['*.png', '*.txt']) + datapipe2 = dp.iter.FileLoader(datapipe1) def _png_decoder(extension, data): if extension != 'png': @@ -267,7 +299,7 @@ def _helper(prior_dp, dp, channel_first=False): _helper(cached, datapipe4, channel_first=True) # TODO(VitalyFedyunin): Generates unclosed buffer warning, need to investigate - def test_groupbykey_iterable_datapipe(self): + def test_groupby_iterable_datapipe(self): temp_dir = self.temp_dir.name temp_tarfile_pathname = os.path.join(temp_dir, "test_tar.tar") file_list = [ @@ -281,16 +313,28 @@ def test_groupbykey_iterable_datapipe(self): f.write('12345abcde') tar.add(file_pathname) - datapipe1 = dp.iter.ListDirFiles(temp_dir, '*.tar') - datapipe2 = dp.iter.LoadFilesFromDisk(datapipe1) - datapipe3 = dp.iter.ReadFilesFromTar(datapipe2) - datapipe4 = dp.iter.GroupByKey(datapipe3, group_size=2) + datapipe1 = dp.iter.FileLister(temp_dir, '*.tar') + datapipe2 = dp.iter.FileLoader(datapipe1) + datapipe3 = dp.iter.TarArchiveReader(datapipe2) + + def group_fn(data): + filepath, _ = data + return os.path.basename(filepath).split(".")[0] - expected_result = [("a.png", "a.json"), ("c.png", "c.json"), ("b.png", "b.json"), ("d.png", "d.json"), ( - "f.png", "f.json"), ("g.png", "g.json"), ("e.png", "e.json"), ("h.json", "h.txt")] + datapipe4 = dp.iter.Grouper(datapipe3, group_key_fn=group_fn, group_size=2) + + def order_fn(data): + data.sort(key=lambda f: f[0], reverse=True) + return data + + datapipe5 = dp.iter.Mapper(datapipe4, fn=order_fn) # type: ignore[var-annotated] + + expected_result = [ + ("a.png", "a.json"), ("c.png", "c.json"), ("b.png", "b.json"), ("d.png", "d.json"), + ("f.png", "f.json"), ("g.png", "g.json"), ("e.png", "e.json"), ("h.txt", "h.json")] count = 0 - for rec, expected in zip(datapipe4, expected_result): + for rec, expected in zip(datapipe5, expected_result): count = count + 1 self.assertEqual(os.path.basename(rec[0][0]), expected[0]) self.assertEqual(os.path.basename(rec[1][0]), expected[1]) @@ -310,6 +354,15 @@ def test_demux_mux_datapipe(self): n = n1.mux(n2, n3) self.assertEqual(list(range(10)), list(n)) + # Test Case: Uneven DataPipes + source_numbers = list(range(0, 10)) + [10, 12] + numbers_dp = IDP(source_numbers) + n1, n2 = numbers_dp.demux(2, lambda x: x % 2) + self.assertEqual([0, 2, 4, 6, 8, 10, 12], list(n1)) + self.assertEqual([1, 3, 5, 7, 9], list(n2)) + n = n1.mux(n2) + self.assertEqual(source_numbers, list(n)) + class FileLoggerSimpleHTTPRequestHandler(http.server.SimpleHTTPRequestHandler): def __init__(self, *args, logfile=None, **kwargs): @@ -407,13 +460,14 @@ def _get_data_from_tuple_fn(data, *args, **kwargs): create_temp_files_for_serving(tmpdir, test_file_count, test_file_size, file_url_template) - datapipe_dir_f = dp.iter.ListDirFiles(tmpdir, '*_list') - datapipe_f_lines = dp.iter.ReadLinesFromFile(datapipe_dir_f) + datapipe_dir_f = dp.iter.FileLister(tmpdir, '*_list') + datapipe_stream = dp.iter.FileLoader(datapipe_dir_f) + datapipe_f_lines = dp.iter.LineReader(datapipe_stream) datapipe_line_url: IterDataPipe[str] = \ - dp.iter.Map(datapipe_f_lines, _get_data_from_tuple_fn, (1,)) + dp.iter.Mapper(datapipe_f_lines, _get_data_from_tuple_fn, (1,)) datapipe_http = dp.iter.HttpReader(datapipe_line_url, timeout=timeout) - datapipe_tob = dp.iter.ToBytes(datapipe_http, chunk=chunk) + datapipe_tob = dp.iter.StreamReader(datapipe_http, chunk=chunk) for (url, data) in datapipe_tob: self.assertGreater(len(url), 0) @@ -499,18 +553,18 @@ class TestFunctionalIterDataPipe(TestCase): def _test_picklable(self): arr = range(10) picklable_datapipes: List[Tuple[Type[IterDataPipe], IterDataPipe, Tuple, Dict[str, Any]]] = [ - (dp.iter.Map, IDP(arr), (), {}), - (dp.iter.Map, IDP(arr), (_fake_fn, (0, ), {'test': True}), {}), - (dp.iter.Collate, IDP(arr), (), {}), - (dp.iter.Collate, IDP(arr), (_fake_fn, (0, ), {'test': True}), {}), + (dp.iter.Mapper, IDP(arr), (), {}), + (dp.iter.Mapper, IDP(arr), (_fake_fn, (0, ), {'test': True}), {}), + (dp.iter.Collator, IDP(arr), (), {}), + (dp.iter.Collator, IDP(arr), (_fake_fn, (0, ), {'test': True}), {}), (dp.iter.Filter, IDP(arr), (_fake_filter_fn, (0, ), {'test': True}), {}), ] for dpipe, input_dp, dp_args, dp_kwargs in picklable_datapipes: p = pickle.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg] unpicklable_datapipes: List[Tuple[Type[IterDataPipe], IterDataPipe, Tuple, Dict[str, Any]]] = [ - (dp.iter.Map, IDP(arr), (lambda x: x, ), {}), - (dp.iter.Collate, IDP(arr), (lambda x: x, ), {}), + (dp.iter.Mapper, IDP(arr), (lambda x: x, ), {}), + (dp.iter.Collator, IDP(arr), (lambda x: x, ), {}), (dp.iter.Filter, IDP(arr), (lambda x: x >= 5, ), {}), ] for dpipe, input_dp, dp_args, dp_kwargs in unpicklable_datapipes: @@ -526,10 +580,10 @@ def test_concat_datapipe(self): input_dp2 = IDP(range(5)) with self.assertRaisesRegex(ValueError, r"Expected at least one DataPipe"): - dp.iter.Concat() + dp.iter.Concater() with self.assertRaisesRegex(TypeError, r"Expected all inputs to be `IterDataPipe`"): - dp.iter.Concat(input_dp1, ()) # type: ignore[arg-type] + dp.iter.Concater(input_dp1, ()) # type: ignore[arg-type] concat_dp = input_dp1.concat(input_dp2) self.assertEqual(len(concat_dp), 15) @@ -546,6 +600,217 @@ def test_concat_datapipe(self): self.assertEqual(list(concat_dp), list(range(10)) + list(range(5))) + + def test_fork_datapipe(self): + input_dp = IDP(range(10)) + + # Test Case: making sure all child DataPipe shares the same reference + dp1, dp2, dp3 = input_dp.fork(num_instances=3) + self.assertTrue(all(n1 is n2 for n1, n2 in zip(dp1, dp2))) + self.assertTrue(all(n1 is n3 for n1, n3 in zip(dp1, dp3))) + + # Test Case: one child DataPipe yields all value at a time + output1, output2, output3 = list(dp1), list(dp2), list(dp3) + self.assertEqual(list(range(10)), output1) + self.assertEqual(list(range(10)), output2) + self.assertEqual(list(range(10)), output3) + + # Test Case: two child DataPipes yield value together + dp1, dp2 = input_dp.fork(num_instances=2) + output = [] + for n1, n2 in zip(dp1, dp2): + output.append((n1, n2)) + self.assertEqual([(i, i) for i in range(10)], output) + + # Test Case: one child DataPipe yields all value first, but buffer_size = 5 being too small + dp1, dp2 = input_dp.fork(num_instances=2, buffer_size=5) + it1 = iter(dp1) + for _ in range(5): + next(it1) + with self.assertRaises(BufferError): + next(it1) + + # Test Case: two child DataPipes yield value together with buffer size 1 + dp1, dp2 = input_dp.fork(num_instances=2, buffer_size=1) + output = [] + for n1, n2 in zip(dp1, dp2): + output.append((n1, n2)) + self.assertEqual([(i, i) for i in range(10)], output) + + # Test Case: make sure logic related to slowest_ptr is working properly + dp1, dp2, dp3 = input_dp.fork(num_instances=3) + output1, output2 , output3 = [], [], [] + for i, (n1, n2) in enumerate(zip(dp1, dp2)): + output1.append(n1) + output2.append(n2) + if i == 4: # yield all of dp3 when halfway through dp1, dp2 + output3 = list(dp3) + break + self.assertEqual(list(range(5)), output1) + self.assertEqual(list(range(5)), output2) + self.assertEqual(list(range(10)), output3) + + # Test Case: DataPipe doesn't reset if this pipe hasn't been read + dp1, dp2 = input_dp.fork(num_instances=2) + i1, i2 = iter(dp1), iter(dp2) + output2 = [] + for i, n2 in enumerate(i2): + output2.append(n2) + if i == 4: + i1 = iter(dp1) # Doesn't reset because i1 hasn't been read + self.assertEqual(list(range(10)), output2) + + # Test Case: DataPipe reset when some of it have been read + dp1, dp2 = input_dp.fork(num_instances=2) + i1, i2 = iter(dp1), iter(dp2) + output1, output2 = [], [] + for i, (n1, n2) in enumerate(zip(i1, i2)): + output1.append(n1) + output2.append(n2) + if i == 4: + with warnings.catch_warnings(record=True) as wa: + i1 = iter(dp1) # Reset both all child DataPipe + self.assertEqual(len(wa), 1) + self.assertRegex(str(wa[0].message), r"Some child DataPipes are not exhausted") + self.assertEqual(list(range(5)) + list(range(10)), output1) + self.assertEqual(list(range(5)) + list(range(10)), output2) + + # Test Case: DataPipe reset, even when some other child DataPipes are not read + dp1, dp2, dp3 = input_dp.fork(num_instances=3) + output1, output2 = list(dp1), list(dp2) + self.assertEqual(list(range(10)), output1) + self.assertEqual(list(range(10)), output2) + output1, output2 = list(dp1), list(dp2) + with warnings.catch_warnings(record=True) as wa: + self.assertEqual(list(range(10)), list(dp1)) # Resets even though dp3 has not been read + self.assertEqual(len(wa), 1) + self.assertRegex(str(wa[0].message), r"Some child DataPipes are not exhausted") + output3 = [] + for i, n3 in enumerate(dp3): + output3.append(n3) + if i == 4: + with warnings.catch_warnings(record=True) as wa: + output1 = list(dp1) # Resets even though dp3 is only partially read + self.assertEqual(len(wa), 1) + self.assertRegex(str(wa[0].message), r"Some child DataPipes are not exhausted") + self.assertEqual(list(range(5)), output3) + self.assertEqual(list(range(10)), output1) + break + self.assertEqual(list(range(10)), list(dp3)) # dp3 has to read from the start again + + # Test Case: Each DataPipe inherits the source datapipe's length + dp1, dp2, dp3 = input_dp.fork(num_instances=3) + self.assertEqual(len(input_dp), len(dp1)) + self.assertEqual(len(input_dp), len(dp2)) + self.assertEqual(len(input_dp), len(dp3)) + + + def test_demux_datapipe(self): + input_dp = IDP(range(10)) + + # Test Case: split into 2 DataPipes and output them one at a time + dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2) + output1, output2 = list(dp1), list(dp2) + self.assertEqual(list(range(0, 10, 2)), output1) + self.assertEqual(list(range(1, 10, 2)), output2) + + # Test Case: split into 2 DataPipes and output them together + dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2) + output = [] + for n1, n2 in zip(dp1, dp2): + output.append((n1, n2)) + self.assertEqual([(i, i + 1) for i in range(0, 10, 2)], output) + + # Test Case: values of the same classification are lumped together, and buffer_size = 3 being too small + dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: 0 if x >= 5 else 1, buffer_size=4) + it1 = iter(dp1) + with self.assertRaises(BufferError): + next(it1) # Buffer raises because first 5 elements all belong to the a different child + + # Test Case: values of the same classification are lumped together, and buffer_size = 5 is just enough + dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: 0 if x >= 5 else 1, buffer_size=5) + output1, output2 = list(dp1), list(dp2) + self.assertEqual(list(range(5, 10)), output1) + self.assertEqual(list(range(0, 5)), output2) + + # Test Case: classifer returns a value outside of [0, num_instance - 1] + dp = input_dp.demux(num_instances=1, classifier_fn=lambda x: x % 2) + it = iter(dp[0]) + with self.assertRaises(ValueError): + next(it) + next(it) + + # Test Case: DataPipe doesn't reset when it has not been read + dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2) + i1 = iter(dp1) + output2 = [] + i = 0 + for i, n2 in enumerate(dp2): + output2.append(n2) + if i == 4: + i1 = iter(dp1) + self.assertEqual(list(range(1, 10, 2)), output2) + + # Test Case: DataPipe reset when some of it has been read + dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2) + output1, output2 = [], [] + for n1, n2 in zip(dp1, dp2): + output1.append(n1) + output2.append(n2) + if n1 == 4: + break + with warnings.catch_warnings(record=True) as wa: + i1 = iter(dp1) # Reset all child DataPipes + self.assertEqual(len(wa), 1) + self.assertRegex(str(wa[0].message), r"Some child DataPipes are not exhausted") + for n1, n2 in zip(dp1, dp2): + output1.append(n1) + output2.append(n2) + self.assertEqual([0, 2, 4] + list(range(0, 10, 2)), output1) + self.assertEqual([1, 3, 5] + list(range(1, 10, 2)), output2) + + # Test Case: DataPipe reset, even when not all child DataPipes are exhausted + dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2) + output1 = list(dp1) + self.assertEqual(list(range(0, 10, 2)), output1) + with warnings.catch_warnings(record=True) as wa: + self.assertEqual(list(range(0, 10, 2)), list(dp1)) # Reset even when dp2 is not read + self.assertEqual(len(wa), 1) + self.assertRegex(str(wa[0].message), r"Some child DataPipes are not exhausted") + output2 = [] + for i, n2 in enumerate(dp2): + output2.append(n2) + if i == 1: + self.assertEqual(list(range(1, 5, 2)), output2) + with warnings.catch_warnings(record=True) as wa: + self.assertEqual(list(range(0, 10, 2)), list(dp1)) # Can reset even when dp2 is partially read + self.assertEqual(len(wa), 1) + self.assertRegex(str(wa[0].message), r"Some child DataPipes are not exhausted") + break + output2 = list(dp2) # output2 has to read from beginning again + self.assertEqual(list(range(1, 10, 2)), output2) + + # Test Case: drop_none = True + dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2 if x % 5 != 0 else None, + drop_none=True) + self.assertEqual([2, 4, 6, 8], list(dp1)) + self.assertEqual([1, 3, 7, 9], list(dp2)) + + # Test Case: drop_none = False + dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2 if x % 5 != 0 else None, + drop_none=False) + it1 = iter(dp1) + with self.assertRaises(ValueError): + next(it1) + + # Test Case: __len__ not implemented + dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2) + with self.assertRaises(TypeError): + len(dp1) # It is not implemented as we do not know length for each child in advance + with self.assertRaises(TypeError): + len(dp2) + + def test_map_datapipe(self): input_dp = IDP(range(10)) @@ -768,7 +1033,7 @@ def _filter_fn(data, val, clip=False): for data, exp in zip(filter_dp, range(5, 10)): self.assertEqual(data, exp) - with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"): + with self.assertRaisesRegex(TypeError, r"has no len"): len(filter_dp) def _non_bool_fn(data): @@ -873,59 +1138,17 @@ def test_shuffle_datapipe(self): with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"): len(shuffle_dp_nl) - @skipIfNoTorchVision - def test_transforms_datapipe(self): - torch.set_default_dtype(torch.float) - # A sequence of numpy random numbers representing 3-channel images - w = h = 32 - inputs = [np.random.randint(0, 255, (h, w, 3), dtype=np.uint8) for i in range(10)] - tensor_inputs = [torch.tensor(x, dtype=torch.float).permute(2, 0, 1) / 255. for x in inputs] - - input_dp = IDP(inputs) - # Raise TypeError for python function - with self.assertRaisesRegex(TypeError, r"`transforms` are required to be"): - input_dp.legacy_transforms(_fake_fn) - - # transforms.Compose of several transforms - transforms = torchvision.transforms.Compose([ - torchvision.transforms.ToTensor(), - torchvision.transforms.Pad(1, fill=1, padding_mode='constant'), - ]) - tsfm_dp = input_dp.legacy_transforms(transforms) - self.assertEqual(len(tsfm_dp), len(input_dp)) - for tsfm_data, input_data in zip(tsfm_dp, tensor_inputs): - self.assertEqual(tsfm_data[:, 1:(h + 1), 1:(w + 1)], input_data) - - # nn.Sequential of several transforms (required to be instances of nn.Module) - input_dp = IDP(tensor_inputs) - transforms = nn.Sequential( - torchvision.transforms.Pad(1, fill=1, padding_mode='constant'), - ) - tsfm_dp = input_dp.legacy_transforms(transforms) - self.assertEqual(len(tsfm_dp), len(input_dp)) - for tsfm_data, input_data in zip(tsfm_dp, tensor_inputs): - self.assertEqual(tsfm_data[:, 1:(h + 1), 1:(w + 1)], input_data) - - # Single transform - input_dp = IDP_NoLen(inputs) # type: ignore[assignment] - transform = torchvision.transforms.ToTensor() - tsfm_dp = input_dp.legacy_transforms(transform) - with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"): - len(tsfm_dp) - for tsfm_data, input_data in zip(tsfm_dp, tensor_inputs): - self.assertEqual(tsfm_data, input_data) - def test_zip_datapipe(self): with self.assertRaises(TypeError): - dp.iter.Zip(IDP(range(10)), list(range(10))) # type: ignore[arg-type] + dp.iter.Zipper(IDP(range(10)), list(range(10))) # type: ignore[arg-type] - zipped_dp = dp.iter.Zip(IDP(range(10)), IDP_NoLen(range(5))) # type: ignore[var-annotated] + zipped_dp = dp.iter.Zipper(IDP(range(10)), IDP_NoLen(range(5))) # type: ignore[var-annotated] with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"): len(zipped_dp) exp = list((i, i) for i in range(5)) self.assertEqual(list(zipped_dp), exp) - zipped_dp = dp.iter.Zip(IDP(range(10)), IDP(range(5))) + zipped_dp = dp.iter.Zipper(IDP(range(10)), IDP(range(5))) self.assertEqual(len(zipped_dp), 5) self.assertEqual(list(zipped_dp), exp) # Reset @@ -939,8 +1162,8 @@ def _test_picklable(self): picklable_datapipes: List[ Tuple[Type[MapDataPipe], MapDataPipe, Tuple, Dict[str, Any]] ] = [ - (dp.map.Map, MDP(arr), (), {}), - (dp.map.Map, MDP(arr), (_fake_fn, (0,), {'test': True}), {}), + (dp.map.Mapper, MDP(arr), (), {}), + (dp.map.Mapper, MDP(arr), (_fake_fn, (0,), {'test': True}), {}), ] for dpipe, input_dp, dp_args, dp_kwargs in picklable_datapipes: p = pickle.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg] @@ -948,7 +1171,7 @@ def _test_picklable(self): unpicklable_datapipes: List[ Tuple[Type[MapDataPipe], MapDataPipe, Tuple, Dict[str, Any]] ] = [ - (dp.map.Map, MDP(arr), (lambda x: x,), {}), + (dp.map.Mapper, MDP(arr), (lambda x: x,), {}), ] for dpipe, input_dp, dp_args, dp_kwargs in unpicklable_datapipes: with warnings.catch_warnings(record=True) as wa: @@ -965,10 +1188,10 @@ def test_concat_datapipe(self): input_dp2 = MDP(range(5)) with self.assertRaisesRegex(ValueError, r"Expected at least one DataPipe"): - dp.map.Concat() + dp.map.Concater() with self.assertRaisesRegex(TypeError, r"Expected all inputs to be `MapDataPipe`"): - dp.map.Concat(input_dp1, ()) # type: ignore[arg-type] + dp.map.Concater(input_dp1, ()) # type: ignore[arg-type] concat_dp = input_dp1.concat(input_dp2) self.assertEqual(len(concat_dp), 15) @@ -1007,6 +1230,39 @@ def fn(item, dtype=torch.float, *, sum=False): map_dp[index], torch.tensor(input_dp[index], dtype=torch.int).sum() ) + def test_mux_datapipe(self): + + # Test Case: Elements are yielded one at a time from each DataPipe, until they are all exhausted + input_dp1 = IDP(range(4)) + input_dp2 = IDP(range(4, 8)) + input_dp3 = IDP(range(8, 12)) + output_dp = input_dp1.mux(input_dp2, input_dp3) + expected_output = [0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11] + self.assertEqual(len(expected_output), len(output_dp)) + self.assertEqual(expected_output, list(output_dp)) + + # Test Case: Uneven input Data Pipes + input_dp1 = IDP([1, 2, 3, 4]) + input_dp2 = IDP([10]) + input_dp3 = IDP([100, 200, 300]) + output_dp = input_dp1.mux(input_dp2, input_dp3) + expected_output = [1, 10, 100, 2, 200, 3, 300, 4] + self.assertEqual(len(expected_output), len(output_dp)) + self.assertEqual(expected_output, list(output_dp)) + + # Test Case: Empty Data Pipe + input_dp1 = IDP([0, 1, 2, 3]) + input_dp2 = IDP([]) + output_dp = input_dp1.mux(input_dp2) + self.assertEqual(len(input_dp1), len(output_dp)) + self.assertEqual(list(input_dp1), list(output_dp)) + + # Test Case: raises TypeError when __len__ is called and an input doesn't have __len__ + input_dp1 = IDP(range(10)) + input_dp_no_len = IDP_NoLen(range(10)) + output_dp = input_dp1.mux(input_dp_no_len) + with self.assertRaises(TypeError): + len(output_dp) # Metaclass conflict for Python 3.6 # Multiple inheritance with NamedTuple is not supported for Python 3.9 @@ -1330,24 +1586,25 @@ def test_simple_traverse(self): expected: Dict[Any, Any] = {mapped_dp: {numbers_dp: {}}} self.assertEqual(expected, graph) - # TODO(VitalyFedyunin): This test is incorrect because of 'buffer' nature - # of the fork fake implementation, update fork first and fix this test too @skipIfNoDill def test_traverse_forked(self): numbers_dp = NumbersDataset(size=50) - dp0, dp1, dp2 = numbers_dp.fork(3) + dp0, dp1, dp2 = numbers_dp.fork(num_instances=3) dp0_upd = dp0.map(lambda x: x * 10) dp1_upd = dp1.filter(lambda x: x % 3 == 1) combined_dp = dp0_upd.mux(dp1_upd, dp2) graph = torch.utils.data.graph.traverse(combined_dp) - expected = {combined_dp: {dp0_upd: {dp0: {}}, dp1_upd: {dp1: {}}, dp2: {}}} + expected = {combined_dp: {dp0_upd: {dp0: {dp0.main_datapipe: {dp0.main_datapipe.main_datapipe: {}}}}, + dp1_upd: {dp1: {dp1.main_datapipe: {dp1.main_datapipe.main_datapipe: {}}}}, + dp2: {dp2.main_datapipe: {dp2.main_datapipe.main_datapipe: {}}}}} self.assertEqual(expected, graph) class TestSharding(TestCase): + def _get_pipeline(self): numbers_dp = NumbersDataset(size=10) - dp0, dp1 = numbers_dp.fork(2) + dp0, dp1 = numbers_dp.fork(num_instances=2) dp0_upd = dp0.map(lambda x: x * 10) dp1_upd = dp1.filter(lambda x: x % 3 == 1) combined_dp = dp0_upd.mux(dp1_upd) @@ -1369,6 +1626,27 @@ def test_simple_sharding(self): self.assertEqual(sorted(all_items), sorted(items)) + def test_sharding_length(self): + numbers_dp = IDP(range(13)) + sharded_dp0 = numbers_dp.sharding_filter() + torch.utils.data.sharding.apply_sharding(sharded_dp0, 3, 0) + sharded_dp1 = numbers_dp.sharding_filter() + torch.utils.data.sharding.apply_sharding(sharded_dp1, 3, 1) + sharded_dp2 = numbers_dp.sharding_filter() + torch.utils.data.sharding.apply_sharding(sharded_dp2, 3, 2) + self.assertEqual(13, len(numbers_dp)) + self.assertEqual(5, len(sharded_dp0)) + self.assertEqual(4, len(sharded_dp1)) + self.assertEqual(4, len(sharded_dp2)) + + numbers_dp = IDP(range(1)) + sharded_dp0 = numbers_dp.sharding_filter() + torch.utils.data.sharding.apply_sharding(sharded_dp0, 2, 0) + sharded_dp1 = numbers_dp.sharding_filter() + torch.utils.data.sharding.apply_sharding(sharded_dp1, 2, 1) + self.assertEqual(1, len(sharded_dp0)) + self.assertEqual(0, len(sharded_dp1)) + @skipIfNoDill def test_old_dataloader(self): dp = self._get_pipeline() diff --git a/test/test_determination.py b/test/test_determination.py index 6d338af4b6c8f..ca00835429c4c 100644 --- a/test/test_determination.py +++ b/test/test_determination.py @@ -16,7 +16,6 @@ class DeterminationTest(unittest.TestCase): "test_jit_profiling", "test_jit", "test_torch", - "distributed/test_distributed_fork", "distributed/test_distributed_spawn", "test_cpp_extensions_aot_ninja", "test_cpp_extensions_aot_no_ninja", @@ -31,9 +30,16 @@ def determined_tests(cls, changed_files): return [ test for test in cls.TESTS - if run_test.determine_target(run_test.TARGET_DET_LIST, test, changed_files, DummyOptions()) + if run_test.should_run_test(run_test.TARGET_DET_LIST, test, changed_files, DummyOptions()) ] + def test_target_det_list_is_sorted(self): + # We keep TARGET_DET_LIST sorted to minimize merge conflicts + # but most importantly to allow us to comment on the absence + # of a test. It would be very difficult to add a file right + # next to a comment that says to keep it out of the list. + self.assertListEqual(run_test.TARGET_DET_LIST, sorted(run_test.TARGET_DET_LIST)) + def test_config_change_only(self): """CI configs trigger all tests""" self.assertEqual( @@ -104,7 +110,6 @@ def test_torch_file(self): self.assertEqual( self.determined_tests(["torch/utils/cpp_extension.py"]), [ - "distributed/test_distributed_fork", "test_cpp_extensions_aot_ninja", "test_cpp_extensions_aot_no_ninja", "test_utils", diff --git a/test/test_foreach.py b/test/test_foreach.py index ce9b0d7ee55e3..c6cf1302ffb5c 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -4,11 +4,16 @@ import re import torch import unittest + +from torch.testing import make_tensor from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_ROCM, TEST_WITH_SLOW from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, dtypes, onlyCUDA, skipCUDAIfRocm, skipMeta, ops) from torch.testing._internal.common_methods_invocations import \ - (foreach_unary_op_db, foreach_binary_op_db, foreach_pointwise_op_db, foreach_minmax_op_db, make_tensor) + (foreach_unary_op_db, foreach_binary_op_db, foreach_pointwise_op_db, foreach_minmax_op_db) +from torch.testing._internal.common_dtype import ( + get_all_dtypes, get_all_int_dtypes, get_all_complex_dtypes, get_all_fp_dtypes, +) # Includes some values such that N * N won't be a multiple of 4, # which should ensure we test the vectorized and non-vectorized @@ -131,7 +136,7 @@ def _test_binary_op_tensorlists(self, device, dtype, opinfo, N, is_fastpath, dis self._binary_test(dtype, inplace_op, inplace_ref, inputs, is_fastpath, is_inplace=True) if opinfo.supports_alpha_param: alpha = None - if dtype in torch.testing.get_all_int_dtypes(): + if dtype in get_all_int_dtypes(): alpha = 3 elif dtype.is_complex: alpha = complex(3, 3) @@ -168,7 +173,7 @@ def _test_binary_op_tensorlists(self, device, dtype, opinfo, N, is_fastpath, dis @ops(foreach_binary_op_db) def test_binary_op_tensorlists_fastpath(self, device, dtype, op): for N in N_values: - disable_fastpath = op.ref == torch.div and dtype in torch.testing.get_all_int_dtypes() + [torch.bool] + disable_fastpath = op.ref == torch.div and dtype in get_all_int_dtypes() + [torch.bool] if op.ref == torch.add and dtype == torch.bool: disable_fastpath = True self._test_binary_op_tensorlists(device, dtype, op, N, True, disable_fastpath) @@ -190,17 +195,17 @@ def _test_binary_op_scalar(self, device, dtype, opinfo, N, scalar, is_fastpath, @ops(foreach_binary_op_db) def test_binary_op_scalar_fastpath(self, device, dtype, op): for N, scalar in itertools.product(N_values, Scalars): - disable_fastpath = op.ref == torch.div and dtype in torch.testing.get_all_int_dtypes() + [torch.bool] + disable_fastpath = op.ref == torch.div and dtype in get_all_int_dtypes() + [torch.bool] if isinstance(scalar, int): disable_fastpath |= dtype == torch.bool if isinstance(scalar, float): - disable_fastpath |= dtype in torch.testing.get_all_int_dtypes() + [torch.bool] + disable_fastpath |= dtype in get_all_int_dtypes() + [torch.bool] if isinstance(scalar, bool): disable_fastpath |= dtype == torch.bool if op.ref in (torch.add, torch.mul): disable_fastpath = False if isinstance(scalar, complex): - disable_fastpath |= dtype not in torch.testing.get_all_complex_dtypes() + disable_fastpath |= dtype not in get_all_complex_dtypes() self._test_binary_op_scalar(device, dtype, op, N, scalar, True, disable_fastpath) @ops(foreach_binary_op_db) @@ -230,16 +235,16 @@ def _test_binary_op_scalarlist(self, device, dtype, opinfo, N, scalarlist, is_fa def test_binary_op_scalarlist_fastpath(self, device, dtype, op): for N in N_values: for type_str, scalarlist in getScalarLists(N): - bool_int_div = op.ref == torch.div and dtype in torch.testing.get_all_int_dtypes() + [torch.bool] + bool_int_div = op.ref == torch.div and dtype in get_all_int_dtypes() + [torch.bool] disable_fastpath = bool_int_div if type_str == "int": disable_fastpath |= dtype == torch.bool if type_str == "float": - disable_fastpath |= dtype in torch.testing.get_all_int_dtypes() + [torch.bool] + disable_fastpath |= dtype in get_all_int_dtypes() + [torch.bool] if type_str == "complex": - disable_fastpath |= dtype not in torch.testing.get_all_complex_dtypes() + disable_fastpath |= dtype not in get_all_complex_dtypes() if type_str == "mixed": - disable_fastpath |= True and dtype not in torch.testing.get_all_complex_dtypes() + disable_fastpath |= True and dtype not in get_all_complex_dtypes() self._test_binary_op_scalarlist(device, dtype, op, N, scalarlist, True, disable_fastpath) @ops(foreach_binary_op_db) @@ -296,7 +301,7 @@ def _test_pointwise_op(self, device, dtype, opinfo, N, is_fastpath, disable_fast @skipMeta @ops(foreach_pointwise_op_db) def test_pointwise_op_fastpath(self, device, dtype, op): - disable_fastpath = dtype in torch.testing.get_all_int_dtypes() + [torch.bool] + disable_fastpath = dtype in get_all_int_dtypes() + [torch.bool] # for N, scalar in itertools.product(N_values, Scalars): for N in N_values: self._test_pointwise_op(device, dtype, op, N, True, disable_fastpath) @@ -354,7 +359,7 @@ def _test_unary(self, device, dtype, opinfo, N, is_fastpath): op, ref, inplace_op, inplace_ref = self._get_funcs(opinfo, 1) inputs = opinfo.sample_inputs(device, dtype, N, noncontiguous=not is_fastpath), # note(mkozuki): Complex inputs for `_foreach_abs` go through slowpath. - if opinfo.name == "_foreach_abs" and dtype in torch.testing.get_all_complex_dtypes(): + if opinfo.name == "_foreach_abs" and dtype in get_all_complex_dtypes(): is_fastpath = False self._regular_unary_test(dtype, op, ref, inputs, is_fastpath) self._inplace_unary_test(dtype, inplace_op, inplace_ref, inputs, is_fastpath) @@ -365,7 +370,7 @@ def test_unary_fastpath(self, device, dtype, op): for N in N_values: self._test_unary(device, dtype, op, N, is_fastpath=True) - @dtypes(*torch.testing.get_all_dtypes()) + @dtypes(*get_all_dtypes()) @ops(foreach_unary_op_db) def test_unary_slowpath(self, device, dtype, op): for N in N_values: @@ -376,14 +381,14 @@ def _minmax_test(self, opinfo, inputs, is_fastpath, n_expected_cudaLaunchKernels self.assertEqual(ref(inputs), op(inputs, self.is_cuda, is_fastpath)) # note(mkozuki): in-place of foreach_minimum and foreach_maximum aren't implemented. - # @dtypes(*torch.testing.get_all_dtypes(include_bfloat16=False, include_complex=False)) + # @dtypes(*get_all_dtypes(include_bfloat16=False, include_complex=False)) @ops(foreach_minmax_op_db) def test_minmax_fastpath(self, device, dtype, op): for N in N_values: inputs = tuple(op.sample_inputs(device, dtype, N) for _ in range(2)) self._minmax_test(op, inputs, True, N if dtype == torch.bool else 1) - @dtypes(*torch.testing.get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False)) + @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False)) @ops(foreach_minmax_op_db) def test_minmax_slowpath(self, device, dtype, op): for N in N_values: @@ -392,7 +397,7 @@ def test_minmax_slowpath(self, device, dtype, op): # note(mkozuki): ForeachFuncInfo's of both `_foreach_maximum` and `_foreach_minimum` include integer types. # so, manually limit dtypes to fp types for inf&nan tests. - @dtypes(*torch.testing.get_all_fp_dtypes(include_bfloat16=True, include_half=True)) + @dtypes(*get_all_fp_dtypes(include_bfloat16=True, include_half=True)) @ops(foreach_minmax_op_db) def test_minmax_float_inf_nan(self, device, dtype, op): inputs = ( @@ -411,7 +416,7 @@ def test_minmax_float_inf_nan(self, device, dtype, op): ) self._minmax_test(op, inputs, True, 1) - @dtypes(*torch.testing.get_all_dtypes()) + @dtypes(*get_all_dtypes()) def test_add_scalar_with_empty_list_and_empty_tensor(self, device, dtype): # TODO: enable empty list case for tensors in [[torch.randn([0])]]: @@ -421,7 +426,7 @@ def test_add_scalar_with_empty_list_and_empty_tensor(self, device, dtype): torch._foreach_add_(tensors, 1) self.assertEqual(res, tensors) - @dtypes(*torch.testing.get_all_dtypes()) + @dtypes(*get_all_dtypes()) @ops(foreach_binary_op_db) def test_binary_op_scalar_with_overlapping_tensors(self, device, dtype, op): foreach_op, ref = op.method_variant, op.ref @@ -455,7 +460,7 @@ def test_binary_op_scalar_with_different_tensor_dtypes(self, device, dtype, op): runtime_error = e self.assertIsNone(runtime_error) - @dtypes(*torch.testing.get_all_dtypes()) + @dtypes(*get_all_dtypes()) @ops(foreach_binary_op_db) def test_binary_op_list_error_cases(self, device, dtype, op): foreach_op, foreach_op_, ref, ref_ = op.method_variant, op.inplace_variant, op.ref, op.ref_inplace @@ -511,7 +516,7 @@ def test_binary_op_list_error_cases(self, device, dtype, op): return with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"): foreach_op([tensor1], [tensor2]) - if dtype in torch.testing.get_all_int_dtypes() + [torch.bool] and foreach_op == torch._foreach_div: + if dtype in get_all_int_dtypes() + [torch.bool] and foreach_op == torch._foreach_div: with self.assertRaisesRegex(RuntimeError, "result type"): foreach_op_([tensor1], [tensor2]) else: @@ -520,7 +525,7 @@ def test_binary_op_list_error_cases(self, device, dtype, op): @skipMeta @unittest.skipIf(not torch.cuda.is_available(), "CUDA not found") - @dtypes(*torch.testing.get_all_dtypes()) + @dtypes(*get_all_dtypes()) @ops(foreach_binary_op_db) def test_binary_op_list_slow_path(self, device, dtype, op): # note(mkozuki): why `n_expected_cudaLaunchKernels=0`? @@ -613,7 +618,7 @@ def test_binary_op_tensors_on_different_devices(self, device, dtype, op): self.assertEqual(actual, tensors1) @onlyCUDA - @dtypes(*torch.testing.get_all_fp_dtypes(include_half=False, include_bfloat16=False)) + @dtypes(*get_all_fp_dtypes(include_half=False, include_bfloat16=False)) @ops(foreach_pointwise_op_db) def test_pointwise_op_tensors_on_different_devices(self, device, dtype, op): # tensors1: ['cuda', 'cpu] diff --git a/test/test_function_schema.py b/test/test_function_schema.py index 0451debebd196..7c7a0f77cb922 100644 --- a/test/test_function_schema.py +++ b/test/test_function_schema.py @@ -86,6 +86,27 @@ def test_backward_compatible_arguments(self): new_schema = parse_schema('any(Tensor self, *, Tensor b, int[] c) -> Tensor') self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) + def test_backward_compatible_with_smart_serialization(self): + # cases where out arg is provided + old_schema = parse_schema('foo(Tensor self, *, int a, Tensor(a!) out) -> Tensor(a!)') + new_schema_same_out = parse_schema('foo(Tensor self, *, int a, int b=1, Tensor(a!) out) -> Tensor(a!)') + new_schema_wrong_default = parse_schema('foo(Tensor self, *, int b=1, int a, Tensor(a!) out) -> Tensor(a!)') + new_schema_more_out = parse_schema('foo(Tensor self, *, int a, int b=1, Tensor(a!) out, Tensor(b!) b) -> Tensor(a!)') + new_schema_wrong_pos = parse_schema('foo(Tensor self, *, int a, int b=1, Tensor(b!) b, Tensor(a!) out) -> Tensor(a!)') + self.assertTrue(new_schema_same_out.is_backward_compatible_with(old_schema)) + self.assertTrue(new_schema_more_out.is_backward_compatible_with(old_schema)) + self.assertFalse(new_schema_wrong_default.is_backward_compatible_with(old_schema)) + self.assertFalse(new_schema_wrong_pos.is_backward_compatible_with(old_schema)) + + # cases where out arg is not provided + old_schema_without_arg = parse_schema('foo(Tensor self, int a, int b=1) -> int') + new_schema_without_arg = parse_schema('foo(Tensor self, int a, int b=1, int c=2) -> int') + new_schema_without_arg_multiple_default = parse_schema('foo(Tensor self, int a, int b=1, int c=2, int d=3) -> int') + new_schema_without_arg_wrong_pos = parse_schema('foo(Tensor self, int a, int c=2, int b=1) -> int') + self.assertTrue(new_schema_without_arg.is_backward_compatible_with(old_schema_without_arg)) + self.assertTrue(new_schema_without_arg_multiple_default.is_backward_compatible_with(old_schema_without_arg)) + self.assertFalse(new_schema_without_arg_wrong_pos.is_backward_compatible_with(old_schema_without_arg)) + def test_string_optional_parameter_default_value(self): schema_a = parse_schema("example::op(str? order=\"NCHW\") -> (Tensor)") schema_b = parse_schema(str(schema_a)) diff --git a/test/test_functional_optim.py b/test/test_functional_optim.py index c37823427fc1d..accc72058578d 100644 --- a/test/test_functional_optim.py +++ b/test/test_functional_optim.py @@ -1,19 +1,9 @@ -import unittest - import torch import torch.nn as nn import torch.nn.functional as F -from torch.optim import SGD, Adam -from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS - -if not IS_WINDOWS: - from torch.distributed.optim.functional_sgd import _FunctionalSGD - from torch.distributed.optim.functional_adam import _FunctionalAdam - _SUPPORTED_OPTIM_MAPPING = { - SGD: _FunctionalSGD, - Adam: _FunctionalAdam - } - +from torch.optim import SGD, Adam, AdamW +from torch.testing._internal.common_utils import TestCase, run_tests +from torch.distributed.optim import functional_optim_map class MyModule(torch.nn.Module): def __init__(self): @@ -37,7 +27,7 @@ def _test_functional_optim_parity(self, optim_cls, *args, **kwargs): optim_params = module_optim.parameters() functional_params = module_functional.parameters() optim = optim_cls(optim_params, *args, **kwargs) - functional_optim_cls = _SUPPORTED_OPTIM_MAPPING.get(optim_cls, None) + functional_optim_cls = functional_optim_map.get(optim_cls, None) if not functional_optim_cls: raise ValueError(f"Functional optimizer not implemented for {optim_cls}") optim_functional = functional_optim_cls( @@ -88,20 +78,15 @@ def _test_functional_optim_parity(self, optim_cls, *args, **kwargs): self.assertNotEqual(old_module_optim_params[i], optim_param) self.assertNotEqual(old_module_functional_params[i], functional_param) - @unittest.skipIf( - IS_WINDOWS, - "Functional optimizer not support on windows, see https://github.com/pytorch/pytorch/issues/62137", - ) def test_functional_optim_parity_sgd(self): self._test_functional_optim_parity(SGD, 1e-2, momentum=0.9, weight_decay=0.01) - @unittest.skipIf( - IS_WINDOWS, - "Functional optimizer not support on windows, see https://github.com/pytorch/pytorch/issues/62137", - ) def test_functional_optim_parity_adam(self): self._test_functional_optim_parity(Adam, 1e-2, betas=(0.9, 0.999), eps=1e-6) + def test_functional_optim_parity_adam_w(self): + self._test_functional_optim_parity(AdamW, 1e-2, betas=(0.9, 0.999), eps=1e-6) + if __name__ == "__main__": run_tests() diff --git a/test/test_fx.py b/test/test_fx.py index f0a3291d07d4a..57a2960a409c3 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -11,6 +11,8 @@ import sys import torch import traceback +import typing +import types import warnings import unittest from math import sqrt @@ -31,6 +33,7 @@ from collections import namedtuple from torch.fx.proxy import TraceError +from torch.fx._compatibility import _BACK_COMPAT_OBJECTS, _MARKED_WITH_COMATIBLITY from fx.test_subgraph_rewriter import TestSubgraphRewriter # noqa: F401 from fx.test_dce_pass import TestDCE # noqa: F401 @@ -95,6 +98,8 @@ def a_lifted_leaf2(a, b): wrap('len') +wrap('getattr') + @wrap def wrapped_via_decorator(a): return a + 1 @@ -127,10 +132,17 @@ def __init__(self, a, b): class TestFX(JitTestCase): def setUp(self): - if TEST_WITH_ROCM or IS_FBCODE or IS_WINDOWS or IS_MACOS: - return - lib_file_path = find_library_location('libtorchbind_test.so') - torch.ops.load_library(str(lib_file_path)) + # Checking for mutable operations whil tracing is feature flagged + # Enable it in testing but not by default + self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations + torch.fx.proxy.TracerBase.check_mutable_operations = True + + if not (TEST_WITH_ROCM or IS_FBCODE or IS_WINDOWS or IS_MACOS): + lib_file_path = find_library_location('libtorchbind_test.so') + torch.ops.load_library(str(lib_file_path)) + + def tearDown(self): + torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag def checkGraphModule(self, m: torch.nn.Module, args, kwargs=None): """Check that an nn.Module's results match the GraphModule version @@ -185,6 +197,19 @@ def forward(self, A, b=4, *args, c=5, **kwargs): t = T() symbolic_trace(t) + # test for issue described at https://github.com/pytorch/pytorch/issues/63883 + class M3(torch.nn.Module): + def forward(self, x): + return torch.relu(x) + + m3 = M3() + gm3 = symbolic_trace(m3) + new_instance = gm3.__new__(type(gm3)) + new_instance.__init__(gm3, gm3.graph) + + x = torch.randn(5, 3) + torch.testing.assert_allclose(new_instance(x), torch.relu(x)) + def test_custom_import(self): graph = torch.fx.Graph() a = graph.placeholder('x') @@ -593,17 +618,17 @@ def __init__(self, interpreter): x = torch.rand(3, 4) ref_out = msm(x) test_out = lowered(x) - torch.testing.assert_allclose(test_out, ref_out) + torch.testing.assert_close(test_out, ref_out) # Test TorchScript compilation scripted_lowered = torch.jit.script(lowered) script_out = scripted_lowered(x) - torch.testing.assert_allclose(script_out, ref_out) + torch.testing.assert_close(script_out, ref_out) # Test TorchScript ser/de import_copy = self.getExportImportCopy(scripted_lowered) imported_out = import_copy(x) - torch.testing.assert_allclose(imported_out, ref_out) + torch.testing.assert_close(imported_out, ref_out) def test_reserved_getattr(self): """Ensure that we do not name any nodes with a reserved builtin like `getattr`""" @@ -926,6 +951,14 @@ def forward(self, x): self.assertEqual(traced2(inp), inp + 3.0) self.assertIs(len, builtins.len) + def test_torch_fx_getattr(self): + class FXGetattrTest(torch.nn.Module): + def forward(self, x): + return getattr(x, 'nonexistent_attr', torch.Tensor([2, 3])) + + traced = symbolic_trace(FXGetattrTest()) + self.assertEqual(traced(torch.rand(3, 4)), torch.Tensor([2, 3])) + def test_sqrt(self): class Sqrt1(torch.nn.Module): def forward(self, x): @@ -1280,6 +1313,12 @@ def test_wrong_topo(self): with self.assertRaisesRegex(RuntimeError, 'was used before it has been defined'): graph.lint() + def test_wrong_target_type(self): + graph : torch.fx.Graph = torch.fx.Graph() + with self.assertRaises(ValueError): + n = torch.fx.Node(graph=graph, name='foo', op='call_function', target='foo', + args=(), kwargs={}) + def test_example_shape_prop(self): class TestCase(torch.nn.Module): def __init__(self): @@ -1943,6 +1982,25 @@ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> boo with self.assertRaisesRegex(RuntimeError, 'cannot contain a Node'): traced_graph = MyTracer().trace(CallsModWithDict()) + def test_module_deepcopy_edit_nodes(self): + class Foo(torch.nn.Module): + def forward(self, x): + return torch.relu(x) + + traced1 = symbolic_trace(Foo()) + copied = copy.deepcopy(traced1) + + for node in copied.graph.nodes: + if node.target == torch.relu: + node.target = torch.neg + + copied.recompile() + traced1.recompile() + + x = torch.randn(15, 15) + torch.testing.assert_allclose(traced1(x), torch.relu(x)) + torch.testing.assert_allclose(copied(x), torch.neg(x)) + def test_direct_param_use(self): class TransposeTest(torch.nn.Module): def __init__(self): @@ -2277,6 +2335,21 @@ def forward(self, x): r"Call using an FX-traced Module, line .* of the " r"traced Module's generated forward function:") + def test_graph_module_replicate_for_dp(self): + class Foo(torch.nn.Module): + def forward(self, x): + return torch.relu(x) + + gm = torch.fx.symbolic_trace(Foo()) + + x = torch.randn(5, 3) + out = gm(x) + + replica = gm._replicate_for_data_parallel() + out_replica = replica(x) + + torch.testing.assert_allclose(out_replica, out) + def test_ast_rewriter_rewrites_assert(self): class M(torch.nn.Module): def forward(self, x: torch.Tensor, y: int, z: int): @@ -2301,6 +2374,19 @@ def forward(self, x: torch.Tensor, y: int, z: int): traced.graph.lint() + def test_throw_out_variant(self): + def foo(x): + y = torch.rand_like(x) + torch.sigmoid(x, out=y) + return y + + class MyTracer(torch.fx.Tracer): + check_mutable_operations = True + + tracer = MyTracer() + with self.assertRaisesRegex(RuntimeError, 'mutable operation aten::sigmoid.out'): + traced_graph = tracer.trace(foo) + def test_ast_rewriter_reassigns_submodules(self): class M(torch.nn.Module): def __init__(self): @@ -2316,6 +2402,96 @@ def forward(self, x: torch.Tensor): traced.graph.lint() + def test_ast_rewriter_wrap(self): + self.assertEqual(3 + 4 + 5, a_lifted_leaf((3, 4), 5)) + + def to_trace(y): + return ( + a_lifted_leaf((4, y), 3) + + a_lifted_leaf((3, 4), 5) + + a_lifted_leaf((y, y), y) + ) + + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(to_trace) + traced = GraphModule(ast_rewriter.root, graph, "gm") + + self.assertIn("a_lifted_leaf", traced.code) + self.assertEqual(27, traced(2)) + self.assertIs(a_lifted_leaf, real_a_lifed_leaf) + + def test_ast_rewriter_wrap_fn_directly(self): + self.assertEqual(3 + 4 + 5, a_lifted_leaf2((3, 4), 5)) + + def to_trace(y): + return ( + a_lifted_leaf2((4, y), 3) + + a_lifted_leaf2((3, 4), 5) + + a_lifted_leaf2((y, y), y) + ) + + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(to_trace) + traced = GraphModule(ast_rewriter.root, graph, "gm") + + self.assertIn("a_lifted_leaf2", traced.code) + self.assertEqual(27, traced(2)) + self.assertIs(a_lifted_leaf2, real_a_lifed_leaf2) + + def test_ast_rewriter_wrapped_via_decorator(self): + class F(torch.nn.Module): + def forward(self, x): + return wrapped_via_decorator(x) + + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(F()) + traced = GraphModule(ast_rewriter.root, graph, "gm") + + self.assertIn("wrapped_via_decorator", traced.code) + self.assertEqual(traced(0), 1) + self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator) + self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched")) + + def test_ast_rewriter_wrapped_via_decorator_and_transformed(self): + self.assertEqual(wrapped_via_decorator(0), 1) + + def to_trace(y): + return wrapped_via_decorator(y) + + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(to_trace) + traced = GraphModule(ast_rewriter.root, graph, "gm") + + self.assertIn("wrapped_via_decorator", traced.code) + self.assertEqual(traced(0), 1) + self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator) + self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched")) + + transformed = torch.fx.Transformer(traced).transform() + self.assertIn("wrapped_via_decorator", transformed.code) + self.assertEqual(transformed(0), 1) + self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator) + self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched")) + + def test_ast_rewriter_wrap_with_submodule(self): + class M(torch.nn.Module): + def __init__(self): + super(M, self).__init__() + self.batchnorm1d = torch.nn.BatchNorm1d(2, affine=False) + + def forward(self, x: torch.Tensor): + return wrapped_with_submodule(x, self.batchnorm1d) + + ast_rewriter = RewritingTracer() + graph = ast_rewriter.trace(M()) + traced = GraphModule(ast_rewriter.root, graph, "gm") + + self.assertIn("wrapped_with_submodule", traced.code) + + input = torch.rand(3, 2) + ref_batchnorm1d = torch.nn.BatchNorm1d(2, affine=False) + self.assertEqual(ref_batchnorm1d(input), traced(input)) + def test_submodule_manipulation_API(self): class C(torch.nn.Module): def __init__(self): @@ -2865,6 +3041,15 @@ def run_getitem_target(): class TestOperatorSignatures(JitTestCase): + def setUp(self): + # Checking for mutable operations whil tracing is feature flagged + # Enable it in testing but not by default + self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations + torch.fx.proxy.TracerBase.check_mutable_operations = True + + def tearDown(self): + torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag + @onlyCPU @ops(op_db, allowed_dtypes=(torch.float,)) def test_get_torch_func_signature_exhaustive(self, device, dtype, op): @@ -2930,7 +3115,264 @@ def test_get_torch_func_signature_exhaustive(self, device, dtype, op): assert op.name in known_no_schema or "nn.functional" in op.name +class TestFXAPIBackwardCompatibility(JitTestCase): + def setUp(self): + self.maxDiff = None + + # Checking for mutable operations whil tracing is feature flagged + # Enable it in testing but not by default + self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations + torch.fx.proxy.TracerBase.check_mutable_operations = True + + def tearDown(self): + torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag + + + def _fn_to_stable_annotation_str(self, obj): + """ + Unfortunately we have to serialize function signatures manually since + serialization for `inspect.Signature` objects is not stable across + python versions + """ + fn_name = torch.typename(obj) + + signature = inspect.signature(obj) + + sig_str = f'{fn_name}{signature}' + + arg_strs = [] + for k, v in signature.parameters.items(): + maybe_type_annotation = f': {self._annotation_type_to_stable_str(v.annotation, sig_str)}'\ + if v.annotation is not inspect.Signature.empty else '' + + def default_val_str(val): + if isinstance(val, (tuple, list)): + str_pieces = ['(' if isinstance(val, tuple) else '['] + str_pieces.append(', '.join(default_val_str(v) for v in val)) + if isinstance(val, tuple) and len(str_pieces) == 2: + str_pieces.append(',') + str_pieces.append(')' if isinstance(val, tuple) else ']') + return ''.join(str_pieces) + + # Need to fix up some default value strings. + # First case: modules. Default module `repr` contains the FS path of the module. + # Don't leak that + if isinstance(val, types.ModuleType): + return f'' + + # Second case: callables. Callables (such as lambdas) encode their address in + # their string repr. Don't do that + if callable(val): + return f'' + + return str(val) + + if v.default is not inspect.Signature.empty: + default_val_str = default_val_str(v.default) if not isinstance(v.default, str) else f"'{v.default}'" + maybe_default = f' = {default_val_str}' + else: + maybe_default = '' + maybe_stars = '' + if v.kind == inspect.Parameter.VAR_POSITIONAL: + maybe_stars = '*' + elif v.kind == inspect.Parameter.VAR_KEYWORD: + maybe_stars = '**' + arg_strs.append(f'{maybe_stars}{k}{maybe_type_annotation}{maybe_default}') + + return_annot = f' -> {self._annotation_type_to_stable_str(signature.return_annotation, sig_str)}'\ + if signature.return_annotation is not inspect.Signature.empty else '' + + return f'{fn_name}({", ".join(arg_strs)}){return_annot}' + + def _annotation_type_to_stable_str(self, t, sig_str): + if t is inspect.Signature.empty: + return '' + + # Forward ref + if isinstance(t, str): + return f"'{t}'" + if hasattr(typing, 'ForwardRef') and isinstance(t, typing.ForwardRef): + return t.__forward_arg__ + if hasattr(typing, '_ForwardRef') and isinstance(t, typing._ForwardRef): + return t.__forward_arg__ + + trivial_mappings = { + str : 'str', + int : 'int', + float: 'float', + bool: 'bool', + torch.dtype: 'torch.dtype', + torch.Tensor: 'torch.Tensor', + torch.device: 'torch.device', + torch.memory_format: 'torch.memory_format', + slice: 'slice', + torch.nn.Module: 'torch.nn.modules.module.Module', + torch.fx.Graph : 'torch.fx.graph.Graph', + torch.fx.Node : 'torch.fx.node.Node', + torch.fx.Proxy : 'torch.fx.proxy.Proxy', + torch.fx.node.Target : 'torch.fx.node.Target', + torch.fx.node.Argument : 'torch.fx.node.Argument', + torch.fx.graph.PythonCode : 'torch.fx.graph.PythonCode', + torch.fx.graph_module.GraphModule: 'torch.fx.graph_module.GraphModule', + torch.fx.subgraph_rewriter.Match: 'torch.fx.subgraph_rewriter.Match', + Ellipsis : '...', + typing.Any: 'Any', + type(None): 'NoneType', + None: 'None', + typing.Iterator: 'Iterator', + } + + mapping = trivial_mappings.get(t, None) + if mapping: + return mapping + + # Handle types with contained types + contained = getattr(t, '__args__', None) or [] + + # Callables contain a bare List for arguments + contained = t if isinstance(t, list) else contained + + # Python 3.8 puts type vars into __args__ for unbound types such as Dict + if all(isinstance(ct, typing.TypeVar) for ct in contained): + contained = [] + + contained_type_annots = [self._annotation_type_to_stable_str(ct, sig_str) for ct in contained] + contained_type_str = f'[{", ".join(contained_type_annots)}]' if len(contained_type_annots) > 0 else '' + + + origin = getattr(t, '__origin__', None) + if origin is None: + # Unbound types don't have `__origin__` in some Python versions, so fix that up here. + origin = t if t in {typing.Tuple, typing.Union, typing.Dict, typing.List, typing.Type, typing.Callable} else origin + + if origin in {tuple, typing.Tuple}: + return f'Tuple{contained_type_str}' + if origin in {typing.Union}: + # Annoying hack to detect Optional + if len(contained) == 2 and (contained[0] is type(None)) ^ (contained[1] is type(None)): + not_none_param = contained[0] if contained[0] is not type(None) else contained[1] + return f'Optional[{self._annotation_type_to_stable_str(not_none_param, sig_str)}]' + return f'Union{contained_type_str}' + if origin in {dict, typing.Dict}: + return f'Dict{contained_type_str}' + if origin in {list, typing.List}: + return f'List{contained_type_str}' + if origin in {type, typing.Type}: + return f'Type{contained_type_str}' + if isinstance(t, typing.Callable): + if len(contained) > 0 and contained[0] is not Ellipsis: + return f'Callable[[{", ".join(contained_type_annots[:-1])}], {contained_type_annots[-1]}]' + else: + return f'Callable{contained_type_str}' + + raise RuntimeError(f'Unrecognized type {t} used in BC-compatible type signature {sig_str}.' + f'Please add support for this type and confirm with the ' + f'FX team that your signature change is valid.') + + + def test_function_back_compat(self): + """ + Test backward compatibility for function signatures with + @compatibility(is_backward_compatible=True). Currently this checks for + exact signature matches, which may lead to false positives. If this + becomes too annoying, we can refine this check to actually parse out + the saved schema strings and check if the change is truly backward- + incompatible. + """ + signature_strs = [] + + for obj in _BACK_COMPAT_OBJECTS: + if not isinstance(obj, type): + signature_strs.append(self._fn_to_stable_annotation_str(obj)) + + signature_strs.sort() + + try: + self.assertExpected('\n'.join(signature_strs), 'fx_backcompat_function_signatures') + except AssertionError as e: + msg = f"{e}\n****** ERROR ******\nAn FX function that has been marked " \ + f"as backwards-compatible has experienced a signature change. See the " \ + f"above exception context for more information. If this change was " \ + f"unintended, please revert it. If it was intended, check with the FX " \ + f"team to ensure that the proper deprecation protocols have been followed " \ + f"and subsequently --accept the change." + raise AssertionError(msg) + + def test_class_member_back_compat(self): + """ + Test backward compatibility for members of classes with + @compatibility(is_backward_compatible=True). Currently this checks for + exact matches on the publicly visible members of the class. + """ + class_method_strs = [] + + for obj in _BACK_COMPAT_OBJECTS: + if isinstance(obj, type): + public_members = [name for name in obj.__dict__ if not name.startswith('_')] + class_method_strs.append(f'{torch.typename(obj)} {sorted(public_members)}') + + class_method_strs.sort() + + try: + self.assertExpected('\n'.join(class_method_strs), 'fx_backcompat_class_members') + except AssertionError as e: + msg = f"{e}\n****** ERROR ******\nAn FX class that has been marked " \ + f"as backwards-compatible has experienced change in its public members. See the " \ + f"above exception context for more information. If this change was " \ + f"unintended, please revert it. If it was intended, check with the FX " \ + f"team to ensure that the proper deprecation protocols have been followed " \ + f"and subsequently --accept the change." + raise AssertionError(msg) + + def test_public_api_surface(self): + mod = torch.fx + + non_back_compat_objects = {} + + def check_symbols_have_bc_designation(m, prefix): + if not m.__name__.startswith('torch.fx'): + return + if m.__name__.startswith('torch.fx.experimental'): + return + for k, v in m.__dict__.items(): + if v is m: + continue + if k.startswith('_'): + continue + if isinstance(v, types.ModuleType): + check_symbols_have_bc_designation(v, prefix + [k]) + elif isinstance(v, type) or isinstance(v, types.FunctionType): + if v not in _MARKED_WITH_COMATIBLITY: + non_back_compat_objects.setdefault(v) + + check_symbols_have_bc_designation(mod, ['torch', 'fx']) + + + non_back_compat_strs = [torch.typename(obj) for obj in non_back_compat_objects.keys()] + # Only want objects in torch.fx + non_back_compat_strs = [ + s for s in non_back_compat_strs if s.startswith('torch.fx') and not s.startswith('torch.fx.experimental')] + # Only want objects in public namespaces + non_back_compat_strs = [ + s for s in non_back_compat_strs if all(not atom.startswith('_') for atom in s.split('.'))] + non_back_compat_strs.sort() + + if len(non_back_compat_strs) != 0: + raise AssertionError(f"Public FX API(s) {non_back_compat_strs} introduced but not given a " + f"backwards-compatibility classification! Please decorate these " + f"API(s) with `@torch.fx._compatibility.compatibility` to specify " + f"BC guarantees.") + class TestFunctionalTracing(JitTestCase): + def setUp(self): + # Checking for mutable operations whil tracing is feature flagged + # Enable it in testing but not by default + self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations + torch.fx.proxy.TracerBase.check_mutable_operations = True + + def tearDown(self): + torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag + IGNORE_FUNCS = ("has_torch_function", "has_torch_function_unary", "has_torch_function_variadic", "handle_torch_function", "boolean_dispatch") @@ -2945,6 +3387,7 @@ class TestFunctionalTracing(JitTestCase): ARG_TYPE_MISMATCH = (TypeError, r", not Proxy$") CONTROL_FLOW = (TraceError, r"symbolically traced variables cannot be used as inputs to control flow") INTERPOLATE_ARGS_CONFLICT = (ValueError, r"only one of size or scale_factor should be defined") + MUTABLE = (RuntimeError, r"Tried to trace mutable operation") UNTRACEABLE_FUNCTIONALS = { "adaptive_avg_pool1d": BUILT_IN_FUNC, @@ -3064,6 +3507,8 @@ class TestFunctionalTracing(JitTestCase): "upsample_bilinear": INTERPOLATE_ARGS_CONFLICT, "upsample_nearest": INTERPOLATE_ARGS_CONFLICT, + + "normalize" : MUTABLE, } # List of nn.functionals with Tensor inputs but not with type annotation @@ -3178,6 +3623,15 @@ def tearDownClass(cls): @skipIfNoTorchVision class TestVisionTracing(JitTestCase): + def setUp(self): + # Checking for mutable operations whil tracing is feature flagged + # Enable it in testing but not by default + self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations + torch.fx.proxy.TracerBase.check_mutable_operations = True + + def tearDown(self): + torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag + PROXY_ITERATED = (TraceError, r"Proxy object cannot be iterated") INCONSISTENT_TYPE = ( RuntimeError, diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py index 00f3201452964..fc90f494e3917 100644 --- a/test/test_fx_experimental.py +++ b/test/test_fx_experimental.py @@ -32,7 +32,7 @@ type_matches, create_type_hint, ) -from torch.fx.passes.shape_prop import extract_tensor_metadata, ShapeProp +from torch.fx.passes.shape_prop import _extract_tensor_metadata, ShapeProp from torch.fx.passes.split_module import split_module from torch.testing._internal.common_device_type import ( ops, @@ -96,13 +96,13 @@ def forward(self, a, b, c): # Fix for now to add type/shape to output for node in traced.graph.nodes: if node.op == "output": - node.meta["tensor_meta"] = extract_tensor_metadata(a) + node.meta["tensor_meta"] = _extract_tensor_metadata(a) for mod in module_with_submodules.modules(): if isinstance(mod, GraphModule): for node in mod.graph.nodes: - node.meta["tensor_meta"] = extract_tensor_metadata(a) + node.meta["tensor_meta"] = _extract_tensor_metadata(a) for node in module_with_submodules.graph.nodes: - node.meta["tensor_meta"] = extract_tensor_metadata(a) + node.meta["tensor_meta"] = _extract_tensor_metadata(a) weights1 = {} weights2 = {} @@ -876,7 +876,7 @@ def forward(self, x, y): traced = symbolic_trace(WrapperMod()) normalized = NormalizeOperators(traced).transform() x, y = torch.randn(3, 4), torch.randn(3, 4) - torch.testing.assert_allclose(traced(x, y), normalized(x, y)) + torch.testing.assert_close(traced(x, y), normalized(x, y)) self.assertFalse( any(n.target in ops_to_test for n in normalized.graph.nodes) ) @@ -891,7 +891,7 @@ def forward(self, x): traced = symbolic_trace(WrapperMod()) normalized = NormalizeOperators(traced).transform() x = torch.randn(3, 4) - torch.testing.assert_allclose(traced(x), normalized(x)) + torch.testing.assert_close(traced(x), normalized(x)) self.assertFalse( any(n.target in ops_to_test for n in normalized.graph.nodes) ) @@ -1413,12 +1413,12 @@ def forward(self, x): with torch.no_grad(): model = Foo().eval() optimized_model = optimization.optimize_for_inference(model) - torch.testing.assert_allclose(model(inp), optimized_model(inp)) + torch.testing.assert_close(model(inp), optimized_model(inp)) optimized_model2 = optimization.optimize_for_inference( model, pass_config={"remove_dropout": False} ) - torch.testing.assert_allclose(model(inp), optimized_model2(inp)) + torch.testing.assert_close(model(inp), optimized_model2(inp)) @skipIfNoTorchVision @skipIfNoMkldnn @@ -1450,7 +1450,7 @@ def test_optimize_for_inference_cpu_torchvision(self): orig_out = model(inp) new_out = optimized_model(inp) - torch.testing.assert_allclose(orig_out, new_out) + torch.testing.assert_close(orig_out, new_out) class TestNormalizeOperators(JitTestCase): @@ -1497,7 +1497,7 @@ def test_normalize_operator_exhaustive(self, device, dtype, op): return # These ops currently don't trace in FX for various reasons (i.e. they take a list of tensors) - fx_fail = {"stack", "hstack", "vstack", "dstack", "linalg.multi_dot"} + fx_fail = {"cat", "stack", "hstack", "vstack", "dstack", "linalg.multi_dot"} sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False) for sample_input in sample_inputs_itr: unsupported_arg_type = False diff --git a/test/test_gen_backend_stubs.py b/test/test_gen_backend_stubs.py index e1a66c69fe6f5..f788a8f34c761 100644 --- a/test/test_gen_backend_stubs.py +++ b/test/test_gen_backend_stubs.py @@ -138,11 +138,11 @@ def test_supported_invalid_op(self): self.assertExpectedInline(output_error, '''Found an invalid operator name: abs_BAD''') # The backend is valid, but doesn't have a valid autograd key. They can't override autograd kernels in that case. - # Only using MSNPU here because it has a valid backend key but not an autograd key- if this changes we can update the test. + # Only using Vulkan here because it has a valid backend key but not an autograd key- if this changes we can update the test. def test_backend_has_no_autograd_key_but_provides_entries(self): yaml_str = '''\ -backend: MSNPU -cpp_namespace: torch_msnpu +backend: Vulkan +cpp_namespace: torch_vulkan supported: - add autograd: @@ -155,7 +155,7 @@ def test_backend_has_no_autograd_key_but_provides_entries(self): def test_backend_autograd_kernel_mismatch_out_functional(self): yaml_str = '''\ backend: XLA -cpp_namespace: torch_msnpu +cpp_namespace: torch_xla supported: - add.Tensor autograd: @@ -168,7 +168,7 @@ def test_backend_autograd_kernel_mismatch_out_functional(self): def test_backend_autograd_kernel_mismatch_functional_inplace(self): yaml_str = '''\ backend: XLA -cpp_namespace: torch_msnpu +cpp_namespace: torch_xla supported: - add.Tensor autograd: @@ -182,7 +182,7 @@ def test_backend_autograd_kernel_mismatch_functional_inplace(self): def test_op_appears_in_supported_and_autograd_lists(self): yaml_str = '''\ backend: XLA -cpp_namespace: torch_msnpu +cpp_namespace: torch_xla supported: - add.Tensor autograd: diff --git a/test/test_indexing.py b/test/test_indexing.py index 61580910f2cfb..8b8a2ead9ed72 100644 --- a/test/test_indexing.py +++ b/test/test_indexing.py @@ -8,7 +8,8 @@ import numpy as np -from torch.testing._internal.common_utils import TestCase, run_tests, make_tensor +from torch.testing import make_tensor +from torch.testing._internal.common_utils import TestCase, run_tests from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, onlyCUDA, dtypes, dtypesIfCPU, dtypesIfCUDA, onlyOnCPUAndCUDA) diff --git a/test/test_jit.py b/test/test_jit.py index 99df960da5dc4..7051d66dcf83c 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -61,6 +61,8 @@ from jit.test_parametrization import TestParametrization # noqa: F401 from jit.test_attr import TestGetDefaultAttr # noqa: F401 from jit.test_aten_pow import TestAtenPow # noqa: F401 +from jit.test_optimize_for_mobile_preserve_debug_info import TestOptimizeForMobilePreserveDebugInfo # noqa: F401 +from jit.test_union import TestUnion # noqa: F401 # Torch from torch import Tensor @@ -69,8 +71,7 @@ from torch.autograd import Variable from torch.jit.annotations import BroadcastingList2, BroadcastingList3, Any # noqa: F401 from torch.nn.utils.rnn import PackedSequence -from torch.testing import FileCheck -from torch.testing._internal.common_utils import make_tensor +from torch.testing import FileCheck, make_tensor import torch.autograd.profiler import torch.cuda import torch.jit @@ -392,11 +393,6 @@ def __init__(self, cpu_device_str): self.assertFalse(m2.p0.is_cuda) self.assertFalse(m2.b0.is_cuda) - def test_model_save_error(self): - with TemporaryFileName() as fname: - with self.assertRaisesRegex(pickle.PickleError, "not supported"): - torch.save(FooToPickle(), fname) - @unittest.skipIf(not RUN_CUDA, "restore device requires CUDA") def test_restore_device_cuda(self): class MyModule(torch.jit.ScriptModule): @@ -497,7 +493,7 @@ def forward(self, a, b, c): FileCheck().check_not("aten::relu(") \ .check("aten::_add_relu(") \ .run(m.graph) - torch.testing.assert_allclose(orig_res, new_res) + torch.testing.assert_close(orig_res, new_res) # add, relu_ a = torch.rand((7, 11)) @@ -516,7 +512,7 @@ def forward(self, a, b, c): FileCheck().check_not("aten::relu_(") \ .check("aten::_add_relu(") \ .run(m.graph) - torch.testing.assert_allclose(orig_res, new_res) + torch.testing.assert_close(orig_res, new_res) class Madd_(torch.nn.Module): def __init__(self, relu_op): @@ -547,10 +543,10 @@ def forward(self, a, b): .check_not("aten::relu_(") \ .check("aten::_add_relu_(") \ .run(m.graph) - torch.testing.assert_allclose(orig_res, new_res) + torch.testing.assert_close(orig_res, new_res) # Since _add_relu_ does inplace mutation ensure # a_copy is modified - torch.testing.assert_allclose(orig_res, a_copy) + torch.testing.assert_close(orig_res, a_copy) class Madd_out(torch.nn.Module): def __init__(self, relu_op): @@ -585,10 +581,10 @@ def forward(self, a, b): .check_not("aten::relu_(") \ .check("aten::_add_relu(") \ .run(m.graph) - torch.testing.assert_allclose(orig_res, new_res) + torch.testing.assert_close(orig_res, new_res) # Since _add_relu_ with out=a does inplace mutation ensure # a_copy is modified - torch.testing.assert_allclose(orig_res, a_copy) + torch.testing.assert_close(orig_res, a_copy) @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "Simple executor doesn't have shape information") def test_peephole_optimize_shape_ops(self): @@ -2523,32 +2519,6 @@ def forward(self, input, other=four): t = Test() self.assertEqual(t(torch.ones(1)), torch.ones(1) + 4) - def test_union_to_optional(self): - def test1(u: Union[int, None]) -> int: - if u is not None: - return u - else: - return 0 - scripted = torch.jit.script(test1) - self.assertEqual(scripted(10), test1(10)) - - def test2(u: Union[None, int]) -> int: - if u is not None: - return u - else: - return 0 - scripted = torch.jit.script(test2) - self.assertEqual(scripted(40), test2(40)) - - def test3(u: Union[float, int]) -> int: - if u is not None: - return u - else: - return 0 - expected_result = "General Union types are not currently supported" - with self.assertRaisesRegex(RuntimeError, expected_result): - torch.jit.script(test3) - def test_mutable_default_values(self): with self.assertRaisesRegex(Exception, "Mutable default parameters"): @torch.jit.script @@ -5944,7 +5914,6 @@ def test_bool_arith_not(lhs): self.assertEqual(test_bool_arith_not(torch.zeros(3)), 1) self.assertTrue(str(test_bool_arith_not.graph).count('if') == 0) - def test_conditional_casting(self): def test_bool_cast_tensor(x): if x: @@ -8888,7 +8857,7 @@ def forward(self, x): def test_pack_unpack_state(self): sm = TestScript.DerivedStateModule() x = torch.rand(3, 4, dtype=torch.float) - torch.testing.assert_allclose(sm(x), x + torch.neg(torch.ones(3, 4, dtype=torch.float))) + torch.testing.assert_close(sm(x), x + torch.neg(torch.ones(3, 4, dtype=torch.float))) # Test save path self.assertFalse(sm.pack_called.item()) @@ -8899,13 +8868,14 @@ def test_pack_unpack_state(self): # ensure unpack was called after serialization so as to leave the module in an initialized state self.assertTrue(sm.unpack_called.item()) - torch.testing.assert_allclose(sm.derived, torch.neg(sm.param)) + torch.testing.assert_close(sm.derived, torch.neg(sm.param)) # Test load paths self.assertTrue(imported.unpack_called.item()) - torch.testing.assert_allclose(imported(x), x + torch.neg(torch.ones(3, 4, dtype=torch.float))) + torch.testing.assert_close(imported(x), x + torch.neg(torch.ones(3, 4, dtype=torch.float))) @unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support") + @unittest.skipIf(True, "Skipping while landing PR stack") def test_torch_functional(self): def stft(input, n_fft): # type: (Tensor, int) -> Tensor @@ -9101,11 +9071,11 @@ def forward(self, x): return self.submod(x + self.buf) m = Mod() - torch.testing.assert_allclose(m(torch.zeros(3, 4)), torch.ones(3, 4) * 6) + torch.testing.assert_close(m(torch.zeros(3, 4)), torch.ones(3, 4) * 6) m.apply(lambda s: s._pack()) - torch.testing.assert_allclose(m(torch.zeros(3, 4)), torch.zeros(3, 4)) + torch.testing.assert_close(m(torch.zeros(3, 4)), torch.zeros(3, 4)) m.apply(lambda s: s._unpack()) - torch.testing.assert_allclose(m(torch.zeros(3, 4)), torch.ones(3, 4) * 6) + torch.testing.assert_close(m(torch.zeros(3, 4)), torch.ones(3, 4) * 6) def test_torch_any(self): def fn(x): @@ -9815,8 +9785,9 @@ def bar(): bar() def test_if_different_type(self): - with self.assertRaisesRegex(RuntimeError, "Type mismatch: c0 is set to type int " - "in the true branch and type float in the false branch:"): + with self.assertRaisesRegex(RuntimeError, "c0 is set to type " + "int in the true branch and type " + "float in the false branch"): @torch.jit.script def diff_type_used(): if 1 == 2: @@ -9825,7 +9796,7 @@ def diff_type_used(): c0 = 1.0 return c0 - with self.assertRaisesRegex(RuntimeError, "Variable 'c0' previously has type float"): + with self.assertRaisesRegex(RuntimeError, "Variable 'c0' previously had type float"): @torch.jit.script def diff_existing_type(x): c0 = 1.0 @@ -10608,7 +10579,7 @@ def f5(a): with self.assertRaisesRegex(RuntimeError, r'Expected a value of' r' type \'List\[int\]\' for argument' r' \'size\' but instead found type ' - r'\'List\[Any\]\''): + r'\'List\[Union\[List\[int\], int\]\]'): @torch.jit.script def f6(a): a.expand(size=[3, [4]]) @@ -10774,6 +10745,68 @@ def addmm_grad_test(b, x, w): self.assertEqual(w.grad, w_ref.grad) self.assertEqual(b.grad, b_ref.grad) + @unittest.skipIf(not RUN_CUDA, "running tests on cuda to verify cudnn fix") + def test_batch_norm_inference_backward_cuda(self): + with enable_profiling_mode_for_profiling_tests(): + class MyBatchNorm(torch.nn.Module): + def __init__(self, num_features, affine, track_running_stats): + super(MyBatchNorm, self).__init__() + self.bn = torch.nn.BatchNorm2d( + num_features, 1e-5, affine=affine, track_running_stats=track_running_stats).float() + + def forward(self, x: torch.Tensor): + o = self.bn(x) + o = torch.nn.functional.relu(o) + return o + + batch = 4 + c = 2 + hw = 3 + # Initialize param and input values + x_init = torch.randn(batch, c, hw, hw, dtype=torch.float).cuda() + grad = torch.randn(batch, c, hw, hw, dtype=torch.float).cuda() + + training = False + affine = True + track_running_stats = True + + module = torch.jit.script(MyBatchNorm(c, affine, track_running_stats)).cuda() + ref_module = MyBatchNorm(c, affine, track_running_stats).cuda() + module.eval() + ref_module.eval() + + jit_module = torch.jit.script(module) + ref_module.load_state_dict(module.state_dict()) + + x = x_init.detach().clone() + x.requires_grad_() + x_ref = x_init.detach().clone() + x_ref.requires_grad_() + + # Test symbolic differentiation + # Run Forward and Backward thrice to trigger autodiff graph + for i in range(0, 3): + y = jit_module(x) + y.backward(grad) + x.grad.zero_() + + module.bn.running_mean.zero_() + module.bn.running_var.fill_(1.0) + ref_module.bn.running_mean.zero_() + ref_module.bn.running_var.fill_(1.0) + + # run jitted module + y = jit_module(x) + y.backward(grad) + # reference computation + y_ref = ref_module(x_ref) + y_ref.backward(grad) + + self.assertEqual(y_ref, y) + self.assertEqual(x.grad, x_ref.grad) + self.assertEqual(module.bn.running_mean, ref_module.bn.running_mean) + self.assertEqual(module.bn.running_var, ref_module.bn.running_var) + def test_zeros(self): class M(torch.jit.ScriptModule): __constants__ = ['d'] @@ -10958,7 +10991,7 @@ def forward(self, x): torch._C._jit_pass_remove_dropout(m._c) res = m(data) FileCheck().check_not("aten::dropout").run(str(m.graph)) - torch.testing.assert_allclose(ref_res, res, rtol=1e-2, atol=1e-3) + torch.testing.assert_close(ref_res, res, rtol=1e-2, atol=1e-3) def test_unfold_zero_dim(self): def fn(x): @@ -12616,7 +12649,7 @@ def foo(x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]: for pair in self.type_input_return_pairs(): cu = torch.jit.CompilationUnit(self.format_code(code, pair)) test_str.append(str(cu.foo.schema)) - self.assertExpected("\n".join(test_str)) + self.assertExpected("\n".join(test_str) + "\n") # String frontend , Python 3-style type annotations , Script method def test_annot_string_py3_method(self): @@ -12635,7 +12668,7 @@ def foo(self, x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output tm = TestModule() tm.define(self.format_code(code, pair)) test_str.append(str(tm.foo.schema)) - self.assertExpectedStripMangled("\n".join(test_str)) + self.assertExpectedStripMangled("\n".join(test_str) + "\n") # String frontend , MyPy-style type comments , Script function def test_annot_string_mypy_fn(self): @@ -12648,7 +12681,7 @@ def foo(x, y): for pair in self.type_input_return_pairs(): cu = torch.jit.CompilationUnit(self.format_code(code, pair)) test_str.append(str(cu.foo.schema)) - self.assertExpectedStripMangled("\n".join(test_str)) + self.assertExpectedStripMangled("\n".join(test_str) + "\n") # String frontend , MyPy-style type comments , Script method def test_annot_string_mypy_method(self): @@ -12669,7 +12702,7 @@ def foo(self, x, y): tm = TestModule() tm.define(self.format_code(code, pair)) test_str.append(str(tm.foo.schema)) - self.assertExpectedStripMangled("\n".join(test_str)) + self.assertExpectedStripMangled("\n".join(test_str) + "\n") # Python AST Frontend , Python 3-style type annotations , Script function def test_annot_ast_py3_fn(self): @@ -12686,7 +12719,7 @@ def foo(x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]: for pair in self.type_input_return_pairs(): fn = jit_utils._get_py3_code(self.format_code(code, pair), 'foo') test_str.append(str(fn.schema)) - self.assertExpectedStripMangled("\n".join(test_str)) + self.assertExpectedStripMangled("\n".join(test_str) + "\n") def test_multiline_annot_ast_py3_fn(self): code = dedent(''' @@ -12761,7 +12794,7 @@ def foo(self, x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output for pair in self.type_input_return_pairs(): fn = jit_utils._get_py3_code(self.format_code(code, pair), 'instance') test_str.append(str(fn.foo.schema)) - self.assertExpectedStripMangled("\n".join(test_str)) + self.assertExpectedStripMangled("\n".join(test_str) + "\n") # Python AST Frontend , MyPy-style type comments , Script function def test_annot_ast_mypy_fn(self): @@ -12777,7 +12810,7 @@ def foo(x, y): for pair in self.type_input_return_pairs(): fn = jit_utils._get_py3_code(self.format_code(code, pair), 'foo') test_str.append(str(fn.schema)) - self.assertExpected("\n".join(test_str)) + self.assertExpected("\n".join(test_str) + "\n") # Python AST Frontend , MyPy-style type comments , Script method def test_annot_ast_mypy_method(self): @@ -12795,7 +12828,7 @@ def foo(self, x, y): for pair in self.type_input_return_pairs(): fn = jit_utils._get_py3_code(self.format_code(code, pair), 'instance') test_str.append(str(fn.foo.schema)) - self.assertExpectedStripMangled("\n".join(test_str)) + self.assertExpectedStripMangled("\n".join(test_str) + "\n") # Tests that "# type: ignore[*]" is supported in type lines and is # properly ignored. @@ -13465,8 +13498,8 @@ def fn(x): self.checkScript(fn, ("y")) def index_str_to_tensor(s): - # type: (str) -> int - return torch.tensor(ord(s)) + # type: (str) -> Tensor + return torch.tensor(ord(s)) # noqa: T484 s = u'\u00a3'.encode('utf8')[:1] self.checkScript(index_str_to_tensor, (s,)) @@ -14893,7 +14926,7 @@ def jit_multihead_attn_forward(query, # type: Tensor attn_mask=mask)[0] # print("rel. error: ") # print(jit_out / py_out - 1) - self.assertTrue(torch.allclose(jit_out, py_out, atol=5e-4, rtol=1e-4)) + self.assertEqual(jit_out, py_out, atol=5e-4, rtol=1e-4) @unittest.skipIf(not RUN_CUDA, "no CUDA") def test_scriptmodule_multi_head_attn_cuda(self): @@ -14929,7 +14962,7 @@ def forward(self, q, k, v): None, None, None, 0.0, model.mod.out_proj.weight, model.mod.out_proj.bias)[0] - self.assertTrue(torch.allclose(jit_out, py_out, atol=5e-4, rtol=1e-4)) + self.assertEqual(jit_out, py_out, atol=5e-4, rtol=1e-4) @unittest.skipIf(not RUN_CUDA, "no CUDA") def test_scriptmodule_transformer_cuda(self): @@ -14968,7 +15001,7 @@ def forward(self, q, k): # print(jit_out/py_out-1) # print(torch.allclose(jit_out, py_out, atol=5e-4, rtol=1e-4)) - self.assertTrue(torch.allclose(jit_out, py_out, atol=5e-4, rtol=1e-4)) + self.assertEqual(jit_out, py_out, atol=5e-4, rtol=1e-4) def test_list_python_op(self): def python_list_op(lst): diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py index ba47547256b75..a6cc085b27c70 100644 --- a/test/test_jit_fuser_te.py +++ b/test/test_jit_fuser_te.py @@ -85,10 +85,6 @@ def setUp(self): self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu() torch._C._jit_set_te_must_use_llvm_cpu(False) - # TODO: CPU fuser currently is disabled when multithreading. - self.old_fuse_parallel = torch._C._jit_texpr_parallel_cpu_enabled() - torch._C._jit_set_texpr_parallel_cpu_enabled(True) - self.devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] self.int_dtypes = [ torch.int8, @@ -98,10 +94,10 @@ def setUp(self): torch.bool, ] self.fp_dtypes = [ - # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed - # torch.float16, + torch.float16, torch.float32, torch.float64, + torch.bfloat16, ] self.dtypes = self.int_dtypes + self.fp_dtypes @@ -116,7 +112,6 @@ def tearDown(self): torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state) torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu) - torch._C._jit_set_texpr_parallel_cpu_enabled(self.old_fuse_parallel) def assertLastGraphAllFused(self): self.assertAllFused(torch.jit.last_executed_optimized_graph()) @@ -1135,8 +1130,7 @@ def foo(x): dtypes = [ torch.bool, torch.int, - # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed - # torch.float16, + torch.float16, torch.float32, torch.float64, ] @@ -1151,6 +1145,9 @@ def forward(self, x): bad_dtypes = [] for dtype, output_dtype, device, size in product(dtypes, dtypes, self.devices, sizes): + # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed + if dtype in [torch.float16, torch.bfloat16] and device == "cpu": + continue if dtype == output_dtype: continue @@ -1186,7 +1183,7 @@ def fn(input_v, mask): ref = fn(input_v, mask) try: t = torch.jit.trace(fn, (input_v, mask)) - torch.testing.assert_allclose(ref, t(input_v, mask)) + torch.testing.assert_close(ref, t(input_v, mask)) print(torch.jit.last_executed_optimized_graph()) self.assertLastGraphAllFused() except Exception as e: @@ -1206,18 +1203,16 @@ def test_isnan(self): torch.int16, torch.int32, torch.int64, - # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed - # torch.float16, + torch.float16, torch.float32, torch.float64, torch.bool, ] for inp, device, dtype in product(inputs, self.devices, dtypes): - # TODO - if dtype == torch.float16 and not LLVM_ENABLED: + # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed + if dtype in [torch.float16, torch.bfloat16] and device == "cpu": continue - inp = inp.to(device=device, dtype=dtype) try: f = torch.jit.trace(lambda x: x.isnan(), (inp,)) @@ -1269,13 +1264,20 @@ def apply(fn): torch.round, torch.trunc, torch.frac, - F.hardshrink, + # TODO: broken on ROCm? + # F.hardshrink, F.leaky_relu, lambda x: torch.threshold(x, 0, -10), lambda x: torch.clamp(x, -10, 10), ] + gpu_only = {torch.erf, torch.erfc} sizes = [(1,), (2,), (4, 4)] for dtype, op, device, size in product(self.dtypes, unary_ops, self.devices, sizes): + # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed + if dtype in [torch.float16, torch.bfloat16] and device == "cpu": + continue + if op in gpu_only and device == "cpu": + continue try: x = self.data_for(dtype, device, size=size) fn = apply(op) @@ -1287,7 +1289,7 @@ def apply(fn): continue try: t = torch.jit.trace(fn, (x,)) - torch.testing.assert_allclose(ref, t(x)) + torch.testing.assert_close(ref, t(x)) self.assertAllFused(t.graph_for(x)) except Exception as e: raise RuntimeError( @@ -1325,6 +1327,8 @@ def apply(fn): ] devices = self.devices for dtype, op, device in product(self.dtypes, binary_ops, devices): + if dtype in [torch.float16, torch.bfloat16] and device == "cpu": + continue try: x = self.data_for(dtype, device) y = self.data_for(dtype, device) @@ -1375,6 +1379,8 @@ def fn(x, y): "[[10, 3, 4], [4, 5]]", ] for dtype, size, device in product(self.dtypes, sizes, devices): + if dtype in [torch.float16, torch.bfloat16] and device == "cpu": + continue try: size_x, size_y = size x = self.data_for(dtype, device, size=size_x) @@ -1419,6 +1425,8 @@ def apply_with_scalar(fn, scalar): # only using scalar values relevant to particular ops scalars = [1.5, 3, 0, -2.0, -1] for dtype, op, device, scalar in product(self.dtypes, binary_ops, devices, scalars): + if dtype in [torch.float16, torch.bfloat16] and device == "cpu": + continue try: x = self.data_for(dtype, device) fn = apply_with_scalar(op, scalar) @@ -1451,6 +1459,8 @@ def apply_with_scalar(fn, scalar): # only using scalar values relevant to particular ops scalars = [1.5, 3, -2.0, -1] # skip 0 for dtype, op, device, scalar in product(self.dtypes, binary_ops, devices, scalars): + if dtype in [torch.float16, torch.bfloat16] and device == "cpu": + continue try: x = self.data_for(dtype, device) fn = apply_with_scalar(op, scalar) @@ -1486,6 +1496,8 @@ def apply_with_scalar(fn, scalar): # only using scalar values relevant to particular ops scalars = [1.5, 3, 0, -2.0, -1] for dtype, op, device, scalar in product(dtypes, binary_ops, self.devices, scalars): + if dtype in [torch.float16, torch.bfloat16] and device == "cpu": + continue try: x = self.data_for(dtype, device) fn = apply_with_scalar(op, scalar) @@ -1514,6 +1526,8 @@ def apply(fn): ] devices = self.devices for dtype, op, device in product(self.dtypes, ternary_ops, devices): + if dtype in [torch.float16, torch.bfloat16] and device == "cpu": + continue try: x = self.data_for(dtype, device) y = self.data_for(dtype, device) @@ -1543,6 +1557,8 @@ def apply(fn): ] devices = self.devices for dtype, op, device in product(self.dtypes, ternary_ops, devices): + if dtype in [torch.float16, torch.bfloat16] and device == "cpu": + continue try: x = self.data_for(dtype, device, size=[5, 3, 128, 128]) y = self.data_for(dtype, device, size=[3]) @@ -1574,6 +1590,8 @@ def apply(fn): torch.cat, ] for dtype, op, device in product(self.dtypes, list_ops, devices): + if dtype in [torch.float16, torch.bfloat16] and device == "cpu": + continue try: x = self.data_for(dtype, device, size=[5, 4, 1, 7]) y = self.data_for(dtype, device, size=[5, 4, 1, 7]) @@ -1605,6 +1623,8 @@ def apply(fn): ] devices = self.devices for dtype, op, device in product(self.dtypes, ops, devices): + if dtype in [torch.float16, torch.bfloat16] and device == "cpu": + continue try: cond = self.data_for(torch.bool, device) x = self.data_for(dtype, device) @@ -1632,7 +1652,6 @@ def fn(x): unsupported_dtypes = [ torch.uint8, - torch.bfloat16, torch.complex32, torch.complex64, torch.complex128, @@ -1683,7 +1702,7 @@ def eager(t0, t1, t2, t3, t4): for _ in range(4): for pair in zip(script(*inputs), eager(*inputs)): test, ref = pair - torch.testing.assert_allclose(test, ref) + torch.testing.assert_close(test, ref) self.assertAllFused(script.graph_for(*inputs)) def test_sub_gt_and(self): @@ -1770,16 +1789,20 @@ def test_type_as_cat(self): with inline_fusion_groups(): def eager(x, y): return torch.cat((x, y.type_as(x)), dim=1) - for dtype1, dtype2 in product(self.dtypes, self.dtypes): + dtypes = self.dtypes.copy() + # CPU fuser doesn't support float16. + dtypes.remove(torch.float16) + dtypes.remove(torch.bfloat16) + for dtype1, dtype2 in product(dtypes, dtypes): x = torch.randint(2, (1, 13,)).to(dtype1) zero = torch.tensor([[0]]).to(dtype2) one = torch.tensor([[1]]).to(dtype2) script = torch.jit.trace(eager, (x, zero)) for _ in range(3): - torch.testing.assert_allclose( + torch.testing.assert_close( script(x, zero), eager(x, zero)) - torch.testing.assert_allclose( + torch.testing.assert_close( script(x, one), eager(x, one)) self.assertAllFused(script.graph_for(x, one)) @@ -1824,7 +1847,7 @@ def _test_fwd_bwd(self, fn): xs -= 0.1 * xs.grad x.grad = None xs.grad = None - torch.testing.assert_allclose(y, ys) + torch.testing.assert_close(y, ys) def test_relu_fwd_bwd(self): def eager(x): @@ -1907,12 +1930,36 @@ def eager(x): for _ in range(3): script(x) - torch.testing.assert_allclose(eager(x), script(x)) + torch.testing.assert_close(eager(x), script(x)) # Now when an input hits the unrolled path, it will produce an # incorrectly-sized tensor, since size=1 has been burned in. x = torch.ones((8, 1)) - torch.testing.assert_allclose(eager(x), script(x)) + torch.testing.assert_close(eager(x), script(x)) + + def test_batch_norm(self): + def test(fn, args): + trace = torch.jit.trace(fn, args) + self.assertAllFused(trace.graph_for(*args)) + torch.testing.assert_allclose(fn(*args), trace(*args)) + + def bn(i, x): + return torch.batch_norm(i, x, x, x, x, False, 0.1, 1e-4, False).relu() + + def bn_no_weight(i, x): + return torch.batch_norm(i, None, x, x, x, False, 0.1, 1e-4, False).relu() + + def bn_no_bias(i, x): + return torch.batch_norm(i, x, None, x, x, False, 0.1, 1e-4, False).relu() + + def bn_neither(i, x): + return torch.batch_norm(i, None, None, x, x, False, 0.1, 1e-4, False).relu() + + for device in self.devices: + i = torch.randn(4, 16, 32, 40, device=device) + x = torch.randn(16, device=device) + for fn in [bn, bn_no_weight, bn_no_bias, bn_neither]: + test(fn, (i, x)) works_list = [ '__radd__', @@ -1934,6 +1981,8 @@ def eager(x): 'cosh', 'div.no_rounding_mode', 'div.true_rounding', + 'div.floor_rounding', + 'div.trunc_rounding', 'eq', 'erf', 'erfc', diff --git a/test/test_linalg.py b/test/test_linalg.py index 8ba3373d38ce4..5912111da4c0a 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -14,14 +14,18 @@ from torch.testing._internal.common_utils import \ (TestCase, run_tests, TEST_SCIPY, IS_MACOS, IS_WINDOWS, slowTest, - TEST_WITH_ASAN, make_tensor, TEST_WITH_ROCM, IS_FBCODE, IS_REMOTE_GPU, + TEST_WITH_ASAN, TEST_WITH_ROCM, IS_FBCODE, IS_REMOTE_GPU, iter_indices, gradcheck, gradgradcheck) from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, dtypes, onlyCPU, skipCUDAIf, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride, skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, onlyOnCPUAndCUDA, dtypesIfCUDA, onlyCUDA, skipCUDAVersionIn, skipMeta, skipCUDAIfNoCusolver) -from torch.testing import floating_and_complex_types, floating_types, all_types +from torch.testing import make_tensor +from torch.testing._internal.common_dtype import ( + all_types, floating_types, floating_and_complex_types, get_all_dtypes, get_all_int_dtypes, get_all_complex_dtypes, + get_all_fp_dtypes, +) from torch.testing._internal.common_cuda import SM53OrLater, tf32_on_and_off, CUDA11OrLater, CUDA9 from torch.distributions.binomial import Binomial @@ -89,7 +93,7 @@ def check(a_sizes_, b_sizes_): # Tests torch.outer, and its alias, torch.ger, vs. NumPy @precisionOverride({torch.bfloat16: 1e-1}) - @dtypes(*(torch.testing.get_all_dtypes())) + @dtypes(*(get_all_dtypes())) def test_outer(self, device, dtype): def run_test_case(a, b): if dtype == torch.bfloat16: @@ -483,10 +487,10 @@ def test_cholesky_errors_and_warnings(self, device, dtype): r'1-dimensional array given\. Array must be at least two-dimensional'): np.linalg.cholesky(A.cpu().numpy()) - # if the input matrix is singular, an error should be raised + # if the input matrix is not positive definite, an error should be raised A = torch.eye(3, 3, dtype=dtype, device=device) - A[-1, -1] = 0 # Now A is singular - with self.assertRaisesRegex(RuntimeError, r'U\(3,3\) is zero, singular U\.'): + A[-1, -1] = 0 # Now A is not positive definite + with self.assertRaisesRegex(RuntimeError, r'minor of order 3 is not positive-definite'): torch.linalg.cholesky(A) with self.assertRaisesRegex(np.linalg.LinAlgError, r'Matrix is not positive definite'): np.linalg.cholesky(A.cpu().numpy()) @@ -495,8 +499,8 @@ def test_cholesky_errors_and_warnings(self, device, dtype): A = torch.eye(3, 3, dtype=dtype, device=device) A = A.reshape((1, 3, 3)) A = A.repeat(5, 1, 1) - A[4, -1, -1] = 0 # Now A[4] is singular - with self.assertRaisesRegex(RuntimeError, r'For batch 4: U\(3,3\) is zero, singular U\.'): + A[4, -1, -1] = 0 # Now A[4] is not positive definite + with self.assertRaisesRegex(RuntimeError, r'\(Batch element 4\): The factorization could not be completed'): torch.linalg.cholesky(A) # if out tensor with wrong shape is passed a warning is given @@ -674,7 +678,7 @@ def test_cholesky_ex_non_pd(self, device, dtype): A[-1, -1] = 0 # Now A is singular _, info = torch.linalg.cholesky_ex(A) self.assertEqual(info, 3) - with self.assertRaisesRegex(RuntimeError, r'U\(3,3\) is zero, singular U\.'): + with self.assertRaisesRegex(RuntimeError, r'minor of order 3 is not positive-definite'): torch.linalg.cholesky_ex(A, check_errors=True) # if at least one matrix in the batch is not positive definite, @@ -688,7 +692,7 @@ def test_cholesky_ex_non_pd(self, device, dtype): expected_info = torch.zeros(A.shape[:-2], dtype=torch.int32, device=device) expected_info[3] = 2 self.assertEqual(info, expected_info) - with self.assertRaisesRegex(RuntimeError, r'For batch 3: U\(2,2\) is zero, singular U\.'): + with self.assertRaisesRegex(RuntimeError, r'\(Batch element 3\): The factorization could not be completed'): torch.linalg.cholesky_ex(A, check_errors=True) @skipCUDAIfNoMagmaAndNoCusolver @@ -772,7 +776,7 @@ def check(m, a, b, beta, alpha): check(m_scalar, a, b, beta, alpha) # test nans and infs are not propagated to the output when beta == 0 - float_and_complex_dtypes = torch.testing.get_all_fp_dtypes() + torch.testing.get_all_complex_dtypes() + float_and_complex_dtypes = get_all_fp_dtypes() + get_all_complex_dtypes() if beta == 0 and dtype in float_and_complex_dtypes: m[0][10] = m[10][10] = m[20][20] = float('inf') m[1][10] = m[11][10] = m[21][20] = float('nan') @@ -785,7 +789,7 @@ def test_addr_bool(self, device, dtype): self._test_addr_vs_numpy(device, dtype, beta=False, alpha=False) self._test_addr_vs_numpy(device, dtype, beta=True, alpha=True) - @dtypes(*(torch.testing.get_all_int_dtypes())) + @dtypes(*(get_all_int_dtypes())) def test_addr_integral(self, device, dtype): with self.assertRaisesRegex(RuntimeError, 'argument beta must not be a floating point number.'): @@ -806,7 +810,7 @@ def test_addr_integral(self, device, dtype): self._test_addr_vs_numpy(device, dtype, beta=2, alpha=2) @precisionOverride({torch.bfloat16: 1e-1}) - @dtypes(*(torch.testing.get_all_fp_dtypes() + torch.testing.get_all_complex_dtypes())) + @dtypes(*(get_all_fp_dtypes() + get_all_complex_dtypes())) def test_addr_float_and_complex(self, device, dtype): with self.assertRaisesRegex(RuntimeError, 'Boolean beta only supported for Boolean results.'): @@ -819,11 +823,11 @@ def test_addr_float_and_complex(self, device, dtype): self._test_addr_vs_numpy(device, dtype, beta=0., alpha=2) # when beta is not zero self._test_addr_vs_numpy(device, dtype, beta=0.5, alpha=2) - if dtype in torch.testing.get_all_complex_dtypes(): + if dtype in get_all_complex_dtypes(): self._test_addr_vs_numpy(device, dtype, beta=(0 + 0.1j), alpha=(0.2 - 0.2j)) - @dtypes(*itertools.product(torch.testing.get_all_dtypes(), - torch.testing.get_all_dtypes())) + @dtypes(*itertools.product(get_all_dtypes(), + get_all_dtypes())) def test_outer_type_promotion(self, device, dtypes): a = torch.randn(5).to(device=device, dtype=dtypes[0]) b = torch.randn(5).to(device=device, dtype=dtypes[1]) @@ -831,9 +835,9 @@ def test_outer_type_promotion(self, device, dtypes): result = op(a, b) self.assertEqual(result.dtype, torch.result_type(a, b)) - @dtypes(*itertools.product(torch.testing.get_all_dtypes(), - torch.testing.get_all_dtypes(), - torch.testing.get_all_dtypes())) + @dtypes(*itertools.product(get_all_dtypes(), + get_all_dtypes(), + get_all_dtypes())) def test_addr_type_promotion(self, device, dtypes): a = make_tensor((5,), device=device, dtype=dtypes[0], low=-2, high=2) b = make_tensor((5,), device=device, dtype=dtypes[1], low=-2, high=2) @@ -2892,6 +2896,16 @@ def test_svd_errors_and_warnings(self, device, dtype): # error from out_v svd(a, out=(out_u, out_s, out_v)) + # if input contains NaN then an error is triggered for svd + a = torch.full((3, 3), float('nan'), dtype=dtype, device=device) + a[0] = float('nan') + with self.assertRaisesRegex(RuntimeError, "The algorithm failed to converge"): + svd(a) + a = torch.randn(3, 33, 33, dtype=dtype, device=device) + a[1, 0, 0] = float('nan') + with self.assertRaisesRegex(RuntimeError, r"\(Batch element 1\): The algorithm failed to converge"): + svd(a) + @skipCUDAIfNoMagmaAndNoCusolver @skipCPUIfNoLapack @dtypes(*floating_and_complex_types()) @@ -3237,7 +3251,7 @@ def test_inv_ex_singular(self, device, dtype): A[-1, -1] = 0 # Now A is singular info = torch.linalg.inv_ex(A).info self.assertEqual(info, 3) - with self.assertRaisesRegex(RuntimeError, r'U\(3,3\) is zero, singular U\.'): + with self.assertRaisesRegex(RuntimeError, r'diagonal element 3 is zero, the inversion could not be completed'): torch.linalg.inv_ex(A, check_errors=True) # if at least one matrix in the batch is not positive definite, @@ -3251,7 +3265,7 @@ def test_inv_ex_singular(self, device, dtype): expected_info = torch.zeros(A.shape[:-2], dtype=torch.int32, device=device) expected_info[3] = 2 self.assertEqual(info, expected_info) - with self.assertRaisesRegex(RuntimeError, r'For batch 3: U\(2,2\) is zero, singular U\.'): + with self.assertRaisesRegex(RuntimeError, r'\(Batch element 3\): The diagonal element 2 is zero'): torch.linalg.inv_ex(A, check_errors=True) @slowTest @@ -3289,7 +3303,7 @@ def test_inverse_errors(self, device, dtype): def run_test_singular_input(batch_dim, n): x = torch.eye(3, 3, dtype=dtype, device=device).reshape((1, 3, 3)).repeat(batch_dim, 1, 1) x[n, -1, -1] = 0 - with self.assertRaisesRegex(RuntimeError, rf'For batch {n}: U\(3,3\) is zero'): + with self.assertRaisesRegex(RuntimeError, rf'\(Batch element {n}\): The diagonal element 3 is zero'): torch.inverse(x) for params in [(1, 0), (2, 0), (2, 1), (4, 0), (4, 2), (10, 2)]: @@ -3306,7 +3320,7 @@ def test_inverse_errors_large(self, device, dtype): x = torch.empty((8, 10, 616, 616), dtype=dtype, device=device) x[:] = torch.eye(616, dtype=dtype, device=device) x[..., 10, 10] = 0 - with self.assertRaisesRegex(RuntimeError, r'For batch 0: U\(11,11\) is zero'): + with self.assertRaisesRegex(RuntimeError, r'\(Batch element 0\): The diagonal element 11 is zero'): torch.inverse(x) @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, torch.float64: 1e-7, torch.complex128: 1e-7}) @@ -3428,7 +3442,7 @@ def test_inv_errors_and_warnings(self, device, dtype): def run_test_singular_input(batch_dim, n): a = torch.eye(3, 3, dtype=dtype, device=device).reshape((1, 3, 3)).repeat(batch_dim, 1, 1) a[n, -1, -1] = 0 - with self.assertRaisesRegex(RuntimeError, rf"For batch {n}: U\(3,3\) is zero"): + with self.assertRaisesRegex(RuntimeError, rf"\(Batch element {n}\): The diagonal element 3 is zero"): torch.linalg.inv(a) for params in [(1, 0), (2, 0), (2, 1), (4, 0), (4, 2), (10, 2)]: @@ -3559,7 +3573,7 @@ def run_test_singular_input(batch_dim, n): a = torch.eye(3, 3, dtype=dtype, device=device).reshape((1, 3, 3)).repeat(batch_dim, 1, 1) a[n, -1, -1] = 0 b = torch.randn(batch_dim, 3, 1, dtype=dtype, device=device) - with self.assertRaisesRegex(RuntimeError, rf'For batch {n}: U\(3,3\) is zero'): + with self.assertRaisesRegex(RuntimeError, rf'\(Batch element {n}\): The diagonal element 3 is zero'): torch.linalg.solve(a, b) for params in [(1, 0), (2, 0), (2, 1), (4, 0), (4, 2), (10, 2)]: @@ -3977,8 +3991,14 @@ def _test_dot_vdot_vs_numpy(self, device, dtype, torch_fn, np_fn): def check(x, y): # Compare with numpy res = torch_fn(x, y) - ref = torch.from_numpy(np.array(np_fn(x.cpu().numpy(), y.cpu().numpy()))) - self.assertEqual(res.cpu(), ref) + if x.dtype == torch.bfloat16: + ref = torch.from_numpy(np.array(np_fn(x.cpu().float().numpy(), y.cpu().float().numpy()))) + else: + ref = torch.from_numpy(np.array(np_fn(x.cpu().numpy(), y.cpu().numpy()))) + if res.dtype == torch.bfloat16: + self.assertEqual(res.cpu(), ref.bfloat16()) + else: + self.assertEqual(res.cpu(), ref) # Test out variant out = torch.empty_like(res) @@ -3991,19 +4011,20 @@ def check(x, y): check(x, y) # Contiguous - x = torch.randn(10, dtype=dtype, device=device) - y = torch.randn(10, dtype=dtype, device=device) + x = torch.randn(200, dtype=dtype, device=device) + y = torch.randn(200, dtype=dtype, device=device) check(x, y) # 0 strided - y = torch.randn(1, dtype=dtype, device=device).expand(10) + y = torch.randn(1, dtype=dtype, device=device).expand(200) check(x, y) # 2 strided check(x[::2], y[::2]) - @dtypes(torch.float, torch.cfloat) - @precisionOverride({torch.cfloat: 1e-4, torch.float32: 5e-5}) + @dtypes(torch.float, torch.cfloat, torch.bfloat16) + @dtypesIfCUDA(torch.float, torch.cfloat) + @precisionOverride({torch.cfloat: 1e-4, torch.float32: 5e-5, torch.bfloat16: 1e-0}) def test_dot_vs_numpy(self, device, dtype): self._test_dot_vdot_vs_numpy(device, dtype, torch.dot, np.dot) @@ -4912,7 +4933,7 @@ def test_triangular_solve_singular(self, device, dtype): b = torch.rand(3, 1, dtype=dtype, device=device) A = torch.eye(3, 3, dtype=dtype, device=device) A[-1, -1] = 0 # Now A is singular - err_str = r"triangular_solve: U\(3,3\) is zero, singular U\." + err_str = r"triangular_solve: The diagonal element 3 is zero" with self.assertRaisesRegex(RuntimeError, err_str): torch.triangular_solve(b, A) @@ -5277,8 +5298,8 @@ def call_torch_fn(*args, **kwargs): self.assertEqual([(2, 0, 0), (2, 0)], [A_LU.shape, pivots.shape]) @dtypesIfCUDA(torch.cfloat, torch.cdouble, - *torch.testing.get_all_fp_dtypes(include_half=not CUDA9, include_bfloat16=(CUDA11OrLater and SM53OrLater))) - @dtypes(*(set(torch.testing.get_all_dtypes()) - {torch.half, torch.bool})) + *get_all_fp_dtypes(include_half=not CUDA9, include_bfloat16=(CUDA11OrLater and SM53OrLater))) + @dtypes(*(set(get_all_dtypes()) - {torch.half, torch.bool})) def test_blas_alpha_beta_empty(self, device, dtype): # This test is disabled on CUDA 9 due to: # See: https://github.com/pytorch/pytorch/issues/31006 @@ -5314,7 +5335,7 @@ def test_blas_alpha_beta_empty(self, device, dtype): self.assertEqual(torch.full((2, 3), beta * value, dtype=dtype, device=device), torch.addmm(input=input, mat1=mat, mat2=mat2, alpha=alpha, beta=beta, out=out)) - @dtypes(*(torch.testing.get_all_complex_dtypes() + torch.testing.get_all_fp_dtypes())) + @dtypes(*(get_all_complex_dtypes() + get_all_fp_dtypes())) def test_blas_nan_out(self, device, dtype): # These functions should work correctly with NaN filled outputs, # but need special handling, see [NOTE: cpu_zero] @@ -5940,9 +5961,9 @@ def _test_addmm_addmv(self, f, t, m, v, *, alpha=None, beta=None, transpose_out= @precisionOverride({torch.bfloat16: 1e-0, torch.half: 5e-4, torch.float: 1e-4, torch.double: 1e-8, torch.cfloat: 1e-4, torch.cdouble: 1e-8}) - @dtypesIfCUDA(*torch.testing.get_all_complex_dtypes(), - *torch.testing.get_all_fp_dtypes(include_bfloat16=(TEST_WITH_ROCM or (CUDA11OrLater and SM53OrLater)), - include_half=(not TEST_WITH_ROCM))) + @dtypesIfCUDA(*get_all_complex_dtypes(), + *get_all_fp_dtypes(include_bfloat16=(TEST_WITH_ROCM or (CUDA11OrLater and SM53OrLater)), + include_half=(not TEST_WITH_ROCM))) @dtypes(torch.bfloat16, torch.float, torch.double, torch.cfloat, torch.cdouble) def test_addmv(self, device, dtype): # have to use torch.randn(...).to(bfloat16) instead of @@ -5976,7 +5997,7 @@ def test_addmv(self, device, dtype): for m, v in itertools.product(ms, vs): self._test_addmm_addmv(torch.addmv, t, m, v, beta=0) - @dtypesIfCUDA(*torch.testing.get_all_fp_dtypes(include_bfloat16=(TEST_WITH_ROCM or (CUDA11OrLater and SM53OrLater)))) + @dtypesIfCUDA(*get_all_fp_dtypes(include_bfloat16=(TEST_WITH_ROCM or (CUDA11OrLater and SM53OrLater)))) @dtypes(torch.float, torch.double) def test_addmv_rowmajor_colmajor_incx_incy_lda(self, device, dtype): # tests (o, s)*(s). o is output size, s is summed size. @@ -6007,9 +6028,9 @@ def _test(row_major, incx, incy, lda_tail): @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6, torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8}) - @dtypesIfCUDA(*torch.testing.get_all_complex_dtypes(), - *torch.testing.get_all_fp_dtypes(include_bfloat16=(TEST_WITH_ROCM or (CUDA11OrLater and SM53OrLater)))) - @dtypes(*torch.testing.get_all_complex_dtypes(), *torch.testing.get_all_fp_dtypes()) + @dtypesIfCUDA(*get_all_complex_dtypes(), + *get_all_fp_dtypes(include_bfloat16=(TEST_WITH_ROCM or (CUDA11OrLater and SM53OrLater)))) + @dtypes(*get_all_complex_dtypes(), *get_all_fp_dtypes()) @tf32_on_and_off(0.05) def test_addmm(self, device, dtype): M = torch.randn(10, 25, device=device).to(dtype) @@ -6042,7 +6063,7 @@ def maybe_transpose(cond, m): self._test_addmm_addmv(torch.addmm, M, m1, m2, transpose_out=t4) @dtypes(torch.float, torch.double) - @dtypesIfCUDA(*([torch.float, torch.double] + torch.testing.get_all_complex_dtypes())) + @dtypesIfCUDA(*([torch.float, torch.double] + get_all_complex_dtypes())) @tf32_on_and_off(0.005) def test_addmm_sizes(self, device, dtype): for m in [0, 1, 25]: @@ -6150,12 +6171,12 @@ def genf_int(x, y): return torch.randint(0, 100, (x, y), dtype=dtype, device=device) def genf_bfloat(x, y): - return torch.randn(x, y, dtype=torch.float32, device=device).to(dtype) + return torch.randn(x, y, dtype=torch.float32, device=device).to(dtype) * 0.1 def genf_float(x, y): return torch.randn(x, y, dtype=dtype, device=device) - for (n, m, p) in [(20, 10, 5), (15, 5, 10), (5, 18, 10)]: + for (n, m, p) in [(20, 10, 15), (15, 20, 10), (25, 18, 10)]: if (dtype == torch.int32) or (dtype == torch.int64): genf = genf_int elif (dtype == torch.bfloat16): @@ -6165,6 +6186,38 @@ def genf_float(x, y): _test_mm(n, m, p, dtype, genf) + @onlyOnCPUAndCUDA + def test_mm_bmm_non_memory_dense(self, device): + def _slice(tensor, fn): + return fn(tensor)[..., ::2] + A = torch.randn(3, 6, dtype=torch.cfloat, device=device) + B = torch.randn(3, 3, dtype=torch.cfloat, device=device) + out = torch.empty(3, 3, device=device, dtype=torch.complex64).t() + out1 = torch.empty(3, 3, device=device, dtype=torch.complex64).t() + A_conj = _slice(A, torch.conj) + A_conj_physical = _slice(A, torch.conj_physical) + + self.assertEqual(torch.mm(A_conj, B, out=out), torch.mm(A_conj_physical, B, out=out)) + self.assertEqual(torch.mm(A_conj.t(), B, out=out), torch.mm(A_conj_physical.t(), B, out=out)) + + Ab = torch.randn(2, 3, 6, dtype=torch.cfloat, device=device) + Bb = torch.randn(2, 3, 3, dtype=torch.cfloat, device=device) + Bb_ = torch.randn(1, 3, 3, dtype=torch.cfloat, device=device).expand(2, 3, 3) + out_b = torch.empty(2, 3, 3, device=device, dtype=torch.complex64).transpose(-1, -2) + + Ab_conj = _slice(Ab, torch.conj) + Ab_conj_physical = _slice(Ab, torch.conj_physical) + + def t_b(tensor): + return tensor.transpose(-1, -2) + + self.assertEqual(torch.bmm(Ab_conj, Bb, out=out_b), torch.bmm(Ab_conj_physical, Bb, out=out_b)) + self.assertEqual(torch.bmm(t_b(Ab_conj), Bb, out=out_b), torch.bmm(t_b(Ab_conj_physical), Bb, out=out_b)) + + # test broadcasting + self.assertEqual(torch.bmm(Ab_conj, Bb_, out=out_b), torch.bmm(Ab_conj_physical, Bb_, out=out_b)) + self.assertEqual(torch.bmm(t_b(Ab_conj), Bb_, out=out_b), torch.bmm(t_b(Ab_conj_physical), Bb_, out=out_b)) + @onlyOnCPUAndCUDA @dtypes(torch.float32, torch.float64) def test_strided_mm_bmm(self, device, dtype): @@ -6184,7 +6237,7 @@ def test_strided_mm_bmm(self, device, dtype): @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05}) @skipCUDAIf(torch.version.cuda == "10.1", "flaky on CUDA 10.1") @onlyOnCPUAndCUDA - @dtypes(*torch.testing.get_all_fp_dtypes(), *torch.testing.get_all_complex_dtypes()) + @dtypes(*get_all_fp_dtypes(), *get_all_complex_dtypes()) @tf32_on_and_off(0.05) def test_bmm(self, device, dtype): if self.device_type == 'cuda' and dtype is torch.bfloat16 and CUDA11OrLater and not SM53OrLater: @@ -6194,7 +6247,7 @@ def test_bmm(self, device, dtype): return batch_sizes = [1, 10] - M, N, O = 23, 8, 12 + M, N, O = 23, 15, 12 numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32 is_supported = True @@ -6216,8 +6269,8 @@ def invert_perm(p): def generate_inputs(num_batches): # transposed tensors for perm1, perm2 in itertools.product(itertools.permutations((0, 1, 2)), repeat=2): - b1 = make_tensor((num_batches, M, N), device, dtype, low=-1, high=1) - b2 = make_tensor((num_batches, N, O), device, dtype, low=-1, high=1) + b1 = make_tensor((num_batches, M, N), device, dtype, low=-0.1, high=0.1) + b2 = make_tensor((num_batches, N, O), device, dtype, low=-0.1, high=0.1) b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1)) b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2)) yield b1, b2 @@ -6225,8 +6278,8 @@ def generate_inputs(num_batches): for b1, b2, b3, b4, b5, b6 in itertools.product((True, False), repeat=6): shape1 = (num_batches if b1 else 1, M if b2 else 1, N if b3 else 1) shape2 = (num_batches if b4 else 1, N if b5 else 1, O if b6 else 1) - b1 = make_tensor(shape1, device, dtype, low=-1, high=1).expand(num_batches, M, N) - b2 = make_tensor(shape2, device, dtype, low=-1, high=1).expand(num_batches, N, O) + b1 = make_tensor(shape1, device, dtype, low=-0.1, high=0.1).expand(num_batches, M, N) + b2 = make_tensor(shape2, device, dtype, low=-0.1, high=0.1).expand(num_batches, N, O) yield b1, b2 # zero-sized tensors for z1, z2, z3, z4 in itertools.product((True, False), repeat=4): @@ -6296,7 +6349,7 @@ def _test_addbmm_baddbmm(self, func, b1, b2, ref, out_tensor): @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05}) @onlyOnCPUAndCUDA - @dtypes(*torch.testing.get_all_fp_dtypes(), *torch.testing.get_all_complex_dtypes()) + @dtypes(*get_all_fp_dtypes(), *get_all_complex_dtypes()) @tf32_on_and_off(0.05) def test_addbmm(self, device, dtype): if self.device_type == 'cuda' and dtype is torch.bfloat16 and CUDA11OrLater and not SM53OrLater: @@ -6306,7 +6359,7 @@ def test_addbmm(self, device, dtype): return num_batches = 2 - M, N, O = 2, 3, 4 + M, N, O = 16, 17, 18 is_supported = True if dtype == torch.bfloat16: @@ -6332,8 +6385,8 @@ def generate_tensor(): # transposed tensors for perm1, perm2 in itertools.product(itertools.permutations((0, 1, 2)), repeat=2): for perm3 in itertools.permutations((0, 1)): - b1 = make_tensor((num_batches, M, N), device, dtype, low=-1, high=1) - b2 = make_tensor((num_batches, N, O), device, dtype, low=-1, high=1) + b1 = make_tensor((num_batches, M, N), device, dtype, low=-1, high=1) * 0.1 + b2 = make_tensor((num_batches, N, O), device, dtype, low=-1, high=1) * 0.1 b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1)) b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2)) ref = torch.from_numpy( @@ -6345,8 +6398,8 @@ def generate_tensor(): for s1, s2, s3, s4, s5, s6 in itertools.product((True, False), repeat=6): shape1 = (num_batches if s1 else 1, M if s2 else 1, N if s3 else 1) shape2 = (num_batches if s4 else 1, N if s5 else 1, O if s6 else 1) - b1 = make_tensor(shape1, device, dtype, low=-1, high=1).expand(num_batches, M, N) - b2 = make_tensor(shape2, device, dtype, low=-1, high=1).expand(num_batches, N, O) + b1 = make_tensor(shape1, device, dtype, low=-1, high=1).expand(num_batches, M, N) * 0.1 + b2 = make_tensor(shape2, device, dtype, low=-1, high=1).expand(num_batches, N, O) * 0.1 ref = torch.from_numpy( b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy() ).to(device=device, dtype=dtype).sum(0) @@ -6356,8 +6409,8 @@ def generate_tensor(): for z1, z2, z3, z4 in itertools.product((True, False), repeat=4): shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0) shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0) - b1 = make_tensor(shape1, device, dtype, low=-1, high=1) - b2 = make_tensor(shape2, device, dtype, low=-1, high=1) + b1 = make_tensor(shape1, device, dtype, low=-1, high=1) * 0.1 + b2 = make_tensor(shape2, device, dtype, low=-1, high=1) * 0.1 ref = torch.from_numpy( b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy() ).to(device=device, dtype=dtype).sum(0) @@ -6369,7 +6422,7 @@ def generate_tensor(): @precisionOverride({torch.half: 0.1, torch.bfloat16: 0.5}) @onlyOnCPUAndCUDA - @dtypes(*torch.testing.get_all_fp_dtypes(), *torch.testing.get_all_complex_dtypes()) + @dtypes(*get_all_fp_dtypes(), *get_all_complex_dtypes()) @tf32_on_and_off(0.05) def test_baddbmm(self, device, dtype): if self.device_type == 'cuda' and dtype is torch.bfloat16 and CUDA11OrLater and not SM53OrLater: @@ -6379,7 +6432,7 @@ def test_baddbmm(self, device, dtype): return num_batches = 10 - M, N, O = 12, 8, 5 + M, N, O = 12, 8, 50 is_supported = True if dtype == torch.bfloat16 and self.device_type == 'cuda': @@ -7253,7 +7306,7 @@ def test_cholesky_inverse_errors_and_warnings(self, device, dtype): a = torch.randn(3, 3, device=device, dtype=dtype) a[1, 1] = 0 if self.device_type == 'cpu': - with self.assertRaisesRegex(RuntimeError, r"cholesky_inverse: U\(2,2\) is zero, singular U\."): + with self.assertRaisesRegex(RuntimeError, r"cholesky_inverse: The diagonal element 2 is zero"): torch.cholesky_inverse(a) # cholesky_inverse on GPU does not raise an error for this case elif self.device_type == 'cuda': diff --git a/test/test_mobile_optimizer.py b/test/test_mobile_optimizer.py index 78ebb550d0227..19f07e2454488 100644 --- a/test/test_mobile_optimizer.py +++ b/test/test_mobile_optimizer.py @@ -119,7 +119,7 @@ def forward(self, x): .check_not("aten::relu(") \ .check_count("aten::_add_relu(", 1, exactly=True) \ .run(optimized_scripted_model.graph) - torch.testing.assert_allclose(initial_result, optimized_result, rtol=1e-2, atol=1e-3) + torch.testing.assert_close(initial_result, optimized_result, rtol=1e-2, atol=1e-3) FileCheck().check_not("Tensor = aten::conv2d") \ .check_not("Tensor = prim::CallFunction") \ @@ -131,7 +131,7 @@ def forward(self, x): .check_not("aten::relu(") \ .check_count("aten::_add_relu(", 1, exactly=True) \ .run(optimized_scripted_model.foo.graph) - torch.testing.assert_allclose(initial_foo_result, optimized_foo_result, rtol=1e-2, atol=1e-3) + torch.testing.assert_close(initial_foo_result, optimized_foo_result, rtol=1e-2, atol=1e-3) optimization_blocklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS} @@ -142,7 +142,7 @@ def forward(self, x): .check_not("prepacked::linear_clamp_run") \ .check_not("prepacked::conv2d_clamp_run") \ .run(optimized_scripted_model_no_prepack.graph) - torch.testing.assert_allclose(initial_result, optimized_result_no_prepack, rtol=1e-2, atol=1e-3) + torch.testing.assert_close(initial_result, optimized_result_no_prepack, rtol=1e-2, atol=1e-3) bn_test_module = BNTestModule() @@ -157,14 +157,14 @@ def forward(self, x): bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blocklist_no_prepack) self.assertEqual(len(torch.jit.export_opnames(bn_fold_scripted_module)), 1) bn_input = torch.rand(1, 1, 6, 6) - torch.testing.assert_allclose(bn_scripted_module(bn_input), bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3) + torch.testing.assert_close(bn_scripted_module(bn_input), bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3) optimization_blocklist_no_fold_bn = {MobileOptimizerType.CONV_BN_FUSION} no_bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blocklist_no_fold_bn) FileCheck().check_count("aten::batch_norm", 1, exactly=True) \ .run(str(get_forward_graph(no_bn_fold_scripted_module._c))) bn_input = torch.rand(1, 1, 6, 6) - torch.testing.assert_allclose(bn_scripted_module(bn_input), no_bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3) + torch.testing.assert_close(bn_scripted_module(bn_input), no_bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3) class MyMobileOptimizedTagTest(torch.nn.Module): def __init__(self): @@ -231,7 +231,7 @@ def foo(self, x): FileCheck().check_not("dropout.__") \ .check_count("aten::_add_relu(", 1, exactly=True) \ .run(optimized_scripted_model.foo.graph) - torch.testing.assert_allclose(initial_result, optimized_result, rtol=1e-2, atol=1e-3) + torch.testing.assert_close(initial_result, optimized_result, rtol=1e-2, atol=1e-3) class BNTestNoForwardModule(torch.nn.Module): def __init__(self): @@ -257,7 +257,7 @@ def foo(self, x): bn_fold_no_forward_scripted_module = optimize_for_mobile(bn_no_forward_scripted_module, preserved_methods=['foo']) self.assertEqual(len(torch.jit.export_opnames(bn_fold_no_forward_scripted_module)), 1) bn_input = torch.rand(1, 1, 6, 6) - torch.testing.assert_allclose( + torch.testing.assert_close( bn_no_forward_scripted_module.foo(bn_input), bn_fold_no_forward_scripted_module.foo(bn_input), rtol=1e-2, @@ -493,7 +493,7 @@ def _quant_script_and_optimize(model): data = torch.randn(4, 1, 4, 4) m_res = m(data) m_optim_res = m_optim(data) - torch.testing.assert_allclose(m_res, m_optim_res, rtol=1e-2, atol=1e-3) + torch.testing.assert_close(m_res, m_optim_res, rtol=1e-2, atol=1e-3) # generic case @@ -507,7 +507,7 @@ def _quant_script_and_optimize(model): data = torch.randn(4, 1, 4, 4) m_res = m(data) m_optim_res = m_optim(data) - torch.testing.assert_allclose(m_res, m_optim_res, rtol=1e-2, atol=1e-3) + torch.testing.assert_close(m_res, m_optim_res, rtol=1e-2, atol=1e-3) @unittest.skipUnless(HAS_TORCHVISION, "Needs torchvision") def test_mobilenet_optimize_for_mobile(self): diff --git a/test/test_modules.py b/test/test_modules.py index bb0fe5f1f9689..6d6adbc7ac57d 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -1,3 +1,5 @@ +import tempfile + import torch from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_modules import module_db, modules @@ -108,6 +110,50 @@ def test_factory_kwargs(self, device, dtype, module_info): buffer.dtype, dtype, f'Buffer {name} is of dtype {buffer.dtype} instead of the expected dtype {dtype}') + @modules(module_db) + def test_repr(self, device, dtype, module_info): + # Test module can be represented with repr and str without errors. + module_cls = module_info.module_cls + module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype, + requires_grad=False) + for module_input in module_inputs: + args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs + m = module_cls(*args, **kwargs) + + # Check that these methods do not raise errors + m.__repr__() + str(m) + + @modules(module_db) + def test_pickle(self, device, dtype, module_info): + # Test that module can be pickled and unpickled. + module_cls = module_info.module_cls + module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype, + requires_grad=False) + for module_input in module_inputs: + if module_input.forward_input is None: + continue + + args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs + + with freeze_rng_state(): + # === Instantiate the module. === + args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs + m = module_cls(*args, **kwargs) + m.to(device).to(dtype) + + # === Do forward pass. === + args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs + output = m(*args, **kwargs) + + # === Check unpickled module gives the same output. === + with tempfile.TemporaryFile() as f: + torch.save(m, f) + f.seek(0) + m_copy = torch.load(f) + output_from_copy = m_copy(*args, **kwargs) + self.assertEqual(output, output_from_copy) + instantiate_device_type_tests(TestModule, globals()) diff --git a/test/test_namedtensor.py b/test/test_namedtensor.py index b5e7aac402abb..2c6d2d80a2266 100644 --- a/test/test_namedtensor.py +++ b/test/test_namedtensor.py @@ -1072,6 +1072,11 @@ def test_flatten(self): with self.assertRaisesRegex(RuntimeError, "must be consecutive in"): tensor.flatten(['H', 'D', 'W'], 'features') + def test_flatten_nodims(self): + tensor = torch.empty((2, 3)) + with self.assertRaisesRegex(RuntimeError, "cannot be empty"): + tensor.flatten((), 'abcd') + def test_unflatten(self): # test args: tensor, int, namedshape self.assertTrue(torch.equal( diff --git a/test/test_nn.py b/test/test_nn.py index ccf6f6e933c10..2d66477ff826a 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -33,7 +33,7 @@ from torch.nn import Parameter from torch.nn.parameter import UninitializedParameter, UninitializedBuffer from torch.nn.parallel._functions import Broadcast -from torch.testing import get_all_fp_dtypes +from torch.testing._internal.common_dtype import integral_types, get_all_fp_dtypes, get_all_math_dtypes from torch.testing._internal.common_utils import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, skipIfRocm, \ TEST_NUMPY, TEST_SCIPY, TEST_WITH_ROCM, download_file, \ get_function_arglist, load_tests, repeat_test_for_types, ALL_TENSORTYPES, \ @@ -229,7 +229,7 @@ def test_doubletensor_avg_pool2d(self): actual = torch.nn.functional.avg_pool2d(input[0], (i, j)) actual = actual.view(1, actual.numel()) expected = self._avg_pool2d(input, (i, j)) - self.assertTrue(torch.allclose(actual, expected, rtol=0, atol=1e-5)) + self.assertEqual(actual, expected, rtol=0, atol=1e-5) def test_avg_pool2d_with_zero_divisor(self): self.assertRaisesRegex(RuntimeError, "divisor must be not zero", @@ -244,7 +244,7 @@ def test_doubletensor_avg_pool2d_with_divisor(self): actual = F.avg_pool2d(input[0], (i, j), divisor_override=divisor) actual = actual.view(1, actual.numel()) expected = self._sum_pool2d(input, (i, j)) / divisor - self.assertTrue(torch.allclose(actual, expected, rtol=0, atol=1e-5)) + self.assertEqual(actual, expected, rtol=0, atol=1e-5) def test_doubletensor_avg_pool3d(self): h, w, d = 5, 6, 7 @@ -255,7 +255,7 @@ def test_doubletensor_avg_pool3d(self): actual = torch.nn.functional.avg_pool3d(input.unsqueeze(0), (i, j, k)) actual = actual.view(1, actual.numel()) expected = self._avg_pool3d(input, (i, j, k)) - self.assertTrue(torch.allclose(actual, expected, rtol=0, atol=1e-5)) + self.assertEqual(actual, expected, rtol=0, atol=1e-5) def test_doubletensor_avg_pool3d_with_divisor(self): h, w, d = 6, 5, 7 @@ -267,7 +267,7 @@ def test_doubletensor_avg_pool3d_with_divisor(self): actual = torch.nn.functional.avg_pool3d(input.unsqueeze(0), (i, j, k), divisor_override=divisor) actual = actual.view(1, actual.numel()) expected = self._sum_pool3d(input, (i, j, k)) / divisor - self.assertTrue(torch.allclose(actual, expected, rtol=0, atol=1e-5)) + self.assertEqual(actual, expected, rtol=0, atol=1e-5) def test_avg_pool3d_with_zero_divisor(self): self.assertRaisesRegex(RuntimeError, "divisor must be not zero", @@ -2260,7 +2260,7 @@ def forward(self, x): self.assertNotIn("weight", model._parameters) # Result should be skew-symmetric A = model.weight - self.assertTrue(torch.allclose(A, -A.T)) + self.assertEqual(A, -A.T) # Remove and check consistency parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) self.assertFalse(hasattr(model, "parametrizations")) @@ -2277,7 +2277,7 @@ def forward(self, x): self.assertNotIn("weight", model._parameters) # Result should be skew-symmetric A = model.weight - self.assertTrue(torch.allclose(A, -A.T)) + self.assertEqual(A, -A.T) # Remove and check consistency parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) self.assertFalse(hasattr(model, "parametrizations")) @@ -2291,7 +2291,7 @@ def forward(self, x): # Result should be orthogonal X = model.weight Id = torch.eye(X.size(0), device=X.device) - self.assertTrue(torch.allclose(X.T @ X, Id)) + self.assertEqual(X.T @ X, Id) # Structure tests self.assertTrue(hasattr(model, "parametrizations")) self.assertTrue(parametrize.is_parametrized(model)) @@ -2810,10 +2810,10 @@ def right_inverse(self, w): init_weight = model.weight.clone() parametrize.register_parametrization(model, "weight", RankOne()) # Projecting a rank 1 matrix onto the matrices of rank one does not change the matrix - self.assertTrue(torch.allclose(init_weight, model.weight)) + self.assertEqual(init_weight, model.weight) parametrize.register_parametrization(model, "weight", Double()) # The matrix now is twice the initial matrix - self.assertTrue(torch.allclose(2.0 * init_weight, model.weight)) + self.assertEqual(2.0 * init_weight, model.weight) # Multiplying by a scalar does not change the rank self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1) @@ -4220,6 +4220,9 @@ def fn(input): out1 = wrapped_m(input) return out0 + out1 + # Make sure we can compute gradients wrt to all the parameters in the case + # of double forward + fn(input.clone().requires_grad_()).sum().backward() gradcheck(fn, (input.clone().requires_grad_(),), check_batched_grad=False) # test removing @@ -4515,6 +4518,139 @@ def test_spectral_norm_pickle(self): m = pickle.loads(pickle.dumps(m)) self.assertIsInstance(m, nn.Linear) + def test_orthogonal_parametrization(self): + # Orthogonal implements 6 algorithms (3x parametrizations times 2 options of use_trivialization) + + def assert_is_orthogonal(X): + n, k = X.size(-2), X.size(-1) + if n < k: + X = X.transpose(-2, -1) + n, k = k, n + Id = torch.eye(k, dtype=X.dtype, device=X.device).expand(*(X.size()[:-2]), k, k) + eps = 10 * n * torch.finfo(X.dtype).eps + torch.testing.assert_allclose(X.transpose(-2, -1).conj() @ X, Id, atol=eps, rtol=0.) + + + def assert_weight_allclose_Q(weight, W): + # Test that weight is equal to the Q part of the QR decomposition of W + # (or of its transpose if the matrix is wide) + wide_matrix = W.size(-2) < W.size(-1) + if wide_matrix: + W = W.transpose(-2, -1) + Q, R = torch.linalg.qr(W) + Q *= R.diagonal(dim1=-2, dim2=-1).sgn().unsqueeze(-2) + if wide_matrix: + Q = Q.transpose(-2, -1) + torch.testing.assert_allclose(Q, weight, atol=1e-5, rtol=0.) + + + for shape, dtype, use_linear in product(((4, 4), (5, 3), (3, 5)), # square/ tall / wide + (torch.float32, torch.complex64), + (True, False)): + # Conv2d does not support complex yet + if not use_linear and dtype.is_complex: + continue + + if use_linear: + input = torch.randn(3, shape[0], dtype=dtype) + else: + input = torch.randn(2, 2, shape[0] + 2, shape[1] + 1, dtype=dtype) + + for parametrization, use_trivialization in product(("matrix_exp", "cayley", "householder"), + (False, True)): + # right_inverse for Cayley and matrix_exp not implemented for use_trivialization=False + # See Note [right_inverse expm cayley] + can_initialize = use_trivialization or parametrization == "householder" + + # We generate them every time to always start with fresh weights + if use_linear: + m = nn.Linear(*shape, dtype=dtype) + else: + m = nn.Conv2d(2, 3, shape, dtype=dtype) + + # We do not support householder for complex inputs + # See Note [Householder complex] + w_init = m.weight.clone() + if parametrization == "householder" and m.weight.is_complex(): + msg = "householder parametrization does not support complex tensors" + with self.assertRaisesRegex(ValueError, msg): + torch.nn.utils.parametrizations.orthogonal(m, + "weight", + parametrization, + use_trivialization=use_trivialization) + continue + + wide_matrix = w_init.size(-2) < w_init.size(-1) + torch.nn.utils.parametrizations.orthogonal(m, + "weight", + parametrization, + use_trivialization=use_trivialization) + # Forwards works as expected + self.assertEqual(w_init.shape, m.weight.shape) + assert_is_orthogonal(m.weight) + if can_initialize: + assert_weight_allclose_Q(m.weight, w_init) + + # Intializing with a given orthogonal matrix works + X = torch.randn_like(m.weight) + if wide_matrix: + X = X.transpose(-2, -1) + w_new = torch.linalg.qr(X).Q + if wide_matrix: + w_new = w_new.transpose(-2, -1) + if can_initialize: + m.weight = w_new + torch.testing.assert_allclose(w_new, m.weight, atol=1e-5, rtol=0.) + else: + msg = "assign to the matrix exponential or the Cayley parametrization" + with self.assertRaisesRegex(NotImplementedError, msg): + m.weight = w_new + + # Intializing with a non-orthogonal matrix makes m.weight be the Q part of the given matrix + w_new = torch.randn_like(m.weight) + if can_initialize: + m.weight = w_new + assert_weight_allclose_Q(m.weight, w_new) + else: + msg = "assign to the matrix exponential or the Cayley parametrization" + with self.assertRaisesRegex(NotImplementedError, msg): + m.weight = w_new + + opt = torch.optim.SGD(m.parameters(), lr=0.1) + for _ in range(2): + opt.zero_grad() + m(input).norm().backward() + grad = m.parametrizations.weight.original.grad + self.assertIsNotNone(grad) + # We do not update the upper triangular part of the matrix if tall tril if wide + if grad.size(-2) >= grad.size(-1): + zeros_grad = grad.triu(1) + else: + zeros_grad = grad.tril(-1) + self.assertEqual(zeros_grad, torch.zeros_like(zeros_grad)) + # The gradient in the diagonal can only be imaginary because a skew-Hermitian + # matrix has imaginary diagonal + diag_grad = grad.diagonal(dim1=-2, dim2=-1) + if grad.is_complex(): + diag_grad = diag_grad.real + self.assertEqual(diag_grad, torch.zeros_like(diag_grad)) + opt.step() + assert_is_orthogonal(m.weight) + + def test_orthogonal_errors(self): + m = nn.Linear(3, 4) + with self.assertRaisesRegex(ValueError, "has to be one of"): + torch.nn.utils.parametrizations.orthogonal(m, "weight", "foo") + + with self.assertRaisesRegex(ValueError, "Expected a matrix"): + torch.nn.utils.parametrizations.orthogonal(m, "bias") + + torch.nn.utils.parametrizations.orthogonal(m, "weight") + with self.assertRaisesRegex(ValueError, "matrices of shape"): + m.weight = torch.randn(5, 5) + torch.nn.utils.parametrize.remove_parametrizations(m, "weight") + + def test_threshold_int(self): x = torch.tensor([-3, -2, -1, 0, 1, 2, 3]) expected = torch.tensor([99, 99, 99, 99, 1, 2, 3]) @@ -4717,7 +4853,7 @@ def fc_op(X, W, b): packed_w_tensor = torch.fbgemm_pack_gemm_matrix_fp16(w_tensor) actual_output = torch.fbgemm_linear_fp16_weight(x_tensor, packed_w_tensor, b_tensor) expected_output = fc_op(X, W, b) - torch.testing.assert_allclose(expected_output, actual_output.cpu(), atol=1e-3, rtol=1e-3) + torch.testing.assert_close(torch.from_numpy(expected_output), actual_output.cpu(), atol=1e-3, rtol=1e-3) def test_embeddingbag_from_pretrained(self): a = torch.tensor([[1., 2., 3.], [4., 5., 6.]]) @@ -5462,6 +5598,92 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, self.assertEqual(mm[0].param[0].item(), 10) self.assertEqual(mm[0].sub.weight[0, 0].item(), 555) + def test_extra_state(self): + + class SubModule(torch.nn.Module): + def __init__(self, foo): + super().__init__() + self.foo = foo + + def get_extra_state(self): + return { + 'foo': self.foo + } + + def set_extra_state(self, state): + self.foo = state['foo'] + + class MyModule(torch.nn.Module): + def __init__(self, foo, bar): + super().__init__() + self.sub = SubModule(foo) + self.bar = bar + + def get_extra_state(self): + return { + 'bar': self.bar + } + + def set_extra_state(self, state): + self.bar = state['bar'] + + # Ensure state_dict contains the extra state by loading it into another module. + m = MyModule(3, 'something') + m2 = MyModule(5, 'something else') + m2.load_state_dict(m.state_dict()) + self.assertEqual(m.state_dict(), m2.state_dict()) + self.assertEqual(m2.bar, m.bar) + self.assertEqual(m2.sub.foo, m.sub.foo) + + def test_extra_state_non_dict(self): + + class MyModule(torch.nn.Module): + def __init__(self, foo): + super().__init__() + self.foo = foo + + def get_extra_state(self): + return self.foo + + def set_extra_state(self, state): + self.foo = state + + # Test various types of extra state. + for state in ('something', 5, MyModule(3)): + m = MyModule(state) + m2 = MyModule('something else') + m2.load_state_dict(m.state_dict()) + self.assertEqual(m.state_dict(), m2.state_dict()) + self.assertEqual(m.foo, m2.foo) + + def test_extra_state_missing_set_extra_state(self): + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def get_extra_state(self): + return { + 'foo': 5 + } + + m = MyModule() + with self.assertRaisesRegex(RuntimeError, 'Unexpected key'): + m.load_state_dict(m.state_dict()) + + def test_extra_state_missing_get_extra_state(self): + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def set_extra_state(self): + pass + + m = MyModule() + with self.assertRaisesRegex(RuntimeError, 'Missing key'): + m.load_state_dict(m.state_dict()) + def test_parameter_assignment(self): l = nn.Linear(5, 5) @@ -6097,6 +6319,37 @@ def test_MaxUnpool2d_output_size(self): else: self.assertRaises(ValueError, lambda: mu(output_small, indices_small, (h, w))) + def test_max_unpool2d_nhwc_cpu(self): + input = torch.randn(2, 10, 9, 9).float().cpu() + input = input.contiguous(memory_format=torch.channels_last) + ref_input = input.clone().contiguous() + + pool = nn.MaxPool2d(3, stride=2, return_indices=True).cpu() + ref_pool = nn.MaxPool2d(3, stride=2, return_indices=True).cpu() + + out, ind = pool(input) + ref_out, ref_ind = ref_pool(ref_input) + out.requires_grad_() + ref_out.requires_grad_() + + unpool = nn.MaxUnpool2d(3, stride=2).cpu() + ref_unpool = nn.MaxUnpool2d(3, stride=2).cpu() + + upout = unpool(out, ind) + ref_upout = ref_unpool(ref_out, ref_ind) + + grad = torch.randn(upout.size()).float().cpu() + grad = grad.contiguous(memory_format=torch.channels_last) + ref_grad = grad.clone().contiguous() + + upout.backward(grad) + ref_upout.backward(ref_grad) + + self.assertTrue(upout.is_contiguous(memory_format=torch.channels_last)) + self.assertTrue(ref_upout.is_contiguous()) + self.assertTrue(torch.allclose(upout, ref_upout)) + self.assertTrue(torch.allclose(out.grad, ref_out.grad)) + def test_container_copy(self): class Model(nn.Module): def __init__(self): @@ -6797,8 +7050,7 @@ def perm_fn(x): encoder_input = torch.tensor([[[20., 30., 40., 50.]]]) result = model(encoder_input) ref_output = torch.tensor([[[2.249815, 0.131006, -0.702199, 0.177868]]]) - self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) - torch.testing.assert_allclose(result, ref_output) + torch.testing.assert_close(result, ref_output, rtol=1e-5, atol=0) # deterministic input encoder_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]], @@ -6806,8 +7058,7 @@ def perm_fn(x): result = model(encoder_input) ref_output = perm_fn(torch.tensor([[[2.264103, 0.121417, -0.696012, 0.159724]], [[2.264103, 0.121417, -0.696012, 0.159724]]])) - self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) - torch.testing.assert_allclose(result, ref_output) + torch.testing.assert_close(result, ref_output, rtol=1e-5, atol=0) # deterministic input encoder_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891], @@ -6831,8 +7082,7 @@ def perm_fn(x): [2.4237977, 0.03290575, -0.60561789, -0.05940082]], [[2.41383916, 0.02686345, -0.61256377, -0.06380707], [2.42000277, 0.03800944, -0.60824798, -0.04754947]]])) - self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) - torch.testing.assert_allclose(result, ref_output) + torch.testing.assert_close(result, ref_output, rtol=1e-5, atol=0) def test_transformerdecoderlayer(self): # this is a deterministic test for TransformerDecoderLayer @@ -7013,8 +7263,7 @@ def perm_fn(x): memory_input = torch.tensor([[[60., 70., 80., 90.]]]) result = model(decoder_input, memory_input) ref_output = torch.tensor([[[2.306435, 0.095946, -0.675796, 0.10687]]]) - self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) - torch.testing.assert_allclose(result, ref_output) + torch.testing.assert_close(result, ref_output, rtol=1e-5, atol=0) # deterministic input decoder_input = perm_fn(torch.tensor([[[9., 10., 11., 12.]], @@ -7023,8 +7272,7 @@ def perm_fn(x): result = model(decoder_input, memory_input) ref_output = perm_fn(torch.tensor([[[2.415448, 0.054389, -0.610932, -0.0156613]], [[2.415448, 0.054389, -0.610932, -0.0156613]]])) - self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) - torch.testing.assert_allclose(result, ref_output) + torch.testing.assert_close(result, ref_output, rtol=1e-5, atol=0) # deterministic input decoder_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]], @@ -7034,8 +7282,7 @@ def perm_fn(x): result = model(decoder_input, memory_input) ref_output = perm_fn(torch.tensor([[[2.338531, 0.087709, -0.65776, 0.080646]], [[2.338531, 0.087709, -0.65776, 0.080646]]])) - self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) - torch.testing.assert_allclose(result, ref_output) + torch.testing.assert_close(result, ref_output, rtol=1e-5, atol=0) # deterministic input decoder_input = perm_fn(torch.tensor([[[0.4517, 0.6793, 0.5313, 0.0034], @@ -7061,8 +7308,7 @@ def perm_fn(x): [2.42216881, 0.03586554, -0.6067524, -0.05289126]], [[2.42205716, 0.03488046, -0.60683681, -0.05460596], [2.42240309, 0.0354595, -0.60659063, -0.05378816]]])) - self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) - torch.testing.assert_allclose(result, ref_output) + torch.testing.assert_close(result, ref_output, rtol=1e-5, atol=0) def test_transformerencoder(self): def get_a_test_layer(use_cuda, activation, batch_first=False): @@ -7130,13 +7376,13 @@ def perm_fn(x): [2.422901, 0.024187, -0.606178, -0.074929]]] )).to(device) self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) - torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-5) + torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) # all 0 mask = torch.zeros([2, 5]).to(device) == 1 result = model(encoder_input, src_key_padding_mask=mask) self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) - torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-5) + torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) mask[0, 1] = 1 mask[1, 3] = 1 mask[1, 4] = 1 @@ -7153,7 +7399,7 @@ def perm_fn(x): [2.4242, 0.024653, -0.605266, -0.074959]]] )).to(device) self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) - torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-5) + torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) # test case 2, multiple layers no norm model = nn.TransformerEncoder(encoder_layer, 2).to(device) @@ -7170,7 +7416,7 @@ def perm_fn(x): [2.419075, 0.017449, -0.608722, -0.085014]]] )).to(device) self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) - torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-5) + torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) model = nn.TransformerEncoder(encoder_layer, 6).to(device) result = model(encoder_input, src_key_padding_mask=mask) @@ -7186,7 +7432,7 @@ def perm_fn(x): [2.419101, 0.017453, -0.608704, -0.085025]]] )).to(device) self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) - torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-5) + torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) # test case 3, multiple layers with norm # d_model = 4 @@ -7205,7 +7451,7 @@ def perm_fn(x): [1.695952, -0.357637, -0.893065, -0.445251]]] )).to(device) self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) - torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-5) + torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) model = nn.TransformerEncoder(encoder_layer, 6, norm=norm).to(device) result = model(encoder_input, src_key_padding_mask=mask) @@ -7221,7 +7467,7 @@ def perm_fn(x): [1.695955, -0.357639, -0.893051, -0.445265]]] )).to(device) self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) - torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-5) + torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) def test_transformerdecoder(self): @@ -7271,7 +7517,7 @@ def perm_fn(x): ref_output = torch.tensor( [[[2.314351, 0.094805, -0.671322, 0.101977]]]).to(device) self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) - torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-3) + torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-3) # deterministic input decoder_input = perm_fn(torch.tensor([[[9., 10., 11., 12.]], @@ -7282,7 +7528,7 @@ def perm_fn(x): [[2.422245, 0.051716, -0.606338, -0.024756]]] )).to(device) self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) - torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-4) + torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-4) # deterministic input decoder_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]], @@ -7294,7 +7540,7 @@ def perm_fn(x): [[2.343536, 0.085561, -0.654954, 0.074991]]] )).to(device) self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) - torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-4) + torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-4) # deterministic input decoder_input = perm_fn(torch.tensor([[[0.4517, 0.6793, 0.5313, 0.0034], @@ -7324,7 +7570,7 @@ def perm_fn(x): [2.432306, 0.028858, -0.599542, -0.072846]]] )).to(device) self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) - torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-5) + torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) # key_padding_mask key_padding_mask = torch.zeros(2, 3).to(device) == 1 @@ -7338,7 +7584,7 @@ def perm_fn(x): [2.432306, 0.028858, -0.599542, -0.072846]]] )).to(device) self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) - torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-5) + torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) # key_padding_mask key_padding_mask[0, 2] = 1 @@ -7354,7 +7600,7 @@ def perm_fn(x): [2.432659, 0.029244, -0.599294, -0.072382]]] )).to(device) self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) - torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-5) + torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) # memory_key_padding_mask key_padding_mask = torch.zeros(2, 5).to(device) == 1 @@ -7368,7 +7614,7 @@ def perm_fn(x): [2.432306, 0.028858, -0.599542, -0.072846]]] )).to(device) self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) - torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-5) + torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) # memory_key_padding_mask key_padding_mask[0, 4] = 1 @@ -7385,7 +7631,7 @@ def perm_fn(x): [2.433075, 0.028543, -0.598987, -0.073985]]] )).to(device) self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) - torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-5) + torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) # multiple layers no norm model = nn.TransformerDecoder(decoder_layer, 2).to(device) @@ -7397,7 +7643,7 @@ def perm_fn(x): ref_output = torch.tensor( [[[2.31316, 0.0950293, -0.671995, 0.102802]]]).to(device) self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) - torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-3) + torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-3) # multiple layers no norm model = nn.TransformerDecoder(decoder_layer, 6).to(device) @@ -7430,7 +7676,7 @@ def perm_fn(x): [2.43113, 0.0279516, -0.600376, -0.0736896]]] )).to(device) self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) - torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-5) + torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) # multiple layers with norm # d_model = 4 @@ -7444,7 +7690,7 @@ def perm_fn(x): ref_output = torch.tensor( [[[1.66166, -0.326986, -1.01466, -0.320017]]]).to(device) self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) - torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-3) + torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-3) # multiple layers with norm model = nn.TransformerDecoder(decoder_layer, 6, norm=norm).to(device) @@ -7477,7 +7723,7 @@ def perm_fn(x): [1.69571, -0.357363, -0.894154, -0.444196]]] )).to(device) self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) - torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-5) + torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) # gelu activation test cases activation = "gelu" @@ -7495,7 +7741,7 @@ def perm_fn(x): result = model(decoder_input, memory_input) ref_output = torch.tensor([[[2.306435, 0.095946, -0.675796, 0.10687]]]).to(device) self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) - torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-3) + torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-3) # deterministic input decoder_input = perm_fn(torch.tensor([[[9., 10., 11., 12.]], @@ -7505,7 +7751,7 @@ def perm_fn(x): ref_output = perm_fn(torch.tensor([[[2.415448, 0.054389, -0.610932, -0.0156613]], [[2.415448, 0.054389, -0.610932, -0.0156613]]])).to(device) self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) - torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-4) + torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-4) # deterministic input decoder_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]], @@ -7516,7 +7762,7 @@ def perm_fn(x): ref_output = perm_fn(torch.tensor([[[2.338531, 0.087709, -0.65776, 0.080646]], [[2.338531, 0.087709, -0.65776, 0.080646]]])).to(device) self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) - torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-4) + torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-4) # deterministic input decoder_input = perm_fn(torch.tensor([[[0.4517, 0.6793, 0.5313, 0.0034], @@ -7546,7 +7792,7 @@ def perm_fn(x): [2.42240309, 0.0354595, -0.60659063, -0.05378816]]] )).to(device) self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) - torch.testing.assert_allclose(result, ref_output, rtol=1e-7, atol=1e-5) + torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) @unittest.skipIf(not (TEST_CUDNN and TEST_MULTIGPU), 'CUDNN or multi-gpu not available') def test_cudnn_rnn_dropout_states_device(self): @@ -8936,6 +9182,25 @@ def helper(self, size): helper(self, (4, 1, 9, 9)) helper(self, (4, 9, 1, 1)) + def test_batchnorm_non_contig_cpu(self): + input = torch.arange(6, dtype=torch.float).reshape(1, 3, 2, 1).cpu() + input = input.permute(0, 2, 1, 3) + + bn = torch.nn.BatchNorm2d(2).cpu().float().eval() + bn.weight.data.uniform_() + bn.bias.data.uniform_() + + ref_input = input.detach().clone().contiguous() + ref_bn = nn.BatchNorm2d(2).cpu().float().eval() + ref_bn.load_state_dict(bn.state_dict()) + + out = bn(input) + ref_out = ref_bn(ref_input) + + self.assertTrue(out.is_contiguous(memory_format=torch.channels_last)) + self.assertTrue(ref_out.is_contiguous()) + self.assertEqual(out, ref_out) + @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") @unittest.skipIf(not TEST_CUDNN, "needs cudnn") @skipIfRocm @@ -9141,9 +9406,9 @@ def test_cosine_embedding_loss_with_diff_type(self): input2 = torch.tensor([[2, 3, 5], [3, 2, 1]], dtype=torch.double, device=device) target = torch.tensor([1, -1], dtype=torch.int, device=device) expected = torch.nn.functional.cosine_embedding_loss(input1, input2, target) - for dt1 in torch.testing.get_all_math_dtypes(device): - for dt2 in torch.testing.get_all_math_dtypes(device): - for dt3 in torch.testing.get_all_math_dtypes(device): + for dt1 in get_all_math_dtypes(device): + for dt2 in get_all_math_dtypes(device): + for dt3 in get_all_math_dtypes(device): # dt3 is used as dtype for target = [1, -1], so let's skip unsigned type if dt3 == torch.uint8: continue @@ -9160,7 +9425,7 @@ def test_kl_div_with_diff_type(self): input = torch.tensor([[2, 3, 5], [3, 2, 1]], dtype=torch.double, device=device) target = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.double, device=device) expected = torch.nn.functional.kl_div(input, target) - for input_dtype in torch.testing.get_all_math_dtypes(device): + for input_dtype in get_all_math_dtypes(device): if input_dtype.is_complex: continue for target_dtype in [torch.float32, torch.float64, torch.float16]: @@ -9176,7 +9441,7 @@ def test_kl_div_with_diff_type_log_target(self): input = torch.tensor([[2, 3, 5], [3, 2, 1]], dtype=torch.double, device=device) target = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.double, device=device).log() expected = torch.nn.functional.kl_div(input, target, log_target=True) - for input_dtype in torch.testing.get_all_math_dtypes(device): + for input_dtype in get_all_math_dtypes(device): if input_dtype.is_complex: continue for target_dtype in [torch.float32, torch.float64, torch.float16]: @@ -9319,7 +9584,7 @@ def _input_grad(input, target, reduction): return input.grad for device, dtype, reduction in product(device_(), - torch.testing.integral_types(), + integral_types(), ('none', 'sum', 'mean')): input = torch.randn(2, 2, device=device, requires_grad=True) target = torch.randint(0, 9, (2, 2), device=device, dtype=dtype) @@ -9352,25 +9617,6 @@ def test_huber_loss_zero_delta(): test_huber_loss_zero_delta() def test_cosine_similarity(self): - input1 = torch.randn(4, 4, requires_grad=True) - input2 = torch.randn(4, 4, requires_grad=True) - self.assertTrue(gradcheck(lambda x, y: F.cosine_similarity(x, y), (input1, input2))) - - input1 = torch.randn(4, 5, 6, requires_grad=True) - input2 = torch.randn(4, 5, 6, requires_grad=True) - self.assertTrue(gradcheck(lambda x, y: F.cosine_similarity(x, y, dim=0), (input1, input2))) - self.assertTrue(gradcheck(lambda x, y: F.cosine_similarity(x, y, dim=-1), (input1, input2))) - - input1 = torch.randn((), requires_grad=True) - input2 = torch.randn((), requires_grad=True) - self.assertTrue(gradcheck(lambda x, y: F.cosine_similarity(x, y, dim=0), (input1, input2))) - self.assertTrue(gradcheck(lambda x, y: F.cosine_similarity(x, y, dim=-1), (input1, input2))) - - # Check broadcasting - input1 = torch.randn(2, 1, 3, requires_grad=True) - input2 = torch.randn(1, 2, 3, requires_grad=True) - self.assertTrue(gradcheck(lambda x, y: F.cosine_similarity(x, y, dim=-1), (input1, input2))) - # Check cosine_similarity input/output shapes input_size = (1, 3, 2, 1) expected_size = (1, 2, 1) @@ -9397,7 +9643,6 @@ def test_cosine_similarity(self): with self.assertRaises(RuntimeError): F.cosine_similarity(input1, input2) - # Check type promotion, issue #61454 input = torch.tensor(12.) out = F.cosine_similarity(input.to(torch.int8), input, dim=-1) @@ -10374,6 +10619,13 @@ def test_upsamplingTrilinear3d_spatial_invariance(self): out_t_5 = m(in_t_9[:, :, :5, :5, :5]) self.assertEqual(out_t_9[:, :, :15, :15, :15], out_t_5) + def test_upsampling_small_scale(self): + m = torch.nn.Upsample(scale_factor=0.5, mode="bilinear") + in_t = torch.arange(1, 5, dtype=torch.float64).reshape(1, 1, 2, 2) + out_t = m(in_t) + expected_out_t = torch.tensor([[[[2.5]]]]) + self.assertEqual(expected_out_t, out_t) + @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") def test_interpolate_illegal_memory_access(self): in_s = 45 @@ -11053,6 +11305,48 @@ def test_convert_sync_batchnorm(self): self.assertEqual(layer.state_dict()[key].device, converted_layer.state_dict()[key].device) self.assertEqual(layer.state_dict()[key], converted_layer.state_dict()[key]) + @unittest.skipIf(not TEST_CUDA, "CUDA not available") + def test_sync_batchnorm_backward_elemt(self): + device = 'cuda' + saved_input = torch.rand(2, 3, 2, 1, device=device) + grad_output = torch.rand(2, 3, 2, 1, device=device) + mean = torch.rand(3, device=device) + invstd = torch.rand(3, device=device) + weight = torch.rand(3, device=device) + sum_dy = torch.rand(3, device=device) + sum_dy_xmu = torch.rand(3, device=device) + count_tensor = torch.tensor([5, 5, 5], dtype=torch.int32, device=device) + + gI_contiguous = torch.batch_norm_backward_elemt( + grad_output, + saved_input, + mean, + invstd, + weight, + sum_dy, + sum_dy_xmu, + count_tensor + ) + + # Test batch_norm_backward_elemt gives the same answer for all + # combinations of contiguous as channels_last input + for a, b in [ + (torch.channels_last, torch.contiguous_format), + (torch.contiguous_format, torch.channels_last), + (torch.channels_last, torch.channels_last), + ]: + gI_actual = torch.batch_norm_backward_elemt( + grad_output.contiguous(memory_format=a), + saved_input.contiguous(memory_format=b), + mean, + invstd, + weight, + sum_dy, + sum_dy_xmu, + count_tensor + ) + self.assertEqual(gI_actual, gI_contiguous) + @unittest.skipIf(not TEST_CUDA, "CUDA not available") def test_sync_batchnorm_accuracy_cuda(self): # The target of this test is to test the functionality and accuracy of @@ -11175,7 +11469,7 @@ def test_layer_norm_grads_with_create_graph_flag(self): grads1 = torch.autograd.grad(layer_norm(x).sum(), x, create_graph=False)[0] grads2 = torch.autograd.grad(layer_norm(x).sum(), x, create_graph=True)[0] - self.assertTrue(torch.allclose(grads1, grads2, rtol, atol)) + self.assertEqual(grads1, grads2, rtol=rtol, atol=atol) if TEST_CUDA: x = x.to('cuda') @@ -11184,7 +11478,7 @@ def test_layer_norm_grads_with_create_graph_flag(self): grads1 = torch.autograd.grad(layer_norm(x).sum(), x, create_graph=False)[0] grads2 = torch.autograd.grad(layer_norm(x).sum(), x, create_graph=True)[0] - self.assertTrue(torch.allclose(grads1, grads2, rtol, atol)) + self.assertEqual(grads1, grads2, rtol=rtol, atol=atol) def test_padding_list(self): # Padding can be a list, or tuple (regression test for gh-54452) @@ -11692,7 +11986,7 @@ def test_add_relu(self): relu_res = torch.relu(add_res) add_relu_res = torch._VF._add_relu(a, b) - self.assertTrue(torch.allclose(add_relu_res, relu_res)) + self.assertEqual(add_relu_res, relu_res) def test_add_relu_broadcasting(self): a = torch.rand((1, 32)) @@ -11701,7 +11995,7 @@ def test_add_relu_broadcasting(self): res = torch._VF._add_relu(a, b) broadcasted_res = torch._VF._add_relu(a, b_scalar) - self.assertTrue(torch.allclose(broadcasted_res, res)) + self.assertEqual(broadcasted_res, res) def add_test(test, decorator=None): @@ -12883,7 +13177,7 @@ def test_Dropout(self, device): self._test_dropout_stride_mean_preserve(nn.Dropout, device) - if self.device_type == 'cuda': + if self.device_type == 'cuda' or self.device_type == 'cpu': input = input.bfloat16() self._test_dropout(nn.Dropout, device, input) @@ -13033,6 +13327,40 @@ def test_GroupNorm_empty(self, device): with torch.backends.cudnn.flags(enabled=False): self._test_module_empty_input(mod, inp) + @onlyCPU + @dtypes(torch.float, torch.double) + def test_groupnorm_nhwc(self, device, dtype): + def helper(self, size, groups): + channels = size[1] + input = torch.randn(size, dtype=dtype, device=device, requires_grad=True) + input = input.contiguous(memory_format=torch.channels_last) + input.retain_grad() + grad = torch.randn(size, dtype=dtype, device=device) + grad = grad.contiguous(memory_format=torch.channels_last) + gn = nn.GroupNorm(groups, channels).to(device).to(dtype) + gn.weight.data.uniform_() + gn.bias.data.uniform_() + + ref_input = input.detach().clone().contiguous().requires_grad_(True) + ref_grad = grad.detach().clone().contiguous() + ref_gn = nn.GroupNorm(groups, channels).to(device).to(dtype) + ref_gn.load_state_dict(gn.state_dict()) + + out = gn(input) + out.backward(grad) + ref_out = ref_gn(ref_input) + ref_out.backward(ref_grad) + + self.assertTrue(out.is_contiguous(memory_format=torch.channels_last)) + self.assertTrue(ref_out.is_contiguous()) + self.assertEqual(out, ref_out) + self.assertEqual(gn.weight.grad, ref_gn.weight.grad) + self.assertEqual(gn.bias.grad, ref_gn.bias.grad) + self.assertEqual(input.grad, ref_input.grad) + + helper(self, (4, 8, 10, 10), 4) + helper(self, (2, 30, 9, 9), 3) + @onlyOnCPUAndCUDA def test_GroupNorm_numeric(self, device): def group_norm_ref(X, gamma, beta, groups, channels, eps): @@ -13935,14 +14263,17 @@ def helper(n, c, h, w, kernel_size, stride=None, self.assertTrue(out.is_contiguous(memory_format=torch.channels_last)) self.assertTrue(ref_out.is_contiguous()) - self.assertTrue(torch.allclose(out, ref_out)) - self.assertTrue(torch.allclose(input.grad, ref_input.grad)) + self.assertEqual(out, ref_out) + self.assertEqual(input.grad, ref_input.grad) helper(4, 8, 8, 8, 3) helper(4, 8, 8, 8, 3, count_include_pad=False, padding=1) helper(4, 8, 8, 8, 3, count_include_pad=False, padding=2, stride=2) helper(4, 8, 8, 8, 3, divisor_override=42) helper(4, 8, 8, 8, 7) + # ROCm 16GB MI25 hits OOM error. Clear caching allocator prior to running large subtest. + if TEST_WITH_ROCM and 'cuda' in device: + torch.cuda.empty_cache() helper(200, 512, 28, 28, 2) helper(4, 8, 7, 7, 3, stride=1) helper(4, 8, 7, 7, 3, padding=2, stride=1) @@ -14062,9 +14393,9 @@ def helper(n, c, h, w, kernel_size, stride=None): self.assertTrue(ref_out.is_contiguous()) self.assertTrue(ind.is_contiguous(memory_format=torch.channels_last)) self.assertTrue(ref_ind.is_contiguous()) - self.assertTrue(torch.allclose(out, ref_out)) - self.assertTrue(torch.allclose(ind, ref_ind)) - self.assertTrue(torch.allclose(input.grad, ref_input.grad)) + self.assertEqual(out, ref_out) + self.assertEqual(ind, ref_ind) + self.assertEqual(input.grad, ref_input.grad) helper(4, 8, 8, 8, 7) helper(200, 512, 28, 28, 2) @@ -17007,6 +17338,78 @@ def test_cross_entropy_loss_one_hot_target(self, device): output_one_hot = m(input, target_one_hot) self.assertEqual(output, output_one_hot) + def test_cross_entropy_label_smoothing_errors(self, device): + N, C = 3, 4 + input_args = [ + (torch.randn((N, C), device=device), torch.arange(0, C, device=device)), + (torch.randn((N, C), device=device), torch.randn(N, C, device=device)) + ] + for input_arg in input_args: + loss = nn.CrossEntropyLoss(label_smoothing=1.2) + with self.assertRaisesRegex(RuntimeError, + r"label_smoothing must be between 0\.0"): + loss(*input_arg) + + def test_cross_entropy_label_smoothing_consistent_index_target_and_probs(self, device): + N, C = 10, 4 + ks = range(5) + reductions = ['none', 'mean', 'sum'] + label_smoothings = [0.05, 0.15] + + for k, reduction, label_smoothing in product(ks, reductions, label_smoothings): + other_dims = [torch.randint(2, 5, size=(1,)).item() for _ in range(k)] + input = torch.randn(N, C, *other_dims, device=device, requires_grad=True) + target = torch.empty(N, *other_dims, dtype=torch.long, device=device).random_(0, C) + + # construct target probablity that should have the same result as label_smoothing + target_proba = F.one_hot(target, num_classes=C) + # Need to put the C dim at index 1. + target_proba = target_proba.permute(0, -1, *range(1, target_proba.dim() - 1)) + target_mask = (target_proba == 1) + target_proba = target_proba.to(dtype=input.dtype) + + # y_k^ls = y_k * (1 - label_smoothing) + label_smoothing / n_classes + # Get one-hot representation of the target. + target_proba.masked_fill_(target_mask, 1 - label_smoothing + label_smoothing / C) + target_proba.masked_fill_(~target_mask, label_smoothing / C) + + loss = nn.CrossEntropyLoss(reduction=reduction) + output_with_prob = loss(input, target_proba) + + loss = nn.CrossEntropyLoss( + reduction=reduction, label_smoothing=label_smoothing) + output_with_index = loss(input, target) + + self.assertEqual(output_with_prob, output_with_index, + rtol=1e-07, atol=1e-05) + + def test_cross_entropy_label_smoothing_with_probs(self, device): + N, C = 10, 4 + ks = range(5) + reductions = ['none', 'mean', 'sum'] + label_smoothings = [0.05, 0.15] + + # Test with k-dimensional loss. + for k, label_smoothing in product(ks, label_smoothings): + other_dims = [torch.randint(2, 5, size=(1,)).item() for _ in range(k)] + input = torch.randn(N, C, *other_dims, device=device, requires_grad=True) + target = F.log_softmax(torch.randn(N, C, *other_dims, device=device), dim=1) + + for reduction in reductions: + # use with label_smoothing + loss = nn.CrossEntropyLoss(reduction=reduction, label_smoothing=label_smoothing) + output_with_smoothing = loss(input, target) + + # manually smoothing target + # class_proba^ls = class_proba * (1 - label_smoothing) + + # label_smoothing / n_classes + target_with_smoothing = target * (1 - label_smoothing) + label_smoothing / C + loss = nn.CrossEntropyLoss(reduction=reduction) + output_with_manual_smoothing = loss(input, target_with_smoothing) + + self.assertEqual(output_with_smoothing, output_with_manual_smoothing) + + def test_softshrink_negative(self, device): input = torch.randn(5, device=device, requires_grad=True) m = torch.nn.Softshrink(-1) @@ -17015,14 +17418,30 @@ def test_softshrink_negative(self, device): m(input) def test_fold(self, device): + def test_dtype(fn, input, dtype): + input = input.detach().clone().to(dtype=dtype).requires_grad_(True) + input2 = input.detach().clone().float().requires_grad_(True) + out = fn(input) + out.sum().backward() + out2 = fn(input2) + out2.sum().backward() + self.assertEqual(out.dtype, dtype) + self.assertEqual(input.grad.dtype, dtype) + self.assertEqual(out, out2.to(dtype=dtype), atol=0.05, rtol=0) + self.assertEqual(input.grad, input2.grad.to(dtype=dtype)) + def func(x): return F.fold(x, output_size=(4, 5), kernel_size=(2, 2)) + seeds = (44, 83, 71, 25, 999) for sd in seeds: torch.manual_seed(sd) x = torch.randn(1, 12, 12, device=device, requires_grad=True) gradcheck(func, [x]) gradgradcheck(func, [x]) + if device == 'cpu': + test_dtype(func, x, torch.bfloat16) + def test_logsigmoid_out(self, device): # this isn't actually documented, but was broken previously: @@ -17042,7 +17461,7 @@ def test_maxpool3d_non_square_backward(self, device): shape = tuple(32 if i != dim else 256 for i in range(4)) x = torch.randn(shape, device=device, requires_grad=True) F.max_pool3d(x, kernel_size=(1, 1, 1)).sum().backward() - self.assertTrue(torch.allclose(x.grad, torch.ones_like(x.grad))) + self.assertEqual(x.grad, torch.ones_like(x.grad)) # Check that clip_grad_norm_ raises an error if the total norm of the # parameters' gradients is non-finite @@ -17534,7 +17953,7 @@ def removable_hook_2(m, input): input = torch.randn(2, 2) output = module(input) - self.assertTrue(torch.allclose(torch.sigmoid(input), output)) + self.assertEqual(torch.sigmoid(input), output) # make sure hook removal is successful self.assertFalse(handle.id in handle.hooks_dict_ref()) @@ -17569,7 +17988,7 @@ def removable_hook_2(m, input, output): input = torch.randn(2, 2) output = module(input) - self.assertTrue(torch.allclose(torch.sigmoid(input), output)) + self.assertEqual(torch.sigmoid(input), output) # make sure hook removal is successful self.assertFalse(handle.id in handle.hooks_dict_ref()) @@ -17863,7 +18282,7 @@ def hook_function(module, input): module = TestModule() module.register_forward_pre_hook(hook_function) output = module(torch.zeros(2, 2)) - self.assertTrue(torch.allclose(output, torch.ones(2, 2))) + self.assertEqual(output, torch.ones(2, 2)) def test_lazy_forward_hook(self): """ @@ -17886,7 +18305,7 @@ def hook_function(module, input, output): module = TestModule() module.register_forward_hook(hook_function) output = module(torch.zeros(2, 2)) - self.assertTrue(torch.allclose(output, torch.ones(2, 2))) + self.assertEqual(output, torch.ones(2, 2)) @suppress_warnings def test_lazy_conv1d(self): diff --git a/test/test_nnapi.py b/test/test_nnapi.py index 19efa7f0ae738..f8db7e1a3df90 100644 --- a/test/test_nnapi.py +++ b/test/test_nnapi.py @@ -49,6 +49,7 @@ def check( convert_args=None, atol_rtol=None, limit=None, + expected_memory_format=None ): with torch.no_grad(): if isinstance(arg_or_args, torch.Tensor): @@ -76,6 +77,8 @@ def check( # Too many mismatches. Re-run the check with no tolerance # to get a nice message. self.assertEqual(eager_output, nnapi_output, atol=0, rtol=0) + if expected_memory_format: + self.assertTrue(nnapi_output.is_contiguous(memory_format=expected_memory_format)) def float_and_quant_and_nhwc(self, inp_float, scale, zero_point): torch.manual_seed(29) @@ -319,6 +322,28 @@ def forward(self, lhs, rhs): torch.tensor([[3.0, 4.0], [5.0, 6.0]]), ]) + def test_pointwise_binary_const(self): + const = torch.randn(1, 4, 6, 6) + + class ArgPlusConst(torch.nn.Module): + def forward(self, arg): + return arg + const + + class ConstPlusArg(torch.nn.Module): + def forward(self, arg): + return const + arg + + arg_contig = torch.randn(2, 4, 6, 6) + arg_nhwc = nhwc(torch.randn(2, 4, 6, 6)) + + for mod_class in [ArgPlusConst, ConstPlusArg]: + for use_nhwc in [False, True]: + with self.subTest(mod_class=mod_class.__name__, use_nhwc=use_nhwc): + arg = arg_nhwc if use_nhwc else arg_contig + memory_format = torch.channels_last if use_nhwc else torch.contiguous_format + self.check(mod_class(), arg, + expected_memory_format=memory_format) + def test_hardtanh(self): inp = torch.tensor([-2.0, -0.5, 0.5, 2.0, 7.0]) self.check(torch.nn.Hardtanh(), inp) diff --git a/test/test_numpy_interop.py b/test/test_numpy_interop.py index be46f93bdf3a8..a6f5be036c7a6 100644 --- a/test/test_numpy_interop.py +++ b/test/test_numpy_interop.py @@ -7,6 +7,7 @@ (TestCase, run_tests) from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, onlyCPU, dtypes) +from torch.testing._internal.common_dtype import get_all_dtypes # For testing handling NumPy objects and sending tensors to / accepting # arrays from NumPy. @@ -393,7 +394,7 @@ def test_has_storage_numpy(self, device): self.assertIsNotNone(torch.tensor(arr, device=device, dtype=torch.long).storage()) self.assertIsNotNone(torch.tensor(arr, device=device, dtype=torch.uint8).storage()) - @dtypes(*torch.testing.get_all_dtypes()) + @dtypes(*get_all_dtypes()) def test_numpy_scalar_cmp(self, device, dtype): if dtype.is_complex: tensors = (torch.tensor(complex(1, 3), dtype=dtype, device=device), diff --git a/test/test_ops.py b/test/test_ops.py index 76a7b6a1485ca..a9d470fec5e44 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1,16 +1,17 @@ from collections.abc import Sequence from functools import partial, wraps +import unittest import warnings import torch -from torch.testing import \ - (FileCheck, floating_and_complex_types_and, get_all_dtypes) +from torch.testing import FileCheck, make_tensor +from torch.testing._internal.common_dtype import floating_and_complex_types_and, get_all_dtypes from torch.testing._internal.common_utils import \ - (TestCase, is_iterable_of_tensors, run_tests, IS_SANDCASTLE, clone_input_helper, make_tensor, + (TestCase, is_iterable_of_tensors, run_tests, IS_SANDCASTLE, clone_input_helper, gradcheck, gradgradcheck, IS_IN_CI, suppress_warnings) from torch.testing._internal.common_methods_invocations import \ - (op_db, _NOTHING, UnaryUfuncInfo, SpectralFuncInfo) + (op_db, _NOTHING, UnaryUfuncInfo, ReductionOpInfo, SpectralFuncInfo) from torch.testing._internal.common_device_type import \ (deviceCountAtLeast, instantiate_device_type_tests, ops, onlyCUDA, onlyOnCPUAndCUDA, skipCUDAIfRocm, OpDTypes) from torch.testing._internal.common_jit import JitCommonTestCase, check_against_reference @@ -27,8 +28,8 @@ # Get names of all the operators which have ref in their entry in OpInfo (testing infra) # except for Unary Ufuncs (separately implemented in test/test_unary_ufuncs.py) # and Spectral Functions (separately implemented for only 1D as of now, in test/test_spectral_ops.py) -_ref_test_ops = list(filter(lambda op: not isinstance(op, (UnaryUfuncInfo, SpectralFuncInfo)) and - op.ref is not None and op.ref is not _NOTHING, op_db)) +_ref_test_ops = list(filter(lambda op: not isinstance(op, (UnaryUfuncInfo, ReductionOpInfo, + SpectralFuncInfo)) and op.ref is not None and op.ref is not _NOTHING, op_db)) # Tests that apply to all operators and aren't related to any particular @@ -684,6 +685,7 @@ class TestJit(JitCommonTestCase): # and runtimes (eager, traced, scripted). # TODO WARNING: inplace x {traced, scripted} not currently tested @_variant_ops(op_db) + @unittest.skipIf(True, "Temporarily skipping while landing Union PR stack") def test_variant_consistency_jit(self, device, dtype, op): _requires_grad = op.supports_autograd and (dtype.is_floating_point or op.supports_complex_autograd(torch.device(device).type)) diff --git a/test/test_optim.py b/test/test_optim.py index 20b8e5c443de5..d69e9351d33a0 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -12,8 +12,8 @@ from torch.autograd import Variable from torch import sparse from torch.optim.lr_scheduler import LambdaLR, MultiplicativeLR, StepLR, \ - MultiStepLR, WarmUpLR, ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau, \ - _LRScheduler, CyclicLR, CosineAnnealingWarmRestarts, OneCycleLR + MultiStepLR, ConstantLR, LinearLR, ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau, \ + _LRScheduler, CyclicLR, CosineAnnealingWarmRestarts, OneCycleLR, ChainedScheduler from torch.optim.swa_utils import AveragedModel, SWALR, update_bn from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_UBSAN, load_tests, \ skipIfRocm @@ -274,16 +274,16 @@ def test_sgd(self): ) self._test_basic_cases( lambda weight, bias: optimizer([weight, bias], lr=1e-3), - [lambda opt: WarmUpLR(opt, warmup_factor=0.4, warmup_iters=4, warmup_method="linear")] + [lambda opt: LinearLR(opt, start_factor=0.4, end_factor=0.8, total_iters=4)] ) self._test_basic_cases( lambda weight, bias: optimizer([weight, bias], lr=1e-3), - [lambda opt: WarmUpLR(opt, warmup_factor=0.4, warmup_iters=4, warmup_method="constant")] + [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)] ) self._test_basic_cases( lambda weight, bias: optimizer([weight, bias], lr=1e-3), [lambda opt: StepLR(opt, gamma=0.9, step_size=10), - lambda opt: WarmUpLR(opt, warmup_factor=0.4, warmup_iters=4)] + lambda opt: LinearLR(opt, start_factor=0.4, end_factor=0.6, total_iters=4)] ) self._test_basic_cases( lambda weight, bias: optimizer([weight, bias], lr=1e-3), @@ -430,18 +430,18 @@ def test_adam(self): lambda weight, bias: optimizer( self._build_params_dict(weight, bias, lr=1e-2), lr=1e-3), - [lambda opt: WarmUpLR(opt, warmup_factor=0.4, warmup_iters=4, warmup_method="linear")] + [lambda opt: LinearLR(opt, start_factor=0.4, total_iters=4)] ) self._test_basic_cases( lambda weight, bias: optimizer( self._build_params_dict(weight, bias, lr=1e-2), lr=1e-3), - [lambda opt: WarmUpLR(opt, warmup_factor=0.4, warmup_iters=4, warmup_method="constant")] + [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)] ) self._test_basic_cases( lambda weight, bias: optimizer([weight, bias], lr=1e-3, amsgrad=True), - [lambda opt: ExponentialLR(opt, gamma=0.9), - lambda opt: WarmUpLR(opt, warmup_factor=0.4, warmup_iters=4, warmup_method="constant")] + [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4), + lambda opt: ExponentialLR(opt, gamma=0.9)] ) self._test_basic_cases( lambda weight, bias: optimizer([weight, bias], lr=1e-3, amsgrad=True), @@ -992,12 +992,12 @@ def test_exponential_lr_is_constant_for_constant_epoch(self): scheduler = ExponentialLR(self.opt, gamma=0.9) self._test_lr_is_constant_for_constant_epoch(scheduler) - def test_constant_warmup_lr_is_constant_for_constant_epoch(self): - scheduler = WarmUpLR(self.opt, warmup_method="constant") + def test_constantlr_is_constant_for_constant_epoch(self): + scheduler = ConstantLR(self.opt) self._test_lr_is_constant_for_constant_epoch(scheduler) - def test_linear_warmup_lr_is_constant_for_constant_epoch(self): - scheduler = WarmUpLR(self.opt, warmup_method="linear") + def test_linear_linearlr_is_constant_for_constant_epoch(self): + scheduler = LinearLR(self.opt) self._test_lr_is_constant_for_constant_epoch(scheduler) def test_step_lr(self): @@ -1051,76 +1051,78 @@ def test_multi_step_lr_with_epoch(self): scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) self._test_with_epoch(scheduler, targets, epochs) - def test__get_last_lr_constant_warmup_lr(self): + def test_get_last_lr_constantlr(self): # lr = 0.025 if epoch < 5 # lr = 0.005 if 5 <= epoch epochs = 10 single_targets = [0.025] * 5 + [0.05] * 5 targets = [single_targets, [x * epochs for x in single_targets]] - scheduler = WarmUpLR(self.opt, warmup_factor=1.0 / 2, warmup_iters=5, warmup_method="constant") + scheduler = ConstantLR(self.opt, factor=1.0 / 2, total_iters=5) self._test_get_last_lr(scheduler, targets, epochs) - def test__get_last_lr_linear_warmup_lr(self): + def test_get_last_lr_linearlr(self): # lr = 0.025 if epoch == 0 # lr = 0.03125 if epoch == 1 # lr = 0.0375 if epoch == 2 # lr = 0.04375 if epoch == 3 # lr = 0.005 if 4 <= epoch epochs = 10 - factor = 1.0 / 2 + start_factor = 1.0 / 4 + end_factor = 3. / 5 iters = 4 - interpolation = [factor + i * (1 - factor) / iters for i in range(iters)] - single_targets = [x * 0.05 for x in interpolation] + [0.05] * (epochs - iters) + interpolation = [start_factor + i * (end_factor - start_factor) / iters for i in range(iters)] + single_targets = [x * 0.05 for x in interpolation] + [0.05 * end_factor] * (epochs - iters) targets = [single_targets, [x * epochs for x in single_targets]] - scheduler = WarmUpLR(self.opt, warmup_factor=factor, warmup_iters=iters, warmup_method="linear") + scheduler = LinearLR(self.opt, start_factor=start_factor, end_factor=end_factor, total_iters=iters) self._test_get_last_lr(scheduler, targets, epochs) - def test__constant_warmup_lr(self): + def test_constantlr(self): # lr = 0.025 if epoch < 5 # lr = 0.005 if 5 <= epoch epochs = 10 single_targets = [0.025] * 5 + [0.05] * 5 targets = [single_targets, [x * epochs for x in single_targets]] - scheduler = WarmUpLR(self.opt, warmup_factor=1.0 / 2, warmup_iters=5, warmup_method="constant") + scheduler = ConstantLR(self.opt, factor=1.0 / 2, total_iters=5) self._test(scheduler, targets, epochs) - def test__linear_warmup_lr(self): + def test_linearlr(self): # lr = 0.025 if epoch == 0 # lr = 0.03125 if epoch == 1 # lr = 0.0375 if epoch == 2 # lr = 0.04375 if epoch == 3 # lr = 0.005 if 4 <= epoch epochs = 10 - factor = 1.0 / 2 + start_factor = 1.0 / 2 iters = 4 - interpolation = [factor + i * (1 - factor) / iters for i in range(iters)] + interpolation = [start_factor + i * (1 - start_factor) / iters for i in range(iters)] single_targets = [x * 0.05 for x in interpolation] + [0.05] * (epochs - iters) targets = [single_targets, [x * epochs for x in single_targets]] - scheduler = WarmUpLR(self.opt, warmup_factor=factor, warmup_iters=iters, warmup_method="linear") + scheduler = LinearLR(self.opt, start_factor=start_factor, total_iters=iters) self._test(scheduler, targets, epochs) - def test_constant_warmup_with_epoch(self): + def test_constantlr_with_epoch(self): # lr = 0.025 if epoch < 5 # lr = 0.005 if 5 <= epoch epochs = 10 single_targets = [0.025] * 5 + [0.05] * 5 targets = [single_targets, [x * epochs for x in single_targets]] - scheduler = WarmUpLR(self.opt, warmup_factor=1.0 / 2, warmup_iters=5, warmup_method="constant") + scheduler = ConstantLR(self.opt, factor=1.0 / 2, total_iters=5) self._test_with_epoch(scheduler, targets, epochs) - def test_linear_warmup_with_epoch(self): + def test_linearlr_with_epoch(self): # lr = 0.025 if epoch == 0 # lr = 0.03125 if epoch == 1 # lr = 0.0375 if epoch == 2 # lr = 0.04375 if epoch == 3 # lr = 0.005 if 4 <= epoch epochs = 10 - factor = 1.0 / 2 + start_factor = 1.0 / 2 + end_factor = 1. iters = 4 - interpolation = [factor + i * (1 - factor) / iters for i in range(iters)] + interpolation = [start_factor + i * (end_factor - start_factor) / iters for i in range(iters)] single_targets = [x * 0.05 for x in interpolation] + [0.05] * (epochs - iters) targets = [single_targets, [x * epochs for x in single_targets]] - scheduler = WarmUpLR(self.opt, warmup_factor=factor, warmup_iters=iters, warmup_method="linear") + scheduler = LinearLR(self.opt, start_factor=start_factor, total_iters=iters) self._test_with_epoch(scheduler, targets, epochs) def test_exp_lr(self): @@ -1145,14 +1147,14 @@ def test_closed_form_step_lr(self): closed_form_scheduler = StepLR(self.opt, gamma=0.1, step_size=3) self._test_against_closed_form(scheduler, closed_form_scheduler, 20) - def test_closed_form_linear_warmup_lr(self): - scheduler = WarmUpLR(self.opt, warmup_factor=1.0 / 3, warmup_iters=4, warmup_method="linear") - closed_form_scheduler = WarmUpLR(self.opt, warmup_factor=1.0 / 3, warmup_iters=4, warmup_method="linear") + def test_closed_form_linearlr(self): + scheduler = LinearLR(self.opt, start_factor=1.0 / 3, end_factor=0.7, total_iters=4) + closed_form_scheduler = LinearLR(self.opt, start_factor=1.0 / 3, end_factor=0.7, total_iters=4) self._test_against_closed_form(scheduler, closed_form_scheduler, 20) - def test_closed_form_constant_warmup_lr(self): - scheduler = WarmUpLR(self.opt, warmup_factor=1.0 / 3, warmup_iters=4, warmup_method="constant") - closed_form_scheduler = WarmUpLR(self.opt, warmup_factor=1.0 / 3, warmup_iters=4, warmup_method="constant") + def test_closed_form_constantlr(self): + scheduler = ConstantLR(self.opt, factor=1.0 / 3, total_iters=4) + closed_form_scheduler = ConstantLR(self.opt, factor=1.0 / 3, total_iters=4) self._test_against_closed_form(scheduler, closed_form_scheduler, 20) def test_closed_form_multi_step_lr(self): @@ -1253,6 +1255,44 @@ def test_reduce_lr_on_plateau8(self): threshold=0.1, patience=5, cooldown=5) self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) + def test_chained_lr1(self): + epochs = 10 + schedulers = [None] * 1 + targets = [[0.05] * 3 + [0.005] * 3 + [0.0005] * 3 + [0.00005] * 3] + schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3) + scheduler = ChainedScheduler(schedulers) + self._test([scheduler], targets, epochs) + + def test_chained_lr2(self): + epochs = 10 + schedulers = [None] * 1 + targets = [[0.02, 0.03, 0.04] + [0.05] * 9] + schedulers[0] = LinearLR(self.opt, start_factor=0.4, total_iters=3) + scheduler = ChainedScheduler(schedulers) + self._test([scheduler], targets, epochs) + + def test_chained_lr3(self): + epochs = 10 + schedulers = [None] * 2 + targets = [[0.02, 0.03, 0.04, 0.05] + [0.005] * 4 + [0.0005] * 3 + [0.00005] * 3] + schedulers[0] = LinearLR(self.opt, start_factor=0.4, total_iters=3) + schedulers[1] = MultiStepLR(self.opt, milestones=[4, 8, 10], gamma=0.1) + scheduler = ChainedScheduler(schedulers) + self._test([scheduler], targets, epochs) + + def test_chained_lr4(self): + epochs = 9 + schedulers = [None] * 3 + targets = [[0.05 * 0.2 * 0.9 ** x for x in range(3)] + + [0.05 * 0.2 * 0.9 ** 3 * 0.1] + + [0.05 * 0.9 ** x * 0.1 for x in range(4, 6)] + + [0.05 * 0.9 ** x * 0.01 for x in range(6, 9)]] + schedulers[0] = ExponentialLR(self.opt, gamma=0.9) + schedulers[1] = ConstantLR(self.opt, factor=0.2, total_iters=4) + schedulers[2] = StepLR(self.opt, gamma=0.1, step_size=3) + scheduler = ChainedScheduler(schedulers) + self._test([scheduler], targets, epochs) + def test_compound_step_and_multistep_lr(self): epochs = 10 schedulers = [None] * 2 @@ -1285,20 +1325,23 @@ def test_compound_exp_and_multistep_lr(self): schedulers[1] = ExponentialLR(self.opt, gamma=0.9) self._test(schedulers, targets, epochs) - def test_compound_exp_and_linear_warmup_lr(self): + def test_compound_exp_and_linearlr(self): epochs = 10 iters = 4 - factor = 0.4 + start_factor = 0.4 + end_factor = 0.9 schedulers = [None] * 2 single_targets = [0.05 * (0.9 ** x) for x in range(11)] for i in range(iters): - single_targets[i] *= factor + i / iters * (1 - factor) + single_targets[i] *= start_factor + i / iters * (end_factor - start_factor) + for i in range(iters, 11): + single_targets[i] *= end_factor targets = [single_targets, [x * epochs for x in single_targets]] - schedulers[0] = ExponentialLR(self.opt, gamma=0.9) - schedulers[1] = WarmUpLR(self.opt, warmup_factor=factor, warmup_iters=iters, warmup_method="linear") + schedulers[0] = LinearLR(self.opt, start_factor=start_factor, end_factor=end_factor, total_iters=iters) + schedulers[1] = ExponentialLR(self.opt, gamma=0.9) self._test(schedulers, targets, epochs) - def test_compound_step_and_constant_warmup(self): + def test_compound_step_and_constantlr(self): epochs = 10 iters = 4 factor = 0.4 @@ -1306,20 +1349,20 @@ def test_compound_step_and_constant_warmup(self): single_targets = [0.05 * 0.4] * 3 + [0.005 * 0.4] + [0.005] * 2 + [0.0005] * 3 + [0.00005] * 3 targets = [single_targets, [x * epochs for x in single_targets]] schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3) - schedulers[1] = WarmUpLR(self.opt, warmup_factor=0.4, warmup_iters=4, warmup_method="constant") + schedulers[1] = ConstantLR(self.opt, factor=0.4, total_iters=4) self._test(schedulers, targets, epochs) - def test_compound_linear_warmup_and_multistep_lr(self): + def test_compound_linearlr_and_multistep_lr(self): epochs = 10 iters = 4 - factor = 0.4 + start_factor = 0.4 schedulers = [None] * 2 single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 2 for i in range(iters): - single_targets[i] *= factor + i / iters * (1 - factor) + single_targets[i] *= start_factor + i / iters * (1 - start_factor) targets = [single_targets, [x * epochs for x in single_targets]] schedulers[0] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) - schedulers[1] = WarmUpLR(self.opt, warmup_factor=factor, warmup_iters=iters, warmup_method="linear") + schedulers[1] = LinearLR(self.opt, start_factor=start_factor, total_iters=iters) self._test(schedulers, targets, epochs) def test_compound_cosanneal_and_step_lr(self): @@ -1349,20 +1392,20 @@ def test_compound_cosanneal_and_multistep_lr(self): schedulers[1] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) self._test(schedulers, targets, epochs) - def test_compound_cosanneal_and_linear_warmup_lr(self): + def test_compound_cosanneal_and_linearlr(self): epochs = 10 iters = 4 - factor = 0.4 + start_factor = 0.4 eta_min = 1e-10 schedulers = [None] * 2 single_targets = [eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2 for x in range(epochs)] for i in range(iters): - single_targets[i] *= factor + i / iters * (1 - factor) + single_targets[i] *= start_factor + i / iters * (1 - start_factor) targets = [single_targets, [x * epochs for x in single_targets]] - schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min) - schedulers[1] = WarmUpLR(self.opt, warmup_factor=factor, warmup_iters=iters, warmup_method="linear") + schedulers[0] = LinearLR(self.opt, start_factor=start_factor, total_iters=iters) + schedulers[1] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min) self._test(schedulers, targets, epochs) def test_compound_cosanneal_and_exp_lr(self): @@ -1447,14 +1490,14 @@ def test_compound_reduce_lr_on_plateau4(self): def test_compound_reduce_lr_on_plateau5(self): iters = 4 - factor = 0.4 + start_factor = 0.4 epochs = 22 for param_group in self.opt.param_groups: param_group['lr'] = 0.5 single_targets = [0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2 multipliers = [1] * 22 for i in range(iters): - multipliers[i] *= factor + i / iters * (1 - factor) + multipliers[i] *= start_factor + i / iters * (1 - start_factor) single_targets = [x * y for x, y in zip(single_targets, multipliers)] targets = [single_targets] targets = targets[1:] # test runs step before checking lr @@ -1462,7 +1505,7 @@ def test_compound_reduce_lr_on_plateau5(self): schedulers = [None] * 2 schedulers[0] = ReduceLROnPlateau(self.opt, patience=5, cooldown=0, threshold_mode='abs', mode='min', threshold=0.1) - schedulers[1] = WarmUpLR(self.opt, warmup_factor=factor, warmup_iters=iters, warmup_method="linear") + schedulers[1] = LinearLR(self.opt, start_factor=start_factor, total_iters=iters) self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs) def test_cycle_lr_invalid_mode(self): diff --git a/test/test_overrides.py b/test/test_overrides.py index 41044376a40f7..a6252374364c2 100644 --- a/test/test_overrides.py +++ b/test/test_overrides.py @@ -762,6 +762,9 @@ def __bool__(self): def __int__(self): return self.__torch_function__(torch.Tensor.__int__, (Wrapper,), (self,)) + def __len__(self): + return len(self._data) + # unwrap inputs if necessary def unwrap(v): @@ -782,15 +785,15 @@ class TestEinsumOverride(TestCase): def test_wrapper(self): x = Wrapper(torch.randn(5)) y = Wrapper(torch.randn(4)) - self.assertTrue(torch.allclose(torch.einsum('i,j->ij', x, y), - torch.ger(x, y))) + self.assertEqual(torch.einsum('i,j->ij', x, y)._data, + torch.ger(x, y)._data) # in the old einsum interface, `operands` is a list a = Wrapper(torch.randn(2, 3)) b = Wrapper(torch.randn(5, 3, 7)) c = Wrapper(torch.randn(2, 7)) - self.assertTrue(torch.allclose(torch.einsum('ik,jkl,il->ij', [a, b, c]), - torch.nn.functional.bilinear(a, c, b))) + self.assertEqual(torch.einsum('ik,jkl,il->ij', [a, b, c])._data, + torch.nn.functional.bilinear(a, c, b)._data) class TestGradCheckOverride(TestCase): "Test that wrappers work with gradcheck." diff --git a/test/test_pruning_op.py b/test/test_pruning_op.py index 28f31aeabd705..97a499b03ac15 100644 --- a/test/test_pruning_op.py +++ b/test/test_pruning_op.py @@ -50,7 +50,7 @@ def get_reference_result(embedding_weights, mask, indices_type): ref_pruned_weights, ref_compressed_indices_map = get_reference_result( embedding_weights, mask, indices_type) - torch.testing.assert_allclose(pt_pruned_weights, ref_pruned_weights) + torch.testing.assert_close(pt_pruned_weights, ref_pruned_weights) self.assertEqual(pt_compressed_indices_map, ref_compressed_indices_map) self.assertEqual(pt_compressed_indices_map.dtype, indices_type) diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py index ba828e5b3dae7..9f8b79d96958b 100644 --- a/test/test_public_bindings.py +++ b/test/test_public_bindings.py @@ -210,6 +210,7 @@ def test_no_new_bindings(self): "TupleType", "Type", "unify_type_list", + "UnionType", "Use", "Value", "autocast_decrement_nesting", diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py index adacc7efb7093..0f5b6b9cbd70e 100644 --- a/test/test_python_dispatch.py +++ b/test/test_python_dispatch.py @@ -246,6 +246,39 @@ def test_version(self) -> None: x.data.add_(2) self.assertEqual(cur_vc, x._version) + def test_subclass_priority(self) -> None: + class ErrorA(RuntimeError): + pass + + class ErrorB(RuntimeError): + pass + + # The big tests for code coverage are test_precedence_semantics in + # test_overrides.py; this is just to make sure it is wired up at all + # correctly for __torch_dispatch__ + class A(torch.Tensor): + @staticmethod + def __new__(cls, elem): + return torch.Tensor._make_subclass(cls, elem, elem.requires_grad) + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + raise ErrorA + + class B(A): + @staticmethod + def __new__(cls, elem): + return torch.Tensor._make_subclass(cls, elem, elem.requires_grad) + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + raise ErrorB + + self.assertRaises(ErrorA, lambda: torch.add(A(torch.empty(1)), A(torch.empty(1)))) + self.assertRaises(ErrorB, lambda: torch.add(A(torch.empty(1)), B(torch.empty(1)))) + self.assertRaises(ErrorB, lambda: torch.add(B(torch.empty(1)), A(torch.empty(1)))) + self.assertRaises(ErrorB, lambda: torch.add(B(torch.empty(1)), B(torch.empty(1)))) + def test_format(self) -> None: x = LoggingTensor(torch.ones(1)) s1 = str(x) diff --git a/test/test_reductions.py b/test/test_reductions.py index 42edfb3817ce1..9760eae52813d 100644 --- a/test/test_reductions.py +++ b/test/test_reductions.py @@ -2,19 +2,26 @@ import numpy as np import math -from typing import Dict, List +from typing import Dict, List, Sequence import random from functools import partial from itertools import product, combinations, permutations import warnings from torch._six import inf, nan +from torch.testing import make_tensor +from torch.testing._internal.common_dtype import ( + get_all_dtypes, get_all_math_dtypes, get_all_int_dtypes, get_all_complex_dtypes, get_all_fp_dtypes, + integral_types_and, floating_and_complex_types_and +) from torch.testing._internal.common_utils import ( TestCase, run_tests, skipIfNoSciPy, slowTest, torch_to_numpy_dtype_dict, - IS_WINDOWS, make_tensor) + IS_WINDOWS) from torch.testing._internal.common_device_type import ( - instantiate_device_type_tests, onlyCPU, dtypes, dtypesIfCUDA, dtypesIfCPU, - onlyOnCPUAndCUDA, onlyCUDA, largeTensorTest, precisionOverride) + OpDTypes, instantiate_device_type_tests, onlyCPU, dtypes, dtypesIfCUDA, dtypesIfCPU, + onlyOnCPUAndCUDA, onlyCUDA, largeTensorTest, ops, precisionOverride) +from torch.testing._internal.common_methods_invocations import ( + ReductionOpInfo, reduction_ops) # TODO: replace with make_tensor def _generate_input(shape, dtype, device, with_extremal): @@ -53,8 +60,249 @@ def _rand_shape(dim, min_size, max_size): shape.append(random.randint(min_size, max_size)) return tuple(shape) +def _reduced_shape(shape, dim=None, keepdim=False): + """Computes the expected reduced shape given dim and keepdim + + Args: + shape: The shape to reduce + dim : The dimensions to reduce + keepdim: If true, reduced dimensions have size 1 in the reduced shape, + otherwise they are removed from the reduced shape. + + Returns: + The reduced shape + """ + if dim is None: + return [1] * len(shape) if keepdim else [] + + # Wrap negative dims + dim = dim if isinstance(dim, Sequence) else [dim] + dim = set(i if i >= 0 else len(shape) + i for i in dim) + + result = [] + for i, size in enumerate(shape): + if i not in dim: + result.append(size) + elif keepdim: + result.append(1) + + return result + class TestReductions(TestCase): + ########################################################################### + # ReductionOpInfo unit tests + ########################################################################### + + def _test_dim_keepdim(self, op: ReductionOpInfo, device, *, ndim, **dim_keepdim): + """Tests output shape for input with ndim and dim and keepdim kwargs""" + shape = torch.randint(2, 5, (ndim,)).tolist() + t = make_tensor(shape, device, torch.float) + args, kwargs = next(op.generate_args_kwargs(t, **dim_keepdim)) + result = op(t, *args, **dim_keepdim, **kwargs) + expected_shape = _reduced_shape(shape, **dim_keepdim) + self.assertEqual(result.shape, expected_shape, f""" + expected output shape to be {expected_shape} but got {list(result.shape)} + for input shape {shape} and {dim_keepdim} + """) + + # TODO(@heitorschueroff) combine cases with and without keepdim once + # there's support for a @parametrize decorator. + + @ops(reduction_ops, dtypes=OpDTypes.none) + def test_dim_default(self, device, op: ReductionOpInfo): + """Tests that the default dim reduces all dimensions.""" + for ndim in range(3): + self._test_dim_keepdim(op, device, ndim=ndim) + + @ops(reduction_ops, dtypes=OpDTypes.none) + def test_dim_default_keepdim(self, device, op: ReductionOpInfo): + """Tests that the default dim, when keepdim=True, reduces all dimensions to size 1.""" + for ndim in range(3): + self._test_dim_keepdim(op, device, ndim=ndim, keepdim=True) + + @ops(reduction_ops, dtypes=OpDTypes.none) + def test_dim_none(self, device, op: ReductionOpInfo): + """Tests that dim=None reduces all dimensions.""" + for ndim in range(3): + self._test_dim_keepdim(op, device, ndim=ndim, dim=None) + + @ops(reduction_ops, dtypes=OpDTypes.none) + def test_dim_none_keepdim(self, device, op: ReductionOpInfo): + """Tests that dim=None, when keepdim=True, reduces all dimensions to size 1.""" + for ndim in range(3): + self._test_dim_keepdim(op, device, ndim=ndim, dim=None, keepdim=True) + + @ops(reduction_ops, dtypes=OpDTypes.none) + def test_dim_single(self, device, op: ReductionOpInfo): + """Tests that dim=i reduces dimension i.""" + self._test_dim_keepdim(op, device, ndim=0, dim=0) + self._test_dim_keepdim(op, device, ndim=1, dim=0) + self._test_dim_keepdim(op, device, ndim=2, dim=-1) + self._test_dim_keepdim(op, device, ndim=3, dim=1) + + @ops(reduction_ops, dtypes=OpDTypes.none) + def test_dim_single_keepdim(self, device, op: ReductionOpInfo): + """Tests that dim=i, when keepdim=True, reduces dimension i to size 1.""" + self._test_dim_keepdim(op, device, ndim=0, dim=0, keepdim=True) + self._test_dim_keepdim(op, device, ndim=1, dim=0, keepdim=True) + self._test_dim_keepdim(op, device, ndim=2, dim=-1, keepdim=True) + self._test_dim_keepdim(op, device, ndim=3, dim=1, keepdim=True) + + @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none) + def test_dim_empty(self, device, op: ReductionOpInfo): + """Tests that dim=[] is a no-op""" + self._test_dim_keepdim(op, device, ndim=0, dim=[]) + self._test_dim_keepdim(op, device, ndim=2, dim=[]) + + @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none) + def test_dim_empty_keepdim(self, device, op: ReductionOpInfo): + """Tests that dim=[], when keepdim=True, is a no-op""" + self._test_dim_keepdim(op, device, ndim=0, dim=[], keepdim=True) + self._test_dim_keepdim(op, device, ndim=2, dim=[], keepdim=True) + + @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none) + def test_dim_multi(self, device, op: ReductionOpInfo): + """Tests that dim=[i, j, ...] reduces dimensions i, j, ....""" + self._test_dim_keepdim(op, device, ndim=1, dim=[0]) + self._test_dim_keepdim(op, device, ndim=3, dim=[0, 2]) + + @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none) + def test_dim_multi_keepdim(self, device, op: ReductionOpInfo): + """Tests that dim=[i, j, ...], when keepdim=True, reduces dimensions i, j, .... to size 1.""" + self._test_dim_keepdim(op, device, ndim=1, dim=[0], keepdim=True) + self._test_dim_keepdim(op, device, ndim=3, dim=[0, 2], keepdim=True) + + @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none) + def test_dim_multi_unsorted(self, device, op: ReductionOpInfo): + """Tests that operator correctly handles unsorted dim list.""" + self._test_dim_keepdim(op, device, ndim=4, dim=[3, 0, 2]) + + @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none) + def test_dim_multi_unsorted_keepdim(self, device, op: ReductionOpInfo): + """Tests that operator correctly handles unsorted dim list when keepdim=True.""" + self._test_dim_keepdim(op, device, ndim=4, dim=[3, 0, 2], keepdim=True) + + @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none) + def test_dim_multi_duplicate(self, device, op: ReductionOpInfo): + """Tests that an error is raised if dim has duplicate entries.""" + with self.assertRaises(RuntimeError): + self._test_dim_keepdim(op, device, ndim=3, dim=[0, 1, 1, 2]) + + @ops(filter(lambda op: not op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none) + def test_dim_multi_unsupported(self, device, op: ReductionOpInfo): + """Tests that ops claiming to not support multi dim actually don't.""" + with self.assertRaises(TypeError): + self._test_dim_keepdim(op, device, ndim=3, dim=[0, 2]) + + @ops(reduction_ops, dtypes=OpDTypes.none) + def test_dim_offbounds(self, device, op: ReductionOpInfo): + """Tests that passing an off-bounds dim throws""" + with self.assertRaises(IndexError): + self._test_dim_keepdim(op, device, ndim=2, dim=2) + + @ops(reduction_ops, dtypes=OpDTypes.none) + def test_dim_ndim_limit(self, device, op: ReductionOpInfo): + """Tests that an exception is raised when reducing a tensor with more + than 64 dims along some specific dimensions. dim=None is ok""" + t = make_tensor([1] * 65, device, torch.float) + with self.assertRaisesRegex(RuntimeError, "only tensors with up to 64 dims are supported"): + op(t, dim=0) + + @ops(filter(lambda op: op.identity is not None, reduction_ops), dtypes=OpDTypes.supported) + def test_identity(self, device, dtype, op: ReductionOpInfo): + """Tests that the identity value is an identity for the operator""" + t = make_tensor((10,), device, dtype) + t[1::2] = op.identity + args, kwargs = next(op.generate_args_kwargs(t)) + result = op(t[::2], *args, **kwargs) + result_with_identity = op(t, *args, **kwargs) + self.assertEqual(result, result_with_identity, """ + Adding identity value to the input tensor should not change the result. + """) + + # TODO(@heitorschueroff) Update these to use the nan_policy kwarg once + # it is added to reduction operators. + + @ops(filter(lambda op: op.nan_policy == 'propagate', reduction_ops), dtypes=OpDTypes.supported, + allowed_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16)) + def test_nan_policy_propagate(self, device, dtype, op: ReductionOpInfo): + """Tests that nan is propagated to the output by default""" + t = make_tensor((5,), device, dtype) + t[2] = torch.nan + args, kwargs = next(op.generate_args_kwargs(t)) + result = op(t, *args, **kwargs) + self.assertTrue(result.isnan()) + + @ops(filter(lambda op: op.nan_policy == 'omit', reduction_ops), dtypes=OpDTypes.supported, + allowed_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16)) + def test_nan_policy_omit(self, device, dtype, op: ReductionOpInfo): + """Tests that NaN values do not affect the result.""" + t = make_tensor((10,), device, dtype) + t[1::2] = torch.nan + args, kwargs = next(op.generate_args_kwargs(t)) + result = op(t[::2], *args, **kwargs) + result_with_nan = op(t, *args, **kwargs) + self.assertEqual(result, result_with_nan) + + @ops(reduction_ops, dtypes=OpDTypes.supported) + def test_result_dtype(self, device, dtype, op: ReductionOpInfo): + """Tests that the result has the correct dtype""" + t = make_tensor((5,), device, dtype) + args, kwargs = next(op.generate_args_kwargs(t)) + result: torch.Tensor = op(t, *args, **kwargs) + is_integral = dtype in integral_types_and(torch.bool) + if op.promotes_int_to_float and is_integral: + self.assertTrue(torch.is_floating_point(result.dtype)) + elif op.promotes_int_to_int64 and is_integral: + self.assertEqual(result.dtype, torch.int64) + elif op.result_dtype is not None: + self.assertEqual(result.dtype, op.result_dtype) + else: + self.assertEqual(result.dtype, dtype) + + @ops(reduction_ops, dtypes=OpDTypes.none) + def test_empty_tensor_empty_slice(self, device, op: ReductionOpInfo): + """Tests for consistent behavior when reducing over an empty slice. + + The rules for reducing over an empty slice are as follows: + - Return the identity value if the operator has one + - Otherwise, return NaN if the operator promotes integral dtype to + floating point dtypes. + - Otherwise, raise an error + + See discussion here https://github.com/pytorch/pytorch/issues/61901 + """ + t = make_tensor((0, 2, 3), device, torch.float) + for dim in [0] + [[0, 2]] if op.supports_multiple_dims else []: + args, kwargs = next(op.generate_args_kwargs(t, dim=dim)) + if op.identity is not None: + # Reducing along empty slice should return identity + result = op(t, *args, dim=dim, **kwargs) + self.assertEqual(result, torch.full_like(result, op.identity)) + elif op.promotes_int_to_float: + # Reducing along empty slice should return NaN + result = op(t, *args, dim=dim, **kwargs) + self.assertEqual(result, torch.full_like(result, torch.nan)) + else: + # Reducing along empty slice should raise an error + with self.assertRaises(IndexError): + op(t, *args, dim=dim, **kwargs) + + @ops(reduction_ops, dtypes=OpDTypes.none) + def test_empty_tensor_nonempty_slice(self, device, op: ReductionOpInfo): + """Tests that reducing a nonempty slice of an empty tensor returns an + empty tensor with the dimensions reduced.""" + t = make_tensor((0, 2, 3), device, torch.float) + for dim in [1] + [[1, 2]] if op.supports_multiple_dims else []: + args, kwargs = next(op.generate_args_kwargs(t, dim=dim)) + result = op(t, *args, dim=dim, **kwargs) + self.assertEqual(result.shape, _reduced_shape(t.shape, dim)) + + ########################################################################### + # TODO: Legacy tests - port to ReductionOpInfo + ########################################################################### + def test_var_unbiased(self, device): tensor = torch.randn(100, device=device) self.assertEqual(tensor.var(0), tensor.var(0, unbiased=True)) @@ -380,7 +628,7 @@ def _test_out(dtype, other_dtype): # 'out' is favored over dtype, check error self.assertRaises(RuntimeError, lambda: fn(x, out=out, dtype=other_dtype)) - for dtype in [dtype for dtype in torch.testing.get_all_math_dtypes('cpu') if dtype != torch.float16]: + for dtype in [dtype for dtype in get_all_math_dtypes('cpu') if dtype != torch.float16]: x = torch.ones(shape, dtype=dtype) expected_dtype = dtype if dtype.is_floating_point or dtype.is_complex else torch.int64 self.assertIs(expected_dtype, fn(x).dtype) @@ -1011,7 +1259,24 @@ def test_output_dtype(dtype, is_int32): test_output_dtype(torch.int32, False) test_output_dtype(torch.int64, True) - @dtypes(*torch.testing.get_all_dtypes(include_bool=False, include_complex=False)) + # scalar type bfloat16 + if self.device_type == 'cpu': + def test_dtype_bfloat16(values_bf16=False, boundaries_bf16=False): + values_1d_float = values_1d.to(torch.float32) + boundaries = torch.tensor([0.9, 1, 2, 2, 3, 3, 4, 4.1, 9, 9], device=device, dtype=torch.float32) + if values_bf16: + values_1d_float = values_1d_float.to(torch.bfloat16) + if boundaries_bf16: + boundaries = boundaries.to(torch.bfloat16) + expected_result = torch.tensor([1, 2, 4, 6, 8, 8, 8, 8, 8], device=device, dtype=torch.int32) + self.assertEqual(torch.searchsorted(boundaries, values_1d_float, out_int32=True), expected_result) + self.assertEqual(torch.bucketize(values_1d_float, boundaries, out_int32=True), expected_result) + + test_dtype_bfloat16(True, False) + test_dtype_bfloat16(False, True) + test_dtype_bfloat16(True, True) + + @dtypes(*get_all_dtypes(include_bool=False, include_complex=False)) def test_nansum(self, device, dtype): args = product( (True, False), # noncontiguous @@ -1064,15 +1329,15 @@ def _test_reduction_function_with_numpy(self, torch_func, np_func, device, dtype self.compare_with_numpy(torch_func_partial, np_func_partial, x, device=None, dtype=None, atol=atol, rtol=rtol, exact_dtype=exact_dtype) - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False) + - torch.testing.get_all_complex_dtypes())) + @dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes(include_bfloat16=False) + + get_all_complex_dtypes())) def test_count_nonzero(self, device, dtype): self._test_reduction_function_with_numpy(torch.count_nonzero, np.count_nonzero, device, dtype) self._test_reduction_function_with_numpy(torch.count_nonzero, np.count_nonzero, device, dtype, True) def _test_sum_reduction_vs_numpy(self, torch_fn, np_fn, device, dtype, with_keepdim=False, with_extremal=False): def is_integral(dtype): - return dtype in torch.testing.get_all_int_dtypes() + return dtype in get_all_int_dtypes() # On Windows CI, the current version of `numpy` promotes all lower integers # dtypes to int32 while `torch` promotes them to int64. Hence we skip on checking @@ -1101,27 +1366,27 @@ def is_integral(dtype): with_keepdim=with_keepdim, with_extremal=with_extremal) @onlyOnCPUAndCUDA - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False))) + @dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes(include_bfloat16=False))) def test_sum_vs_numpy(self, device, dtype): self._test_sum_reduction_vs_numpy(torch.sum, np.sum, device, dtype) self._test_sum_reduction_vs_numpy(torch.sum, np.sum, device, dtype, with_extremal=True) self._test_sum_reduction_vs_numpy(torch.sum, np.sum, device, dtype, with_keepdim=True) @onlyOnCPUAndCUDA - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False))) + @dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes(include_bfloat16=False))) def test_nansum_vs_numpy(self, device, dtype): self._test_sum_reduction_vs_numpy(torch.nansum, np.nansum, device, dtype) self._test_sum_reduction_vs_numpy(torch.nansum, np.nansum, device, dtype, with_extremal=True) self._test_sum_reduction_vs_numpy(torch.nansum, np.nansum, device, dtype, with_keepdim=True) - @dtypes(*(torch.testing.get_all_complex_dtypes())) + @dtypes(*(get_all_complex_dtypes())) def test_nansum_complex(self, device, dtype): x = torch.randn((3, 3, 3), device=device, dtype=dtype) with self.assertRaisesRegex(RuntimeError, "nansum does not support complex inputs"): torch.nansum(x) def test_nansum_out_dtype(self, device): - dtypes = list(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False)) + dtypes = list(get_all_int_dtypes() + get_all_fp_dtypes(include_bfloat16=False)) for inp_dtype, out_dtype in combinations(dtypes, 2): shape = _rand_shape(random.randint(2, 5), min_size=5, max_size=10) x = _generate_input(shape, inp_dtype, device, with_extremal=False) @@ -1130,7 +1395,7 @@ def test_nansum_out_dtype(self, device): np_fn = partial(np.nansum, dtype=np_out_dtype) self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None) - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False))) + @dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes(include_bfloat16=False))) def test_argminmax_multiple(self, device, dtype): # Case: All Ones t = torch.ones(3, 3, device=device, dtype=dtype) @@ -1138,7 +1403,7 @@ def test_argminmax_multiple(self, device, dtype): self.compare_with_numpy(torch.argmin, np.argmin, t) # Case: With single `nan` present. - if dtype in torch.testing.get_all_fp_dtypes(): + if dtype in get_all_fp_dtypes(): t[2, 2] = float('nan') self.compare_with_numpy(torch.argmax, np.argmax, t) self.compare_with_numpy(torch.argmin, np.argmin, t) @@ -1215,8 +1480,8 @@ def verify_against_numpy(t): [0, 0]], device=device, dtype=dtype) verify_against_numpy(t) - @dtypes(*(torch.testing.get_all_dtypes(include_half=True, include_bfloat16=False, - include_bool=True, include_complex=True))) + @dtypes(*(get_all_dtypes(include_half=True, include_bfloat16=False, + include_bool=True, include_complex=True))) def test_all_any_vs_numpy(self, device, dtype): # Note [all, any uint8 compatibility]: However for compatibility reason, # for `uint8`, they return Tensor of same dtype `uint8`. @@ -1444,7 +1709,7 @@ def test_minmax_illegal_dtype(self, device): with self.assertRaisesRegex(RuntimeError, rmsg): torch.min(x, dim=0, out=(illegal_values, illegal_indices)) - @dtypes(*torch.testing.get_all_dtypes(include_bool=False, include_complex=False)) + @dtypes(*get_all_dtypes(include_bool=False, include_complex=False)) def test_dim_arg_reduction_scalar(self, device, dtype): example = 4.0 @@ -1462,7 +1727,7 @@ def test_dim_arg_reduction_scalar(self, device, dtype): @precisionOverride({torch.float16: 1e-2, torch.bfloat16: 1e-2}) - @dtypes(*(set(torch.testing.get_all_dtypes(include_bool=False, include_complex=False)) - {torch.uint8})) + @dtypes(*(set(get_all_dtypes(include_bool=False, include_complex=False)) - {torch.uint8})) def test_dim_reduction(self, device, dtype): example = [[-1, 2, 1], [5, 3, 6]] @@ -1771,7 +2036,7 @@ def run_test(input_): run_test(torch.zeros(64, 61, dtype=dtype, device=device)) run_test(torch.zeros(64, 1, dtype=dtype, device=device)) - @slowTest + @onlyCUDA def test_argminmax_large_axis(self, device): # Regression test for gh-32863 x = torch.zeros(2**31, device=device, dtype=torch.int8) @@ -2664,36 +2929,38 @@ def test_tensor_reduce_ops_empty(self, device): self.assertEqual(np_function(np_input, axis=-1), fn(master_input, dim=-1).cpu().numpy(), msg=error_msg, exact_dtype=False) - self.assertEqual(torch.empty((2, 0, 1), device=device), fn(master_input, dim=2, keepdim=True), msg=error_msg) + self.assertEqual(torch.empty((2, 0, 1), device=device), fn(master_input, dim=2, keepdim=True), + msg=error_msg) self.assertEqual(np_function(np_input, axis=2, keepdims=True), fn(master_input, dim=2, keepdim=True), msg=error_msg, exact_dtype=False) - self.assertEqual(torch.empty((2, 0, 1), device=device), fn(master_input, dim=-1, keepdim=True), msg=error_msg) + self.assertEqual(torch.empty((2, 0, 1), device=device), fn(master_input, dim=-1, keepdim=True), + msg=error_msg) self.assertEqual(np_function(np_input, axis=-1, keepdims=True), fn(master_input, dim=-1, keepdim=True), msg=error_msg, exact_dtype=False) - # Check if returned data is correct. - check_func = (torch.testing.assert_allclose if math.isnan(return_value) or math.isinf(return_value) else - self.assertEqual) - - check_func(torch.full((2, 4), return_value, device=device), fn(master_input, dim=1), msg=error_msg) - check_func(torch.full((2, 4), return_value, device=device), fn(master_input, dim=-2), msg=error_msg) - check_func(torch.full((2, 1, 4), return_value, device=device), fn(master_input, dim=1, keepdim=True), msg=error_msg) - check_func(torch.full((2, 1, 4), return_value, device=device), fn(master_input, dim=-2, keepdim=True), msg=error_msg) + self.assertEqual(torch.full((2, 4), return_value, device=device), fn(master_input, dim=1), msg=error_msg) + self.assertEqual(torch.full((2, 4), return_value, device=device), fn(master_input, dim=-2), msg=error_msg) + self.assertEqual(torch.full((2, 1, 4), return_value, device=device), fn(master_input, dim=1, keepdim=True), + msg=error_msg) + self.assertEqual(torch.full((2, 1, 4), return_value, device=device), fn(master_input, dim=-2, keepdim=True), + msg=error_msg) if name != 'logsumexp': # The scipy function does not work for reduction the zero dimension - check_func(np.float32(np_function(np_input, axis=1)), fn(master_input, dim=1).cpu().numpy(), msg=error_msg) - check_func(np.float32(np_function(np_input, axis=-2)), fn(master_input, dim=-2).cpu().numpy(), msg=error_msg) - check_func(np.float32(np_function(np_input, axis=1, keepdims=True)), - fn(master_input, dim=1, keepdim=True).cpu().numpy(), - msg=error_msg) - check_func(np.float32(np_function(np_input, axis=-2, keepdims=True)), - fn(master_input, dim=-2, keepdim=True).cpu().numpy(), - msg=error_msg) + self.assertEqual(np.float32(np_function(np_input, axis=1)), fn(master_input, dim=1).cpu().numpy(), + msg=error_msg) + self.assertEqual(np.float32(np_function(np_input, axis=-2)), fn(master_input, dim=-2).cpu().numpy(), + msg=error_msg) + self.assertEqual(np.float32(np_function(np_input, axis=1, keepdims=True)), + fn(master_input, dim=1, keepdim=True).cpu().numpy(), + msg=error_msg) + self.assertEqual(np.float32(np_function(np_input, axis=-2, keepdims=True)), + fn(master_input, dim=-2, keepdim=True).cpu().numpy(), + msg=error_msg) # logsumexp throws a type error when not specifying dim so test separately. - check_func(torch.full((), return_value, device=device), fn(master_input), msg=error_msg) + self.assertEqual(torch.full((), return_value, device=device), fn(master_input), msg=error_msg) else: self.assertRaises(TypeError, lambda: fn(master_input)) @@ -2704,8 +2971,8 @@ def test_reduction_empty_any_all(self, device): shape = (2, 0, 4) x = torch.randn(shape, device=device) - for dtype in torch.testing.get_all_dtypes(include_half=True, include_bfloat16=False, - include_bool=True, include_complex=True): + for dtype in get_all_dtypes(include_half=True, include_bfloat16=False, + include_bool=True, include_complex=True): # Refer: [all, any uint8 compatibility] if dtype == torch.uint8: out_dtype = torch.uint8 diff --git a/test/test_shape_ops.py b/test/test_shape_ops.py index 916adee666307..3f8c760264709 100644 --- a/test/test_shape_ops.py +++ b/test/test_shape_ops.py @@ -7,11 +7,13 @@ import warnings from torch._six import nan +from torch.testing import make_tensor from torch.testing._internal.common_utils import ( - TestCase, run_tests, make_tensor, torch_to_numpy_dtype_dict) + TestCase, run_tests, torch_to_numpy_dtype_dict) from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, onlyCPU, onlyCUDA, dtypes, onlyOnCPUAndCUDA, dtypesIfCPU, dtypesIfCUDA, largeTensorTest) +from torch.testing._internal.common_dtype import get_all_dtypes # TODO: replace with make_tensor def _generate_input(shape, dtype, device, with_extremal): @@ -223,9 +225,9 @@ def test_diagonal_multidim(self, device, dtype): self.assertEqual(expected, result) @onlyOnCPUAndCUDA - @dtypesIfCPU(*torch.testing.get_all_dtypes(include_complex=False, include_bool=False, include_half=False, - include_bfloat16=False)) - @dtypesIfCUDA(*torch.testing.get_all_dtypes(include_complex=False, include_bool=False, include_bfloat16=False)) + @dtypesIfCPU(*get_all_dtypes(include_complex=False, include_bool=False, include_half=False, + include_bfloat16=False)) + @dtypesIfCUDA(*get_all_dtypes(include_complex=False, include_bool=False, include_bfloat16=False)) def test_trace(self, device, dtype): def test(shape): tensor = make_tensor(shape, device, dtype, low=-9, high=9) @@ -337,7 +339,7 @@ def test_clamp_raises_arg_errors(self, device): with self.assertRaisesRegex(RuntimeError, error_msg): torch.clamp(X) - @dtypes(*torch.testing.get_all_dtypes()) + @dtypes(*get_all_dtypes()) def test_flip(self, device, dtype): make_from_data = partial(torch.tensor, device=device, dtype=dtype) make_from_size = partial(make_tensor, device=device, dtype=dtype) @@ -436,7 +438,7 @@ def gen_data(): for dims in test_dims: self.assertEqual(size, list(data.flip(dims).size())) - @dtypes(*torch.testing.get_all_dtypes()) + @dtypes(*get_all_dtypes()) def test_flip_errors(self, device, dtype): make_arg = partial(make_tensor, dtype=dtype, device=device) data = make_arg((2, 2, 2)) @@ -454,7 +456,7 @@ def test_flip_errors(self, device, dtype): def _rand_shape(self, dim, min_size, max_size): return tuple(torch.randint(min_size, max_size + 1, (dim,))) - @dtypes(*torch.testing.get_all_dtypes()) + @dtypes(*get_all_dtypes()) def test_flip_numpy(self, device, dtype): make_arg = partial(make_tensor, dtype=dtype, device=device) @@ -563,7 +565,7 @@ def test_nonzero_no_warning(self, device): t.nonzero() self.assertEqual(len(w), 0) - @dtypes(*torch.testing.get_all_dtypes(include_complex=False)) + @dtypes(*get_all_dtypes(include_complex=False)) def test_nonzero(self, device, dtype): shapes = [ diff --git a/test/test_sort_and_select.py b/test/test_sort_and_select.py index 564258aa77b51..52c32952a6965 100644 --- a/test/test_sort_and_select.py +++ b/test/test_sort_and_select.py @@ -5,9 +5,12 @@ from torch._six import nan from itertools import permutations, product -from torch.testing import all_types, all_types_and +from torch.testing import make_tensor +from torch.testing._internal.common_dtype import ( + all_types, all_types_and, floating_types_and, get_all_dtypes, get_all_int_dtypes, get_all_fp_dtypes, +) from torch.testing._internal.common_utils import \ - (TEST_WITH_ROCM, TestCase, run_tests, make_tensor, slowTest) + (TEST_WITH_ROCM, TestCase, run_tests, slowTest) from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, dtypes, onlyOnCPUAndCUDA, skipCUDAIfRocm, onlyCUDA, dtypesIfCUDA, dtypesIfCPU, onlyCPU, largeTensorTest) @@ -128,7 +131,7 @@ def test_sort(self, device): 'random with NaNs') # FIXME: remove torch.bool from unsupported types once support is added for cub sort - @dtypes(*set(torch.testing.get_all_dtypes()) - {torch.bool, torch.complex64, torch.complex128}) + @dtypes(*set(get_all_dtypes()) - {torch.bool, torch.complex64, torch.complex128}) def test_stable_sort(self, device, dtype): if TEST_WITH_ROCM and dtype == torch.bfloat16: return @@ -223,11 +226,11 @@ def test_topk_1d_output_discontiguous(self, device, dtype): self.assertEqual(values, values_cont) # FIXME: remove torch.bool from unsupported types once support is added for cub sort - @dtypes(*set(torch.testing.get_all_dtypes()) - {torch.bool, torch.complex64, torch.complex128}) + @dtypes(*set(get_all_dtypes()) - {torch.bool, torch.complex64, torch.complex128}) def test_stable_sort_against_numpy(self, device, dtype): if TEST_WITH_ROCM and dtype == torch.bfloat16: return - if dtype in torch.testing.floating_types_and(torch.float16, torch.bfloat16): + if dtype in floating_types_and(torch.float16, torch.bfloat16): inf = float('inf') neg_inf = -float('inf') nan = float('nan') @@ -288,7 +291,7 @@ def repeated_index_fill(t, dim, idxs, vals): idx_numpy = np.argsort(sample_numpy, axis=dim, kind='stable') self.assertEqual(idx_torch, idx_numpy) - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes())) + @dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes())) def test_msort(self, device, dtype): if TEST_WITH_ROCM and dtype == torch.bfloat16: return @@ -634,7 +637,7 @@ def test_topk_bfloat16(self, device, dtype): for curr_size in (small, large): self._test_topk_dtype(device, dtype, False, curr_size) - @dtypesIfCUDA(*torch.testing.get_all_fp_dtypes()) + @dtypesIfCUDA(*get_all_fp_dtypes()) @dtypes(torch.float, torch.double, torch.bfloat16) def test_topk_nonfinite(self, device, dtype): if TEST_WITH_ROCM and dtype == torch.bfloat16: @@ -665,11 +668,11 @@ def test_topk_4d(self, device): self.assertEqual(ind, expected_ind, atol=0, rtol=0) @onlyOnCPUAndCUDA - @dtypesIfCUDA(*(torch.testing.get_all_dtypes(include_complex=False, - include_bool=False, - include_half=False, - include_bfloat16=True))) - @dtypes(*(torch.testing.get_all_dtypes(include_complex=False, include_bool=False, include_half=False, include_bfloat16=False))) + @dtypesIfCUDA(*(get_all_dtypes(include_complex=False, + include_bool=False, + include_half=False, + include_bfloat16=True))) + @dtypes(*(get_all_dtypes(include_complex=False, include_bool=False, include_half=False, include_bfloat16=False))) def test_topk_zero(self, device, dtype): if TEST_WITH_ROCM and dtype == torch.bfloat16: return @@ -726,8 +729,8 @@ def ensure_tuple(x): self.assertEqual(expected_inverse.view(additional_shape), y_inverse) self.assertEqual(expected_counts, y_counts) - @dtypesIfCPU(*set(torch.testing.get_all_dtypes()) - {torch.complex64, torch.complex128}) - @dtypes(*set(torch.testing.get_all_dtypes()) - {torch.bfloat16, torch.complex64, torch.complex128}) + @dtypesIfCPU(*set(get_all_dtypes()) - {torch.complex64, torch.complex128}) + @dtypes(*set(get_all_dtypes()) - {torch.bfloat16, torch.complex64, torch.complex128}) def test_unique(self, device, dtype): if dtype is torch.half and self.device_type == 'cpu': return # CPU does not have half support @@ -786,8 +789,8 @@ def ensure_tuple(x): count += 1 self.assertEqual(j, count) - @dtypesIfCPU(*set(torch.testing.get_all_dtypes()) - {torch.complex64, torch.complex128}) - @dtypes(*set(torch.testing.get_all_dtypes()) - {torch.bfloat16, torch.complex64, torch.complex128}) + @dtypesIfCPU(*set(get_all_dtypes()) - {torch.complex64, torch.complex128}) + @dtypes(*set(get_all_dtypes()) - {torch.bfloat16, torch.complex64, torch.complex128}) def test_unique_consecutive(self, device, dtype): if dtype is torch.half and self.device_type == 'cpu': return # CPU does not have half support diff --git a/test/test_sparse.py b/test/test_sparse.py index abe5e93889498..f9ed0dc11ffbd 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -5,8 +5,9 @@ import random from collections import defaultdict import unittest +from torch.testing import make_tensor from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, \ - do_test_empty_full, load_tests, TEST_NUMPY, TEST_SCIPY, IS_WINDOWS, gradcheck, coalescedonoff, make_tensor, \ + do_test_empty_full, load_tests, TEST_NUMPY, TEST_SCIPY, IS_WINDOWS, gradcheck, coalescedonoff, \ DeterministicGuard from torch.testing._internal.common_cuda import TEST_CUDA, _get_torch_cuda_version from numbers import Number @@ -15,6 +16,9 @@ (instantiate_device_type_tests, ops, dtypes, dtypesIfCPU, onlyCPU, onlyCUDA, deviceCountAtLeast) from torch.testing._internal.common_methods_invocations import \ (sparse_unary_ufuncs) +from torch.testing._internal.common_dtype import ( + floating_and_complex_types, floating_and_complex_types_and, get_all_dtypes, get_all_int_dtypes, +) if TEST_SCIPY: import scipy.sparse @@ -285,7 +289,7 @@ def test_ctor_size_checks(self, device, dtype): RuntimeError, lambda: self.sparse_tensor(indices, values, torch.Size([2, 4, 2, 1]))) - @dtypes(*torch.testing.floating_and_complex_types_and(torch.float16)) + @dtypes(*floating_and_complex_types_and(torch.float16)) def test_to_dense(self, device, dtype): def test_tensor(x, res): x.to_dense() # Tests triple to_dense for memory corruption @@ -1557,7 +1561,7 @@ def _test_basic_ops_shape(self, nnz_x1, nnz_x2, shape_i, shape_v, dtype, device, self.assertEqual(self.safeToDense(y1), expected) self.assertEqual(self.safeToDense(y2), expected) - with self.assertWarnsOnceRegex(UserWarning, 'floor_divide'): + with self.assertWarnsOnceRegex(UserWarning, '__floordiv__'): y1 = x1 // 37.5 y2 = x1.clone() with self.assertWarnsOnceRegex(UserWarning, 'floor_divide'): @@ -1942,7 +1946,7 @@ def test_narrow(self, device, dtype, coalesced): def _test_log1p_tensor(self, sparse_tensor, coalesced): def is_integral(dtype): - return dtype in torch.testing.get_all_int_dtypes() + return dtype in get_all_int_dtypes() dense_tensor = sparse_tensor.to_dense() expected_output = dense_tensor.log1p() @@ -1976,8 +1980,8 @@ def is_integral(dtype): sparse_tensor.requires_grad_() @coalescedonoff - @dtypes(*torch.testing.get_all_dtypes(include_bool=False, include_half=False, - include_bfloat16=False, include_complex=False)) + @dtypes(*get_all_dtypes(include_bool=False, include_half=False, + include_bfloat16=False, include_complex=False)) def test_log1p(self, device, dtype, coalesced): if coalesced: input_coalesced = torch.sparse_coo_tensor( @@ -2085,7 +2089,7 @@ def test_neg_negative(self, device, dtype, coalesced): def _test_asin_arcsin(self, sparse_tensor, coalesced): def is_integral(dtype): - return dtype in torch.testing.get_all_int_dtypes() + return dtype in get_all_int_dtypes() is_integral_dtype = is_integral(sparse_tensor.dtype) dense_tensor = sparse_tensor.to_dense() @@ -2124,8 +2128,8 @@ def is_integral(dtype): op(sparse_tensor) @coalescedonoff - @dtypes(*torch.testing.get_all_dtypes(include_bool=False, include_half=False, - include_bfloat16=False, include_complex=False)) + @dtypes(*get_all_dtypes(include_bool=False, include_half=False, + include_bfloat16=False, include_complex=False)) def test_asin_arcsin(self, device, dtype, coalesced): if coalesced: input_coalesced = torch.sparse_coo_tensor( @@ -2195,7 +2199,7 @@ def test_shape(di, dj, dk, nnz): y, _, _ = self._gen_sparse(2, 20, [10, 100], dtype, device, coalesced) res = x.mv(y) - @dtypes(*torch.testing.floating_and_complex_types()) + @dtypes(*floating_and_complex_types()) def test_sparse_add_coalesce(self, device, dtype): i = self.index_tensor([[1, 2, 1]], device=device) v = torch.tensor([3, 4, 5], dtype=dtype, device=device) @@ -2612,14 +2616,14 @@ def test_legacy_new(self, device): @onlyCPU # not really, but we only really want to run this once def test_dtypes(self, device): - all_sparse_dtypes = torch.testing.get_all_dtypes(include_complex=True) + all_sparse_dtypes = get_all_dtypes(include_complex=True) do_test_dtypes(self, all_sparse_dtypes, torch.sparse_coo, torch.device('cpu')) if torch.cuda.is_available(): do_test_dtypes(self, all_sparse_dtypes, torch.sparse_coo, torch.device('cuda:0')) @onlyCPU # not really, but we only really want to run this once def test_empty_full(self, device): - all_sparse_dtypes = torch.testing.get_all_dtypes(include_complex=True) + all_sparse_dtypes = get_all_dtypes(include_complex=True) do_test_empty_full(self, all_sparse_dtypes, torch.sparse_coo, torch.device('cpu')) if torch.cuda.device_count() > 0: do_test_empty_full(self, all_sparse_dtypes, torch.sparse_coo, None) @@ -2910,7 +2914,7 @@ def test_div_by_sparse_error(self, device): / torch.tensor(1., device=device).to_sparse()) def test_floor_divide_by_sparse_error(self, device): - self.assertRaisesRegex(RuntimeError, 'Sparse floor division requires', + self.assertRaisesRegex(RuntimeError, 'Sparse division requires', lambda: torch.tensor(1., device=device).to_sparse() // torch.tensor(1., device=device).to_sparse()) diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py index b9f48855e46db..af99fa031fca3 100644 --- a/test/test_sparse_csr.py +++ b/test/test_sparse_csr.py @@ -3,10 +3,13 @@ import unittest import random import itertools + +from torch.testing import make_tensor from torch.testing._internal.common_utils import \ - (IS_MACOS, IS_WINDOWS, TestCase, run_tests, load_tests, coalescedonoff, make_tensor) + (IS_MACOS, IS_WINDOWS, TestCase, run_tests, load_tests, coalescedonoff) from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, dtypes, onlyCPU, onlyCUDA) +from torch.testing._internal.common_dtype import floating_types, get_all_dtypes # load_tests from torch.testing._internal.common_utils is used to automatically filter tests for # sharding on sandcastle. This line silences flake warnings @@ -41,7 +44,7 @@ def test_csr_layout(self): self.assertEqual(str(torch.sparse_csr), 'torch.sparse_csr') self.assertEqual(type(torch.sparse_csr), torch.layout) - @dtypes(*torch.testing.get_all_dtypes()) + @dtypes(*get_all_dtypes()) def test_sparse_csr_constructor_shape_inference(self, device, dtype): crow_indices = [0, 2, 4] col_indices = [0, 1, 0, 1] @@ -54,7 +57,7 @@ def test_sparse_csr_constructor_shape_inference(self, device, dtype): self.assertEqual(dtype, sparse.dtype) self.assertEqual(torch.device(device), sparse.device) - @dtypes(*torch.testing.get_all_dtypes()) + @dtypes(*get_all_dtypes()) def test_sparse_csr_constructor(self, device, dtype): crow_indices = [0, 2, 4] col_indices = [0, 1, 0, 1] @@ -71,7 +74,7 @@ def test_sparse_csr_constructor(self, device, dtype): self.assertEqual(torch.tensor(col_indices, dtype=index_dtype), sparse.col_indices()) self.assertEqual(torch.tensor(values, dtype=dtype), sparse.values()) - @dtypes(*torch.testing.get_all_dtypes()) + @dtypes(*get_all_dtypes()) def test_sparse_csr_constructor_from_lists(self, device, dtype): # without size sparse = torch.sparse_csr_tensor([0, 2, 4], @@ -207,7 +210,7 @@ def test_factory_indices_invariants_check(self, device): device=device) @onlyCUDA - @dtypes(*torch.testing.get_all_dtypes()) + @dtypes(*get_all_dtypes()) def test_factory_device_type_inference(self, device, dtype): cpu_cuda = ('cpu', 'cuda') cpu_cuda_none = cpu_cuda + (None,) @@ -260,7 +263,7 @@ def test_sparse_csr_print(self, device): printed.append("# col_indices shape: {}".format(col_indices_shape)) printed.append("# values_shape: {}".format(values_shape)) for index_dtype in [torch.int32, torch.int64]: - for dtype in torch.testing.floating_types(): + for dtype in floating_types(): printed.append("########## {}/{} ##########".format(dtype, index_dtype)) x = torch.sparse_csr_tensor(torch.tensor([0, 2, 4], dtype=index_dtype), torch.tensor([0, 1, 0, 1], dtype=index_dtype), @@ -278,7 +281,7 @@ def test_sparse_csr_print(self, device): self.assertExpected('\n'.join(printed)) self.maxDiff = orig_maxDiff - @dtypes(*torch.testing.get_all_dtypes()) + @dtypes(*get_all_dtypes()) def test_sparse_csr_from_dense(self, device, dtype): dense = torch.tensor([[4, 5, 0], [0, 0, 0], [1, 0, 0]], dtype=dtype, device=device) sparse = dense.to_sparse_csr() @@ -298,7 +301,7 @@ def test_sparse_csr_from_dense(self, device, dtype): self.assertEqual(torch.tensor([0, 1, 2] * 3, dtype=torch.int64), sparse.col_indices()) self.assertEqual(torch.tensor([2] * 9, dtype=dtype), sparse.values()) - @dtypes(*torch.testing.get_all_dtypes()) + @dtypes(*get_all_dtypes()) def test_sparse_csr_to_dense(self, device, dtype): mn = [5, 2, 0] for (m, n) in itertools.product(mn, mn): @@ -375,7 +378,7 @@ def test_mkl_matvec_warnings(self, device, dtype): self.assertIn("Pytorch is compiled with MKL LP64 and will convert col_indices to int32", str(w[1].message)) - @dtypes(*torch.testing.get_all_dtypes()) + @dtypes(*get_all_dtypes()) def test_sparse_csr_from_dense_convert_error(self, device, dtype): size = (4, 2, 4) dense = make_tensor(size, dtype=dtype, device=device) @@ -443,7 +446,7 @@ def test_shape(di, dj, dk, nnz): test_shape(i, j, k, i * j // 2) test_shape(4, 4, 4, 0) - @dtypes(*torch.testing.floating_types()) + @dtypes(*floating_types()) def test_sparse_mm(self, device, dtype): def test_shape(d1, d2, d3, nnz, transposed): if transposed: @@ -457,7 +460,7 @@ def test_shape(d1, d2, d3, nnz, transposed): test_shape(7, 8, 9, 20, False) test_shape(7, 8, 9, 20, True) - @dtypes(*torch.testing.floating_types()) + @dtypes(*floating_types()) def test_sparse_addmm(self, device, dtype): def test_shape(m, n, p, nnz, broadcast, alpha_beta=None): if alpha_beta is None: @@ -512,7 +515,7 @@ def _test_spadd_shape(nnz, shape): _test_spadd_shape(10, [100, 1]) _test_spadd_shape(10, [1, 100]) - @dtypes(*torch.testing.get_all_dtypes()) + @dtypes(*get_all_dtypes()) def test_coo_csr_conversion(self, device, dtype): for m, n in itertools.product([5, 2, 0], [5, 2, 0]): size = (m, n) diff --git a/test/test_spectral_ops.py b/test/test_spectral_ops.py index e7e4832ad5631..f632e95d9c704 100644 --- a/test/test_spectral_ops.py +++ b/test/test_spectral_ops.py @@ -658,7 +658,7 @@ def test_fftshift_frequencies(self, device, dtype): # Test fftshift sorts the fftfreq output shifted = torch.fft.fftshift(x) - self.assertTrue(torch.allclose(shifted, shifted.sort().values)) + self.assertEqual(shifted, shifted.sort().values) self.assertEqual(sorted_fft_freqs, shifted) # And ifftshift is the inverse @@ -1126,9 +1126,6 @@ def _test_istft_is_inverse_of_stft(stft_kwargs): original = torch.randn(*sizes, dtype=dtype, device=device) stft = torch.stft(original, return_complex=True, **stft_kwargs) inversed = torch.istft(stft, length=original.size(1), **istft_kwargs) - - # trim the original for case when constructed signal is shorter than original - original = original[..., :inversed.size(-1)] self.assertEqual( inversed, original, msg='istft comparison against original', atol=7e-6, rtol=0, exact_dtype=True) @@ -1167,21 +1164,63 @@ def _test_istft_is_inverse_of_stft(stft_kwargs): 'normalized': True, 'onesided': False, }, - # hamming_window, not centered, not normalized, onesided + # hamming_window, centered, not normalized, onesided # window same size as n_fft { 'n_fft': 5, 'hop_length': 2, 'win_length': 5, 'window': torch.hamming_window(5, dtype=dtype, device=device), - 'center': False, + 'center': True, 'pad_mode': 'constant', 'normalized': False, 'onesided': True, }, + ] + for i, pattern in enumerate(patterns): + _test_istft_is_inverse_of_stft(pattern) + + @onlyOnCPUAndCUDA + @skipCPUIfNoFFT + @dtypes(torch.double) + def test_istft_round_trip_with_padding(self, device, dtype): + """long hop_length or not centered may cause length mismatch in the inversed signal""" + def _test_istft_is_inverse_of_stft_with_padding(stft_kwargs): + # generates a random sound signal for each tril and then does the stft/istft + # operation to check whether we can reconstruct signal + num_trials = 100 + sizes = stft_kwargs['size'] + del stft_kwargs['size'] + istft_kwargs = stft_kwargs.copy() + del istft_kwargs['pad_mode'] + for i in range(num_trials): + original = torch.randn(*sizes, dtype=dtype, device=device) + stft = torch.stft(original, return_complex=True, **stft_kwargs) + with self.assertWarnsOnceRegex(UserWarning, "The length of signal is shorter than the length parameter."): + inversed = torch.istft(stft, length=original.size(-1), **istft_kwargs) + n_frames = stft.size(-1) + if stft_kwargs["center"] is True: + len_expected = stft_kwargs["n_fft"] // 2 + stft_kwargs["hop_length"] * (n_frames - 1) + else: + len_expected = stft_kwargs["n_fft"] + stft_kwargs["hop_length"] * (n_frames - 1) + # trim the original for case when constructed signal is shorter than original + padding = inversed[..., len_expected:] + inversed = inversed[..., :len_expected] + original = original[..., :len_expected] + # test the padding points of the inversed signal are all zeros + zeros = torch.zeros_like(padding, device=padding.device) + self.assertEqual( + padding, zeros, msg='istft padding values against zeros', + atol=7e-6, rtol=0, exact_dtype=True) + self.assertEqual( + inversed, original, msg='istft comparison against original', + atol=7e-6, rtol=0, exact_dtype=True) + + patterns = [ # hamming_window, not centered, not normalized, not onesided # window same size as n_fft { + 'size': [2, 20], 'n_fft': 3, 'hop_length': 2, 'win_length': 3, @@ -1191,9 +1230,22 @@ def _test_istft_is_inverse_of_stft(stft_kwargs): 'normalized': False, 'onesided': False, }, + # hamming_window, centered, not normalized, onesided, long hop_length + # window same size as n_fft + { + 'size': [2, 500], + 'n_fft': 256, + 'hop_length': 254, + 'win_length': 256, + 'window': torch.hamming_window(256, dtype=dtype, device=device), + 'center': True, + 'pad_mode': 'constant', + 'normalized': False, + 'onesided': True, + }, ] for i, pattern in enumerate(patterns): - _test_istft_is_inverse_of_stft(pattern) + _test_istft_is_inverse_of_stft_with_padding(pattern) @onlyOnCPUAndCUDA def test_istft_throws(self, device): diff --git a/test/test_static_runtime.py b/test/test_static_runtime.py index 9b38a5a7e36a8..94043e2745626 100644 --- a/test/test_static_runtime.py +++ b/test/test_static_runtime.py @@ -186,10 +186,10 @@ def test_multihead_attention_layer(self): o_test_kw = attention_a(src, src, value=src, mask=src_mask) for a, b in zip(o_ref, o_test): - torch.testing.assert_allclose(a, b) + torch.testing.assert_close(a, b) for a, b in zip(o_ref, o_test_kw): - torch.testing.assert_allclose(a, b) + torch.testing.assert_close(a, b) def test_multihead_attention_layer_benchmark(self): HID_DIM = 256 @@ -228,20 +228,20 @@ def test_mlp(self): top_inp = torch.randn(2048, 100) # torch.Size([2048, 100]) ref_bot = bot_l(bot_inp) acc_bot = bot_l_acc(bot_inp)[0] - torch.testing.assert_allclose(acc_bot, ref_bot) + torch.testing.assert_close(acc_bot, ref_bot) ref_top = top_l(top_inp) acc_top = top_l_acc(top_inp)[0] - torch.testing.assert_allclose(acc_top, ref_top) + torch.testing.assert_close(acc_top, ref_top) for _ in range(5): with torch.no_grad(): bot_inp = torch.randn(2048, 512) # torch.Size([2048, 512]) top_inp = torch.randn(2048, 100) # torch.Size([2048, 100]) ref_bot = bot_l(bot_inp) acc_bot = bot_l_acc(bot_inp)[0] - torch.testing.assert_allclose(acc_bot, ref_bot) + torch.testing.assert_close(acc_bot, ref_bot) ref_top = top_l(top_inp) acc_top = top_l_acc(top_inp)[0] - torch.testing.assert_allclose(acc_top, ref_top) + torch.testing.assert_close(acc_top, ref_top) def test_trivial_graph(self): s = torch.full((2, 2), 2) @@ -249,7 +249,7 @@ def test_trivial_graph(self): o_ref = tg(s, s, s) tg_a = StaticModule(tg) o_test = tg_a(s, s, s)[0] - torch.testing.assert_allclose(o_ref, o_test) + torch.testing.assert_close(o_ref, o_test) def test_leaky_relu(self): s = torch.randn(5, 5) @@ -257,7 +257,7 @@ def test_leaky_relu(self): o_ref = tg(s) tg_a = StaticModule(tg) o_test = tg_a(s)[0] - torch.testing.assert_allclose(o_ref, o_test) + torch.testing.assert_close(o_ref, o_test) def test_attr(self): """ @@ -293,7 +293,7 @@ def test_attr(self): ms = torch.jit.script(m) sm = StaticModule(ms) output_sm = sm(input)[0] - torch.testing.assert_allclose(output_s, output_sm) + torch.testing.assert_close(output_s, output_sm) sm.benchmark([input], {}, 2, 2) sm.benchmark_individual_ops([input], {}, 2, 2) sm.benchmark([], {"x": input}, 2, 2) @@ -307,7 +307,7 @@ def test_fusion_trivial_graph(self): torch._C._fuse_to_static_module(tg.graph) assert "StaticSubgraph" in str(tg.graph) o_test = tg(s, s, s) - torch.testing.assert_allclose(o_ref, o_test) + torch.testing.assert_close(o_ref, o_test) @unittest.skip("Temporarily disabled") def test_fusion_multihead_attention_layer(self): @@ -332,7 +332,7 @@ def test_fusion_multihead_attention_layer(self): o_test = attention(src, src, src, src_mask) for a, b in zip(o_ref, o_test): - torch.testing.assert_allclose(a, b) + torch.testing.assert_close(a, b) @unittest.skip("Temporarily disabled") def test_fusion_loop(self): @@ -344,7 +344,7 @@ def test_fusion_loop(self): torch._C._fuse_to_static_module(lg.graph) assert "StaticSubgraph" in str(lg.graph) o_test = lg(a, b, c) - torch.testing.assert_allclose(o_ref, o_test) + torch.testing.assert_close(o_ref, o_test) @unittest.skip("Temporarily disabled") def test_fusion_outputs(self): @@ -357,7 +357,7 @@ def test_fusion_outputs(self): assert "StaticSubgraph" in str(og.graph) o_test = og(a, b, b, c) for i in o_ref.keys(): - torch.testing.assert_allclose(o_ref[i], o_test[i]) + torch.testing.assert_close(o_ref[i], o_test[i]) if __name__ == "__main__": diff --git a/test/test_tensor_creation_ops.py b/test/test_tensor_creation_ops.py index 192e03f61cac0..4a2216d230203 100644 --- a/test/test_tensor_creation_ops.py +++ b/test/test_tensor_creation_ops.py @@ -8,14 +8,18 @@ from itertools import product, combinations, combinations_with_replacement, permutations import random +from torch.testing import make_tensor from torch.testing._internal.common_utils import ( TestCase, run_tests, do_test_empty_full, TEST_WITH_ROCM, suppress_warnings, - torch_to_numpy_dtype_dict, slowTest, make_tensor, TEST_SCIPY, IS_MACOS, IS_PPC, + torch_to_numpy_dtype_dict, slowTest, TEST_SCIPY, IS_MACOS, IS_PPC, IS_WINDOWS) from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, deviceCountAtLeast, onlyOnCPUAndCUDA, onlyCPU, largeTensorTest, precisionOverride, dtypes, onlyCUDA, skipCPUIf, dtypesIfCUDA, dtypesIfCPU, skipMeta) +from torch.testing._internal.common_dtype import ( + get_all_dtypes, get_all_math_dtypes, get_all_int_dtypes, get_all_fp_dtypes, get_all_complex_dtypes +) # TODO: refactor tri_tests_args, _compare_trilu_indices, run_additional_tri_tests from torch.testing._internal.common_methods_invocations import ( @@ -139,7 +143,7 @@ def test_vander_types(self, device, dtype): exact_dtype=False) def test_cat_all_dtypes_and_devices(self, device): - for dt in torch.testing.get_all_dtypes(): + for dt in get_all_dtypes(): x = torch.tensor([[1, 2], [3, 4]], dtype=dt, device=device) expected1 = torch.tensor([[1, 2], [3, 4], [1, 2], [3, 4]], dtype=dt, device=device) @@ -149,7 +153,7 @@ def test_cat_all_dtypes_and_devices(self, device): self.assertEqual(torch.cat((x, x), 1), expected2) def test_fill_all_dtypes_and_devices(self, device): - for dt in torch.testing.get_all_dtypes(): + for dt in get_all_dtypes(): for x in [torch.tensor((10, 10), dtype=dt, device=device), torch.empty(10000, dtype=dt, device=device)]: # large tensor numel = x.numel() @@ -303,7 +307,7 @@ def run_test(shape, device, diagonal, dtype): (3, 1), (5, 3, 1), (7, 5, 3, 1), # very fat matrices (1, 3), (5, 1, 3), (7, 5, 1, 3), # very thin matrices (1, 3, 3, 3), (3, 1, 3, 3, 3)] # unsqueezed batch dimensions - dtypes = [dtype for dtype in torch.testing.get_all_dtypes() if dtype != torch.bfloat16] + dtypes = [dtype for dtype in get_all_dtypes() if dtype != torch.bfloat16] for s, d, dtype in product(shapes, diagonals, dtypes): run_test(s, device, d, dtype) @@ -694,6 +698,47 @@ def test_cat_preserve_channels_last(self, device): self.assertEqual(res1, res2) self.assertTrue(res1.is_contiguous(memory_format=torch.channels_last)) + @onlyCUDA + def test_cat_out_memory_format(self, device): + inp_size = (4, 4, 4, 4) + expected_size = (8, 4, 4, 4) + a_cuda = torch.randn(inp_size, device=device).contiguous(memory_format=torch.channels_last) + a_cpu = torch.randn(inp_size, device='cpu').contiguous(memory_format=torch.channels_last) + b_cuda = torch.randn(inp_size, device=device).contiguous(memory_format=torch.contiguous_format) + b_cpu = torch.randn(inp_size, device='cpu').contiguous(memory_format=torch.contiguous_format) + c_cuda = torch.randn(inp_size, device=device).contiguous(memory_format=torch.channels_last) + + # Case 1: if out= is the correct shape then the memory format of out= is respected + + out_cuda = torch.empty(expected_size, device=device).contiguous(memory_format=torch.contiguous_format) + res1_cuda = torch.cat((a_cuda, b_cuda), out=out_cuda) + + out_cpu = torch.empty(expected_size, device='cpu').contiguous(memory_format=torch.contiguous_format) + res1_cpu = torch.cat((a_cpu, b_cpu), out=out_cpu) + + self.assertTrue(res1_cuda.is_contiguous(memory_format=torch.contiguous_format)) + self.assertTrue(res1_cpu.is_contiguous(memory_format=torch.contiguous_format)) + + # Case 2: if out= is not the correct shape then the output it is resized internally + # - For the CPU variant the memory format is that of the first tensor + # - For the CUDA variant it only propagates memory format if all the tensors have + # the same memory format, otherwise it just uses contiguous_format as a default + + out_cuda = torch.empty((0), device=device).contiguous(memory_format=torch.contiguous_format) + # a_cuda and b_cuda have different memory_format + res2_cuda = torch.cat((a_cuda, b_cuda), out=out_cuda) + + out_cpu = torch.empty((0), device='cpu').contiguous(memory_format=torch.contiguous_format) + res2_cpu = torch.cat((a_cpu, b_cpu), out=out_cpu) + + self.assertTrue(res2_cuda.is_contiguous(memory_format=torch.contiguous_format)) + self.assertTrue(res2_cpu.is_contiguous(memory_format=torch.channels_last)) + + out_cuda = torch.empty((0), device=device).contiguous(memory_format=torch.contiguous_format) + # a_cuda and c_cuda have same memory_format + res3_cuda = torch.cat((a_cuda, c_cuda), out=out_cuda) + + self.assertTrue(res3_cuda.is_contiguous(memory_format=torch.channels_last)) @onlyCUDA @deviceCountAtLeast(2) @@ -712,8 +757,8 @@ def test_cat_different_devices(self, devices): def test_cat_stack_cross_devices(self, device): cuda = torch.randn((3, 3), device=device) cpu = torch.randn((3, 3), device='cpu') - out_cpu = cpu.clone() - out_cuda = cuda.clone() + + # cat with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"): torch.cat((cuda, cpu)) @@ -721,18 +766,6 @@ def test_cat_stack_cross_devices(self, device): "Expected all tensors to be on the same device"): torch.cat((cpu, cuda)) - with self.assertRaisesRegex(RuntimeError, - "Expected all tensors to be on the same device"): - torch.cat((cpu, cuda), out=out_cuda) - - with self.assertRaisesRegex(RuntimeError, - "Expected all tensors to be on the same device"): - torch.cat((cpu, cpu), out=out_cuda) - - with self.assertRaisesRegex(RuntimeError, - "Expected all tensors to be on the same device"): - torch.cat((cuda, cuda), out=out_cpu) - # Stack with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"): @@ -741,18 +774,6 @@ def test_cat_stack_cross_devices(self, device): "Expected all tensors to be on the same device"): torch.stack((cpu, cuda)) - with self.assertRaisesRegex(RuntimeError, - "Expected all tensors to be on the same device"): - torch.stack((cpu, cuda), out=out_cuda) - - with self.assertRaisesRegex(RuntimeError, - "Expected all tensors to be on the same device"): - torch.stack((cpu, cpu), out=out_cuda) - - with self.assertRaisesRegex(RuntimeError, - "Expected all tensors to be on the same device"): - torch.stack((cuda, cuda), out=out_cpu) - # TODO: reconcile with other cat tests # TODO: Compare with a NumPy reference instead of CPU @onlyCUDA @@ -969,8 +990,8 @@ def _test_special_stacks(self, dim, at_least_dim, torch_fn, np_fn, device, dtype np_fn(np_input) @onlyOnCPUAndCUDA - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False) + - torch.testing.get_all_complex_dtypes())) + @dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes(include_bfloat16=False) + + get_all_complex_dtypes())) def test_hstack_column_stack(self, device, dtype): ops = ((torch.hstack, np.hstack), (torch.column_stack, np.column_stack)) for torch_op, np_op in ops: @@ -989,8 +1010,8 @@ def test_hstack_column_stack(self, device, dtype): torch_result) @onlyOnCPUAndCUDA - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False) + - torch.testing.get_all_complex_dtypes())) + @dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes(include_bfloat16=False) + + get_all_complex_dtypes())) def test_vstack_row_stack(self, device, dtype): ops = ((torch.vstack, np.vstack), (torch.row_stack, np.row_stack)) for torch_op, np_op in ops: @@ -1007,8 +1028,8 @@ def test_vstack_row_stack(self, device, dtype): self.assertEqual(actual, expected) @onlyOnCPUAndCUDA - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False) + - torch.testing.get_all_complex_dtypes())) + @dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes(include_bfloat16=False) + + get_all_complex_dtypes())) def test_dstack(self, device, dtype): self._test_special_stacks(2, 3, torch.dstack, np.dstack, device, dtype) for i in range(5): @@ -1554,7 +1575,7 @@ def test_random_from_to_bool(self, device): lambda: t.random_(from_, to_) ) - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes())) + @dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes())) def test_random_full_range(self, device, dtype): size = 2000 alpha = 0.1 @@ -1588,7 +1609,7 @@ def test_random_full_range(self, device, dtype): self.assertTrue(from_ <= t.to(torch.double).min() < (from_ + delta)) self.assertTrue((to_inc_ - delta) < t.to(torch.double).max() <= to_inc_) - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes())) + @dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes())) def test_random_from_to(self, device, dtype): size = 2000 alpha = 0.1 @@ -1677,7 +1698,7 @@ def test_random_from_to(self, device, dtype): lambda: t.random_(from_, to_) ) - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes())) + @dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes())) def test_random_to(self, device, dtype): size = 2000 alpha = 0.1 @@ -1735,7 +1756,7 @@ def test_random_to(self, device, dtype): lambda: t.random_(from_, to_) ) - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes())) + @dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes())) def test_random_default(self, device, dtype): size = 2000 alpha = 0.1 @@ -1763,10 +1784,10 @@ def test_empty_full(self, device): device_type = torch_device.type if device_type == 'cpu': - do_test_empty_full(self, torch.testing.get_all_math_dtypes('cpu'), torch.strided, torch_device) + do_test_empty_full(self, get_all_math_dtypes('cpu'), torch.strided, torch_device) if device_type == 'cuda': - do_test_empty_full(self, torch.testing.get_all_math_dtypes('cpu'), torch.strided, None) - do_test_empty_full(self, torch.testing.get_all_math_dtypes('cpu'), torch.strided, torch_device) + do_test_empty_full(self, get_all_math_dtypes('cpu'), torch.strided, None) + do_test_empty_full(self, get_all_math_dtypes('cpu'), torch.strided, torch_device) # TODO: this test should be updated @suppress_warnings @@ -2454,7 +2475,7 @@ def test_empty_tensor_props(self, device): self.assertEqual(x.stride(), y.stride()) def test_eye(self, device): - for dtype in torch.testing.get_all_dtypes(): + for dtype in get_all_dtypes(): if dtype == torch.bfloat16: continue # Test the RuntimeError is raised when either m or n is a negative number @@ -2487,8 +2508,8 @@ def test_eye(self, device): self.assertEqual(res1, res2) @precisionOverride({torch.float: 1e-8, torch.double: 1e-10}) - @dtypes(*(torch.testing.get_all_fp_dtypes(include_half=False, include_bfloat16=False) + - torch.testing.get_all_complex_dtypes())) + @dtypes(*(get_all_fp_dtypes(include_half=False, include_bfloat16=False) + + get_all_complex_dtypes())) def test_linspace_vs_numpy(self, device, dtype): start = -0.0316082797944545745849609375 + (0.8888888888j if dtype.is_complex else 0) end = .0315315723419189453125 + (0.444444444444j if dtype.is_complex else 0) @@ -2525,7 +2546,7 @@ def test_logspace_vs_numpy_complex(self, device, dtype): device, dtype) @precisionOverride({torch.float: 1e-6, torch.double: 1e-10}) - @dtypes(*torch.testing.get_all_fp_dtypes(include_half=False, include_bfloat16=False)) + @dtypes(*get_all_fp_dtypes(include_half=False, include_bfloat16=False)) def test_logspace_vs_numpy(self, device, dtype): start = -0.0316082797944545745849609375 end = .0315315723419189453125 @@ -2635,7 +2656,7 @@ def test_tensor_factories_empty(self, device): shapes = [(5, 0, 1), (0,), (0, 0, 1, 0, 2, 0, 0)] for shape in shapes: - for dt in torch.testing.get_all_dtypes(): + for dt in get_all_dtypes(): self.assertEqual(shape, torch.zeros(shape, device=device, dtype=dt).shape) self.assertEqual(shape, torch.zeros_like(torch.zeros(shape, device=device, dtype=dt)).shape) @@ -2721,8 +2742,8 @@ def test_arange_bfloat16(self, device): bfloat16_tensor = torch.arange(0, 6, step=2, dtype=torch.bfloat16, device=device) self.assertEqual(ref_tensor, bfloat16_tensor) - @dtypes(*torch.testing.get_all_dtypes(include_bool=False, include_half=False)) - @dtypesIfCUDA(*torch.testing.get_all_dtypes(include_bool=False, include_half=True)) + @dtypes(*get_all_dtypes(include_bool=False, include_half=False)) + @dtypesIfCUDA(*get_all_dtypes(include_bool=False, include_half=True)) def test_linspace(self, device, dtype): _from = random.random() to = _from + random.random() @@ -2836,12 +2857,12 @@ def _test_linspace(self, device, dtype, steps): # See NOTE [Linspace+Logspace precision override] @skipCPUIf(True, "compares with CPU") @precisionOverride({torch.half: 0.0039 + LINSPACE_LOGSPACE_EXTRA_EPS}) - @dtypes(*(torch.testing.get_all_fp_dtypes() + torch.testing.get_all_complex_dtypes())) + @dtypes(*(get_all_fp_dtypes() + get_all_complex_dtypes())) def test_linspace_device_vs_cpu(self, device, dtype): self._test_linspace(device, dtype, steps=10) @skipCPUIf(True, "compares with CPU") - @dtypes(*(torch.testing.get_all_fp_dtypes() + torch.testing.get_all_complex_dtypes())) + @dtypes(*(get_all_fp_dtypes() + get_all_complex_dtypes())) def test_linspace_special_steps(self, device, dtype): for steps in self.LINSPACE_LOGSPACE_SPECIAL_STEPS: self._test_linspace(device, dtype, steps=steps) @@ -2882,10 +2903,10 @@ def test_logspace_special_steps(self, device, dtype): self._test_logspace(device, dtype, steps=steps) self._test_logspace_base2(device, dtype, steps=steps) - @dtypes(*torch.testing.get_all_dtypes(include_bool=False, include_half=False, include_complex=False)) - @dtypesIfCUDA(*((torch.testing.get_all_int_dtypes() + [torch.float32, torch.float16, torch.bfloat16]) + @dtypes(*get_all_dtypes(include_bool=False, include_half=False, include_complex=False)) + @dtypesIfCUDA(*((get_all_int_dtypes() + [torch.float32, torch.float16, torch.bfloat16]) if TEST_WITH_ROCM - else torch.testing.get_all_dtypes(include_bool=False, include_half=True, include_complex=False))) + else get_all_dtypes(include_bool=False, include_half=True, include_complex=False))) def test_logspace(self, device, dtype): _from = random.random() to = _from + random.random() @@ -3257,7 +3278,7 @@ def seed(generator): self.assertTrue((res1 >= 0).all().item()) @dtypes(torch.half, torch.float, torch.bfloat16, torch.double, - torch.complex32, torch.complex64, torch.complex128) + torch.complex64, torch.complex128) def test_randn(self, device, dtype): SIZE = 100 for size in [0, SIZE]: diff --git a/test/test_tensorexpr.py b/test/test_tensorexpr.py index 50145100abf8f..366c262ad7c1d 100644 --- a/test/test_tensorexpr.py +++ b/test/test_tensorexpr.py @@ -24,9 +24,6 @@ def setUp(self): torch._C._debug_set_fusion_group_inlining(False) self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu() torch._C._jit_set_te_must_use_llvm_cpu(False) - # TODO: CPU fuser currently is disabled when multithreading. - self.old_fuse_parallel = torch._C._jit_texpr_parallel_cpu_enabled() - torch._C._jit_set_texpr_parallel_cpu_enabled(True) self.devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] @@ -39,7 +36,6 @@ def tearDown(self): torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuser_state) torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining) torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu) - torch._C._jit_set_texpr_parallel_cpu_enabled(self.old_fuse_parallel) def assertLastGraphAllFused(self): self.assertAllFused(torch.jit.last_executed_optimized_graph()) @@ -1226,7 +1222,6 @@ def bias_gelu(bias, y): x = warmup_and_run_forward(traced, a, b) self.assertLastGraphAllFused() - @unittest.skip("float16 is not supported yet.") def test_half_bn_relu(self): devices = ["cuda"] if torch.cuda.is_available() else [] @@ -1468,7 +1463,7 @@ def getModule(script): am_s = getModule(True) ref = am(x, x, x) test = am_s(x, x, x) - torch.testing.assert_allclose(ref, test) + torch.testing.assert_close(ref, test) # Now do the aliasing am.a = am.b @@ -1477,7 +1472,7 @@ def getModule(script): am_s.a = am_s.b test = am_s(x, x, x) - torch.testing.assert_allclose(ref, test) + torch.testing.assert_close(ref, test) def test_alias_analysis_inputs(self): class AliasModule(nn.Module): @@ -1510,7 +1505,7 @@ def getModule(script): x = torch.randn(128, 128) test = am_s(x, x, x) - torch.testing.assert_allclose(ref, test) + torch.testing.assert_close(ref, test) def test_alias_analysis_input_and_module(self): class AliasModule(nn.Module): @@ -1545,7 +1540,7 @@ def getModule(script): am_s.b = x test = am_s(x, x, x) - torch.testing.assert_allclose(ref, test) + torch.testing.assert_close(ref, test) def test_multiple_outputs(self): for device in self.devices: diff --git a/test/test_tensorexpr_pybind.py b/test/test_tensorexpr_pybind.py index 4138b2f81dfda..6a348053c01fd 100644 --- a/test/test_tensorexpr_pybind.py +++ b/test/test_tensorexpr_pybind.py @@ -9,14 +9,6 @@ LLVM_ENABLED = torch._C._llvm_enabled() -class kernel_arena_scope(object): - def __enter__(self): - self.scope = torch._C._te.KernelScope() - - def __exit__(self, typ, val, traceback): - self.scope = None - - def construct_adder(n: int, dtype=te.Dtype.Float): dN = te.ExprHandle.int(n) A = te.Placeholder('A', dtype, [dN]) @@ -36,85 +28,80 @@ def compute(i): class TestTensorExprPyBind(JitTestCase): def test_simple_sum(self): - with kernel_arena_scope(): - n = 32 - cg = construct_adder(n) + n = 32 + cg = construct_adder(n) - tA = torch.randn(n) - tB = torch.randn(n) - tC = torch.empty(n) - cg.call([tA, tB, tC]) - torch.testing.assert_allclose(tA + tB, tC) + tA = torch.randn(n) + tB = torch.randn(n) + tC = torch.empty(n) + cg.call([tA, tB, tC]) + torch.testing.assert_close(tA + tB, tC) def test_call_raw(self): - with kernel_arena_scope(): - n = 16 - cg = construct_adder(n, dtype=torch.float64) + n = 16 + cg = construct_adder(n, dtype=torch.float64) - tA = torch.randn(n, dtype=torch.float64) - tB = torch.randn(n, dtype=torch.float64) - tC = torch.empty(n, dtype=torch.float64) - cg.call_raw([tA.data_ptr(), tB.data_ptr(), tC.data_ptr()]) - torch.testing.assert_allclose(tA + tB, tC) + tA = torch.randn(n, dtype=torch.float64) + tB = torch.randn(n, dtype=torch.float64) + tC = torch.empty(n, dtype=torch.float64) + cg.call_raw([tA.data_ptr(), tB.data_ptr(), tC.data_ptr()]) + torch.testing.assert_close(tA + tB, tC) def test_external_calls(self): - with kernel_arena_scope(): - dtype = torch.float32 + dtype = torch.float32 - ONE = te.ExprHandle.int(1) - FOUR = te.ExprHandle.int(4) - A = te.BufHandle('A', [ONE, FOUR], dtype) - B = te.BufHandle('B', [FOUR, ONE], dtype) - C = te.BufHandle('C', [ONE, ONE], dtype) + ONE = te.ExprHandle.int(1) + FOUR = te.ExprHandle.int(4) + A = te.BufHandle('A', [ONE, FOUR], dtype) + B = te.BufHandle('B', [FOUR, ONE], dtype) + C = te.BufHandle('C', [ONE, ONE], dtype) - s = te.ExternalCall(C, "nnc_aten_matmul", [A, B], []) + s = te.ExternalCall(C, "nnc_aten_matmul", [A, B], []) - loopnest = te.LoopNest(s, [C]) - loopnest.prepare_for_codegen() - codegen = te.construct_codegen('ir_eval', s, [te.BufferArg(x) for x in [A, B, C]]) + loopnest = te.LoopNest(s, [C]) + loopnest.prepare_for_codegen() + codegen = te.construct_codegen('ir_eval', s, [te.BufferArg(x) for x in [A, B, C]]) - tA = torch.ones(1, 4) - tB = torch.ones(4, 1) - tC = torch.empty(1, 1) - codegen.call([tA, tB, tC]) - torch.testing.assert_allclose(torch.matmul(tA, tB), tC) + tA = torch.ones(1, 4) + tB = torch.ones(4, 1) + tC = torch.empty(1, 1) + codegen.call([tA, tB, tC]) + torch.testing.assert_close(torch.matmul(tA, tB), tC) def test_dynamic_shape(self): - with kernel_arena_scope(): - dN = te.VarHandle(torch.int32) - A = te.BufHandle(torch.float64) - B = te.BufHandle(torch.float64) + dN = te.VarHandle(torch.int32) + A = te.BufHandle(torch.float64) + B = te.BufHandle(torch.float64) - def compute(i): - return A.load(i) - B.load(i) + def compute(i): + return A.load(i) - B.load(i) - C = te.Compute('C', [dN], compute) + C = te.Compute('C', [dN], compute) - loopnest = te.LoopNest([C]) - loopnest.prepare_for_codegen() + loopnest = te.LoopNest([C]) + loopnest.prepare_for_codegen() - cg = te.construct_codegen( - 'ir_eval', - loopnest.simplify(), - [A, B, C, dN]) + cg = te.construct_codegen( + 'ir_eval', + loopnest.simplify(), + [A, B, C, dN]) - def test_with_shape(n): - tA = torch.randn(n, dtype=torch.double) - tB = torch.randn(n, dtype=torch.double) - tC = torch.empty(n, dtype=torch.double) - cg.call([tA, tB, tC, n]) - torch.testing.assert_allclose(tA - tB, tC) + def test_with_shape(n): + tA = torch.randn(n, dtype=torch.double) + tB = torch.randn(n, dtype=torch.double) + tC = torch.empty(n, dtype=torch.double) + cg.call([tA, tB, tC, n]) + torch.testing.assert_close(tA - tB, tC) - test_with_shape(8) - test_with_shape(31) + test_with_shape(8) + test_with_shape(31) def test_dtype_error(self): - with kernel_arena_scope(): - one = te.ExprHandle.int(1) - te.Placeholder([one], torch.float32) # ok - te.Placeholder([one]) # ok - self.assertRaises(TypeError, - lambda: te.Placeholder([one], "float55")) + one = te.ExprHandle.int(1) + te.Placeholder([one], torch.float32) # ok + te.Placeholder([one]) # ok + self.assertRaises(TypeError, + lambda: te.Placeholder([one], "float55")) @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled") def test_kernel_with_tensor_inputs(self): @@ -394,28 +381,24 @@ def f(a): np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3) np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3) - def test_forgot_kernel_arena(self): - self.assertRaises(RuntimeError, lambda: torch._C._te.VarHandle("n", torch._C._te.Dtype.Int)) - @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled") def test_alloc_in_loop(self): - with kernel_arena_scope(): - a, tmp, b = [ - te.Placeholder(name, te.Dtype.Float, [te.ExprHandle.int(1)]) - for name in ["a", "tmp", "b"]] - t0, t100 = [te.ExprHandle.int(n) for n in [0, 100]] - body = te.Block([ - tmp.store([t0], a.load([t0])), - b.store([t0], tmp.load([t0])) - ]) - for _ in range(4): - i = te.VarHandle("i", te.Dtype.Int) - body = te.For.make(i, t0, t100, body) - nest = te.LoopNest(body, [b.data()]) - nest.prepare_for_codegen() - f = te.construct_codegen("llvm", nest.simplify(), [a, b]) - ta, tb = [torch.ones(1) for _ in range(2)] - f.call([ta.data_ptr(), tb.data_ptr()]) + a, tmp, b = [ + te.Placeholder(name, te.Dtype.Float, [te.ExprHandle.int(1)]) + for name in ["a", "tmp", "b"]] + t0, t100 = [te.ExprHandle.int(n) for n in [0, 100]] + body = te.Block([ + tmp.store([t0], a.load([t0])), + b.store([t0], tmp.load([t0])) + ]) + for _ in range(4): + i = te.VarHandle("i", te.Dtype.Int) + body = te.For.make(i, t0, t100, body) + nest = te.LoopNest(body, [b.data()]) + nest.prepare_for_codegen() + f = te.construct_codegen("llvm", nest.simplify(), [a, b]) + ta, tb = [torch.ones(1) for _ in range(2)] + f.call([ta.data_ptr(), tb.data_ptr()]) if __name__ == '__main__': run_tests() diff --git a/test/test_testing.py b/test/test_testing.py index d59290b36c27b..e45977f3a855e 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -10,20 +10,22 @@ import torch +from torch.testing import make_tensor from torch.testing._internal.common_utils import \ - (IS_FBCODE, IS_SANDCASTLE, IS_WINDOWS, TestCase, make_tensor, run_tests, skipIfRocm, slowTest) + (IS_FBCODE, IS_SANDCASTLE, IS_WINDOWS, TestCase, run_tests, skipIfRocm, slowTest) from torch.testing._internal.common_device_type import \ (PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY, PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, dtypes, get_device_type_test_bases, instantiate_device_type_tests, onlyCUDA, onlyOnCPUAndCUDA, deviceCountAtLeast) from torch.testing._internal.common_methods_invocations import op_db import torch.testing._internal.opinfo_helper as opinfo_helper +from torch.testing._internal.common_dtype import get_all_dtypes # For testing TestCase methods and torch.testing functions class TestTesting(TestCase): # Ensure that assertEqual handles numpy arrays properly - @dtypes(*(torch.testing.get_all_dtypes(include_half=True, include_bfloat16=False, - include_bool=True, include_complex=True))) + @dtypes(*(get_all_dtypes(include_half=True, include_bfloat16=False, + include_bool=True, include_complex=True))) def test_assertEqual_numpy(self, device, dtype): S = 10 test_sizes = [ @@ -87,25 +89,19 @@ def test__comparescalars_debug_msg(self, device): "atol=1e-05 is only 1.9100000000000003e-05!") self.assertEqual(debug_msg, expected_msg) - # complex x complex, real difference + # complex x complex result, debug_msg = self._compareScalars(complex(1, 3), complex(3, 1)) - expected_msg = ("Comparing the real part 1.0 and 3.0 gives a difference " - "of 2.0, but the allowed difference with rtol=1.3e-06 " - "and atol=1e-05 is only 1.39e-05!") - self.assertEqual(debug_msg, expected_msg) - - # complex x complex, imaginary difference - result, debug_msg = self._compareScalars(complex(1, 3), complex(1, 5.5)) - expected_msg = ("Comparing the imaginary part 3.0 and 5.5 gives a " - "difference of 2.5, but the allowed difference with " - "rtol=1.3e-06 and atol=1e-05 is only 1.715e-05!") + expected_msg = ("Comparing (1+3j) and (3+1j) gives a difference " + "of 2.8284271247461903, but the allowed difference " + "with rtol=1.3e-06 and atol=1e-05 is only " + "1.4110960958218895e-05!") self.assertEqual(debug_msg, expected_msg) # complex x int result, debug_msg = self._compareScalars(complex(1, -2), 1) - expected_msg = ("Comparing the imaginary part -2.0 and 0.0 gives a " - "difference of 2.0, but the allowed difference with " - "rtol=1.3e-06 and atol=1e-05 is only 1e-05!") + expected_msg = ("Comparing (1-2j) and 1 gives a difference of 2.0, " + "but the allowed difference with rtol=1.3e-06 and " + "atol=1e-05 is only 1.13e-05!") self.assertEqual(debug_msg, expected_msg) # NaN x NaN, equal_nan=False @@ -169,28 +165,6 @@ def test__comparetensors_debug_msg(self, device): "occuring at index 0.") self.assertEqual(debug_msg, expected_msg) - # Checks complex tensor comparisons (real part) - a = torch.tensor((1 - 1j, 4 + 3j), device=device) - b = torch.tensor((1 - 1j, 1 + 3j), device=device) - result, debug_msg = self._compareTensors(a, b) - expected_msg = ("Real parts failed to compare as equal! " - "With rtol=1.3e-06 and atol={0}, " - "found 1 element(s) (out of 2) whose difference(s) exceeded the " - "margin of error (including 0 nan comparisons). The greatest difference was " - "3.0 (4.0 vs. 1.0), which occurred at index 1.").format(atol) - self.assertEqual(debug_msg, expected_msg) - - # Checks complex tensor comparisons (imaginary part) - a = torch.tensor((1 - 1j, 4 + 3j), device=device) - b = torch.tensor((1 - 1j, 4 - 21j), device=device) - result, debug_msg = self._compareTensors(a, b) - expected_msg = ("Imaginary parts failed to compare as equal! " - "With rtol=1.3e-06 and atol={0}, " - "found 1 element(s) (out of 2) whose difference(s) exceeded the " - "margin of error (including 0 nan comparisons). The greatest difference was " - "24.0 (3.0 vs. -21.0), which occurred at index 1.").format(atol) - self.assertEqual(debug_msg, expected_msg) - # Checks size mismatch a = torch.tensor((1, 2), device=device) b = torch.tensor((3), device=device) @@ -335,8 +309,6 @@ def test_isclose_comparetensors_float(self, device, dtype): self._comparetensors_helper(tests, device, dtype, True) - # torch.close with equal_nan=True is not implemented for complex inputs - # see https://github.com/numpy/numpy/issues/15959 # Note: compareTensor will compare the real and imaginary parts of a # complex tensors separately, unlike isclose. @dtypes(torch.complex64, torch.complex128) @@ -408,7 +380,7 @@ def test_isclose_comparetensors_complex(self, device, dtype): tests = ( (complex(1, -1), complex(-1, 1), False), (complex(1, -1), complex(2, -2), True), - (complex(1, 99), complex(4, 100), False), + (complex(1, 99), complex(4, 100), True), ) self._comparetensors_helper(tests, device, dtype, False, atol=.5, rtol=.5) @@ -416,13 +388,12 @@ def test_isclose_comparetensors_complex(self, device, dtype): # equal_nan = True tests tests = ( (complex(1, 1), complex(1, float('nan')), False), - (complex(float('nan'), 1), complex(1, float('nan')), False), + (complex(1, 1), complex(float('nan'), 1), False), (complex(float('nan'), 1), complex(float('nan'), 1), True), + (complex(float('nan'), 1), complex(1, float('nan')), True), + (complex(float('nan'), float('nan')), complex(float('nan'), float('nan')), True), ) - - with self.assertRaises(RuntimeError): - self._isclose_helper(tests, device, dtype, True) - + self._isclose_helper(tests, device, dtype, True) self._comparetensors_helper(tests, device, dtype, True) # Tests that isclose with rtol or atol values less than zero throws a @@ -449,6 +420,19 @@ def test_isclose_equality_shortcut(self): self.assertFalse(torch.isclose(a, b, rtol=0, atol=0)) + @dtypes(torch.float16, torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_isclose_nan_equality_shortcut(self, device, dtype): + if dtype.is_floating_point: + a = b = torch.nan + else: + a = complex(torch.nan, 0) + b = complex(0, torch.nan) + + expected = True + tests = [(a, b, expected)] + + self._isclose_helper(tests, device, dtype, equal_nan=True, rtol=0, atol=0) + @dtypes(torch.bool, torch.long, torch.float, torch.cfloat) def test_make_tensor(self, device, dtype): def check(size, low, high, requires_grad, noncontiguous): @@ -880,20 +864,43 @@ def test_matching_atol(self): for fn in assert_close_with_inputs(actual, expected): fn(rtol=0.0, atol=eps * 2) - def test_matching_nan(self): - actual = torch.tensor(float("NaN")) - expected = actual.clone() + # TODO: the code that this test was designed for was removed in https://github.com/pytorch/pytorch/pull/56058 + # We need to check if this test is still needed or if this behavior is now enabled by default. + def test_matching_conjugate_bit(self): + actual = torch.tensor(complex(1, 1)).conj() + expected = torch.tensor(complex(1, -1)) for fn in assert_close_with_inputs(actual, expected): - with self.assertRaises(AssertionError): - fn() + fn() + + def test_matching_nan(self): + nan = float("NaN") + + tests = ( + (nan, nan), + (complex(nan, 0), complex(0, nan)), + (complex(nan, nan), complex(nan, 0)), + (complex(nan, nan), complex(nan, nan)), + ) + + for actual, expected in tests: + for fn in assert_close_with_inputs(actual, expected): + with self.assertRaises(AssertionError): + fn() def test_matching_nan_with_equal_nan(self): - actual = torch.tensor(float("NaN")) - expected = actual.clone() + nan = float("NaN") - for fn in assert_close_with_inputs(actual, expected): - fn(equal_nan=True) + tests = ( + (nan, nan), + (complex(nan, 0), complex(0, nan)), + (complex(nan, nan), complex(nan, 0)), + (complex(nan, nan), complex(nan, nan)), + ) + + for actual, expected in tests: + for fn in assert_close_with_inputs(actual, expected): + fn(equal_nan=True) def test_numpy(self): tensor = torch.rand(2, 2, dtype=torch.float32) @@ -1198,30 +1205,6 @@ def test_mapping_mismatching_values_msg(self): torch.testing.assert_close(actual, expected) -class TestAssertCloseComplex(TestCase): - def test_mismatching_nan_with_equal_nan(self): - actual = torch.tensor(complex(1, float("NaN"))) - expected = torch.tensor(complex(float("NaN"), 1)) - - for fn in assert_close_with_inputs(actual, expected): - with self.assertRaises(AssertionError): - fn(equal_nan=True) - - def test_mismatching_nan_with_equal_nan_relaxed(self): - actual = torch.tensor(complex(1, float("NaN"))) - expected = torch.tensor(complex(float("NaN"), 1)) - - for fn in assert_close_with_inputs(actual, expected): - fn(equal_nan="relaxed") - - def test_matching_conjugate_bit(self): - actual = torch.tensor(complex(1, 1)).conj() - expected = torch.tensor(complex(1, -1)) - - for fn in assert_close_with_inputs(actual, expected): - fn() - - class TestAssertCloseSparseCOO(TestCase): def test_matching_coalesced(self): indices = ( diff --git a/test/test_throughput_benchmark.py b/test/test_throughput_benchmark.py index 9d60344b5912b..139ca0c4cc559 100644 --- a/test/test_throughput_benchmark.py +++ b/test/test_throughput_benchmark.py @@ -1,7 +1,6 @@ import torch from torch.utils import ThroughputBenchmark -from torch.testing import assert_allclose from torch.testing._internal.common_utils import run_tests, TestCase, TemporaryFileName @@ -56,7 +55,7 @@ def linear_test(self, Module, profiler_output_path=""): # or just unpack the list of inputs module_result = module(*inputs[i]) bench_result = bench.run_once(*inputs[i]) - assert_allclose(bench_result, module_result) + torch.testing.assert_close(bench_result, module_result) stats = bench.benchmark( num_calling_threads=4, diff --git a/test/test_torch.py b/test/test_torch.py index 6766d50e6425d..6de409be60d1d 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -27,13 +27,14 @@ from itertools import product, combinations, permutations from functools import partial from torch import multiprocessing as mp +from torch.testing import make_tensor from torch.testing._internal.common_utils import ( TestCase, TEST_WITH_ROCM, run_tests, IS_WINDOWS, IS_FILESYSTEM_UTF8_ENCODING, NO_MULTIPROCESSING_SPAWN, do_test_dtypes, IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, load_tests, slowTest, skipCUDAMemoryLeakCheckIf, BytesIOContext, noarchTest, skipIfRocm, skipIfNoSciPy, TemporaryFileName, TemporaryDirectoryName, - wrapDeterministicFlagAPITest, DeterministicGuard, CudaSyncGuard, make_tensor) + wrapDeterministicFlagAPITest, DeterministicGuard, CudaSyncGuard) from multiprocessing.reduction import ForkingPickler from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, @@ -47,6 +48,9 @@ import torch.backends.quantized import torch.testing._internal.data from torch.testing._internal.common_cuda import tf32_on_and_off, tf32_is_not_fp32 +from torch.testing._internal.common_dtype import ( + get_all_fp_dtypes, get_all_int_dtypes, get_all_math_dtypes, get_all_dtypes, get_all_complex_dtypes +) # Protects against includes accidentally setting the default dtype assert torch.get_default_dtype() is torch.float32 @@ -221,10 +225,10 @@ def test_namespace(ns, *skips): # TODO: add torch.* tests when we have proper namespacing on ATen functions # test_namespace(torch) - def test_msnpu_error(self): + def test_ort_error(self): with self.assertRaisesRegex(RuntimeError, - "Could not run 'aten::empty.memory_format' with arguments from the 'MSNPU' backend"): - torch.zeros(1, device=torch.device('msnpu')) + "Could not run 'aten::empty.memory_format' with arguments from the 'ORT' backend"): + torch.zeros(1, device=torch.device('ort')) def test_has_storage(self): self.assertIsNotNone(torch.tensor([]).storage()) @@ -273,8 +277,8 @@ def get_tensor(size, dtype, device, contiguous): height = 5 width = 5 for device in torch.testing.get_all_device_types(): - for dt1 in torch.testing.get_all_dtypes(): - for dt2 in torch.testing.get_all_dtypes(): + for dt1 in get_all_dtypes(): + for dt2 in get_all_dtypes(): for contiguous in [True, False]: x1 = get_tensor((height, width), dt1, device, contiguous) x2 = get_tensor((height, width), dt2, device, contiguous) @@ -292,14 +296,14 @@ def get_tensor(size, dtype, device, contiguous): self.assertEqual(expected, result) def test_dtypes(self): - all_dtypes = torch.testing.get_all_dtypes() + all_dtypes = get_all_dtypes() do_test_dtypes(self, all_dtypes, torch.strided, torch.device('cpu')) if torch.cuda.is_available(): all_dtypes.remove(torch.bfloat16) # Remove once _th_zero_ is enabled on cuda for bfloat16 do_test_dtypes(self, all_dtypes, torch.strided, torch.device('cuda:0')) def test_copy_dtypes(self): - all_dtypes = torch.testing.get_all_dtypes() + all_dtypes = get_all_dtypes() for dtype in all_dtypes: copied_dtype = copy.deepcopy(dtype) self.assertIs(dtype, copied_dtype) @@ -323,6 +327,14 @@ def test_copy_transpose(self): self.assertEqual(y[:, 0], range(100)) self.assertEqual(y[:, 40], range(4000, 4100)) + # Verifies the bugfix for https://github.com/pytorch/pytorch/issues/64358 + def test_copy_transpose_2d_broadcast(self): + # The shape (60, 60) is chosen because of + # `MIN_SZ = 60 * 60` in `copy_transpose_valid` from aten/src/ATen/native/Copy.cpp + A = torch.randn(60, 60) + A.copy_(torch.tensor([[1.]])) + self.assertEqual(A, torch.ones(60, 60)) + def test_device(self): cpu = torch.device('cpu') self.assertEqual('cpu', str(cpu)) @@ -713,7 +725,7 @@ def reference(x, k, o3, o32): self._test_conv_corr_eq(lambda x, k: torch.conv3(x, k, 'F'), reference) def test_dtype_is_signed(self): - for dtype in torch.testing.get_all_dtypes(): + for dtype in get_all_dtypes(): self.assertEqual(dtype.is_signed, torch.is_signed(torch.tensor(0, dtype=dtype))) self.assertRaisesRegex(RuntimeError, 'not supported for quantized', lambda: torch.quint8.is_signed) @@ -950,7 +962,7 @@ def test_index_add(self): # https://github.com/pytorch/pytorch/issues/29153 def test_index_add_all_dtypes(self): for device in torch.testing.get_all_device_types(): - for dtype in torch.testing.get_all_math_dtypes(device): + for dtype in get_all_math_dtypes(device): for idx_dtype in [torch.int, torch.long]: size = [5, 5] if dtype.is_floating_point or dtype.is_complex: @@ -1056,8 +1068,11 @@ def _test_gather(self, cast, test_bounds=True): torch.gather(src, dim, idx.to(torch.int)) # should throw an error when out.dtype != src.dtype. - with self.assertRaisesRegex(RuntimeError, 'Expected self.dtype to be equal to src.dtype'): - torch.gather(src, dim, idx, out=expected.to(torch.int)) + # Note that on Windows, the out tensor's dtype is returned as: struct c10::complex in the error + # message, hence the use of .* in regex here + with self.assertRaisesRegex(RuntimeError, + 'Expected out tensor to have dtype .*c10::complex, but got int instead'): + torch.gather(src.to(torch.complex128), dim, idx, out=expected.to(torch.int)) # checks for the same dimensionality with self.assertRaisesRegex(RuntimeError, 'Index tensor must have the same number of dimensions as input tensor'): @@ -1566,7 +1581,7 @@ def test_sobolengine_continuing(self, scramble: bool = False): n_half = len(ref_sample) // 2 _ = engine.draw(n=n_half) sample = engine.draw(n=n_half) - torch.testing.assert_allclose(sample, ref_sample[n_half:]) + torch.testing.assert_close(sample, ref_sample[n_half:]) def test_sobolengine_continuing_scrambled(self): self.test_sobolengine_continuing(scramble=True) @@ -1578,7 +1593,7 @@ def test_sobolengine_reset(self, scramble: bool = False): engine.reset() self.assertEqual(engine.num_generated, 0) sample = engine.draw(n=len(ref_sample)) - torch.testing.assert_allclose(sample, ref_sample) + torch.testing.assert_close(sample, ref_sample) def test_sobolengine_reset_scrambled(self): self.test_sobolengine_reset(scramble=True) @@ -1588,7 +1603,7 @@ def test_sobolengine_fast_forward(self, scramble: bool = False): engine = torch.quasirandom.SobolEngine(2, scramble=scramble, seed=123456) engine.fast_forward(4) sample = engine.draw(n=4) - torch.testing.assert_allclose(sample, ref_sample[4:]) + torch.testing.assert_close(sample, ref_sample[4:]) # alternate fast forwarding with sampling engine.reset() even_draws = [] @@ -1597,9 +1612,9 @@ def test_sobolengine_fast_forward(self, scramble: bool = False): even_draws.append(engine.draw()) else: engine.fast_forward(1) - torch.testing.assert_allclose( + torch.testing.assert_close( ref_sample[[i for i in range(8) if i % 2 == 0]], - np.concatenate(even_draws), + torch.from_numpy(np.concatenate(even_draws)), ) def test_sobolengine_fast_forward_scrambled(self): @@ -1609,13 +1624,13 @@ def test_sobolengine_distribution(self, scramble=False): d = 50 engine = torch.quasirandom.SobolEngine(d, scramble=scramble, seed=123456) sample = engine.draw(1024) - torch.testing.assert_allclose( + torch.testing.assert_close( torch.mean(sample, dim=0), torch.full((d,), 0.5), atol=2, rtol=2 ) - torch.testing.assert_allclose( + torch.testing.assert_close( np.percentile(sample, 25, axis=0), np.repeat(0.25, d), atol=2, rtol=2 ) - torch.testing.assert_allclose( + torch.testing.assert_close( np.percentile(sample, 75, axis=0), np.repeat(0.75, d), atol=2, rtol=2 ) @@ -2440,7 +2455,7 @@ def test_c10_layer_norm(self): actual_norm, actual_mean, actual_stdev = \ torch.ops._caffe2.LayerNorm(torch.tensor(X), torch.tensor( weight), torch.tensor(bias), 1, epsilon, True) - torch.testing.assert_allclose(expected_norm, actual_norm) + torch.testing.assert_close(expected_norm, actual_norm) def test_memory_format(self): def test_helper(x, memory_format): @@ -4285,13 +4300,13 @@ def _cond_fn(x): _sync_raises_helper(f, level) - @dtypes(*torch.testing.get_all_fp_dtypes()) + @dtypes(*get_all_fp_dtypes()) def test_log_normal(self, device, dtype): a = torch.tensor([10], dtype=dtype, device=device).log_normal_() self.assertEqual(a.dtype, dtype) self.assertEqual(a.size(), torch.Size([1])) - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes())) + @dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes())) def test_geometric(self, device, dtype): a = torch.tensor([10], dtype=dtype, device=device).geometric_(0.5) self.assertEqual(a.dtype, dtype) @@ -4323,8 +4338,9 @@ def test_repeat_interleave(self, device): self.assertEqual(a_with_output.dtype, y.dtype) self.assertEqual(a_with_output.size(), torch.Size([3, 2])) - @dtypes(*(torch.testing.get_all_fp_dtypes(include_half=False, include_bfloat16=False))) - @dtypesIfCUDA(*(torch.testing.get_all_fp_dtypes(include_bfloat16=False))) + @dtypes(*get_all_fp_dtypes(include_half=False, include_bfloat16=False)) + @dtypesIfCPU(*(get_all_fp_dtypes(include_half=False, include_bfloat16=True))) + @dtypesIfCUDA(*(get_all_fp_dtypes(include_bfloat16=False))) def test_bernoulli_p(self, device, dtype): for trivial_p in ([0, 1], [1, 0, 1, 1, 0, 1]): x = torch.tensor(trivial_p, dtype=dtype, device=device) @@ -4344,9 +4360,9 @@ def isBinary(t): self.assertTrue(isBinary(p)) # RngUniform not implemented for Integral type in XLA test - @dtypes(*(torch.testing.get_all_fp_dtypes(include_half=False, include_bfloat16=False))) - @dtypesIfCPU(*(torch.testing.get_all_dtypes(include_half=False, include_bfloat16=False, include_complex=False))) - @dtypesIfCUDA(*(torch.testing.get_all_dtypes(include_bfloat16=False, include_complex=False))) + @dtypes(*(get_all_fp_dtypes(include_half=False, include_bfloat16=False))) + @dtypesIfCPU(*(get_all_dtypes(include_half=False, include_bfloat16=False, include_complex=False))) + @dtypesIfCUDA(*(get_all_dtypes(include_bfloat16=False, include_complex=False))) def test_bernoulli_self(self, device, dtype): def isBinary(t): @@ -4358,8 +4374,7 @@ def isBinary(t): t.bernoulli_(0.5) self.assertTrue(isBinary(t)) - for p_dtype in torch.testing.get_all_fp_dtypes(include_half=device.startswith('cuda'), - include_bfloat16=False): + for p_dtype in get_all_fp_dtypes(include_half=device.startswith('cuda'), include_bfloat16=False): p = torch.rand(10, dtype=p_dtype, device=device).expand(10, 10) t.fill_(2) t.bernoulli_(p) @@ -4374,8 +4389,8 @@ def isBinary(t): self.assertTrue(isBinary(t)) @slowTest - @dtypes(*(torch.testing.get_all_fp_dtypes(include_half=False, include_bfloat16=False))) - @dtypesIfCUDA(*(torch.testing.get_all_fp_dtypes(include_bfloat16=False))) + @dtypes(*(get_all_fp_dtypes(include_half=False, include_bfloat16=False))) + @dtypesIfCUDA(*(get_all_fp_dtypes(include_bfloat16=False))) def test_bernoulli_edge_cases(self, device, dtype): # Need to draw a lot of samples to cover every random floating point number. a = torch.zeros(10000, 10000, dtype=dtype, device=device) # probability of drawing "1" is 0 @@ -4386,7 +4401,7 @@ def test_bernoulli_edge_cases(self, device, dtype): num_zeros = (torch.bernoulli(b) == 0).sum() self.assertEqual(num_zeros, 0) - @dtypes(*torch.testing.get_all_fp_dtypes()) + @dtypes(*get_all_fp_dtypes()) def test_exponential(self, device, dtype): a = torch.tensor([10], dtype=dtype, device=device).exponential_(0.5) self.assertEqual(a.dtype, dtype) @@ -4469,7 +4484,7 @@ def check(msg, *args, **kwargs): check(r'aweights cannot be negative', a, aweights=torch.tensor([-1., -2.])) @skipIfNoSciPy - @dtypes(*torch.testing.get_all_fp_dtypes()) + @dtypes(*get_all_fp_dtypes()) def test_uniform_kstest(self, device, dtype): from scipy import stats size = 1000 @@ -4481,8 +4496,8 @@ def test_uniform_kstest(self, device, dtype): self.assertTrue(res.statistic < 0.1) @skipIfNoSciPy - @dtypes(*torch.testing.get_all_fp_dtypes(include_bfloat16=False)) - @dtypesIfCUDA(*torch.testing.get_all_fp_dtypes()) + @dtypes(*get_all_fp_dtypes(include_bfloat16=False)) + @dtypesIfCUDA(*get_all_fp_dtypes()) def test_normal_kstest(self, device, dtype): from scipy import stats size = 1000 @@ -4493,7 +4508,7 @@ def test_normal_kstest(self, device, dtype): self.assertTrue(res.statistic < 0.1) @skipIfNoSciPy - @dtypes(*torch.testing.get_all_fp_dtypes()) + @dtypes(*get_all_fp_dtypes()) def test_lognormal_kstest(self, device, dtype): from scipy import stats size = 1000 @@ -4507,7 +4522,7 @@ def test_lognormal_kstest(self, device, dtype): self.assertTrue(res.statistic < 0.1) @skipIfNoSciPy - @dtypes(*torch.testing.get_all_fp_dtypes()) + @dtypes(*get_all_fp_dtypes()) def test_exponential_kstest(self, device, dtype): from scipy import stats size = 1000 @@ -4517,7 +4532,7 @@ def test_exponential_kstest(self, device, dtype): self.assertTrue(res.statistic < 0.1) @skipIfNoSciPy - @dtypes(*torch.testing.get_all_fp_dtypes()) + @dtypes(*get_all_fp_dtypes()) def test_cauchy_kstest(self, device, dtype): from scipy import stats size = 1000 @@ -4538,7 +4553,7 @@ def test_cauchy_no_inf(self, device, dtype): self.assertFalse(x.isinf().sum()) @skipIfNoSciPy - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes())) + @dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes())) def test_geometric_kstest(self, device, dtype): from scipy import stats size = 1000 @@ -4992,7 +5007,7 @@ def to_np(t): # All tensors appear contiguous on XLA @onlyOnCPUAndCUDA - @dtypes(*torch.testing.get_all_dtypes()) + @dtypes(*get_all_dtypes()) def test_diff_noncontig(self, device, dtype): shapes = ( (1,), @@ -5012,9 +5027,9 @@ def test_diff_noncontig(self, device, dtype): self._test_diff_numpy(non_contig) # RngNormal not implemented for type f16 for XLA - @dtypes(*torch.testing.get_all_dtypes(include_half=False)) - @dtypesIfCPU(*torch.testing.get_all_dtypes()) - @dtypesIfCUDA(*torch.testing.get_all_dtypes()) + @dtypes(*get_all_dtypes(include_half=False)) + @dtypesIfCPU(*get_all_dtypes()) + @dtypesIfCUDA(*get_all_dtypes()) def test_diff(self, device, dtype): shapes = ( (1,), @@ -5119,7 +5134,7 @@ def filter_shape(shape, dim): spacing = [space.cpu().detach().numpy() for space in spacing] expected = np.gradient(t_np, *self._wrap_to_list(spacing), axis=dims, edge_order=edge_order) actual, expected = self._inf_nan_preprocess(list(actual), self._wrap_to_list(expected)) - self.assertEqual(actual, expected, equal_nan="relaxed", atol=1e-4, rtol=0, exact_dtype=False) + self.assertEqual(actual, expected, equal_nan=True, atol=1e-4, rtol=0, exact_dtype=False) @onlyOnCPUAndCUDA @dtypes(torch.long, torch.float32, torch.complex64) @@ -5186,7 +5201,7 @@ def test_gradient_type_promotion(self, device): self.assertEqual(expected[i].imag, torch.zeros(actual[i].shape), exact_dtype=False) else: actual, expected = self._inf_nan_preprocess(list(actual), expected) - self.assertEqual(actual, expected, equal_nan="relaxed", exact_dtype=False) + self.assertEqual(actual, expected, equal_nan=True, exact_dtype=False) @onlyOnCPUAndCUDA @dtypes(torch.long, torch.float32, torch.complex64) @@ -5285,7 +5300,7 @@ def test_bool_tensor_value_change(self, device): self.assertEqual(x, torch.tensor([False, True], dtype=torch.bool, device=device)) def test_unfold_all_devices_and_dtypes(self, device): - for dt in torch.testing.get_all_dtypes(): + for dt in get_all_dtypes(): if dt == torch.bool: x = torch.empty((0, 1, 3, 0), dtype=dt, device=device) @@ -5305,7 +5320,7 @@ def test_unfold_scalars(self, device): def test_copy_all_dtypes_and_devices(self, device): from copy import copy - for dt in torch.testing.get_all_dtypes(): + for dt in get_all_dtypes(): x = torch.tensor([1, 2, 3, 4], dtype=dt, device=device) x_clone = x.clone() y = copy(x) @@ -5315,7 +5330,7 @@ def test_copy_all_dtypes_and_devices(self, device): self.assertEqual(x, y) def test_clone_all_dtypes_and_devices(self, device): - for dt in torch.testing.get_all_dtypes(): + for dt in get_all_dtypes(): x = torch.tensor((1, 1), dtype=dt, device=device) y = x.clone() self.assertEqual(x, y) @@ -5326,8 +5341,15 @@ def test_clone_zero_stride_dim(self, device): y = x.as_strided([2, 1, 5], [1, 0, 2]) self.assertEqual(y, y.clone()) - @dtypesIfCUDA(*set(torch.testing.get_all_math_dtypes('cuda'))) - @dtypes(*set(torch.testing.get_all_math_dtypes('cpu'))) + def test_clone_not_memory_dense(self): + # github issue: https://github.com/pytorch/pytorch/issues/64176 + x = torch.randn(10, 8).t()[::2, ::2] + y = x.clone() + # should retain permutation after densification + self.assertTrue(y.stride() == (1, 4)) + + @dtypesIfCUDA(*set(get_all_math_dtypes('cuda'))) + @dtypes(*set(get_all_math_dtypes('cpu'))) def test_addcmul(self, device, dtype): # Returns floating or integral scalar corresponding to dtype def _number(floating, integer, dtype): @@ -5376,7 +5398,7 @@ def test_narrow_empty(self, device): sz[d] = 0 self.assertEqual(sz, y.size()) - @dtypes(*torch.testing.get_all_dtypes()) + @dtypes(*get_all_dtypes()) def test_index_copy(self, device, dtype): # We just test for num_copy <= num_dest, as otherwise there are repeated indices # and the behavior is undefined @@ -5410,7 +5432,7 @@ def ref_index_copy(tgt, dim, idx, src): # onlyOnCPUAndCUDA due to an XLA error: # https://github.com/pytorch/pytorch/issues/53256 @onlyOnCPUAndCUDA - @dtypes(*torch.testing.get_all_dtypes()) + @dtypes(*get_all_dtypes()) def test_index_copy_scalars(self, device, dtype): # Create the 8 possible combinations of scalar sizes for target / index / source scalars = ((make_tensor(size_t, dtype=dtype, device=device, low=None, high=None), @@ -5515,7 +5537,7 @@ def test_index_put_non_accumulate_deterministic(self, device) -> None: self.assertEqual(output, input_list) - @dtypes(*torch.testing.get_all_dtypes()) + @dtypes(*get_all_dtypes()) def test_index_fill(self, device, dtype): x = torch.tensor([[1, 2], [4, 5]], dtype=dtype, device=device) index = torch.tensor([0], device=device) @@ -5532,7 +5554,7 @@ def test_index_fill(self, device, dtype): # The test fails for zero-dimensional tensors on XLA @onlyOnCPUAndCUDA - @dtypes(*torch.testing.get_all_dtypes()) + @dtypes(*get_all_dtypes()) def test_index_select(self, device, dtype): num_src, num_out = 3, 5 @@ -5575,7 +5597,7 @@ def ref_index_select(src, dim, idx): out = source.index_select(0, idx) self.assertEqual(out.item(), source.item()) - @dtypes(*torch.testing.get_all_dtypes()) + @dtypes(*get_all_dtypes()) def test_take(self, device, dtype): idx_size = (4,) @@ -5609,7 +5631,7 @@ def ref_take(src, idx): # The bool instance does not work on GPU. See # https://github.com/pytorch/pytorch/issues/54317 - @dtypes(*torch.testing.get_all_dtypes(include_bool=False)) + @dtypes(*get_all_dtypes(include_bool=False)) def test_put(self, device, dtype): src_size = (4,) @@ -5679,7 +5701,7 @@ def ref_put(dst, idx, src, accumulate): # The bool instance does not work on GPU. See # https://github.com/pytorch/pytorch/issues/54317 - @dtypes(*torch.testing.get_all_dtypes(include_bool=False)) + @dtypes(*get_all_dtypes(include_bool=False)) def test_put_accumulate(self, device, dtype): # Test for parallel adds with accumulate == True low_precision = dtype == torch.half or dtype == torch.bfloat16 @@ -5722,10 +5744,10 @@ def scatter_allow_reduce(self, device, dtype, reduceop): # torch.{zeros, ones} do not support ComplexHalf (torch.complex32) # So, we are skipping it here. - @dtypes(*(torch.testing.get_all_fp_dtypes(include_bfloat16=False, include_half=False) + - torch.testing.get_all_complex_dtypes())) - @dtypesIfCPU(*torch.testing.get_all_dtypes()) - @dtypesIfCUDA(*torch.testing.get_all_dtypes()) + @dtypes(*(get_all_fp_dtypes(include_bfloat16=False, include_half=False) + + get_all_complex_dtypes())) + @dtypesIfCPU(*get_all_dtypes()) + @dtypesIfCUDA(*get_all_dtypes()) def test_scatter_reduce_operations_to_large_input(self, device, dtype): index = torch.tensor([[1], [2]], device=device, dtype=torch.long) test_data = [ @@ -5752,10 +5774,10 @@ def test_scatter_reduce_operations_to_large_input(self, device, dtype): # torch.{zeros, ones} do not support ComplexHalf (torch.complex32) # So, we are skipping it here. - @dtypes(*(torch.testing.get_all_fp_dtypes(include_bfloat16=False, include_half=False) + - torch.testing.get_all_complex_dtypes())) - @dtypesIfCPU(*torch.testing.get_all_dtypes()) - @dtypesIfCUDA(*torch.testing.get_all_dtypes()) + @dtypes(*(get_all_fp_dtypes(include_bfloat16=False, include_half=False) + + get_all_complex_dtypes())) + @dtypesIfCPU(*get_all_dtypes()) + @dtypesIfCUDA(*get_all_dtypes()) def test_scatter_reduce_scalar(self, device, dtype): index = torch.tensor([[1], [2]], device=device, dtype=torch.long) test_data = [ @@ -5793,10 +5815,10 @@ def test_scatter_add_non_unique_index(self, device): # torch.{zeros, ones} do not support ComplexHalf (torch.complex32) # So, we are skipping it here. - @dtypes(*(torch.testing.get_all_fp_dtypes(include_bfloat16=False, include_half=False) + - torch.testing.get_all_complex_dtypes())) - @dtypesIfCPU(*torch.testing.get_all_dtypes()) - @dtypesIfCUDA(*torch.testing.get_all_dtypes()) + @dtypes(*(get_all_fp_dtypes(include_bfloat16=False, include_half=False) + + get_all_complex_dtypes())) + @dtypesIfCPU(*get_all_dtypes()) + @dtypesIfCUDA(*get_all_dtypes()) def test_scatter_reduce_non_unique_index(self, device, dtype): height = 2 width = 2 @@ -5820,8 +5842,8 @@ def test_scatter_reduce_non_unique_index(self, device, dtype): # torch.{zeros, ones} do not support ComplexHalf (torch.complex32) # So, we are skipping it here. @onlyCUDA - @dtypesIfCUDA(*(torch.testing.get_all_complex_dtypes() + - torch.testing.get_all_int_dtypes())) + @dtypesIfCUDA(*(get_all_complex_dtypes() + + get_all_int_dtypes())) def test_scatter_reduce_multiply_unsupported_dtypes(self, device, dtype): height = 2 width = 2 @@ -5868,7 +5890,7 @@ def test_scatter_add_bool(self, device): [True, False, True, False, True]], device=device)) @onlyOnCPUAndCUDA - @dtypes(*torch.testing.get_all_dtypes()) + @dtypes(*get_all_dtypes()) def test_masked_scatter(self, device, dtype): dt = dtype with warnings.catch_warnings(record=True) as w: @@ -5953,7 +5975,7 @@ def test_masked_scatter_large_tensor(self, device): result = t.masked_scatter(t, t) self.assertEqual(result, result_cpu) - @dtypes(*torch.testing.get_all_dtypes()) + @dtypes(*get_all_dtypes()) def test_masked_select(self, device, dtype): if device == 'cpu': warn = 'masked_select received a mask with dtype torch.uint8,' @@ -6011,15 +6033,15 @@ def test_masked_select_discontiguous(self, device): out_dc = torch.empty(size * size, device=device)[::2] for v, m in product(vals_list, mask_list): if m.is_contiguous(): - expected = v[:, ::2].clone().view(-1) + expected = v[:, ::2].clone().reshape((-1, )) else: - expected = v[::2].clone().view(-1) + expected = v[::2].clone().reshape((-1, )) out = torch.masked_select(v, m) self.assertEqual(out, expected, atol=0, rtol=0) torch.masked_select(v, m, out=out_dc) self.assertEqual(out_dc, expected, atol=0, rtol=0) - @dtypes(*product(torch.testing.get_all_dtypes(), (torch.uint8, torch.bool))) + @dtypes(*product(get_all_dtypes(), (torch.uint8, torch.bool))) def test_masked_fill(self, device, dtypes): dtype = dtypes[0] mask_dtype = dtypes[1] @@ -6329,8 +6351,8 @@ def test_pdist_norm_large(self, device): self.assertEqual(expected_cpu, actual_gpu.cpu()) @onlyOnCPUAndCUDA - @dtypesIfCUDA(*set(torch.testing.get_all_math_dtypes('cuda'))) - @dtypes(*set(torch.testing.get_all_math_dtypes('cpu'))) + @dtypesIfCUDA(*set(get_all_math_dtypes('cuda'))) + @dtypes(*set(get_all_math_dtypes('cpu'))) def test_addcdiv(self, device, dtype): # Returns floating or integral scalar corresponding to dtype def _number(floating, integer, dtype): @@ -7073,7 +7095,7 @@ def compare_strides(s1, s2, div): _test_helper(x, op, unary=True) @skipMeta - @dtypes(*torch.testing.get_all_dtypes()) + @dtypes(*get_all_dtypes()) def test_dlpack_conversion(self, device, dtype): # DLpack does not explicitly support bool # It does it through uint8 type @@ -7673,8 +7695,8 @@ def _where_valid_scalar_tensor_combination(self, scalar_type, dtype): return False @onlyOnCPUAndCUDA - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes() + - torch.testing.get_all_complex_dtypes())) + @dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes() + + get_all_complex_dtypes())) def test_where_scalar_invalid_combination_raises(self, device, dtype): def checkRaises(scalar_type, dtype, condition, x, scalar_1): @@ -7686,8 +7708,8 @@ def checkRaises(scalar_type, dtype, condition, x, scalar_1): self._test_where_scalar_template(device, dtype, checkRaises) @skipCUDAVersionIn([(11, 2)]) # test fails for 11.2, see https://github.com/pytorch/pytorch/issues/51980 - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes() + - torch.testing.get_all_complex_dtypes())) + @dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes() + + get_all_complex_dtypes())) def test_where_scalar_valid_combination(self, device, dtype): def checkResult(scalar_type, dtype, condition, x, scalar_1): diff --git a/test/test_type_promotion.py b/test/test_type_promotion.py index bd48e38045a13..81411c058bca6 100644 --- a/test/test_type_promotion.py +++ b/test/test_type_promotion.py @@ -8,6 +8,9 @@ TEST_NUMPY, torch_to_numpy_dtype_dict) from torch.testing._internal.common_device_type import (instantiate_device_type_tests, onlyOnCPUAndCUDA, dtypes, dtypesIfCUDA, onlyCPU, expectedFailureMeta) +from torch.testing._internal.common_dtype import ( + get_all_dtypes, get_all_math_dtypes, get_all_int_dtypes, get_all_fp_dtypes +) if TEST_NUMPY: import numpy as np @@ -179,7 +182,7 @@ def test_bfloat16(self, device): self.assertEqual(bf + scalar, scalar + bf) # with tensor - for dtype in torch.testing.get_all_dtypes(): + for dtype in get_all_dtypes(): t = torch.tensor(1, dtype=dtype, device=device) self.assertEqual(bf + t, t + bf) if dtype in (torch.float16, torch.float32, torch.float64, torch.cfloat, torch.cdouble): @@ -254,8 +257,8 @@ def _get_test_tensor(self, device, dtype, remove_zeros=False): def test_many_promotions(self, device): # Can also include half on CPU in cases where it will be promoted to a # supported dtype - dtypes1 = torch.testing.get_all_math_dtypes('cuda') - dtypes2 = torch.testing.get_all_math_dtypes(device) + dtypes1 = get_all_math_dtypes('cuda') + dtypes2 = get_all_math_dtypes(device) ops = [torch.add, torch.sub, torch.mul, torch.div, torch.rsub] for dt1, dt2 in itertools.product(dtypes1, dtypes2): for op, non_contiguous in itertools.product(ops, [True, False]): @@ -331,7 +334,7 @@ def test_create_bool_tensors(self, device): # this seems like odd behavior but ints also create float tensors, numpy doesn't have this function. self.assertEqual(torch.scalar_tensor(False, device=device), torch.tensor(0., device=device)) - @dtypes(*itertools.product(torch.testing.get_all_dtypes(), torch.testing.get_all_dtypes())) + @dtypes(*itertools.product(get_all_dtypes(), get_all_dtypes())) def test_result_type(self, device, dtypes): "Test result_type for tensor vs tensor and scalar vs scalar." @@ -460,8 +463,8 @@ def test_comparison_ops_with_type_promotion(self, device): ), ] for op in comparison_ops: - for dt1 in torch.testing.get_all_math_dtypes(device): - for dt2 in torch.testing.get_all_math_dtypes(device): + for dt1 in get_all_math_dtypes(device): + for dt2 in get_all_math_dtypes(device): if (dt1.is_complex or dt2.is_complex) and not (op["name"] == "eq" or op["name"] == "ne"): continue val1 = value_for_type[dt1] @@ -511,8 +514,8 @@ def test_complex_assertraises(self, device): dict(name="ne", compare_op=lambda x, y: x != y, ), ] for op in comparison_ops: - for dt1 in torch.testing.get_all_math_dtypes(device): - for dt2 in torch.testing.get_all_math_dtypes(device): + for dt1 in get_all_math_dtypes(device): + for dt2 in get_all_math_dtypes(device): if (dt1.is_complex or dt2.is_complex) and not (op["name"] == "eq" or op["name"] == "ne"): u = torch.tensor([1], dtype=dt1, device=device) v = torch.tensor([2], dtype=dt2, device=device) @@ -520,7 +523,7 @@ def test_complex_assertraises(self, device): @float_double_default_dtype def test_lt_with_type_promotion(self, device): - for dt in torch.testing.get_all_math_dtypes(device): + for dt in get_all_math_dtypes(device): x = torch.tensor([0], dtype=dt, device=device) expected = torch.tensor([True], dtype=torch.bool, device=device) @@ -553,7 +556,7 @@ def test_promote_types(self, device): @float_double_default_dtype def test_promote_self(self, device): - for dtype in torch.testing.get_all_dtypes(): + for dtype in get_all_dtypes(): self.assertEqual(torch.promote_types(dtype, dtype), dtype) @expectedFailureMeta @@ -758,12 +761,12 @@ def _run_all_tests_for_sparse_op(self, op_name, device, dtypes): @onlyOnCPUAndCUDA def test_sparse_add(self, device): self._run_all_tests_for_sparse_op('add', device, - dtypes=torch.testing.get_all_math_dtypes(device)) + dtypes=get_all_math_dtypes(device)) @onlyOnCPUAndCUDA def test_sparse_mul(self, device): self._run_all_tests_for_sparse_op('mul', device, - dtypes=torch.testing.get_all_math_dtypes(device)) + dtypes=get_all_math_dtypes(device)) @onlyOnCPUAndCUDA def test_sparse_div(self, device): @@ -774,7 +777,7 @@ def test_sparse_div(self, device): @onlyOnCPUAndCUDA def test_sparse_sub(self, device): self._run_all_tests_for_sparse_op('sub', device, - dtypes=torch.testing.get_all_math_dtypes(device)) + dtypes=get_all_math_dtypes(device)) @onlyOnCPUAndCUDA @dtypes(torch.bool, torch.short, torch.uint8, torch.int, torch.long) @@ -871,7 +874,7 @@ def test_numpy_array_binary_ufunc_promotion(self, device, dtypes): @onlyOnCPUAndCUDA def test_cat_different_dtypes(self, device): - dtypes = torch.testing.get_all_dtypes(include_bfloat16=False) + dtypes = get_all_dtypes(include_bfloat16=False) for x_dtype, y_dtype in itertools.product(dtypes, dtypes): x_vals, y_vals = [1, 2, 3], [4, 5, 6] @@ -890,7 +893,7 @@ def test_cat_different_dtypes(self, device): @onlyOnCPUAndCUDA def test_cat_out_different_dtypes(self, device): - dtypes = torch.testing.get_all_dtypes(include_bfloat16=False, include_bool=False) + dtypes = get_all_dtypes(include_bfloat16=False, include_bool=False) for x_dtype, y_dtype, out_dtype in itertools.product(dtypes, dtypes, dtypes): out = torch.zeros(6, device=device, dtype=out_dtype) x = torch.tensor([1, 2, 3], device=device, dtype=x_dtype) @@ -957,21 +960,21 @@ def test_computation_ignores_out(self, device): self.assertEqual(result, a - b, exact_dtype=False) self.assertNotEqual(result, a.double() - b, exact_dtype=False) - @dtypesIfCUDA(*itertools.product(torch.testing.get_all_dtypes(include_bfloat16=False, include_complex=False), - torch.testing.get_all_dtypes(include_bfloat16=False, include_complex=False))) - @dtypes(*itertools.product(torch.testing.get_all_dtypes(include_half=False, include_bfloat16=False, - include_complex=False), - torch.testing.get_all_dtypes(include_half=False, include_bfloat16=False, - include_complex=False))) + @dtypesIfCUDA(*itertools.product(get_all_dtypes(include_bfloat16=False, include_complex=False), + get_all_dtypes(include_bfloat16=False, include_complex=False))) + @dtypes(*itertools.product(get_all_dtypes(include_half=False, include_bfloat16=False, + include_complex=False), + get_all_dtypes(include_half=False, include_bfloat16=False, + include_complex=False))) def test_atan2_type_promotion(self, device, dtypes): dtype1, dtype2 = dtypes default_float = torch.get_default_dtype() def is_int(dtype): - return dtype in torch.testing.get_all_int_dtypes() + [torch.bool] + return dtype in get_all_int_dtypes() + [torch.bool] def is_float(dtype): - return dtype in torch.testing.get_all_fp_dtypes(include_half=True, include_bfloat16=False) + return dtype in get_all_fp_dtypes(include_half=True, include_bfloat16=False) def get_binary_float_result_type(x, y): dtype1 = x.dtype diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py index e5b8c4a66093b..c65ae980fd82a 100644 --- a/test/test_unary_ufuncs.py +++ b/test/test_unary_ufuncs.py @@ -11,15 +11,18 @@ from torch._six import inf, nan from torch.testing._internal.common_utils import ( TestCase, run_tests, torch_to_numpy_dtype_dict, numpy_to_torch_dtype_dict, - suppress_warnings, make_tensor, TEST_SCIPY, slowTest, skipIfNoSciPy, IS_WINDOWS) + suppress_warnings, TEST_SCIPY, slowTest, skipIfNoSciPy, IS_WINDOWS) from torch.testing._internal.common_methods_invocations import ( unary_ufuncs, _NOTHING) from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, ops, dtypes, onlyCPU, onlyOnCPUAndCUDA, onlyCUDA, dtypesIfCUDA, precisionOverride, skipCUDAIfRocm, dtypesIfCPU, OpDTypes) -from torch.testing import ( - floating_types_and, all_types_and_complex_and, floating_and_complex_types_and) +from torch.testing import make_tensor +from torch.testing._internal.common_dtype import ( + floating_types_and, all_types_and_complex_and, floating_and_complex_types_and, get_all_dtypes, get_all_math_dtypes, + get_all_int_dtypes, get_all_fp_dtypes, get_all_complex_dtypes +) if TEST_SCIPY: import scipy @@ -359,10 +362,7 @@ def test_reference_numerics_extremal(self, device, dtype, op): tensors = generate_numeric_tensors_extremal(device, dtype, domain=op.domain) - # https://github.com/pytorch/pytorch/issues/50749 - equal_nan = "relaxed" if device.startswith('cuda') else True - - self._test_reference_numerics(dtype, op, tensors, equal_nan) + self._test_reference_numerics(dtype, op, tensors) # Tests for testing (non)contiguity consistency @@ -505,8 +505,8 @@ def test_out_arg_all_dtypes(self, device, dtype, op): out = torch.empty_like(input, dtype=out_dtype) self._test_out_arg(op, input, out, expected, **torch_kwargs) - @dtypes(*(torch.testing.get_all_int_dtypes() + [torch.bool] + - torch.testing.get_all_fp_dtypes(include_bfloat16=False))) + @dtypes(*(get_all_int_dtypes() + [torch.bool] + + get_all_fp_dtypes(include_bfloat16=False))) def test_nan_to_num(self, device, dtype): for contiguous in [False, True]: x = make_tensor((64, 64), low=0., high=100., dtype=dtype, device=device) @@ -584,7 +584,7 @@ def test_digamma(self, device, dtype): self.compare_with_numpy(torch.digamma, scipy.special.digamma, tensor) @skipCUDAIfRocm - @dtypes(*torch.testing.get_all_fp_dtypes(include_half=True, include_bfloat16=False)) + @dtypes(*get_all_fp_dtypes(include_half=True, include_bfloat16=False)) def test_frexp(self, device, dtype): input = make_tensor((50, 50), device, dtype) mantissa, exponent = torch.frexp(input) @@ -598,7 +598,7 @@ def test_frexp(self, device, dtype): self.assertTrue(torch_to_numpy_dtype_dict[exponent.dtype] == np_exponent.dtype) @skipCUDAIfRocm - @dtypes(*torch.testing.get_all_fp_dtypes(include_half=True, include_bfloat16=False)) + @dtypes(*get_all_fp_dtypes(include_half=True, include_bfloat16=False)) def test_frexp_out(self, device, dtype): input = make_tensor((50, 50), device, dtype) outputs = ( @@ -625,20 +625,18 @@ def test_frexp_out(self, device, dtype): @skipCUDAIfRocm def test_frexp_assert_raises(self, device): - invalid_input_dtypes = torch.testing.get_all_int_dtypes() + \ - torch.testing.get_all_complex_dtypes() + \ + invalid_input_dtypes = get_all_int_dtypes() + \ + get_all_complex_dtypes() + \ [torch.bool] for dtype in invalid_input_dtypes: input = make_tensor((50, 50), device, dtype) with self.assertRaisesRegex(RuntimeError, r"torch\.frexp\(\) only supports floating-point dtypes"): torch.frexp(input) - for dtype in torch.testing.get_all_fp_dtypes(include_half=True, include_bfloat16=False): + for dtype in get_all_fp_dtypes(include_half=True, include_bfloat16=False): input = make_tensor((50, 50), device, dtype) - dtypes = list(torch.testing.all_types_and_complex_and(torch.bool, - torch.half, - torch.bfloat16)) + dtypes = list(all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16)) dtypes.remove(dtype) for mantissa_dtype in dtypes: mantissa = torch.empty_like(input, dtype=mantissa_dtype) @@ -1049,7 +1047,7 @@ def test_mish(self, device, dtype): # do ops like threshold need a test_unary(_nonufunc) test suite? @onlyCPU - @dtypes(*torch.testing.get_all_math_dtypes('cpu')) + @dtypes(*get_all_math_dtypes('cpu')) def test_threshold(self, device, dtype): if dtype != torch.uint8 and dtype != torch.float16 and not dtype.is_complex: # 100 is wide enough to use AVX2 instructions for all types @@ -1183,7 +1181,7 @@ def _i0_range_helper(self, range, device, dtype): t = torch.rand(1000, device=device).to(dtype) * r self._i0_helper(t) - @dtypesIfCUDA(*torch.testing.get_all_fp_dtypes()) + @dtypesIfCUDA(*get_all_fp_dtypes()) @dtypes(torch.bfloat16, torch.float32, torch.float64) @unittest.skipIf(not TEST_SCIPY, "SciPy not found") def test_i0_range1(self, device, dtype): @@ -1191,7 +1189,7 @@ def test_i0_range1(self, device, dtype): # The domain is (-13.25, 13.25) self._i0_range_helper(13.25, device, dtype) - @dtypesIfCUDA(*torch.testing.get_all_fp_dtypes()) + @dtypesIfCUDA(*get_all_fp_dtypes()) @dtypes(torch.bfloat16, torch.float32, torch.float64) @unittest.skipIf(not TEST_SCIPY, "SciPy not found") def test_i0_range2(self, device, dtype): @@ -1206,7 +1204,7 @@ def test_i0_range3(self, device, dtype): # The domain is (-709.75, 709.75) self._i0_range_helper(709.75, device, dtype) - @dtypesIfCUDA(*torch.testing.get_all_fp_dtypes()) + @dtypesIfCUDA(*get_all_fp_dtypes()) @dtypes(torch.bfloat16, torch.float32, torch.float64) @unittest.skipIf(not TEST_SCIPY, "SciPy not found") def test_i0_special(self, device, dtype): @@ -1216,7 +1214,7 @@ def test_i0_special(self, device, dtype): t = torch.tensor([inf, -inf, nan], device=device, dtype=dtype) self.assertTrue(torch.i0(t).isnan().all()) - @dtypesIfCUDA(*torch.testing.get_all_fp_dtypes()) + @dtypesIfCUDA(*get_all_fp_dtypes()) @dtypes(torch.bfloat16, torch.float32, torch.float64) @unittest.skipIf(not TEST_SCIPY, "SciPy not found") def test_special_i0_i1_vs_scipy(self, device, dtype): @@ -1308,7 +1306,7 @@ def test_abs_zero(self, device, dtype): for num in abs_zeros: self.assertGreater(math.copysign(1.0, num), 0.0) - @dtypes(*torch.testing.get_all_fp_dtypes()) + @dtypes(*get_all_fp_dtypes()) def test_isfinite_isinf_isnan(self, device, dtype): vals = (-float('inf'), float('inf'), float('nan'), -1, 0, 1) @@ -1324,7 +1322,7 @@ def test_isfinite_isinf_isnan_int(self, device, dtype): self.compare_with_numpy(torch.isinf, np.isinf, vals, device, dtype) self.compare_with_numpy(torch.isnan, np.isnan, vals, device, dtype) - @dtypes(*(torch.testing.get_all_fp_dtypes())) + @dtypes(*(get_all_fp_dtypes())) def test_isposinf_isneginf_float(self, device, dtype): ops = ((torch.isposinf, np.isposinf), (torch.isneginf, np.isneginf)) vals = (-float('inf'), float('inf'), float('nan'), -1, 0, 1) @@ -1349,7 +1347,7 @@ def test_isposinf_isneginf_float(self, device, dtype): torch_op(t, out=out) self.assertEqual(out, t_target) - @dtypes(*(torch.testing.get_all_int_dtypes() + [torch.bool])) + @dtypes(*(get_all_int_dtypes() + [torch.bool])) def test_isposinf_isneginf_int_and_bool(self, device, dtype): ops = ((torch.isposinf, np.isposinf), (torch.isneginf, np.isneginf)) vals = (-1, 0, 1) @@ -1377,7 +1375,7 @@ def test_isposinf_isneginf_complex(self, device, dtype): with self.assertRaisesRegex(RuntimeError, 'does not support complex inputs'): torch_op(t, out=out) - @dtypes(*(torch.testing.get_all_dtypes(include_bool=False))) + @dtypes(*(get_all_dtypes(include_bool=False))) def test_isposinf_isneginf_non_boolean_output(self, device, dtype): # test non-boolean tensors as the `out=` parameters # boolean outputs are tested in the above testcases @@ -1409,7 +1407,7 @@ def test_isreal_complex(self, device, dtype): vals = (1, 1 + 1j, 2 + 0j, 3j, 2 - 1j, 2 - 0j) self.compare_with_numpy(torch.isreal, np.isreal, vals, device, dtype) - @dtypes(*torch.testing.get_all_dtypes()) + @dtypes(*get_all_dtypes()) def test_isreal_noncomplex(self, device, dtype): vals = (1, 2, 3) # Manual check here since numpy doesn't support bfloat16 @@ -1470,7 +1468,7 @@ def assert_tuple_empty(tup, dim): self.assertEqual(1, len(z)) self.assertEqual(torch.empty(0, dtype=torch.long), z[0]) - @dtypes(*torch.testing.get_all_dtypes()) + @dtypes(*get_all_dtypes()) def test_nonzero_noncontiguous(self, device, dtype): x = make_tensor((10, 10, 10), dtype=dtype, device=device, low=1, noncontiguous=False) @@ -1499,10 +1497,10 @@ def permute_storage(tensor, dims): self.assertEqual(nondense.nonzero(), expect) # TODO: rationalize with exp OpInfo - @dtypes(*(torch.testing.get_all_fp_dtypes(include_half=False) + - torch.testing.get_all_complex_dtypes())) - @dtypesIfCUDA(*(torch.testing.get_all_fp_dtypes(include_half=True) + - torch.testing.get_all_complex_dtypes())) + @dtypes(*(get_all_fp_dtypes(include_half=False) + + get_all_complex_dtypes())) + @dtypesIfCUDA(*(get_all_fp_dtypes(include_half=True) + + get_all_complex_dtypes())) def test_exp(self, device, dtype): for v in (2, -2) + ((1j, 1 + 1j) if dtype.is_complex else ()): a = torch.tensor(v, dtype=dtype, device=device) * torch.arange(18, device=device) / 3 * math.pi diff --git a/test/test_utils.py b/test/test_utils.py index d0f8d10d9fbd4..6f9432e0e6392 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -19,7 +19,7 @@ import torch.hub as hub from torch.autograd._functions.utils import check_onnx_broadcast from torch.onnx.symbolic_opset9 import _prepare_onnx_paddings -from torch.testing._internal.common_utils import load_tests, retry, IS_SANDCASTLE, IS_WINDOWS, has_breakpad +from torch.testing._internal.common_utils import has_breakpad, load_tests, retry, IS_SANDCASTLE, IS_WINDOWS, TEST_WITH_ASAN from urllib.error import URLError # load_tests from torch.testing._internal.common_utils is used to automatically filter tests for @@ -739,7 +739,8 @@ def forward(self, x): class TestCrashHandler(TestCase): - @unittest.skipIf(not has_breakpad(), "Crash handler lib was not linked in") + @unittest.skipIf(TEST_WITH_ASAN, "ASAN disables the crash handler's signal handler") + @unittest.skipIf(not has_breakpad(), "Built without breakpad") def test_python_exception_writing(self): with tempfile.TemporaryDirectory() as temp_dir: torch.utils._crash_handler.enable_minidumps(temp_dir) diff --git a/test/test_view_ops.py b/test/test_view_ops.py index 306c6cb411f3f..06aaf31423f3f 100644 --- a/test/test_view_ops.py +++ b/test/test_view_ops.py @@ -6,10 +6,14 @@ from functools import partial import random +from torch.testing import make_tensor from torch.testing._internal.common_utils import \ - (TestCase, run_tests, suppress_warnings, make_tensor) + (TestCase, run_tests, suppress_warnings) from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, onlyCPU, dtypes, onlyOnCPUAndCUDA) +from torch.testing._internal.common_dtype import ( + get_all_dtypes, get_all_int_dtypes, get_all_fp_dtypes, get_all_complex_dtypes +) # TODO: replace this with make_tensor() in common_utils.py def _generate_input(shape, dtype, device, with_extremal): @@ -113,14 +117,14 @@ def _do_transpose(self, x, contiguous=False, dim0=0, dim1=1): else: return x.transpose(dim0, dim1) - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes())) + @dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes())) def test_conj_self(self, device, dtype): t = torch.ones(5, 5, device=device) s = t.conj() self.assertTrue(s is t) @onlyOnCPUAndCUDA - @dtypes(*torch.testing.get_all_fp_dtypes(include_bfloat16=False), torch.complex64) + @dtypes(*get_all_fp_dtypes(include_bfloat16=False), torch.complex64) def test_view_dtype(self, device, dtype): int_dtype = { torch.half: torch.int16, @@ -226,7 +230,7 @@ def fn(contiguous_input=True, dim0=0, dim1=1): self.assertEqual(res.shape, torch.Size([0])) @onlyOnCPUAndCUDA - @dtypes(*torch.testing.get_all_complex_dtypes(include_complex32=True)) + @dtypes(*get_all_complex_dtypes(include_complex32=True)) def test_view_as_real(self, device, dtype): def fn(contiguous_input=True): t = torch.randn(3, 4, dtype=dtype, device=device) @@ -264,7 +268,7 @@ def fn(contiguous_input=True): self.assertEqual(res.shape, torch.Size([2])) @onlyOnCPUAndCUDA - @dtypes(*torch.testing.get_all_dtypes()) + @dtypes(*get_all_dtypes()) def test_view_tensor_split(self, device, dtype): a = make_tensor((40, 30), device, dtype, low=-9, high=9) a_split_dim0 = a.tensor_split(7, 0) @@ -275,7 +279,7 @@ def test_view_tensor_split(self, device, dtype): self.assertTrue(self.is_view_of(a, a_split_dim1_tensor)) @onlyOnCPUAndCUDA - @dtypes(*torch.testing.get_all_dtypes()) + @dtypes(*get_all_dtypes()) def test_view_tensor_hsplit(self, device, dtype): t = make_tensor((4, 4, 4), device, dtype, low=-9, high=9) t_hsplit = torch.hsplit(t, 2) @@ -285,7 +289,7 @@ def test_view_tensor_hsplit(self, device, dtype): self.assertEqual(t_hsplit[1][2, 0, 2], t[2, 2, 2]) @onlyOnCPUAndCUDA - @dtypes(*torch.testing.get_all_dtypes()) + @dtypes(*get_all_dtypes()) def test_view_tensor_vsplit(self, device, dtype): t = make_tensor((4, 4, 4), device, dtype, low=-9, high=9) t_vsplit = torch.vsplit(t, 2) @@ -295,7 +299,7 @@ def test_view_tensor_vsplit(self, device, dtype): self.assertEqual(t_vsplit[1][0, 2, 2], t[2, 2, 2]) @onlyOnCPUAndCUDA - @dtypes(*torch.testing.get_all_dtypes()) + @dtypes(*get_all_dtypes()) def test_view_tensor_dsplit(self, device, dtype): t = make_tensor((4, 4, 4), device, dtype, low=-9, high=9) t_dsplit = torch.dsplit(t, 2) @@ -305,7 +309,7 @@ def test_view_tensor_dsplit(self, device, dtype): self.assertEqual(t_dsplit[1][2, 2, 0], t[2, 2, 2]) @onlyOnCPUAndCUDA - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes())) + @dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes())) def test_real_imag_noncomplex(self, device, dtype): t = torch.ones((5, 5), dtype=dtype, device=device) @@ -316,7 +320,7 @@ def test_real_imag_noncomplex(self, device, dtype): torch.imag(t) @onlyOnCPUAndCUDA - @dtypes(*torch.testing.get_all_complex_dtypes()) + @dtypes(*get_all_complex_dtypes()) def test_real_imag_view(self, device, dtype): def compare_with_numpy(contiguous_input=True): t = torch.randn(3, 3, dtype=dtype, device=device) @@ -347,7 +351,7 @@ def compare_with_numpy(contiguous_input=True): self.assertEqual(a[5:].imag, a.imag[5:]) @onlyOnCPUAndCUDA - @dtypes(*torch.testing.get_all_complex_dtypes()) + @dtypes(*get_all_complex_dtypes()) def test_conj_imag_view(self, device, dtype) -> None: t = _make_tensor((4, 5,), dtype, device) t_numpy_conj = torch.from_numpy(t.cpu().numpy().conj()).to(device=device) @@ -362,7 +366,7 @@ def test_conj_imag_view(self, device, dtype) -> None: self.assertTrue(v_imag.is_neg()) @onlyOnCPUAndCUDA - @dtypes(*product(torch.testing.get_all_complex_dtypes(), torch.testing.get_all_dtypes())) + @dtypes(*product(get_all_complex_dtypes(), get_all_dtypes())) @suppress_warnings def test_set_real_imag(self, device, dtypes): x = torch.randn(10, dtype=dtypes[0], device=device) @@ -1215,8 +1219,8 @@ def _test_atleast_dim(self, torch_fn, np_fn, device, dtype): self.assertEqual(np_res, torch_res) # TODO: are these view ops? - @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False) + - torch.testing.get_all_complex_dtypes())) + @dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes(include_bfloat16=False) + + get_all_complex_dtypes())) def test_atleast(self, device, dtype): self._test_atleast_dim(torch.atleast_1d, np.atleast_1d, device, dtype) self._test_atleast_dim(torch.atleast_2d, np.atleast_2d, device, dtype) @@ -1252,7 +1256,7 @@ def test_broadcast_shapes(self, device): self.assertEqual(expected, actual) # Skip BFloat16 since numpy does not support it - @dtypes(*torch.testing.get_all_dtypes(include_bfloat16=False)) + @dtypes(*get_all_dtypes(include_bfloat16=False)) def test_broadcast_to(self, device, dtype): def can_broadcast(s0, s1): # s0.dim() <= s1.dim(), reverse s0 and s1 to compare trailing dimension @@ -1355,7 +1359,7 @@ def test_view(self, device): self.assertEqual(tensor.view(6, 2, 1), contig_tensor.view(6, 2, 1)) self.assertEqual(tensor.view(1, 6, 2, 1), contig_tensor.view(1, 6, 2, 1)) - @dtypes(*torch.testing.get_all_dtypes()) + @dtypes(*get_all_dtypes()) def test_reshape_view_semantics(self, device, dtype): tensor = make_tensor((15, 4), device, dtype) target = (20, 3) @@ -1382,7 +1386,7 @@ def test_contiguous(self, device): @onlyOnCPUAndCUDA # Skip BFloat16 since numpy does not support it - @dtypes(*torch.testing.get_all_dtypes(include_bfloat16=False)) + @dtypes(*get_all_dtypes(include_bfloat16=False)) def test_tensor_split_sections(self, device, dtype): input_sizes = [ (0,), @@ -1413,7 +1417,7 @@ def test_tensor_split_sections(self, device, dtype): @onlyOnCPUAndCUDA # Skip BFloat16 since numpy does not support it - @dtypes(*torch.testing.get_all_dtypes(include_bfloat16=False)) + @dtypes(*get_all_dtypes(include_bfloat16=False)) def test_tensor_split_indices(self, device, dtype): input_sizes = [ (0,), @@ -1492,20 +1496,20 @@ def test_tensor_split_errors(self, device): def test_resize_all_dtypes_and_devices(self, device): shape = (2, 2) - for dt in torch.testing.get_all_dtypes(): + for dt in get_all_dtypes(): x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) x.resize_(shape) self.assertEqual(shape, x.shape) def test_resize_as_all_dtypes_and_devices(self, device): - for dt in torch.testing.get_all_dtypes(): + for dt in get_all_dtypes(): x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) y = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dt, device=device) x.resize_as_(y) self.assertEqual(y.shape, x.shape) def test_view_all_dtypes_and_devices(self, device): - for dt in torch.testing.get_all_dtypes(): + for dt in get_all_dtypes(): x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) self.assertEqual(x.view(6).shape, [6]) diff --git a/test/test_xnnpack_integration.py b/test/test_xnnpack_integration.py index 4fa64e75eceb4..a0f8328ec660b 100644 --- a/test/test_xnnpack_integration.py +++ b/test/test_xnnpack_integration.py @@ -34,7 +34,7 @@ def test_linear(self, batch_size, data_shape, weight_output_dim, use_bias): ref_result = F.linear(input_data, weight, bias) packed_weight_bias = torch.ops.prepacked.linear_clamp_prepack(weight, bias) output_linearprepacked = torch.ops.prepacked.linear_clamp_run(input_data, packed_weight_bias) - torch.testing.assert_allclose(ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3) + torch.testing.assert_close(ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3) @given(input_size=st.integers(2, 32), weight_output_dim=st.integers(2, 64), @@ -49,7 +49,7 @@ def test_linear_1d_input(self, input_size, weight_output_dim, use_bias): ref_result = F.linear(input_data, weight, bias) packed_weight_bias = torch.ops.prepacked.linear_clamp_prepack(weight, bias) output_linearprepacked = torch.ops.prepacked.linear_clamp_run(input_data, packed_weight_bias) - torch.testing.assert_allclose(ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3) + torch.testing.assert_close(ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3) @given(batch_size=st.integers(0, 3), @@ -107,7 +107,7 @@ def test_conv2d(self, packed_weight_bias = torch.ops.prepacked.conv2d_clamp_prepack(weight, bias, strides, paddings, dilations, groups) xnnpack_result = torch.ops.prepacked.conv2d_clamp_run(input_data, packed_weight_bias) - torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) + torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) @given(batch_size=st.integers(1, 3), input_channels_per_group=st.integers(1, 32), @@ -174,7 +174,7 @@ def test_conv2d_transpose(self, output_paddings, dilations, groups) xnnpack_result = torch.ops.prepacked.conv2d_transpose_clamp_run(input_data, packed_weight_bias) - torch.testing.assert_allclose(ref_result.contiguous(), xnnpack_result.contiguous(), rtol=1e-2, atol=1e-3) + torch.testing.assert_close(ref_result.contiguous(), xnnpack_result.contiguous(), rtol=1e-2, atol=1e-3) @unittest.skipUnless(torch.backends.xnnpack.enabled, " XNNPACK must be enabled for these tests." @@ -214,7 +214,7 @@ def forward(self, x): input_data = torch.rand(data_shape) ref_result = scripted_linear(input_data) output_linearprepacked = scripted_linear_clamp_prepacked(input_data) - torch.testing.assert_allclose(ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3) + torch.testing.assert_close(ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3) # Serialize the modules and then deserialize input_data = torch.rand(data_shape) @@ -228,7 +228,7 @@ def forward(self, x): deserialized_linear_clamp_prepacked = torch.jit.load(buffer) ref_result = deserialized_linear(input_data) output_linearprepacked = deserialized_linear_clamp_prepacked(input_data) - torch.testing.assert_allclose(ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3) + torch.testing.assert_close(ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3) @given(batch_size=st.integers(0, 3), input_channels_per_group=st.integers(1, 32), @@ -309,7 +309,7 @@ def forward(self, x): weight, bias, strides, paddings, dilations, groups)) ref_result = scripted_conv2d(input_data) xnnpack_result = scripted_conv2d_clamp_prepacked(input_data) - torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) + torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) # Serialize the modules and then deserialize input_data = torch.rand((batch_size, input_channels, height, width)) @@ -325,7 +325,7 @@ def forward(self, x): deserialized_conv2d_clamp_prepacked = torch.jit.load(buffer) ref_result = deserialized_conv2d(input_data) xnnpack_result = deserialized_conv2d_clamp_prepacked(input_data) - torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) + torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) @given(batch_size=st.integers(0, 3), input_channels_per_group=st.integers(1, 32), @@ -417,7 +417,7 @@ def forward(self, x): weight, bias, strides, paddings, output_paddings, dilations, groups)) ref_result = scripted_conv2d(input_data) xnnpack_result = scripted_conv2d_clamp_prepacked(input_data) - torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) + torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) # Serialize the modules and then deserialize input_data = torch.rand((batch_size, input_channels, height, width)) @@ -433,7 +433,7 @@ def forward(self, x): deserialized_conv2d_clamp_prepacked = torch.jit.load(buffer) ref_result = deserialized_conv2d(input_data) xnnpack_result = deserialized_conv2d_clamp_prepacked(input_data) - torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) + torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) @given(batch_size=st.integers(0, 3), input_channels_per_group=st.integers(1, 32), @@ -549,7 +549,7 @@ def forward(self, x): groups)) ref_result = scripted_m(input_data) xnnpack_result = scripted_m_prepacked(input_data) - torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) + torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) # Serialize the modules and then deserialize input_data = torch.rand((batch_size, input_channels, height, width)) @@ -564,7 +564,7 @@ def forward(self, x): deserialized_m_prepacked = torch.jit.load(buffer) ref_result = deserialized_m(input_data) xnnpack_result = deserialized_m_prepacked(input_data) - torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) + torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) @unittest.skipUnless(torch.backends.xnnpack.enabled, @@ -610,7 +610,7 @@ def validate_transformed_module( else: FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph) xnnpack_result = deserialized_scripted_model(input_data) - torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) + torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) def test_linear(self): data_shape = [2, 3, 32] @@ -965,7 +965,7 @@ def validate_transform_conv1d_to_conv2d( else: FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph) transformed_result = deserialized_scripted_model(input_data) - torch.testing.assert_allclose(ref_result, transformed_result, rtol=1e-2, atol=1e-3) + torch.testing.assert_close(ref_result, transformed_result, rtol=1e-2, atol=1e-3) optimized_buffer = io.BytesIO() torch.jit.save(optimized_scripted_model, optimized_buffer) @@ -980,7 +980,7 @@ def validate_transform_conv1d_to_conv2d( else: FileCheck().check_count(pattern, v, exactly=True).run(deserialized_optimized_scripted_model.graph) xnnpack_result = deserialized_optimized_scripted_model(input_data) - torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) + torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) def test_conv1d_basic(self): diff --git a/third_party/breakpad b/third_party/breakpad new file mode 160000 index 0000000000000..469a80ee54947 --- /dev/null +++ b/third_party/breakpad @@ -0,0 +1 @@ +Subproject commit 469a80ee54947ad8d000d33a615f1a199165a711 diff --git a/third_party/cuda.BUILD b/third_party/cuda.BUILD new file mode 100644 index 0000000000000..0c58b34a52e74 --- /dev/null +++ b/third_party/cuda.BUILD @@ -0,0 +1,43 @@ +""" +Collect all the CUDA stuff from @local_config_cuda in a single target +for convenience. +""" + +cc_library( + name = "cuda", + visibility = ["//visibility:public"], + deps = [ + "@local_config_cuda//cuda:cublas", + "@local_config_cuda//cuda:cuda_driver", + "@local_config_cuda//cuda:cuda_headers", + "@local_config_cuda//cuda:cudart", + "@local_config_cuda//cuda:cufft", + "@local_config_cuda//cuda:curand", + ], +) + +cc_library( + name = "cupti", + deps = [ + "@local_config_cuda//cuda:cupti_headers", + "@local_config_cuda//cuda:cupti_link", + ], +) + +[ + alias( + name = lib, + actual = "@local_config_cuda//cuda:{}".format(lib), + visibility = ["//visibility:public"], + ) + for lib in [ + "cublas", + "cufft", + "cusolver", + "cusparse", + "curand", + "nvrtc", + "cuda_driver", + "nvToolsExt", + ] +] diff --git a/third_party/fbgemm b/third_party/fbgemm index 10ec0d3388579..7b49986d74a66 160000 --- a/third_party/fbgemm +++ b/third_party/fbgemm @@ -1 +1 @@ -Subproject commit 10ec0d33885795e6b4cc9a17896ee3f25b48fa8e +Subproject commit 7b49986d74a6666fa6913bd9b461ebebb2cad476 diff --git a/third_party/protobuf b/third_party/protobuf index d0bfd5221182d..d1eca4e4b421c 160000 --- a/third_party/protobuf +++ b/third_party/protobuf @@ -1 +1 @@ -Subproject commit d0bfd5221182da1a7cc280f3337b5e41a89539cf +Subproject commit d1eca4e4b421cd2997495c4b4e65cea6be4e9b8a diff --git a/third_party/tensorflow_cuda_bazel_build/BUILD b/third_party/tensorflow_cuda_bazel_build/BUILD new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/third_party/tensorflow_cuda_bazel_build/README.md b/third_party/tensorflow_cuda_bazel_build/README.md new file mode 100644 index 0000000000000..439e195d8e44e --- /dev/null +++ b/third_party/tensorflow_cuda_bazel_build/README.md @@ -0,0 +1,5 @@ +# Config for CUDA + +This is a checked-in copy of the auto-generated config for building CUDA code with bazel. The content of this folder was generated from https://github.com/tensorflow/tensorflow `./configure` execution and then edited manually to fit the pytorch needs. + +The LICENSE for the TensorFlow project is APACHE 2. The full LICENSE file could be found here https://github.com/tensorflow/tensorflow/blob/master/LICENSE. diff --git a/third_party/tensorflow_cuda_bazel_build/WORKSPACE b/third_party/tensorflow_cuda_bazel_build/WORKSPACE new file mode 100644 index 0000000000000..59369ce679c14 --- /dev/null +++ b/third_party/tensorflow_cuda_bazel_build/WORKSPACE @@ -0,0 +1 @@ +workspace(name = "local_config_cuda") diff --git a/third_party/tensorflow_cuda_bazel_build/cuda/BUILD b/third_party/tensorflow_cuda_bazel_build/cuda/BUILD new file mode 100755 index 0000000000000..f7271af2750b8 --- /dev/null +++ b/third_party/tensorflow_cuda_bazel_build/cuda/BUILD @@ -0,0 +1,451 @@ +licenses([ + "restricted", + "reciprocal", + "notice", +]) # MPL2, portions GPL v3, LGPL v3, BSD-like + +package(default_visibility = ["//visibility:public"]) + +config_setting( + name = "using_nvcc", + values = { + "define": "using_cuda_nvcc=true", + }, +) + +config_setting( + name = "using_clang", + values = { + "define": "using_cuda_clang=true", + }, +) + +# Equivalent to using_clang && -c opt. +config_setting( + name = "using_clang_opt", + values = { + "define": "using_cuda_clang=true", + "compilation_mode": "opt", + }, +) + +config_setting( + name = "darwin", + values = {"cpu": "darwin"}, +) + +cc_library( + name = "cuda_headers", + hdrs = [ + ":cuda-include", + ":cudnn-include", + ], + includes = [ + ".", + "include", + ], +) + +cc_library( + name = "cudnn_headers", + hdrs = [ + ":cudnn-include", + ], + includes = [ + ".", + "include", + ], +) + +cc_library( + name = "cudart_static", + linkopts = [ + "-L/usr/local/cuda/lib64", + ], +) + +cc_library( + name = "cuda_driver", + linkopts = ["-lcuda"], + deps = [":linker_search_path"], +) + +# Provides the RPATH for Nvidia-less sytems to be able to run binaries linked to libcuda. +cc_library( + name = "driver_stub_runtime", + linkopts = [ + "-Wl,-rpath,/usr/local/cuda/lib64/stubs", + ], + deps = [":cuda_driver"], +) + +cc_library( + name = "linker_search_path", + linkopts = [ + "-L/usr/local/cuda/lib64", + "-L/usr/local/cuda/lib64/stubs", + "-Wl,-rpath-link,/usr/local/cuda/lib64", + "-Wl,-rpath-link,/usr/local/cuda/lib64/stubs", + ], +) + +[ + cc_library( + name = libname, + linkopts = ["-l" + libname] + (["-lgomp"] if (libname == "cusolver") else []), + linkstatic = True, + deps = [":linker_search_path"], + ) + for libname in [ + "cublas", + "cudart", + "cudnn", + "cufft", + "curand", + "cusolver", + "cusparse", + "nvrtc", + "nvToolsExt", + ] +] + +cc_library( + name = "cuda", + deps = [ + ":cublas", + ":cuda_headers", + ":cudart", + ":cudnn", + ":cufft", + ":curand", + ":nvToolsExt", + ], +) + +# NVIDIA Performance Primitives (http://docs.nvidia.com/cuda/npp/modules.html)) +# used by OpenCV +cc_library( + name = "nppi", + linkopts = [ + "-lnppc", + "-lnppial", + "-lnppicom", + "-lnppidei", + "-lnppif", + "-lnppig", + "-lnppim", + "-lnppist", + "-lnppitc", + "-lnpps", + ], + linkstatic = True, + deps = [":linker_search_path"], +) + +# NVIDIA Management Library +cc_library( + name = "nvml", + linkopts = [ + "-lnvidia-ml", + "-Wl,-rpath,/usr/lib/nvidia-410", + "-Wl,-rpath,/usr/lib/nvidia-390", + "-Wl,-rpath,/usr/lib/nvidia-387", + "-Wl,-rpath,/usr/lib/nvidia-384", + ], + deps = [":linker_search_path"], +) + +cc_library( + name = "cupti_headers", + hdrs = [ + ":cuda-extras", + ], + includes = [ + ".", + "extras/CUPTI/include/", + ], +) + +# cupti .so exposed at linktime +cc_library( + name = "cupti_link", + linkopts = [ + "-L/usr/local/cuda/extras/CUPTI/lib64", + "-lcupti", + ], +) + +cc_library( + name = "libdevice_root", + data = [":cuda-nvvm"], +) + +CUDA_INCLUDES_FILES = [ + "include/builtin_types.h", + "include/channel_descriptor.h", + "include/CL/cl_egl.h", + "include/CL/cl_ext.h", + "include/CL/cl_gl_ext.h", + "include/CL/cl_gl.h", + "include/CL/cl.h", + "include/CL/cl.hpp", + "include/CL/cl_platform.h", + "include/CL/opencl.h", + "include/common_functions.h", + "include/cooperative_groups.h", + "include/cooperative_groups_helpers.h", + "include/crt/common_functions.h", + "include/crt/device_double_functions.h", + "include/crt/device_double_functions.hpp", + "include/crt/device_functions.h", + "include/crt/device_functions.hpp", + "include/crt/func_macro.h", + "include/crt/host_config.h", + "include/crt/host_defines.h", + "include/crt/host_runtime.h", + "include/crt/math_functions.h", + "include/crt/math_functions.hpp", + "include/crt/mma.h", + "include/crt/mma.hpp", + "include/crt/nvfunctional", + "include/crt/sm_70_rt.h", + "include/crt/sm_70_rt.hpp", + "include/crt/storage_class.h", + # TODO: figure out why on a CI machine with CUDA 10.2 it's not present + # "include/cublas_api.h", + # "include/cublas.h", + # "include/cublas_v2.h", + # "include/cublasXt.h", + "include/cuComplex.h", + "include/cuda_device_runtime_api.h", + "include/cudaEGL.h", + "include/cuda_egl_interop.h", + "include/cuda_fp16.h", + "include/cuda_fp16.hpp", + "include/cudaGL.h", + "include/cuda_gl_interop.h", + "include/cuda.h", + "include/cudalibxt.h", + "include/cuda_occupancy.h", + "include/cuda_profiler_api.h", + "include/cudaProfiler.h", + "include/cudart_platform.h", + "include/cuda_runtime_api.h", + "include/cuda_runtime.h", + "include/cuda_surface_types.h", + "include/cuda_texture_types.h", + "include/cudaVDPAU.h", + "include/cuda_vdpau_interop.h", + "include/cufft.h", + "include/cufftw.h", + "include/cufftXt.h", + "include/curand_discrete2.h", + "include/curand_discrete.h", + "include/curand_globals.h", + "include/curand.h", + "include/curand_kernel.h", + "include/curand_lognormal.h", + "include/curand_mrg32k3a.h", + "include/curand_mtgp32dc_p_11213.h", + "include/curand_mtgp32.h", + "include/curand_mtgp32_host.h", + "include/curand_mtgp32_kernel.h", + "include/curand_normal.h", + "include/curand_normal_static.h", + "include/curand_philox4x32_x.h", + "include/curand_poisson.h", + "include/curand_precalc.h", + "include/curand_uniform.h", + "include/cusolver_common.h", + "include/cusolverDn.h", + "include/cusolverRf.h", + "include/cusolverSp.h", + "include/cusolverSp_LOWLEVEL_PREVIEW.h", + "include/cusparse.h", + "include/cusparse_v2.h", + "include/device_atomic_functions.h", + "include/device_atomic_functions.hpp", + "include/device_double_functions.h", + "include/device_functions.h", + "include/device_launch_parameters.h", + "include/device_types.h", + "include/driver_functions.h", + "include/driver_types.h", + "include/fatBinaryCtl.h", + "include/fatbinary.h", + "include/host_config.h", + "include/host_defines.h", + "include/library_types.h", + "include/math_constants.h", + "include/math_functions.h", + "include/mma.h", + "include/nppcore.h", + "include/nppdefs.h", + "include/npp.h", + "include/nppi_arithmetic_and_logical_operations.h", + "include/nppi_color_conversion.h", + "include/nppi_compression_functions.h", + "include/nppi_computer_vision.h", + "include/nppi_data_exchange_and_initialization.h", + "include/nppi_filtering_functions.h", + "include/nppi_geometry_transforms.h", + "include/nppi.h", + "include/nppi_linear_transforms.h", + "include/nppi_morphological_operations.h", + "include/nppi_statistics_functions.h", + "include/nppi_support_functions.h", + "include/nppi_threshold_and_compare_operations.h", + "include/npps_arithmetic_and_logical_operations.h", + "include/npps_conversion_functions.h", + "include/npps_filtering_functions.h", + "include/npps.h", + "include/npps_initialization.h", + "include/npps_statistics_functions.h", + "include/npps_support_functions.h", + # Note: CUDA 10.0 only + # "include/nppversion.h", + # TODO: figure out why on a CI machine with CUDA 10.2 it's not present + # "include/nvblas.h", + "include/nvfunctional", + "include/nvgraph.h", + "include/nvjpeg.h", + "include/nvml.h", + "include/nvrtc.h", + "include/nvToolsExtCuda.h", + "include/nvToolsExtCudaRt.h", + "include/nvToolsExt.h", + "include/nvToolsExtMeta.h", + "include/nvToolsExtSync.h", + "include/nvtx3/nvToolsExtCuda.h", + "include/nvtx3/nvToolsExtCudaRt.h", + "include/nvtx3/nvToolsExt.h", + "include/nvtx3/nvToolsExtOpenCL.h", + "include/nvtx3/nvToolsExtSync.h", + "include/nvtx3/nvtxDetail/nvtxImplCore.h", + "include/nvtx3/nvtxDetail/nvtxImplCudaRt_v3.h", + "include/nvtx3/nvtxDetail/nvtxImplCuda_v3.h", + "include/nvtx3/nvtxDetail/nvtxImpl.h", + "include/nvtx3/nvtxDetail/nvtxImplOpenCL_v3.h", + "include/nvtx3/nvtxDetail/nvtxImplSync_v3.h", + "include/nvtx3/nvtxDetail/nvtxInitDecls.h", + "include/nvtx3/nvtxDetail/nvtxInitDefs.h", + "include/nvtx3/nvtxDetail/nvtxInit.h", + "include/nvtx3/nvtxDetail/nvtxLinkOnce.h", + "include/nvtx3/nvtxDetail/nvtxTypes.h", + "include/sm_20_atomic_functions.h", + "include/sm_20_atomic_functions.hpp", + "include/sm_20_intrinsics.h", + "include/sm_20_intrinsics.hpp", + "include/sm_30_intrinsics.h", + "include/sm_30_intrinsics.hpp", + "include/sm_32_atomic_functions.h", + "include/sm_32_atomic_functions.hpp", + "include/sm_32_intrinsics.h", + "include/sm_32_intrinsics.hpp", + "include/sm_35_atomic_functions.h", + "include/sm_35_intrinsics.h", + "include/sm_60_atomic_functions.h", + "include/sm_60_atomic_functions.hpp", + "include/sm_61_intrinsics.h", + "include/sm_61_intrinsics.hpp", + # CUDA 10.0 only + # "include/sobol_direction_vectors.h", + "include/surface_functions.h", + "include/surface_functions.hpp", + "include/surface_indirect_functions.h", + "include/surface_indirect_functions.hpp", + "include/surface_types.h", + "include/texture_fetch_functions.h", + "include/texture_fetch_functions.hpp", + "include/texture_indirect_functions.h", + "include/texture_indirect_functions.hpp", + "include/texture_types.h", + "include/vector_functions.h", + "include/vector_functions.hpp", + "include/vector_types.h", +] + +genrule( + name = "cuda-include", + outs = CUDA_INCLUDES_FILES, + cmd = " && ".join([ + "ln -s /usr/local/cuda/{relpath} $(@D)/{relpath}".format(relpath = p) + for p in CUDA_INCLUDES_FILES + ]), + local = True, + tags = ["no-cache"], +) + +CUDA_NVVM_FILES = [ + "nvvm/bin/cicc", + "nvvm/include/nvvm.h", + "nvvm/lib64/libnvvm.so", + "nvvm/lib64/libnvvm.so.3", + "nvvm/lib64/libnvvm.so.3.3.0", + "nvvm/libdevice/libdevice.10.bc", +] + +genrule( + name = "cuda-nvvm", + outs = CUDA_NVVM_FILES, + cmd = " && ".join([ + "ln -s /usr/local/cuda/{relpath} $(@D)/{relpath}".format(relpath = p) + for p in CUDA_NVVM_FILES + ]), + local = True, + tags = ["no-cache"], +) + +CUDA_EXTRAS_FILES = [ + "extras/CUPTI/include/cuda_stdint.h", + "extras/CUPTI/include/cupti.h", + "extras/CUPTI/include/cupti_activity.h", + "extras/CUPTI/include/cupti_callbacks.h", + "extras/CUPTI/include/cupti_driver_cbid.h", + "extras/CUPTI/include/cupti_events.h", + "extras/CUPTI/include/cupti_metrics.h", + "extras/CUPTI/include/cupti_nvtx_cbid.h", + "extras/CUPTI/include/cupti_result.h", + "extras/CUPTI/include/cupti_runtime_cbid.h", + "extras/CUPTI/include/cupti_version.h", + "extras/CUPTI/include/generated_cuda_gl_interop_meta.h", + "extras/CUPTI/include/generated_cuda_meta.h", + "extras/CUPTI/include/generated_cuda_runtime_api_meta.h", + "extras/CUPTI/include/generated_cuda_vdpau_interop_meta.h", + "extras/CUPTI/include/generated_cudaGL_meta.h", + "extras/CUPTI/include/generated_cudaVDPAU_meta.h", + "extras/CUPTI/include/generated_nvtx_meta.h", + "extras/CUPTI/include/GL/gl.h", + "extras/CUPTI/include/GL/glew.h", + "extras/CUPTI/include/GL/glext.h", + "extras/CUPTI/include/GL/glu.h", + "extras/CUPTI/include/GL/glut.h", + "extras/CUPTI/include/GL/glx.h", + "extras/CUPTI/include/GL/glxext.h", + "extras/CUPTI/include/GL/wglew.h", + "extras/CUPTI/include/GL/wglext.h", + "extras/CUPTI/include/openacc/cupti_openacc.h", +] + +genrule( + name = "cuda-extras", + outs = CUDA_EXTRAS_FILES, + cmd = " && ".join([ + "ln -s /usr/local/cuda/{relpath} $(@D)/{relpath}".format(relpath = p) + for p in CUDA_EXTRAS_FILES + ]), + local = True, + tags = ["no-cache"], +) + +genrule( + name = "cudnn-include", + outs = [ + "include/cudnn.h", + ], + cmd = """ + ln -s /usr/include/cudnn.h $(@D)/cudnn.h""", + local = True, + tags = ["no-cache"], +) + diff --git a/third_party/tensorpipe b/third_party/tensorpipe index e45b2338d0a31..1cd0ac3e4ce51 160000 --- a/third_party/tensorpipe +++ b/third_party/tensorpipe @@ -1 +1 @@ -Subproject commit e45b2338d0a31192a7e413f3fbbfa7fd90504a37 +Subproject commit 1cd0ac3e4ce5144ee4ea2545741182c76fba6cf2 diff --git a/tools/README.md b/tools/README.md index a28affa5f30aa..e4aba38afd851 100644 --- a/tools/README.md +++ b/tools/README.md @@ -15,10 +15,6 @@ Modern infrastructure: to import arbitrary Python files in a script, without having to add them to the PYTHONPATH first. -Legacy infrastructure (we should kill this): -* [cwrap](cwrap) - Implementation of legacy code generation for THNN/THCUNN. - This is used by nnwrap. - Build system pieces: * [setup_helpers](setup_helpers) - Helper code for searching for diff --git a/tools/amd_build/build_amd.py b/tools/amd_build/build_amd.py index 8cfecda82e328..70f7e7a83e1ec 100755 --- a/tools/amd_build/build_amd.py +++ b/tools/amd_build/build_amd.py @@ -81,12 +81,10 @@ "aten/src/ATen/native/sparse/cuda/*", "aten/src/ATen/native/quantized/cuda/*", "aten/src/THC/*", - "aten/src/THCUNN/*", "aten/src/ATen/test/*", # CMakeLists.txt isn't processed by default, but there are a few # we do want to handle, so explicitly specify them "aten/src/THC/CMakeLists.txt", - "aten/src/THCUNN/CMakeLists.txt", "torch/*", "tools/autograd/templates/python_variable_methods.cpp", ] diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index b52b69018e177..641471ebc8f06 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -801,11 +801,11 @@ - name: xlogy.Tensor(Tensor self, Tensor other) -> Tensor self: grad * at::xlogy((self != 0), other) - other: grad * self / other + other: grad * at::where(other.isnan() | (self != 0), self / other, zeros_like(other)) result: self_t * at::xlogy((self_p != 0), other_p) + other_t * self_p / other_p - name: xlogy.Scalar_Self(Scalar self, Tensor other) -> Tensor - other: grad * self / other + other: grad * at::where(other.isnan() | (!self.equal(0)), self / other, zeros_like(other)) result: auto_element_wise - name: xlogy.Scalar_Other(Tensor self, Scalar other) -> Tensor @@ -1604,10 +1604,6 @@ self: soft_margin_loss_backward(grad, self, target, reduction) - name: relu(Tensor self) -> Tensor - self: threshold_backward(grad, self, 0) - -# NB: `output` instead of `self` saves memory. It avoids saving a copy of self. -- name: relu_(Tensor(a!) self) -> Tensor(a!) self: threshold_backward(grad, result, 0) - name: silu(Tensor self) -> Tensor diff --git a/tools/autograd/gen_autograd_functions.py b/tools/autograd/gen_autograd_functions.py index 7d852aded47a9..08136ab54bfcc 100644 --- a/tools/autograd/gen_autograd_functions.py +++ b/tools/autograd/gen_autograd_functions.py @@ -479,7 +479,7 @@ def save_var(var: SavedAttribute, is_output: bool) -> None: body: List[str] = [] if uses_single_grad(info): - body.append('auto& grad = grads[0];') + body.append('const auto& grad = grads[0];') def emit_derivative( derivative: Derivative, diff --git a/tools/autograd/gen_inplace_or_view_type.py b/tools/autograd/gen_inplace_or_view_type.py index 6c42bec1e5d12..524cca262f4f2 100644 --- a/tools/autograd/gen_inplace_or_view_type.py +++ b/tools/autograd/gen_inplace_or_view_type.py @@ -124,6 +124,10 @@ ); """) +AUTOGRAD_NOT_IMPLEMENTED_REGISTRATION = CodeTemplate("""\ +m.impl("${unqual_operator_name_with_overload}", torch::autograd::autogradNotImplementedFallback()); +""") + INPLACE_REDISPATCH = CodeTemplate("""\ { at::AutoDispatchBelowADInplaceOrView guard; diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py index d1fb70c8abed3..f61d3d0c0709c 100644 --- a/tools/autograd/gen_python_functions.py +++ b/tools/autograd/gen_python_functions.py @@ -164,9 +164,12 @@ def gen(out: str, native_yaml_path: str, deprecated_yaml_path: str, template_pat create_python_bindings( fm, methods, is_py_variable_method, None, 'python_variable_methods.cpp', method=True) + # NOTE: num_shards here must be synced with gatherTorchFunctions in + # torch/csrc/autograd/python_torch_functions_manual.cpp functions = load_signatures(native_functions, deprecated_yaml_path, method=False) - create_python_bindings( - fm, functions, is_py_torch_function, 'torch', 'python_torch_functions.cpp', method=False) + create_python_bindings_sharded( + fm, functions, is_py_torch_function, 'torch', 'python_torch_functions.cpp', + method=False, num_shards=3) create_python_bindings( fm, functions, is_py_nn_function, 'torch.nn', 'python_nn_functions.cpp', method=False) @@ -180,6 +183,16 @@ def gen(out: str, native_yaml_path: str, deprecated_yaml_path: str, template_pat create_python_bindings( fm, functions, is_py_special_function, 'torch.special', 'python_special_functions.cpp', method=False) +def group_filter_overloads( + pairs: Sequence[PythonSignatureNativeFunctionPair], + pred: Callable[[NativeFunction], bool] +) -> Dict[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]: + grouped: Dict[BaseOperatorName, List[PythonSignatureNativeFunctionPair]] = defaultdict(list) + for pair in pairs: + if pred(pair.function): + grouped[pair.function.func.name.name].append(pair) + return grouped + def create_python_bindings( fm: FileManager, pairs: Sequence[PythonSignatureNativeFunctionPair], @@ -194,10 +207,7 @@ def create_python_bindings( py_method_defs: List[str] = [] py_forwards: List[str] = [] - grouped: Dict[BaseOperatorName, List[PythonSignatureNativeFunctionPair]] = defaultdict(list) - for pair in pairs: - if pred(pair.function): - grouped[pair.function.func.name.name].append(pair) + grouped = group_filter_overloads(pairs, pred) for name in sorted(grouped.keys(), key=lambda x: str(x)): overloads = grouped[name] @@ -212,6 +222,44 @@ def create_python_bindings( 'py_method_defs': py_method_defs, }) +def create_python_bindings_sharded( + fm: FileManager, + pairs: Sequence[PythonSignatureNativeFunctionPair], + pred: Callable[[NativeFunction], bool], + module: Optional[str], + filename: str, + *, + method: bool, + num_shards: int +) -> None: + """Generates Python bindings to ATen functions""" + grouped = group_filter_overloads(pairs, pred) + + def key_func(kv: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]) -> str: + return str(kv[0]) + + def env_func( + kv: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]] + ) -> Dict[str, List[str]]: + return { + 'py_forwards': list(forward_decls(kv[0], kv[1], method=method)), + 'py_methods': [method_impl(kv[0], module, kv[1], method=method)], + 'py_method_defs': [method_def(kv[0], module, kv[1], method=method)], + } + + fm.write_sharded( + filename, + grouped.items(), + base_env={ + 'generated_comment': + '@' + f'generated from {fm.template_dir}/{filename}', + }, + key_fn=key_func, + env_callable=env_func, + num_shards=num_shards, + sharded_keys={'py_forwards', 'py_methods', 'py_method_defs'} + ) + def load_signatures( native_functions: List[NativeFunction], deprecated_yaml_path: str, diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index a64f7341e281c..e3f4d5553c34f 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -1,5 +1,8 @@ # Generates VariableType.h/cpp # +# **If any changes are being made to the VariableType codegen please also check +# if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp +# # VariableType is a subclass of at::Type that provides the binding code # necessary to provide a differentiable version of ATen operators. There are a # number of different things we could mean: @@ -30,7 +33,8 @@ from .gen_inplace_or_view_type import ( get_view_info, is_tensor_type, is_tensor_list_type, unpack_args, get_base_name, use_derived, modifies_arguments, WRAPPER_REGISTRATION, TMP_VAR, METHOD_DEFINITION, - ASSIGN_RETURN_VALUE, gen_formals, ALL_VIEW_FUNCTIONS, unpacked_name + ASSIGN_RETURN_VALUE, gen_formals, ALL_VIEW_FUNCTIONS, unpacked_name, + AUTOGRAD_NOT_IMPLEMENTED_REGISTRATION ) from tools.codegen.api.types import (Binding, DispatcherSignature, BaseCType, intArrayRefT, @@ -370,7 +374,7 @@ def gen_variable_type( """ fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) fm.write('VariableType.h', lambda: { - 'generated_comment': f'@generated from {template_path}/VariableType.h' + 'generated_comment': "@" f'generated from {template_path}/VariableType.h' }) # NOTE: see Note [Sharded File] at the top of the VariableType.cpp @@ -381,7 +385,7 @@ def gen_variable_type( key_fn=lambda fn: cpp.name(fn.func.func), base_env={ 'generated_comment': - f'@generated from {template_path}/VariableType.cpp', + "@" f'generated from {template_path}/VariableType.cpp', }, env_callable=gen_variable_type_func, num_shards=5, @@ -404,13 +408,39 @@ def gen_variable_type_func( name = cpp.name(f.func) formals = gen_formals(f) - type_definition = METHOD_DEFINITION.substitute( - return_type=cpp.returns_type(f.func.returns).cpp_type(), - type_wrapper_name=type_wrapper_name(f), - type_definition_body=emit_body(fn), - formals=formals, - ) - wrapper_registration = gen_wrapper_registration(f) + if fn.info is None and not get_base_name(f) in RESET_GRAD_ACCUMULATOR \ + and not get_base_name(f) in DONT_REQUIRE_DERIVATIVE \ + and len(gen_differentiable_outputs(fn)) > 0 \ + and not cpp.name(f.func) in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE \ + and not type_wrapper_name(f) in DONT_ENFORCE_STORAGE_IMPL_USE_COUNT \ + and not type_wrapper_name(f) in DONT_ENFORCE_TENSOR_IMPL_USE_COUNT: + # NOTE: [ Registering AutogradNotImplemented boxed kernel ] + # + # When there is no derivatives.yaml entry, we register a generic boxed + # NotImplemented kernel to set grad_fn to be NotImplemented, so that forward + # proceeds as usual but an error is properly produced on backward. + # TODO: it would be nice to not have these special cases + # + # There are several cases where still let codegen handle it: + # 1) ops that need to reset grad accumulator (we let codegen handle this case + # because) the list is (currently) only accessible in Python. + # 2) User explicitly specifies DONT_REQUIRE_DERIVATIVE. This basically makes + # autograd a fallthrough with NDEBUG checks. This can be useful for when all + # outputs are integral. + # 3) When there are no differentiable outputs. This is similar to (2). + # 4) There are certain ops where we skip certain NDEBUG checks. this is similar + # to (1). + type_definition = "" + wrapper_registration = AUTOGRAD_NOT_IMPLEMENTED_REGISTRATION.substitute( + unqual_operator_name_with_overload=f.func.name) + else: + type_definition = METHOD_DEFINITION.substitute( + return_type=cpp.returns_type(f.func.returns).cpp_type(), + type_wrapper_name=type_wrapper_name(f), + type_definition_body=emit_body(fn), + formals=formals, + ) + wrapper_registration = gen_wrapper_registration(f) # See Note [Manual Backend kernels] assert (name in MANUAL_BACKEND) == f.manual_kernel_registration diff --git a/tools/autograd/load_derivatives.py b/tools/autograd/load_derivatives.py index 3ff11f4d18691..8a5904b732918 100644 --- a/tools/autograd/load_derivatives.py +++ b/tools/autograd/load_derivatives.py @@ -2,9 +2,9 @@ # # Each autograd function is represented by `DifferentiabilityInfo` containing # a list of `Derivative`. See `tools.codegen.api.autograd` for the data models. -from collections import defaultdict, Counter +from collections import defaultdict import re -from typing import Sequence, Any, Tuple, List, Set, Dict, Match, Optional +from typing import Counter, Sequence, Any, Tuple, List, Set, Dict, Match, Optional import yaml from tools.codegen.api.autograd import (Derivative, DifferentiabilityInfo, @@ -43,32 +43,15 @@ def load_derivatives(derivatives_yaml_path: str, native_yaml_path: str) -> Seque assert str(function.func) not in functions_by_schema functions_by_schema[str(function.func)] = function + # Keep track of how many of which ops we've seen so we can + # disambiguate them with a numeric suffix. + op_counter = Counter[str]() + infos = [ - create_differentiability_info(defn, functions_by_signature, functions_by_schema) + create_differentiability_info(defn, functions_by_signature, functions_by_schema, op_counter) for defn in definitions] - # To keep it byte-for-byte compatible with the old codegen, we assign op names as a separate - # step. We only assign op names to those with differentiable args, and only append suffix to - # duplicated op names. This can be simplified if the first of the duplicates can be named - # 'XyzBackward' instead of 'XyzBackward0' or unconditionally append '0' to singletons. - op_names = create_op_names(infos) - res = [ - DifferentiabilityInfo( - name=info.name, - func=info.func, - op=op_name, - derivatives=info.derivatives, - forward_derivatives=info.forward_derivatives, - all_saved_inputs=info.all_saved_inputs, - all_saved_outputs=info.all_saved_outputs, - args_with_derivatives=info.args_with_derivatives, - non_differentiable_arg_names=info.non_differentiable_arg_names, - output_differentiability=info.output_differentiability, - output_differentiability_conditions=info.output_differentiability_conditions, - ) - for info, op_name in zip(infos, op_names)] - - _GLOBAL_LOAD_DERIVATIVE_CACHE[key] = res + _GLOBAL_LOAD_DERIVATIVE_CACHE[key] = infos return _GLOBAL_LOAD_DERIVATIVE_CACHE[key] @@ -279,6 +262,7 @@ def create_differentiability_info( defn: Dict[Any, Any], functions_by_signature: Dict[FunctionSchema, List[NativeFunction]], functions_by_schema: Dict[str, NativeFunction], + op_counter: Counter[str], ) -> DifferentiabilityInfo: """Processes a single entry `defn` in derivatives.yaml""" @@ -424,10 +408,17 @@ def set_up_derivatives(f: NativeFunction) -> Tuple[ derivatives, forward_derivatives, args_with_derivatives, non_differentiable_arg_names = set_up_derivatives(canonical) + # only assign an op name if we are actually going to calculate a derivative + op = None + if args_with_derivatives: + op_prefix = _create_op_prefix(defn_name) + op = f'{op_prefix}{op_counter[op_prefix]}' + op_counter[op_prefix] += 1 + return DifferentiabilityInfo( name=defn_name, func=canonical, - op=None, + op=op, derivatives=derivatives, forward_derivatives=forward_derivatives, all_saved_inputs=dedup_vars([v for d in derivatives for v in d.saved_inputs]), @@ -566,35 +557,22 @@ def repl(m: Match[str]) -> str: return formula, tuple(saved) -def create_op_name(info: DifferentiabilityInfo) -> Optional[str]: - # only assign an op name if we are actually going to calculate a derivative - if not info.args_with_derivatives: - return None - name = info.name +def _create_op_prefix(name: str) -> str: + """Takes a native function name converts to a op prefix name. + + Note that the "name" parameter must be the native function name + without the optional variant suffix, so "add" instead of + "add.out". + + OP names correspond to classes, hence the change to title case. + + Example:: + >>> _create_op_prefix('add') + 'AddBackward' + """ camel_case = ''.join([p.title() for p in name.split('_')]) return (camel_case + 'Backward').replace('ForwardBackward', 'Backward') -def create_op_names(infos: Sequence[DifferentiabilityInfo]) -> Sequence[Optional[str]]: - names = list(map(create_op_name, infos)) - dups = set(item for item, count in Counter(names).items() if count > 1) - - # de-duplicate operation names - # you end up with something like: - # AddBackward0 - # AddBackward1 - # one for each overload - counter: Dict[str, int] = Counter() - dedup: List[Optional[str]] = [] - for name in names: - if name is None: - # Keep a placeholder - dedup.append(None) - elif name in dups: - dedup.append(f'{name}{counter[name]}') - counter[name] += 1 - else: - dedup.append(name) - return dedup def dedup_vars(vars: Sequence[SavedAttribute]) -> Sequence[SavedAttribute]: seen: Set[str] = set() diff --git a/tools/autograd/templates/VariableType.cpp b/tools/autograd/templates/VariableType.cpp index 1ff3604ec21ea..605a700fb1a47 100644 --- a/tools/autograd/templates/VariableType.cpp +++ b/tools/autograd/templates/VariableType.cpp @@ -1,4 +1,5 @@ #include "torch/csrc/autograd/VariableTypeUtils.h" +#include "torch/csrc/autograd/generated/VariableType.h" #include "torch/csrc/autograd/FunctionsManual.h" #include diff --git a/tools/autograd/templates/VariableType.h b/tools/autograd/templates/VariableType.h index fc8ffa5799c11..333e8a0d7ada5 100644 --- a/tools/autograd/templates/VariableType.h +++ b/tools/autograd/templates/VariableType.h @@ -7,6 +7,7 @@ #include #include +#include #include // for size_t #include // for function diff --git a/tools/autograd/templates/python_torch_functions.cpp b/tools/autograd/templates/python_torch_functions.cpp index 9e02036639516..b45b5f298716b 100644 --- a/tools/autograd/templates/python_torch_functions.cpp +++ b/tools/autograd/templates/python_torch_functions.cpp @@ -7,7 +7,6 @@ // and also copied into 'torch' module. #include -#include // Undefine the copysign macro so that at::copysign works as intended with MSVC // https://github.com/python/cpython/blob/c60394c7fc9cc09b16e9675a3eeb5844b6d8523f/PC/pyconfig.h#L196 @@ -15,6 +14,7 @@ #undef copysign #endif // _MSC_VER +#include "torch/csrc/autograd/python_torch_functions.h" #include "torch/csrc/autograd/python_variable.h" #include "torch/csrc/autograd/utils/wrap_outputs.h" #include "torch/csrc/Dtype.h" @@ -34,7 +34,6 @@ #include -#include #include #include #include @@ -59,767 +58,28 @@ using at::ArrayRef; using torch::utils::check_out_type_matches; using namespace torch::autograd::utils; -namespace torch { namespace autograd { - -static PyObject* THPVariableFunctionsModule = NULL; - -inline Tensor dispatch_arange(const Scalar& end, Tensor result) { - pybind11::gil_scoped_release no_gil; - return at::arange_out(result, end); -} - -inline Tensor dispatch_arange(const Scalar& end, const TensorOptions& options) { - torch::utils::maybe_initialize_cuda(options); - pybind11::gil_scoped_release no_gil; - return torch::arange(end, options); -} - -inline Tensor dispatch_arange(const Scalar& start, const Scalar& end, const Scalar& step, Tensor result) { - pybind11::gil_scoped_release no_gil; - return at::arange_out(result, start, end, step); -} - -inline Tensor dispatch_arange(const Scalar& start, const Scalar& end, const Scalar& step, const TensorOptions& options) { - torch::utils::maybe_initialize_cuda(options); - pybind11::gil_scoped_release no_gil; - return torch::arange(start, end, step, options); -} - -static PyObject * THPVariable_arange(PyObject* self, PyObject* args, PyObject* kwargs) -{ - HANDLE_TH_ERRORS - static PythonArgParser parser({ - "arange(Scalar end, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)", - "arange(Scalar start, Scalar end, Scalar step=1, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)", - }, /*traceable=*/true); - - ParsedArgs<9> parsed_args; - auto r = parser.parse(args, kwargs, parsed_args); - - if(r.has_torch_function()) { - return handle_torch_function(r, args, kwargs, THPVariableFunctionsModule, "torch"); - } - - if (r.idx == 0) { - if (r.isNone(1)) { - auto end = r.scalar(0); - // NOTE: r.scalartype(X) gives the default dtype if r.isNone(X) - c10::optional scalarType = r.scalartypeOptional(2); - const auto options = TensorOptions() - .dtype(scalarType) - .device(r.device(4)) - .layout(r.layout(3)) - .requires_grad(r.toBool(6)) - .pinned_memory(r.toBool(5)); - return wrap(dispatch_arange(end, options)); - } else { - TORCH_CHECK(!r.toBool(5), " `pin_memory` and `out` parameters are incompatible"); - check_out_type_matches(r.tensor(1), r.scalartype(2), r.isNone(2), r.layout(3), - r.device(4), r.isNone(4)); - return wrap(dispatch_arange(r.scalar(0), r.tensor(1)).set_requires_grad(r.toBool(6))); - } - } else if (r.idx == 1) { - if (r.isNone(3)) { - auto start = r.scalar(0); - auto end = r.scalar(1); - auto step = r.scalar(2); - // NOTE: r.scalartype(X) gives the default dtype if r.isNone(X) - c10::optional scalarType = r.scalartypeOptional(4); - const auto options = TensorOptions() - .dtype(scalarType) - .device(r.device(6)) - .layout(r.layout(5)) - .requires_grad(r.toBool(8)) - .pinned_memory(r.toBool(7)); - return wrap(dispatch_arange(start, end, step, options)); - } else { - TORCH_CHECK(!r.toBool(7), " `pin_memory` and `out` parameters are incompatible"); - check_out_type_matches(r.tensor(3), r.scalartype(4), r.isNone(4), r.layout(5), - r.device(6), r.isNone(6)); - return wrap(dispatch_arange(r.scalar(0), r.scalar(1), r.scalar(2), r.tensor(3)).set_requires_grad(r.toBool(8))); - } - } - Py_RETURN_NONE; - END_HANDLE_TH_ERRORS -} - -inline Tensor dispatch_range(const Scalar& start, const Scalar& end, const Scalar& step, Tensor result) { - pybind11::gil_scoped_release no_gil; - OptionalDeviceGuard device_guard(device_of(result)); - return at::range_out(result, start, end, step); -} - -inline Tensor dispatch_range(const Scalar& start, const Scalar& end, const Scalar& step, const TensorOptions& options) { - torch::utils::maybe_initialize_cuda(options); - pybind11::gil_scoped_release no_gil; - DeviceGuard device_guard(options.device()); - return torch::range(start, end, step, options); -} - -static PyObject * THPVariable_range(PyObject* self, PyObject* args, PyObject* kwargs) -{ - HANDLE_TH_ERRORS - static PythonArgParser parser({ - "range(Scalar start, Scalar end, Scalar step=1, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool requires_grad=False)", - }); - - ParsedArgs<8> parsed_args; - auto r = parser.parse(args, kwargs, parsed_args); - - if (r.idx == 0) { - auto ret = PyErr_WarnEx( - PyExc_UserWarning, - "torch.range is deprecated and will be removed in a future release " - "because its behavior is inconsistent with Python's range builtin. " - "Instead, use torch.arange, which produces values in [start, end).", - 1); - if (ret != 0) throw python_error(); - if (r.isNone(3)) { - const auto options = TensorOptions() - .dtype(r.scalartype(4)) - .device(r.device(6)) - .layout(r.layout(5)) - .requires_grad(r.toBool(7)); - return wrap(dispatch_range(r.scalar(0), r.scalar(1), r.scalar(2), options)); - } else { - check_out_type_matches(r.tensor(3), r.scalartype(4), r.isNone(4), - r.layout(5), r.device(6), r.isNone(6)); - return wrap(dispatch_range(r.scalar(0), r.scalar(1), r.scalar(2), r.tensor(3)).set_requires_grad(r.toBool(7))); - } - } - Py_RETURN_NONE; - END_HANDLE_TH_ERRORS -} - -inline Tensor dispatch_full( - IntArrayRef size, - const Scalar& fill_val, - const TensorOptions& options) { - torch::utils::maybe_initialize_cuda(options); - pybind11::gil_scoped_release no_gil; - return at::full(size, fill_val, options); -} - -inline Tensor dispatch_full( - IntArrayRef size, - const Scalar& fill_val, - c10::optional names, - const TensorOptions& options) { - torch::utils::maybe_initialize_cuda(options); - pybind11::gil_scoped_release no_gil; - return at::full(size, fill_val, names, options); -} - -inline Tensor dispatch_full( - IntArrayRef size, - const Scalar& fill_val, - Tensor result) { - pybind11::gil_scoped_release no_gil; - return at::full_out(result, size, fill_val); -} - -static PyObject * THPVariable_full(PyObject* self, PyObject* args, PyObject* kwargs) { - HANDLE_TH_ERRORS - - static PythonArgParser parser({ - "full(IntArrayRef size, Scalar fill_value, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)", - "full(IntArrayRef size, Scalar fill_value, *, DimnameList names=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)", - }, /*traceable=*/true); - - // Acquires (common) arguments - ParsedArgs<8> parsed_args; - auto r = parser.parse(args, kwargs, parsed_args); +// NOTE: See [Sharded File] comment in VariableType - if(r.has_torch_function()) { - return handle_torch_function(r, args, kwargs, THPVariableFunctionsModule, "torch"); - } - - auto size = r.intlist(0); - auto fill_val = r.scalar(1); - const auto options = TensorOptions{} - .dtype(r.scalartypeOptional(3)) - .layout(r.layout(4)) - .device(r.device(5)) - .pinned_memory(r.toBool(6)); - - if (r.idx == 0) { - // full - if (r.isNone(2)) { - return wrap(dispatch_full(size, fill_val, options).set_requires_grad(r.toBool(7))); - } - - // full.out - // Validates out tensor and other kwargs - auto result = r.tensor(2); - TORCH_CHECK(!r.toBool(6), " `pin_memory` and `out` parameters are incompatible"); - check_out_type_matches(result, r.scalartype(3), r.isNone(3), r.layout(4), - r.device(5), r.isNone(5)); - - return wrap(dispatch_full(size, fill_val, result).set_requires_grad(r.toBool(7))); - } else if (r.idx == 1) { - // full.names - if (r.isNone(2)) { - return wrap(dispatch_full(size, fill_val, c10::nullopt, options).set_requires_grad(r.toBool(7))); - } - - // Converts from c10::optional to c10::optional - auto raw_names = r.toDimnameListOptional(2); - c10::optional names(*raw_names); - return wrap(dispatch_full(size, fill_val, names, options).set_requires_grad(r.toBool(7))); - } - - Py_RETURN_NONE; - END_HANDLE_TH_ERRORS -} - -inline Tensor dispatch_randint(int64_t high, IntArrayRef size, c10::optional generator, Tensor result) { - pybind11::gil_scoped_release no_gil; - return at::randint_out(result, high, size, generator); -} -inline Tensor dispatch_randint(int64_t high, IntArrayRef size, c10::optional generator, const TensorOptions & options) { - torch::utils::maybe_initialize_cuda(options); - pybind11::gil_scoped_release no_gil; - return torch::randint(high, size, generator, options); -} -inline Tensor dispatch_randint(int64_t high, IntArrayRef size, Tensor result) { - pybind11::gil_scoped_release no_gil; - return at::randint_out(result, high, size); -} -inline Tensor dispatch_randint(int64_t high, IntArrayRef size, const TensorOptions & options) { - torch::utils::maybe_initialize_cuda(options); - pybind11::gil_scoped_release no_gil; - return torch::randint(high, size, options); -} -inline Tensor dispatch_randint(int64_t low, int64_t high, IntArrayRef size, c10::optional generator, Tensor result) { - pybind11::gil_scoped_release no_gil; - return at::randint_out(result, low, high, size, generator); -} -inline Tensor dispatch_randint(int64_t low, int64_t high, IntArrayRef size, c10::optional generator, const TensorOptions & options) { - torch::utils::maybe_initialize_cuda(options); - pybind11::gil_scoped_release no_gil; - return torch::randint(low, high, size, generator, options); -} -inline Tensor dispatch_randint(int64_t low, int64_t high, IntArrayRef size, Tensor result) { - pybind11::gil_scoped_release no_gil; - return at::randint_out(result, low, high, size); -} -inline Tensor dispatch_randint(int64_t low, int64_t high, IntArrayRef size, const TensorOptions & options) { - torch::utils::maybe_initialize_cuda(options); - pybind11::gil_scoped_release no_gil; - return torch::randint(low, high, size, options); -} - -static PyObject * THPVariable_randint(PyObject* self_, PyObject* args, PyObject* kwargs) -{ - HANDLE_TH_ERRORS - static PythonArgParser parser({ - "randint(int64_t high, IntArrayRef size, *, Generator generator=None, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool requires_grad=False)", - "randint(int64_t low, int64_t high, IntArrayRef size, *, Generator generator=None, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool requires_grad=False)", - }, /*traceable=*/false); - - ParsedArgs<9> parsed_args; - auto r = parser.parse(args, kwargs, parsed_args); - - if(r.has_torch_function()) { - return handle_torch_function(r, args, kwargs, THPVariableFunctionsModule, "torch"); - } - - if (r.idx == 0) { - if (r.isNone(3)) { - auto high = r.toInt64(0); - auto size = r.intlist(1); - auto generator = r.generator(2); - // NOTE: r.scalartype(X) gives the default dtype if r.isNone(X) - auto dtype = r.scalartypeWithDefault(4, at::ScalarType::Long); - auto device = r.device(6); - const auto options = TensorOptions() - .dtype(dtype) - .device(device) - .layout(r.layout(5)) - .requires_grad(r.toBool(7)); - return wrap(dispatch_randint(high, size, generator, options)); - } else { - check_out_type_matches(r.tensor(3), r.scalartype(4), r.isNone(4), - r.layout(5), r.device(6), r.isNone(6)); - return wrap(dispatch_randint(r.toInt64(0), r.intlist(1), r.generator(2), r.tensor(3)).set_requires_grad(r.toBool(7))); - } - } else if (r.idx == 1) { - if (r.isNone(4)) { - auto low = r.toInt64(0); - auto high = r.toInt64(1); - auto size = r.intlist(2); - auto generator = r.generator(3); - // NOTE: r.scalartype(X) gives the default dtype if r.isNone(X) - auto dtype = r.scalartypeWithDefault(5, at::ScalarType::Long); - auto device = r.device(7); - const auto options = TensorOptions() - .dtype(dtype) - .device(device) - .layout(r.layout(6)) - .requires_grad(r.toBool(8)); - return wrap(dispatch_randint(low, high, size, generator, options)); - } else { - check_out_type_matches(r.tensor(4), r.scalartype(5), r.isNone(5), - r.layout(6), r.device(7), r.isNone(7)); - return wrap(dispatch_randint(r.toInt64(0), r.toInt64(1), r.intlist(2), r.generator(3), r.tensor(4)).set_requires_grad(r.toBool(8))); - } - } - Py_RETURN_NONE; - END_HANDLE_TH_ERRORS -} - -// implemented on python object to allow torch.as_tensor to be constructed with arbitrarily nested -// python objects - list, tuple, np array, scalar, etc. -static PyObject * THPVariable_as_tensor(PyObject* self, PyObject* args, PyObject* kwargs) -{ - HANDLE_TH_ERRORS - jit::tracer::warn("torch.as_tensor", jit::tracer::WARN_CONSTRUCTOR); - return THPVariable_Wrap(torch::utils::as_tensor(torch::tensors::get_default_dispatch_key(), torch::tensors::get_default_scalar_type(), args, kwargs)); - END_HANDLE_TH_ERRORS -} - -// implemented on python object here because PyObject currently not natively declarable -// See: ATen/native/README.md for more context -static PyObject * THPVariable_from_numpy(PyObject* module, PyObject* arg) -{ - HANDLE_TH_ERRORS - jit::tracer::warn("torch.from_numpy", jit::tracer::WARN_CONSTRUCTOR); - return THPVariable_Wrap(torch::utils::tensor_from_numpy(arg)); - END_HANDLE_TH_ERRORS -} - -static Tensor dispatch_nonzero(const Tensor & self) { - pybind11::gil_scoped_release no_gil; - OptionalDeviceGuard device_guard(device_of(self)); - return self.nonzero(); -} - -static Tensor dispatch_nonzero(const Tensor & self, Tensor out) { - pybind11::gil_scoped_release no_gil; - OptionalDeviceGuard device_guard(device_of(self)); - return at::nonzero_out(out, self); -} - -static std::vector dispatch_nonzero_numpy(const Tensor & self) { - pybind11::gil_scoped_release no_gil; - OptionalDeviceGuard device_guard(device_of(self)); - return self.nonzero_numpy(); -} - -static PyObject * THPVariable_nonzero(PyObject* self, PyObject* args, PyObject* kwargs); - -static PyObject * THPVariable_sparse_csr_tensor(PyObject* self, PyObject* args, PyObject* kwargs) -{ - HANDLE_TH_ERRORS - jit::tracer::warn("torch.sparse_csr_tensor", jit::tracer::WARN_CONSTRUCTOR); - return THPVariable_Wrap(torch::utils::sparse_csr_tensor_ctor(torch::tensors::get_default_dispatch_key(), torch::tensors::get_default_scalar_type(), args, kwargs)); - END_HANDLE_TH_ERRORS -} - -static PyObject * THPVariable__sparse_csr_tensor_unsafe(PyObject* self, PyObject* args, PyObject* kwargs) -{ - HANDLE_TH_ERRORS - jit::tracer::warn("torch._sparse_csr_tensor_unsafe", jit::tracer::WARN_CONSTRUCTOR); - return THPVariable_Wrap(torch::utils::_sparse_csr_tensor_unsafe_ctor(torch::tensors::get_default_dispatch_key(), torch::tensors::get_default_scalar_type(), args, kwargs)); - END_HANDLE_TH_ERRORS -} - -static PyObject * THPVariable_sparse_coo_tensor(PyObject* self, PyObject* args, PyObject* kwargs) -{ - HANDLE_TH_ERRORS - jit::tracer::warn("torch.sparse_coo_tensor", jit::tracer::WARN_CONSTRUCTOR); - return THPVariable_Wrap(torch::utils::sparse_coo_tensor_ctor(torch::tensors::get_default_dispatch_key(), torch::tensors::get_default_scalar_type(), args, kwargs)); - END_HANDLE_TH_ERRORS -} - -static PyObject * THPVariable__sparse_coo_tensor_unsafe(PyObject* self, PyObject* args, PyObject* kwargs) -{ - HANDLE_TH_ERRORS - jit::tracer::warn("torch._sparse_coo_tensor_unsafe", jit::tracer::WARN_CONSTRUCTOR); - return THPVariable_Wrap(torch::utils::_sparse_coo_tensor_unsafe_ctor(torch::tensors::get_default_dispatch_key(), torch::tensors::get_default_scalar_type(), args, kwargs)); - END_HANDLE_TH_ERRORS -} - -// implemented on python object to allow torch.tensor to be constructed with arbitrarily nested -// python objects - list, tuple, np array, scalar, etc. -static PyObject * THPVariable_tensor(PyObject* self, PyObject* args, PyObject* kwargs) -{ - HANDLE_TH_ERRORS - jit::tracer::warn("torch.tensor", jit::tracer::WARN_CONSTRUCTOR); - return THPVariable_Wrap(torch::utils::tensor_ctor(torch::tensors::get_default_dispatch_key(), torch::tensors::get_default_scalar_type(), args, kwargs)); - END_HANDLE_TH_ERRORS -} - -static PyObject * THPVariable_get_device(PyObject* self_, PyObject* args, PyObject* kwargs) -{ - HANDLE_TH_ERRORS - static PythonArgParser parser({ - "get_device(Tensor input)", - }, /*traceable=*/false); - - ParsedArgs<1> parsed_args; - auto r = parser.parse(args, kwargs, parsed_args); - - if (r.idx == 0) { - return wrap(r.tensor(0).get_device()); - } - Py_RETURN_NONE; - END_HANDLE_TH_ERRORS -} - -static PyObject * THPVariable_frombuffer(PyObject* self_, PyObject* args, PyObject* kwargs) -{ - HANDLE_TH_ERRORS - static PythonArgParser parser({ - "frombuffer(PyObject* buffer, *, ScalarType dtype, int64_t count=-1, int64_t offset=0, bool requires_grad=False)", - }, /*traceable=*/false); - - PyObject* ret = nullptr; - ParsedArgs<5> parsed_args; - auto r = parser.parse(args, kwargs, parsed_args); - - if (r.idx == 0) { - auto buffer = r.pyobject(0); - auto dtype = r.scalartype(1); - auto count = r.toInt64(2); - auto offset = r.toInt64(3); - auto requires_grad = r.toBool(4); - - auto elsize = at::elementSize(dtype); - size_t actual_count = 0; - Py_buffer view; - - TORCH_CHECK_VALUE( - PyObject_CheckBuffer(buffer) != 0, - "object does not implement Python buffer protocol."); - - if (PyObject_GetBuffer(buffer, &view, PyBUF_WRITABLE) < 0) { - TORCH_CHECK( - PyObject_GetBuffer(buffer, &view, PyBUF_SIMPLE) >= 0, - "could not retrieve buffer from object"); - TORCH_WARN_ONCE( - "The given buffer is not writable, and PyTorch does " - "not support non-writable tensors. This means you can write to the " - "underlying (supposedly non-writable) buffer using the tensor. " - "You may want to copy the buffer to protect its data or make it writable " - "before converting it to a tensor. This type of warning will be " - "suppressed for the rest of this program."); - PyErr_Clear(); - } - - Py_INCREF(view.obj); - THPObjectPtr obj(view.obj); - - auto len = view.len; - auto buf = view.buf; - PyBuffer_Release(&view); - - TORCH_CHECK_VALUE( - len > 0 && count != 0, - "both buffer length (", len, ") and count (", count, ") must not be 0"); - TORCH_CHECK_VALUE( - offset >= 0 && offset < len, - "offset (", offset, " bytes) must be non-negative and no greater than " - "buffer length (", len, " bytes) minus 1"); - TORCH_CHECK_VALUE( - count > 0 || (len - offset) % elsize == 0, - "buffer length (", len - offset, " bytes) after offset (", offset, " bytes) " - "must be a multiple of element size (", elsize, ")"); - - if (count < 0) { - actual_count = (len - offset) / elsize; - } else { - actual_count = static_cast(count); - } - - TORCH_CHECK_VALUE( - static_cast(offset) + actual_count * elsize <= len, - "requested buffer length (", actual_count, " * ", elsize, " bytes) " - "after offset (", offset, " bytes) must not be greater than actual " - "buffer length (", len, " bytes)"); - - auto offset_buf = static_cast(buf) + offset; - auto options = TensorOptions() - .dtype(dtype) - .device(c10::kCPU); - - auto tensor = at::for_blob(offset_buf, static_cast(actual_count)) - .options(options) - .deleter([obj = obj.release()](void*) { - pybind11::gil_scoped_acquire gil; - Py_DECREF(obj); - }) - .make_tensor(); - tensor.set_requires_grad(requires_grad); - ret = wrap(tensor); - } - - return ret; - - Py_RETURN_NONE; - END_HANDLE_TH_ERRORS -} - -static PyObject * THPVariable_numel(PyObject* self_, PyObject* args, PyObject* kwargs); - -// linspace -static PyObject * THPVariable_linspace(PyObject* self_, PyObject* args, PyObject* kwargs) -{ - HANDLE_TH_ERRORS - static PythonArgParser parser({ - "linspace(Scalar start, Scalar end, int64_t? steps=None, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)", - }, /*traceable=*/true); - - ParsedArgs<9> parsed_args; - auto _r = parser.parse(nullptr, args, kwargs, parsed_args); - if(_r.has_torch_function()) { - return handle_torch_function(_r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch"); - } - if (_r.isNone(3)) { - // aten::linspace(Scalar start, Scalar end, int? steps=None, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - - // NOTE: r.scalartype(X) gives the default dtype if r.isNone(X) - // This leads to problem in the operator argument checks, - // when either `start` or `end` is complex and dtype is None - const auto options = TensorOptions() - .dtype(_r.scalartypeOptional(4)) - .device(_r.device(6)) - .layout(_r.layoutOptional(5)) - .requires_grad(_r.toBool(8)) - .pinned_memory(_r.toBool(7)); - torch::utils::maybe_initialize_cuda(options); - - auto dispatch_linspace = [](Scalar start, Scalar end, c10::optional steps, TensorOptions options) -> Tensor { - pybind11::gil_scoped_release no_gil; - return torch::linspace(start, end, steps, options); - }; - return wrap(dispatch_linspace(_r.scalar(0), _r.scalar(1), _r.toInt64Optional(2), options)); - } else { - // aten::linspace.out(Scalar start, Scalar end, int? steps=None, *, Tensor(a!) out) -> Tensor(a!) - check_out_type_matches(_r.tensor(3), _r.scalartype(4), - _r.isNone(4), _r.layoutOptional(5), - _r.device(6), _r.isNone(6)); - - auto dispatch_linspace_out = [](Tensor out, Scalar start, Scalar end, c10::optional steps) -> Tensor { - pybind11::gil_scoped_release no_gil; - return at::linspace_out(out, start, end, steps); - }; - return wrap(dispatch_linspace_out(_r.tensor(3), _r.scalar(0), _r.scalar(1), _r.toInt64Optional(2)).set_requires_grad(_r.toBool(8))); - } - Py_RETURN_NONE; - END_HANDLE_TH_ERRORS -} - -// logspace -static PyObject * THPVariable_logspace(PyObject* self_, PyObject* args, PyObject* kwargs) -{ - HANDLE_TH_ERRORS - static PythonArgParser parser({ - "logspace(Scalar start, Scalar end, int64_t? steps=None, double base=10.0, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)", - }, /*traceable=*/true); - - ParsedArgs<10> parsed_args; - auto _r = parser.parse(nullptr, args, kwargs, parsed_args); - if(_r.has_torch_function()) { - return handle_torch_function(_r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch"); - } - if (_r.isNone(4)) { - // aten::logspace(Scalar start, Scalar end, int? steps=None, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - - // NOTE: r.scalartype(X) gives the default dtype if r.isNone(X) - // This leads to problem in the operator argument checks, - // when either `start` or `end` is complex and dtype is None - const auto options = TensorOptions() - .dtype(_r.scalartypeOptional(5)) - .device(_r.device(7)) - .layout(_r.layoutOptional(6)) - .requires_grad(_r.toBool(9)) - .pinned_memory(_r.toBool(8)); - torch::utils::maybe_initialize_cuda(options); - - auto dispatch_logspace = [](Scalar start, Scalar end, c10::optional steps, double base, TensorOptions options) -> Tensor { - pybind11::gil_scoped_release no_gil; - return torch::logspace(start, end, steps, base, options); - }; - return wrap(dispatch_logspace(_r.scalar(0), _r.scalar(1), _r.toInt64Optional(2), _r.toDouble(3), options)); - } else { - // aten::logspace.out(Scalar start, Scalar end, int? steps=None, float base=10.0, *, Tensor(a!) out) -> Tensor(a!) - check_out_type_matches(_r.tensor(4), _r.scalartype(5), - _r.isNone(5), _r.layoutOptional(6), - _r.device(7), _r.isNone(7)); - - auto dispatch_logspace_out = [](Tensor out, Scalar start, Scalar end, c10::optional steps, double base) -> Tensor { - pybind11::gil_scoped_release no_gil; - return at::logspace_out(out, start, end, steps, base); - }; - return wrap(dispatch_logspace_out(_r.tensor(4), _r.scalar(0), _r.scalar(1), _r.toInt64Optional(2), _r.toDouble(3)).set_requires_grad(_r.toBool(9))); - } - Py_RETURN_NONE; - END_HANDLE_TH_ERRORS -} +namespace torch { namespace autograd { // generated forward declarations start here ${py_forwards} -// Wrapper converts a raised TypeError into returning NotImplemented -// Used to implement binary arithmetic operators -template -static PyObject * TypeError_to_NotImplemented_(PyObject* self, PyObject* args, PyObject* kwargs) { - PyObject* ret = Func(self, args, kwargs); - if (!ret && PyErr_ExceptionMatches(PyExc_TypeError)) { - PyErr_Clear(); - Py_INCREF(Py_NotImplemented); - ret = Py_NotImplemented; - } - return ret; -} - -// XXX: ops that are bound here are not exposed to the C++ api nor the JIT. -// Any new ops added here should be accompanied with a comment why they are not -// being registered through native_functions.yaml, and be tagged cpp / JIT -static PyMethodDef torch_functions[] = { - {"arange", castPyCFunctionWithKeywords(THPVariable_arange), - METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"as_tensor", castPyCFunctionWithKeywords(THPVariable_as_tensor), - METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"dsmm", castPyCFunctionWithKeywords(THPVariable_mm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"from_numpy", THPVariable_from_numpy, METH_STATIC | METH_O, NULL}, - {"frombuffer", castPyCFunctionWithKeywords(THPVariable_frombuffer), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"full", castPyCFunctionWithKeywords(THPVariable_full), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"hsmm", castPyCFunctionWithKeywords(THPVariable_hspmm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"linspace", castPyCFunctionWithKeywords(THPVariable_linspace), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"logspace", castPyCFunctionWithKeywords(THPVariable_logspace), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"nonzero", castPyCFunctionWithKeywords(THPVariable_nonzero), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"randint", castPyCFunctionWithKeywords(THPVariable_randint), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"range", castPyCFunctionWithKeywords(THPVariable_range), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"saddmm", castPyCFunctionWithKeywords(THPVariable_sspaddmm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"sparse_coo_tensor", castPyCFunctionWithKeywords(THPVariable_sparse_coo_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"_sparse_coo_tensor_unsafe", castPyCFunctionWithKeywords(THPVariable__sparse_coo_tensor_unsafe), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"_validate_sparse_coo_tensor_args", castPyCFunctionWithKeywords(THPVariable__validate_sparse_coo_tensor_args), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"sparse_csr_tensor", castPyCFunctionWithKeywords(THPVariable_sparse_csr_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"_sparse_csr_tensor_unsafe", castPyCFunctionWithKeywords(THPVariable__sparse_csr_tensor_unsafe), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"_validate_sparse_csr_tensor_args", castPyCFunctionWithKeywords(THPVariable__validate_sparse_csr_tensor_args), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"spmm", castPyCFunctionWithKeywords(THPVariable_mm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"tensor", castPyCFunctionWithKeywords(THPVariable_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"get_device", castPyCFunctionWithKeywords(THPVariable_get_device), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"numel", castPyCFunctionWithKeywords(THPVariable_numel), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, +static PyMethodDef torch_functions_shard[] = { ${py_method_defs} - {NULL} }; -static PyTypeObject THPVariableFunctions = { - PyVarObject_HEAD_INIT(NULL, 0) - "torch._C._VariableFunctionsClass", /* tp_name */ - 0, /* tp_basicsize */ - 0, /* tp_itemsize */ - 0, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* tp_reserved */ - 0, /* tp_repr */ - 0, /* tp_as_number */ - 0, /* tp_as_sequence */ - 0, /* tp_as_mapping */ - 0, /* tp_hash */ - 0, /* tp_call */ - 0, /* tp_str */ - 0, /* tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ - NULL, /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ - torch_functions, /* tp_methods */ - 0, /* tp_members */ - 0, /* tp_getset */ - 0, /* tp_base */ - 0, /* tp_dict */ - 0, /* tp_descr_get */ - 0, /* tp_descr_set */ - 0, /* tp_dictoffset */ - 0, /* tp_init */ - 0, /* tp_alloc */ - 0 /* tp_new */ -}; - -void initTorchFunctions(PyObject* module) { - if (PyType_Ready(&THPVariableFunctions) < 0) { - throw python_error(); - } - Py_INCREF(&THPVariableFunctions); - - // Steals - Py_INCREF(&THPVariableFunctions); - if (PyModule_AddObject(module, "_VariableFunctionsClass", reinterpret_cast(&THPVariableFunctions)) < 0) { - throw python_error(); - } - // PyType_GenericNew returns a new reference - THPVariableFunctionsModule = PyType_GenericNew(&THPVariableFunctions, Py_None, Py_None); - // PyModule_AddObject steals a reference - if (PyModule_AddObject(module, "_VariableFunctions", THPVariableFunctionsModule) < 0) { - throw python_error(); - } +void gatherTorchFunctions${shard_id}(std::vector &torch_functions) { + constexpr size_t num_functions = sizeof(torch_functions_shard) / sizeof(torch_functions_shard[0]); + torch_functions.insert( + torch_functions.end(), + torch_functions_shard, + torch_functions_shard + num_functions); } // generated methods start here ${py_methods} -static PyObject * THPVariable_nonzero(PyObject* self, PyObject* args, PyObject* kwargs) -{ - HANDLE_TH_ERRORS - static PythonArgParser parser({ - "nonzero(Tensor input, *, bool as_tuple=False, Tensor out=None)", - }); - ParsedArgs<3> parsed_args; - auto r = parser.parse(args, kwargs, parsed_args); - - if(r.has_torch_function()){ - return handle_torch_function(r, args, kwargs, THPVariableFunctionsModule, "torch"); - } - - const auto as_tuple = r.toBool(1); - const auto has_out = !r.isNone(2); - - if (as_tuple) { - TORCH_CHECK(!has_out, "nonzero does not support the out kwarg when as_tuple is True"); - return wrap(dispatch_nonzero_numpy(r.tensor(0))); - } - - if (has_out) { - return wrap(dispatch_nonzero(r.tensor(0), r.tensor(2))); - } - - return wrap(dispatch_nonzero(r.tensor(0))); - - END_HANDLE_TH_ERRORS -} - -static PyObject * THPVariable_numel(PyObject* self_, PyObject* args, PyObject* kwargs) -{ - HANDLE_TH_ERRORS - static PythonArgParser parser({ - "numel(Tensor input)", - }, /*traceable=*/false); - - ParsedArgs<1> parsed_args; - auto r = parser.parse(args, kwargs, parsed_args); - - if(r.has_torch_function()){ - return handle_torch_function(r, args, kwargs, THPVariableFunctionsModule, "torch"); - } - - if (r.idx == 0) { - return wrap(r.tensor(0).numel()); - } - Py_RETURN_NONE; - END_HANDLE_TH_ERRORS -} }} // namespace torch::autograd diff --git a/tools/build_libtorch.py b/tools/build_libtorch.py index 800d8eb278481..c263e5084f783 100644 --- a/tools/build_libtorch.py +++ b/tools/build_libtorch.py @@ -14,7 +14,10 @@ if __name__ == '__main__': # Placeholder for future interface. For now just gives a nice -h. parser = argparse.ArgumentParser(description='Build libtorch') + parser.add_argument('--rerun-cmake', action="store_true", help='rerun cmake') + parser.add_argument('--cmake-only', action="store_true", + help='Stop once cmake terminates. Leave users a chance to adjust build options') options = parser.parse_args() build_caffe2(version=None, cmake_python_library=None, build_python=False, - rerun_cmake=True, cmake_only=False, cmake=CMake()) + rerun_cmake=options.rerun_cmake, cmake_only=options.cmake_only, cmake=CMake()) diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index 89697b4428ca1..363503d89f9f5 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -23,7 +23,9 @@ GENERATED_CPP = [ "autograd/generated/python_fft_functions.cpp", "autograd/generated/python_linalg_functions.cpp", "autograd/generated/python_special_functions.cpp", - "autograd/generated/python_torch_functions.cpp", + "autograd/generated/python_torch_functions_0.cpp", + "autograd/generated/python_torch_functions_1.cpp", + "autograd/generated/python_torch_functions_2.cpp", "autograd/generated/python_variable_methods.cpp", ] @@ -126,6 +128,7 @@ libtorch_edge_profiler_sources = libtorch_profiler_sources + [ core_trainer_sources = [ "torch/csrc/autograd/anomaly_mode.cpp", "torch/csrc/autograd/autograd.cpp", + "torch/csrc/autograd/autograd_not_implemented_fallback.cpp", "torch/csrc/autograd/cpp_hook.cpp", "torch/csrc/autograd/custom_function.cpp", "torch/csrc/autograd/engine.cpp", @@ -244,6 +247,7 @@ core_sources_full_mobile_no_backend_interface = [ "torch/csrc/jit/passes/symbolic_shape_analysis.cpp", "torch/csrc/jit/passes/specialize_autogradzero.cpp", "torch/csrc/jit/passes/update_differentiable_graph_requires_grad.cpp", + "torch/csrc/jit/passes/variadic_ops.cpp", "torch/csrc/jit/passes/subgraph_rewrite.cpp", "torch/csrc/jit/passes/tensorexpr_fuser.cpp", "torch/csrc/jit/passes/utils/memory_dag.cpp", @@ -300,7 +304,6 @@ core_sources_full_mobile_no_backend_interface = [ "torch/csrc/jit/tensorexpr/llvm_codegen.cpp", "torch/csrc/jit/tensorexpr/llvm_jit.cpp", "torch/csrc/jit/tensorexpr/loopnest.cpp", - "torch/csrc/jit/tensorexpr/mem_arena.cpp", "torch/csrc/jit/tensorexpr/mem_dependency_checker.cpp", "torch/csrc/jit/tensorexpr/operators/conv2d.cpp", "torch/csrc/jit/tensorexpr/operators/matmul.cpp", @@ -316,7 +319,7 @@ core_sources_full_mobile_no_backend_interface = [ "torch/csrc/jit/testing/hooks_for_testing.cpp", "torch/csrc/utils/tensor_flatten.cpp", "torch/csrc/utils/variadic.cpp", -] + libtorch_profiler_sources +] core_sources_full_mobile = core_sources_full_mobile_no_backend_interface + [ "torch/csrc/jit/backends/backend_debug_info.cpp", @@ -329,14 +332,16 @@ core_sources_full = core_sources_full_mobile + [ "torch/csrc/jit/runtime/static/native_ops.cpp", "torch/csrc/jit/runtime/static/ops.cpp", "torch/csrc/jit/runtime/static/passes.cpp", + "torch/csrc/jit/runtime/static/te_wrapper.cpp", "torch/csrc/jit/tensorexpr/external_functions.cpp", "torch/csrc/jit/tensorexpr/external_functions_codegen.cpp", ] -libtorch_core_sources = sorted(core_sources_common + core_sources_full + core_trainer_sources) +libtorch_core_sources = sorted(core_sources_common + core_sources_full + core_trainer_sources + libtorch_profiler_sources) # These files are the only ones that are supported on Windows. libtorch_distributed_base_sources = [ + "torch/csrc/distributed/c10d/frontend.cpp", "torch/csrc/distributed/c10d/comm.cpp", "torch/csrc/distributed/c10d/default_comm_hooks.cpp", "torch/csrc/distributed/c10d/FileStore.cpp", @@ -348,6 +353,7 @@ libtorch_distributed_base_sources = [ "torch/csrc/distributed/c10d/ProcessGroupGloo.cpp", "torch/csrc/distributed/c10d/ProcessGroupMPI.cpp", "torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp", + "torch/csrc/distributed/c10d/quantization/quantization.cpp", "torch/csrc/distributed/c10d/reducer.cpp", "torch/csrc/distributed/c10d/sequence_num.cpp", "torch/csrc/distributed/c10d/Store.cpp", @@ -545,9 +551,11 @@ libtorch_cuda_distributed_base_sources = [ # These files are only supported on Linux (and others) but not on Windows. libtorch_cuda_distributed_extra_sources = [ + "torch/csrc/distributed/c10d/frontend_cuda.cpp", "torch/csrc/distributed/c10d/NCCLUtils.cpp", "torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp", "torch/csrc/distributed/rpc/tensorpipe_cuda.cpp", + "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", ] libtorch_cuda_distributed_sources = libtorch_cuda_distributed_base_sources + libtorch_cuda_distributed_extra_sources @@ -663,6 +671,7 @@ libtorch_python_core_sources = [ "torch/csrc/autograd/python_function.cpp", "torch/csrc/autograd/python_hook.cpp", "torch/csrc/autograd/python_legacy_variable.cpp", + "torch/csrc/autograd/python_torch_functions_manual.cpp", "torch/csrc/autograd/python_variable.cpp", "torch/csrc/autograd/python_variable_indexing.cpp", "torch/csrc/jit/backends/backend_init.cpp", @@ -730,7 +739,6 @@ libtorch_python_core_sources = [ ] libtorch_python_distributed_core_sources = [ - "torch/csrc/distributed/c10d/frontend.cpp", "torch/csrc/distributed/c10d/init.cpp", "torch/csrc/distributed/c10d/python_comm_hook.cpp", ] @@ -759,7 +767,9 @@ def glob_libtorch_python_sources(gencode_pattern = ":generate-code[{}]"): "autograd/generated/python_fft_functions.cpp", "autograd/generated/python_linalg_functions.cpp", "autograd/generated/python_special_functions.cpp", - "autograd/generated/python_torch_functions.cpp", + "autograd/generated/python_torch_functions_0.cpp", + "autograd/generated/python_torch_functions_1.cpp", + "autograd/generated/python_torch_functions_2.cpp", "autograd/generated/python_variable_methods.cpp", ]] @@ -829,6 +839,7 @@ aten_cpu_source_non_codegen_list = [ "aten/src/ATen/detail/CPUGuardImpl.cpp", "aten/src/ATen/detail/CUDAHooksInterface.cpp", "aten/src/ATen/detail/HIPHooksInterface.cpp", + "aten/src/ATen/detail/ORTHooksInterface.cpp", "aten/src/ATen/metal/Context.cpp", "aten/src/ATen/native/AutogradComposite.cpp", "aten/src/ATen/native/BatchLinearAlgebraKernel.cpp", @@ -854,6 +865,7 @@ aten_cpu_source_non_codegen_list = [ "aten/src/ATen/native/mkldnn/TensorShape.cpp", "aten/src/ATen/native/mkldnn/UnaryOps.cpp", "aten/src/ATen/native/mkldnn/Utils.cpp", + "aten/src/ATen/native/mkldnn/Matmul.cpp", "aten/src/ATen/native/quantized/cpu/init_qnnpack.cpp", "aten/src/ATen/record_function.cpp", "aten/src/ATen/SavedTensorHooks.cpp", @@ -896,6 +908,7 @@ aten_native_source_codegen_list = [ "aten/src/ATen/native/cpu/LinearAlgebraKernel.cpp", "aten/src/ATen/native/cpu/MaxPooling.cpp", "aten/src/ATen/native/cpu/MaxPoolKernel.cpp", + "aten/src/ATen/native/cpu/MaxUnpoolKernel.cpp", "aten/src/ATen/native/cpu/MultinomialKernel.cpp", "aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp", "aten/src/ATen/native/cpu/PowKernel.cpp", diff --git a/tools/codegen/api/structured.py b/tools/codegen/api/structured.py index 4f1437fb6f3ff..6aab794413c64 100644 --- a/tools/codegen/api/structured.py +++ b/tools/codegen/api/structured.py @@ -84,7 +84,27 @@ def argument(a: Union[Argument, SelfArgument, TensorOptionsArguments]) -> List[B def impl_arguments(g: NativeFunctionsGroup) -> List[Binding]: args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = [] - args.extend(g.out.func.arguments.non_out) + + if g.out.precomputed: + # A list of parameters for the impl function with + # certain parameters replaced with precomputed counterparts + # as specified in native_functions.yaml. + non_out_args_replaced: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = [] + + for a in g.out.func.arguments.non_out: + if isinstance(a, Argument) and a.name in g.out.precomputed.replace: + # If a is in precompute.replace, append the parameters + # that should replace it onto non_out_args_replaced. + for replacement in g.out.precomputed.replace[a.name]: + non_out_args_replaced.append(replacement) + else: + # If not, push a as it is. + non_out_args_replaced.append(a) + + args.extend(non_out_args_replaced) + else: + args.extend(g.out.func.arguments.non_out) + args.extend(g.out.func.arguments.out) return [r for arg in args for r in argument(arg)] diff --git a/tools/codegen/dest/__init__.py b/tools/codegen/dest/__init__.py index ab4bada277572..441e4426cf29e 100644 --- a/tools/codegen/dest/__init__.py +++ b/tools/codegen/dest/__init__.py @@ -1,2 +1,3 @@ from .register_dispatch_key import RegisterDispatchKey as RegisterDispatchKey +from .register_dispatch_key import gen_registration_helpers as gen_registration_helpers from .native_functions import compute_native_function_declaration as compute_native_function_declaration diff --git a/tools/codegen/dest/register_dispatch_key.py b/tools/codegen/dest/register_dispatch_key.py index a943f51ba5898..ec3a2e6afc0b1 100644 --- a/tools/codegen/dest/register_dispatch_key.py +++ b/tools/codegen/dest/register_dispatch_key.py @@ -23,6 +23,79 @@ from tools.codegen.api.translate import translate from tools.codegen.selective_build.selector import SelectiveBuilder + +def gen_create_out_helper(backend_index: BackendIndex) -> List[str]: + if backend_index.dispatch_key == DispatchKey.Meta: + # TODO: dedupe this with below + core = """ +if (strides.empty()) { + return at::empty(sizes, options.device(at::kMeta)); +} else { + return at::empty_strided(sizes, strides, options.device(at::kMeta)); +} +""" + else: + expanded_topts = "optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), " \ + "options.device_opt(), options.pinned_memory_opt()" + empty_init = "" + if backend_index.dispatch_key == DispatchKey.CPU: + empty_impl = "at::native::empty_cpu" + empty_strided_impl = "at::native::empty_strided_cpu" + elif backend_index.dispatch_key == DispatchKey.CUDA: + empty_init = "globalContext().lazyInitCUDA();" + empty_impl = "at::native::empty_cuda" + empty_strided_impl = "at::native::empty_strided_cuda" + elif backend_index.dispatch_key == DispatchKey.CompositeExplicitAutograd: + empty_impl = "at::empty" + empty_strided_impl = "at::empty_strided" + else: + return [] + core = f""" + {empty_init} + if (strides.empty()) {{ + return {empty_impl}(sizes, {expanded_topts}, options.memory_format_opt()); + }} else {{ + // TODO: assert options.memory_format_opt() is nullopt (debug only?) + return {empty_strided_impl}(sizes, strides, {expanded_topts}); + }} +""" + return [f""" +Tensor create_out(IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {{ +{core} +}} +"""] + + +def gen_resize_out_helper(backend_index: BackendIndex) -> List[str]: + return [""" +void resize_out(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) { + TORCH_CHECK(options.dtype() == out.dtype(), + "Expected out tensor to have dtype ", options.dtype(), ", but got ", out.dtype(), " instead"); + TORCH_CHECK(options.device() == out.device(), + "Expected out tensor to have device ", options.device(), ", but got ", out.device(), " instead"); + const bool resized = at::native::resize_output(out, sizes); + // Only restride if a resize occurred; otherwise we ignore the (advisory) + // strides from the meta function and directly use the output tensor's + // preexisting strides + if (resized) { + if (!strides.empty()) { + TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value()); + at::native::as_strided_(out, sizes, strides); + } else if (options.memory_format_opt().has_value()) { + out.unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt()); + } + } +} +"""] + + +def gen_registration_helpers(backend_index: BackendIndex) -> List[str]: + return [ + *gen_create_out_helper(backend_index), + *gen_resize_out_helper(backend_index) + ] + + # Generates Register{dispatch}.cpp (e.g., RegisterCPU.cpp). # # - The primary function of this file is to register all of the @@ -344,62 +417,17 @@ def gen_class_set_output_body(self, k: SchemaKind) -> str: maybe_set_guard_line = maybe_set_guard = '' if k is SchemaKind.functional: - if self.backend_index.dispatch_key == DispatchKey.Meta: - # TODO: dedupe this with below - return """ -if (strides.empty()) { - outputs_[output_idx] = at::empty(sizes, options.device(at::kMeta)); -} else { - outputs_[output_idx] = at::empty_strided(sizes, strides, options.device(at::kMeta)); -} -""" - else: - expanded_topts = "optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), " \ - "options.device_opt(), options.pinned_memory_opt()" - empty_init = "" - if self.backend_index.dispatch_key == DispatchKey.CPU: - empty_impl = "at::native::empty_cpu" - empty_strided_impl = "at::native::empty_strided_cpu" - elif self.backend_index.dispatch_key == DispatchKey.CUDA: - empty_init = "globalContext().lazyInitCUDA();" - empty_impl = "at::native::empty_cuda" - empty_strided_impl = "at::native::empty_strided_cuda" - elif self.backend_index.dispatch_key == DispatchKey.CompositeExplicitAutograd: - empty_impl = "at::empty" - empty_strided_impl = "at::empty_strided" - else: - raise AssertionError("unsupported dispatch key") - return f"""{maybe_set_guard_line} -{empty_init} -if (strides.empty()) {{ - outputs_[output_idx] = {empty_impl}(sizes, {expanded_topts}, options.memory_format_opt()); -}} else {{ - // TODO: assert options.memory_format_opt() is nullopt (debug only?) - outputs_[output_idx] = {empty_strided_impl}(sizes, strides, {expanded_topts}); -}} -""" + assert self.backend_index.dispatch_key in ( + DispatchKey.Meta, DispatchKey.CPU, DispatchKey.CUDA, + DispatchKey.CompositeExplicitAutograd) + return f"""{maybe_set_guard_line} +outputs_[output_idx] = create_out(sizes, strides, options);""" elif k is SchemaKind.inplace: return maybe_set_guard elif k is SchemaKind.out: return f"""{maybe_set_guard_line} const auto& out = outputs_[output_idx].get(); -TORCH_CHECK(options.dtype() == out.dtype(), - "Expected out tensor to have dtype ", options.dtype(), ", but got ", out.dtype(), " instead"); -TORCH_CHECK(options.device() == out.device(), - "Expected out tensor to have device ", options.device(), ", but got ", out.device(), " instead"); -bool resized = at::native::resize_output(outputs_[output_idx], sizes); -// Only restride if a resize occurred; otherwise we ignore the (advisory) -// strides from the meta function and directly use the output tensor's -// preexisting strides -if (resized) {{ - if (!strides.empty()) {{ - TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value()); - at::native::as_strided_(outputs_[output_idx], sizes, strides); - }} else if (options.memory_format_opt().has_value()) {{ - outputs_[output_idx].get().unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt()); - }} -}} -""" +resize_out(out, sizes, strides, options);""" else: assert_never(k) @@ -556,7 +584,29 @@ def generate_defn(cpp_sig: CppSignature) -> str: method=False ) ) - sig_body.append(f"op.meta({meta_exprs});") + + if self.g.out.precomputed: + # If this function group has precomputed elements, the meta function + # returns a struct containing them which must be saved so that it + # can be unpacked when generating code to call the impl. + sig_body.append(f"auto precompute = op.meta({meta_exprs});") + + # Put all of the contents of the precompute struct into the context + # so that translate will be able to return the correct args for the + # call to the impl. + for precomputed_elems in self.g.out.precomputed.replace.values(): + for arg in precomputed_elems: + context.append(Expr( + expr=f"precompute.{arg.name}", + type=structured.argument_type(arg, binds=arg.name), + )) + + # Add a use of the precompute struct so FB internal compilers don't + # complain that there is an unused variable. + sig_body.append("(void)precompute;") + else: + sig_body.append(f"op.meta({meta_exprs});") + # After running meta, op.outputs_ is guaranteed to be valid; # add it to the context diff --git a/tools/codegen/gen.py b/tools/codegen/gen.py index c0ce886c3d50a..c986f8311604d 100644 --- a/tools/codegen/gen.py +++ b/tools/codegen/gen.py @@ -1,5 +1,5 @@ import os -from typing import List, Dict, Optional, Tuple, Set, Callable, Any, Union, Sequence, TypeVar +from typing import List, Dict, Optional, Tuple, Set, Callable, Any, Union, Sequence, TypeVar, Iterable from typing_extensions import Literal import yaml from collections import OrderedDict, defaultdict, namedtuple @@ -456,9 +456,98 @@ def compute_meta_function_declaration(g: NativeFunctionsGroup) -> Optional[str]: parent_class = g.out.structured_inherits if parent_class is None: parent_class = "at::impl::MetaBase" + meta_return = "void" + precomputed = g.out.precomputed if g.structured else None + + if precomputed: + # Generate the template declaration with one bool parameter for each + # precomputed element. Each parameter is true if the corresponding (in + # terms of position) precomputed element has been set. + precomputed_elements = [elem for replace_list in precomputed.replace.values() for elem in replace_list] + precomputed_template_parameters = [elem.name.upper() for elem in precomputed_elements] + precomputed_template_params_str = ", ".join(f"bool {param} = false" for param in precomputed_template_parameters) + precompute_template_decl = f"template <{precomputed_template_params_str}>" + + # Generate a string containing declarations of all precomputed elements. + precomputed_elements_with_cpp_types = [ + structured.argument_type(elem, binds=elem.name) + for elem in precomputed_elements + ] + + precomputed_elements_decl = ";\n".join( + f"{elem.cpp_type(strip_ref=True)} {elem.name}" for elem in precomputed_elements_with_cpp_types + ) + + # Generate "setter" methods for each precomputed element. Each method will return + # a new instance of precompute_out with the template parameter that corresponds to + # the member set by the method to true (to indicate that it has been set). + setter_methods = [] + for i, elem in enumerate(precomputed_elements): + # Generate the signature. The return type will be the same + # as the type of `this` but with the template parameter + # corresponding to the element set by this method set to true. + # The assert generated below will ensure that this template + # parameter is false on the type of `this`. + return_ty_templates = ", ".join( + precomputed_template_parameters[:i] + ["true"] + precomputed_template_parameters[i + 1:] + ) + return_ty = f"precompute_out<{return_ty_templates}>" + elem_cpp_ty = precomputed_elements_with_cpp_types[i].cpp_type(strip_ref=True) + signature = f"{return_ty} set_{elem.name}({elem_cpp_ty} value)" + + # Generate an assert which checks that the + # template parameter corresponding to the precomputed + # element that is set by this method is false on the + # class corresponding to the object that `this` points to. + # This ensures that each element can be set only once. + assert_msg = f"\"{precomputed_elements[i].name} already set\"" + assert_stmt = f"static_assert({precomputed_template_parameters[i]} == false, {assert_msg});" + + # Generate the new object construction block. All state + # except the element that this method sets is copied from the + # object that `this` points to. The value for the element that + # the method sets is taken from a method parameter. + construction_stmts = [] + construction_stmts.append(f"{return_ty} ret;") + + for j, elem in enumerate(precomputed_elements): + if i == j: + construction_stmts.append(f"ret.{elem.name} = value;") + else: + construction_stmts.append(f"ret.{elem.name} = this->{elem.name};") + + construction_stmts.append("return ret;") + construction_block = "\n".join(construction_stmts) + + setter_methods.append(f""" + {signature} {{ + {assert_stmt} + {construction_block} + }} + """) + setter_methods_decl = "\n".join(setter_methods) + + # Meta should return an instance of the struct containing the precomputed elements. + meta_return_template_params = ", ".join(["true"] * len(precomputed_template_parameters)) + # This typedef (actually a using statement) is needed so that TORCH_META_FUNC can reuse the return + # type (which has a variable number of template parameters). + meta_return_typedef = f"using meta_return_ty = precompute_out <{meta_return_template_params}>;" + meta_return = "meta_return_ty" + precomputed_decl = f""" + {precompute_template_decl} + struct TORCH_API precompute_out {{ + {setter_methods_decl} + {precomputed_elements_decl}; + }};""" + else: + meta_return_typedef = "" + precomputed_decl = "" + return f"""\ struct TORCH_API structured_{name} : public {parent_class} {{ - void meta({args_str}); + {precomputed_decl} + {meta_return_typedef} + {meta_return} meta({args_str}); }}; """ @@ -858,7 +947,7 @@ def write(self, filename: str, env_callable: Callable[[], Union[str, Union[str, def write_sharded( self, filename: str, - items: List[T], + items: Iterable[T], *, key_fn: Callable[[T], str], env_callable: Callable[[T], Dict[str, List[str]]], @@ -1096,13 +1185,11 @@ def make_file_manager(install_dir: str) -> FileManager: fm.write_with_template(f'Register{dispatch_key}.cpp', 'RegisterDispatchKey.cpp', lambda: { 'extra_cuda_headers': extra_cuda_headers if is_cuda_dispatch_key(dispatch_key) else '', - 'legacy_th_headers': - '#include ' if dispatch_key == DispatchKey.CUDA else - '', 'external_backend_headers': '', 'namespaced_headers': f'#include ' if dispatch_key in functions_keys else '', 'DispatchKey': dispatch_key, 'dispatch_namespace': dispatch_key.lower(), + 'dispatch_helpers': dest.gen_registration_helpers(backend_indices[dispatch_key]), 'dispatch_namespaced_definitions': list(concatMap( dest.RegisterDispatchKey( backend_indices[dispatch_key], diff --git a/tools/codegen/gen_backend_stubs.py b/tools/codegen/gen_backend_stubs.py index a712a239ad565..5fad11c343804 100644 --- a/tools/codegen/gen_backend_stubs.py +++ b/tools/codegen/gen_backend_stubs.py @@ -227,11 +227,11 @@ def make_file_manager(install_dir: str) -> FileManager: for dispatch_key in [backend_dispatch_key, autograd_dispatch_key]: fm.write_with_template(f'Register{dispatch_key}.cpp', 'RegisterDispatchKey.cpp', lambda: { 'extra_cuda_headers': '', - 'legacy_th_headers': '', 'external_backend_headers': f'#include "{output_dir}/{backend_key}NativeFunctions.h"', 'namespaced_headers': '', 'DispatchKey': dispatch_key, 'dispatch_namespace': dispatch_key.lower(), + 'dispatch_helpers': dest.gen_registration_helpers(backend_indices[dispatch_key]), 'dispatch_namespaced_definitions': list(concatMap( dest.RegisterDispatchKey( backend_indices[dispatch_key], diff --git a/tools/codegen/model.py b/tools/codegen/model.py index d6f02d5a6898d..e604e72d3a1ad 100644 --- a/tools/codegen/model.py +++ b/tools/codegen/model.py @@ -56,7 +56,7 @@ class DispatchKey(Enum): CUDA = auto() HIP = auto() FPGA = auto() - MSNPU = auto() + ORT = auto() XLA = auto() Lazy = auto() Vulkan = auto() @@ -229,6 +229,14 @@ class NativeFunction: # changes the semantics of set_output to call the parent class. structured_inherits: Optional[str] + # Structured kernels can declare elements as "precomputed". These elements + # are returned by the meta function in one struct and passed to the impl + # function in lieu of certain kernel arguments that these precomputed + # elements supersede. Information about the names and types of these + # precomputed elements and how they correspond to kernel arguments is stored + # in this member, if applicable. + precomputed: Optional['Precompute'] + # Argument names whose default should be excluded from the C++ interface. # Intended for resolving overload ambiguities between signatures. cpp_no_default_args: Set[str] @@ -320,6 +328,10 @@ def from_yaml( category_override = e.pop('category_override', None) assert category_override is None or isinstance(category_override, str), f'not a str: {category_override}' + precomputed_dict = e.pop('precomputed', None) + assert precomputed_dict is None or structured is True + precomputed = Precompute.parse(precomputed_dict) if precomputed_dict else None + from tools.codegen.api import cpp raw_dispatch = e.pop('dispatch', None) @@ -389,6 +401,7 @@ def from_yaml( structured=structured, structured_delegate=structured_delegate, structured_inherits=structured_inherits, + precomputed=precomputed, manual_kernel_registration=manual_kernel_registration, manual_cpp_binding=manual_cpp_binding, python_module=python_module, @@ -1496,3 +1509,42 @@ def parse_returns(return_decl: str) -> Tuple[Return, ...]: if return_decl[0] == '(' and return_decl[-1] == ')': return_decl = return_decl[1:-1] return tuple(Return.parse(arg) for arg in return_decl.split(', ')) + + +# A Precompute instance consists of a map from kernel argument name +# to the list of Argument instances that should replace that +# kernel argument in the impl function. +@dataclass(frozen=True) +class Precompute: + # A map from kernel argument name -> a list of precomputed + # elements that replaces/supersedes it. + replace: Dict[str, List[Argument]] + + @staticmethod + def parse(src: object) -> 'Precompute': + assert isinstance(src, list) + + # src is a list of strings of the format: + # {kernel param name} -> {replacement decl}[, {replacement decl}, ...] + # Parse this list to get the names of which precomputed elements + # should replace which kernel arguments. + replace = {} + for raw_replace_item in src: + assert isinstance(raw_replace_item, str) + + arg, with_list_raw = raw_replace_item.split(' -> ') + with_list = with_list_raw.split(',') + with_list_args = [Argument.parse(name.strip()) for name in with_list] + replace[arg] = with_list_args + + r = Precompute(replace=replace) + assert r.to_list() == src, 'r.to_list() != src' + return r + + def to_list(self) -> List[str]: + replace_list = [] + for kernel_param, replacement_params in self.replace.items(): + replacements = ', '.join(str(param) for param in replacement_params) + replace_list.append(f'{kernel_param} -> {replacements}') + + return replace_list diff --git a/tools/config/BUILD b/tools/config/BUILD index a8f9d0452fce8..ba13eda2bba7b 100644 --- a/tools/config/BUILD +++ b/tools/config/BUILD @@ -13,7 +13,6 @@ selects.config_setting_group( name = "cuda_enabled_and_capable", match_all = [ ":cuda", - "//tools/toolchain:is_cuda_capable", ], ) diff --git a/tools/linter/clang_tidy/__main__.py b/tools/linter/clang_tidy/__main__.py index fc9f2ab4e6687..1846916c26f3d 100644 --- a/tools/linter/clang_tidy/__main__.py +++ b/tools/linter/clang_tidy/__main__.py @@ -74,6 +74,7 @@ def clang_search_dirs() -> List[str]: "-torch/csrc/deploy/interpreter/interpreter.h", "-torch/csrc/deploy/interpreter/interpreter_impl.h", "-torch/csrc/deploy/interpreter/test_main.cpp", + "-torch/csrc/deploy/test_deploy_python_ext.cpp", ], "paths": ["torch/csrc/"], "include-dir": ["/usr/lib/llvm-11/include/openmp"] + clang_search_dirs(), @@ -183,7 +184,8 @@ def main() -> None: f"Could not find '{options.clang_tidy_exe}'\n" + "We provide a custom build of clang-tidy that has additional checks.\n" + "You can install it by running:\n" - + "$ python3 tools/linter/install/clang_tidy.py" + + "$ python3 -m tools.linter.install.clang_tidy \n" + + "from the pytorch folder" ) raise RuntimeError(msg) diff --git a/tools/nightly.py b/tools/nightly.py index 0b387e3b32dcf..7a46a011d232a 100755 --- a/tools/nightly.py +++ b/tools/nightly.py @@ -324,7 +324,7 @@ def deps_install(deps: List[str], existing_env: bool, env_opts: List[str]) -> No @timed("Installing pytorch nightly binaries") -def pytorch_install(url: str) -> tempfile.TemporaryDirectory[str]: +def pytorch_install(url: str) -> "tempfile.TemporaryDirectory[str]": """"Install pytorch into a temporary directory""" pytdir = tempfile.TemporaryDirectory() cmd = ["conda", "create", "--yes", "--no-deps", "--prefix", pytdir.name, url] diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index 4f39fec2188fc..882b7f114e2e3 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -469,6 +469,7 @@ def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, fm: FileManager) - 'is_sparse_csr' : ['is_sparse_csr: _bool'], 'is_quantized': ['is_quantized: _bool'], 'is_meta': ['is_meta: _bool'], + 'is_ort': ['is_ort: _bool'], 'is_mkldnn': ['is_mkldnn: _bool'], 'is_vulkan': ['is_vulkan: _bool'], 'storage_offset': ['def storage_offset(self) -> _int: ...'], diff --git a/tools/rules/workspace.bzl b/tools/rules/workspace.bzl index 59e12e8d92d03..34317bec25f5d 100644 --- a/tools/rules/workspace.bzl +++ b/tools/rules/workspace.bzl @@ -27,3 +27,28 @@ pkg_tar(name = "content", srcs = glob(["**"])) path = path, ) _patched_rule(name = name, **kwargs) + +def _new_empty_repository_impl(repo_ctx): + build_file = repo_ctx.attr.build_file + build_file_content = repo_ctx.attr.build_file_content + if not (bool(build_file) != bool(build_file_content)): + fail("Exactly one of 'build_file' or 'build_file_content' is required") + + if build_file_content: + repo_ctx.file("BUILD", build_file_content) + elif build_file: + repo_ctx.template("BUILD", repo_ctx.attr.build_file, {}) + +new_empty_repository = repository_rule( + attrs = { + "build_file": attr.label(allow_files = True), + "build_file_content": attr.string(), + }, + implementation = _new_empty_repository_impl, +) + +"""Create an empty repository with the supplied BUILD file. + +This is mostly useful to create wrappers for specific target that we want +to be used with the '@' syntax. +""" diff --git a/tools/stats/print_test_stats.py b/tools/stats/print_test_stats.py index 7cc853e925181..1f4c33e8feb43 100755 --- a/tools/stats/print_test_stats.py +++ b/tools/stats/print_test_stats.py @@ -630,7 +630,6 @@ def __init__(self, name: str) -> None: def append(self, test_case: TestCase, test_type: str) -> None: is_multi_test = self.name == 'test_cpp_extensions_aot' or \ - self.name == 'distributed/test_distributed_fork' or \ self.name == 'distributed/test_distributed_spawn' or \ self.name == 'distributed/test_c10d_gloo' or \ self.name == 'cpp' # The caffe2 cpp tests spawn duplicate test cases as well. @@ -782,14 +781,16 @@ def assemble_s3_object( def send_report_to_s3(head_report: Version2Report) -> None: job = os.getenv('JOB_BASE_NAME', os.environ.get('CIRCLE_JOB')) + # SHARD_NUMBER is specific to GHA jobs, as the shard number would be included in CIRCLE_JOB already + shard = os.environ.get('SHARD_NUMBER', '') sha1 = os.environ.get('CIRCLE_SHA1') branch = os.environ.get('CIRCLE_BRANCH', '') now = datetime.datetime.utcnow().isoformat() if branch not in ['master', 'nightly'] and not branch.startswith("release/"): pr = os.environ.get('CIRCLE_PR_NUMBER', 'unknown') - key = f'pr_test_time/{pr}/{sha1}/{job}/{now}Z.json.bz2' # Z meaning UTC + key = f'pr_test_time/{pr}/{sha1}/{job}{shard}/{now}Z.json.bz2' # Z meaning UTC else: - key = f'test_time/{sha1}/{job}/{now}Z.json.bz2' # Z meaning UTC + key = f'test_time/{sha1}/{job}{shard}/{now}Z.json.bz2' # Z meaning UTC obj = get_S3_object_from_bucket('ossci-metrics', key) # use bz2 because the results are smaller than gzip, and the # compression time penalty we pay is only about half a second for diff --git a/tools/test/test_extract_scripts.py b/tools/test/test_extract_scripts.py index 29802517963b3..3126893c4bb39 100644 --- a/tools/test/test_extract_scripts.py +++ b/tools/test/test_extract_scripts.py @@ -20,7 +20,7 @@ def test_extract_none(self) -> None: self.assertEqual( extract_scripts.extract({ 'name': 'Checkout PyTorch', - 'uses': 'actions/checkout@v2', + 'uses': 'zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9', }), None, ) diff --git a/tools/testing/modulefinder_determinator.py b/tools/testing/modulefinder_determinator.py new file mode 100644 index 0000000000000..32dc1031b5616 --- /dev/null +++ b/tools/testing/modulefinder_determinator.py @@ -0,0 +1,228 @@ +import os +import modulefinder +import sys +import pathlib +import warnings +from typing import Dict, Any, List, Set + +REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent + +# These tests are slow enough that it's worth calculating whether the patch +# touched any related files first. This list was manually generated, but for every +# run with --determine-from, we use another generated list based on this one and the +# previous test stats. +TARGET_DET_LIST = [ + "distributed/algorithms/ddp_comm_hooks/test_ddp_hooks", + "distributed/nn/jit/test_instantiator", + "distributed/pipeline/sync/skip/test_api", + "distributed/pipeline/sync/skip/test_gpipe", + "distributed/pipeline/sync/skip/test_inspect_skip_layout", + "distributed/pipeline/sync/skip/test_leak", + "distributed/pipeline/sync/skip/test_portal", + "distributed/pipeline/sync/skip/test_stash_pop", + "distributed/pipeline/sync/skip/test_tracker", + "distributed/pipeline/sync/skip/test_verify_skippables", + "distributed/pipeline/sync/test_balance", + "distributed/pipeline/sync/test_bugs", + "distributed/pipeline/sync/test_checkpoint", + "distributed/pipeline/sync/test_copy", + "distributed/pipeline/sync/test_deferred_batch_norm", + "distributed/pipeline/sync/test_dependency", + "distributed/pipeline/sync/test_inplace", + "distributed/pipeline/sync/test_microbatch", + "distributed/pipeline/sync/test_phony", + "distributed/pipeline/sync/test_pipe", + "distributed/pipeline/sync/test_pipeline", + "distributed/pipeline/sync/test_stream", + "distributed/pipeline/sync/test_transparency", + "distributed/pipeline/sync/test_worker", + "distributed/rpc/cuda/test_tensorpipe_agent", + "distributed/rpc/test_tensorpipe_agent", + "distributed/test_c10d_common", + "distributed/test_c10d_gloo", + "distributed/test_c10d_nccl", + "distributed/test_c10d_spawn_gloo", + "distributed/test_c10d_spawn_nccl", + "distributed/test_distributed_spawn", + "distributed/test_jit_c10d", + "distributed/test_pg_wrapper", + "distributed/test_store", + "distributions/test_distributions", + # test_autograd.py is not slow, so it does not belong here. But + # note that if you try to add it back it will run into + # https://bugs.python.org/issue40350 because it imports files + # under test/autograd/. + "test_binary_ufuncs", + "test_cpp_extensions_aot_ninja", + "test_cpp_extensions_aot_no_ninja", + "test_cpp_extensions_jit", + "test_cuda", + "test_cuda_primary_ctx", + "test_dataloader", + "test_determination", + "test_futures", + "test_jit", + "test_jit_legacy", + "test_jit_profiling", + "test_linalg", + "test_multiprocessing", + "test_nn", + "test_numpy_interop", + "test_optim", + "test_overrides", + "test_pruning_op", + "test_quantization", + "test_reductions", + "test_serialization", + "test_shape_ops", + "test_sort_and_select", + "test_tensorboard", + "test_testing", + "test_torch", + "test_utils", + "test_view_ops", +] + + +_DEP_MODULES_CACHE: Dict[str, Set[str]] = {} + + +def should_run_test( + target_det_list: List[str], test: str, touched_files: List[str], options: Any +) -> bool: + test = parse_test_module(test) + # Some tests are faster to execute than to determine. + if test not in target_det_list: + if options.verbose: + print_to_stderr(f"Running {test} without determination") + return True + # HACK: "no_ninja" is not a real module + if test.endswith("_no_ninja"): + test = test[: (-1 * len("_no_ninja"))] + if test.endswith("_ninja"): + test = test[: (-1 * len("_ninja"))] + + dep_modules = get_dep_modules(test) + + for touched_file in touched_files: + file_type = test_impact_of_file(touched_file) + if file_type == "NONE": + continue + elif file_type == "CI": + # Force all tests to run if any change is made to the CI + # configurations. + log_test_reason(file_type, touched_file, test, options) + return True + elif file_type == "UNKNOWN": + # Assume uncategorized source files can affect every test. + log_test_reason(file_type, touched_file, test, options) + return True + elif file_type in ["TORCH", "CAFFE2", "TEST"]: + parts = os.path.splitext(touched_file)[0].split(os.sep) + touched_module = ".".join(parts) + # test/ path does not have a "test." namespace + if touched_module.startswith("test."): + touched_module = touched_module.split("test.")[1] + if touched_module in dep_modules or touched_module == test.replace( + "/", "." + ): + log_test_reason(file_type, touched_file, test, options) + return True + + # If nothing has determined the test has run, don't run the test. + if options.verbose: + print_to_stderr(f"Determination is skipping {test}") + + return False + + +def test_impact_of_file(filename: str) -> str: + """Determine what class of impact this file has on test runs. + + Possible values: + TORCH - torch python code + CAFFE2 - caffe2 python code + TEST - torch test code + UNKNOWN - may affect all tests + NONE - known to have no effect on test outcome + CI - CI configuration files + """ + parts = filename.split(os.sep) + if parts[0] in [".jenkins", ".circleci"]: + return "CI" + if parts[0] in ["docs", "scripts", "CODEOWNERS", "README.md"]: + return "NONE" + elif parts[0] == "torch": + if parts[-1].endswith(".py") or parts[-1].endswith(".pyi"): + return "TORCH" + elif parts[0] == "caffe2": + if parts[-1].endswith(".py") or parts[-1].endswith(".pyi"): + return "CAFFE2" + elif parts[0] == "test": + if parts[-1].endswith(".py") or parts[-1].endswith(".pyi"): + return "TEST" + + return "UNKNOWN" + + +def log_test_reason(file_type: str, filename: str, test: str, options: Any) -> None: + if options.verbose: + print_to_stderr( + "Determination found {} file {} -- running {}".format( + file_type, + filename, + test, + ) + ) + + +def get_dep_modules(test: str) -> Set[str]: + # Cache results in case of repetition + if test in _DEP_MODULES_CACHE: + return _DEP_MODULES_CACHE[test] + + test_location = REPO_ROOT / "test" / f"{test}.py" + + # HACK: some platforms default to ascii, so we can't just run_script :( + finder = modulefinder.ModuleFinder( + # Ideally exclude all third party modules, to speed up calculation. + excludes=[ + "scipy", + "numpy", + "numba", + "multiprocessing", + "sklearn", + "setuptools", + "hypothesis", + "llvmlite", + "joblib", + "email", + "importlib", + "unittest", + "urllib", + "json", + "collections", + # Modules below are excluded because they are hitting https://bugs.python.org/issue40350 + # Trigger AttributeError: 'NoneType' object has no attribute 'is_package' + "mpl_toolkits", + "google", + "onnx", + # Triggers RecursionError + "mypy", + ], + ) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + finder.run_script(str(test_location)) + dep_modules = set(finder.modules.keys()) + _DEP_MODULES_CACHE[test] = dep_modules + return dep_modules + + +def parse_test_module(test: str) -> str: + return test.split(".")[0] + + +def print_to_stderr(message: str) -> None: + print(message, file=sys.stderr) diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index 90504f025c4a3..7c086855612ca 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -2,7 +2,7 @@ # Now it only builds the Torch python bindings. if(NOT CAFFE2_CMAKE_BUILDING_WITH_MAIN_REPO) - cmake_minimum_required(VERSION 3.5 FATAL_ERROR) + cmake_minimum_required(VERSION 3.10 FATAL_ERROR) project(torch CXX C) find_package(torch REQUIRED) option(USE_CUDA "Use CUDA" ON) @@ -214,11 +214,78 @@ add_custom_command( WORKING_DIRECTORY "${TORCH_ROOT}" ) +if(USE_DISTRIBUTED) + if(WIN32) + append_filelist("libtorch_python_distributed_core_sources" TORCH_PYTHON_SRCS) + else() + append_filelist("libtorch_python_distributed_sources" TORCH_PYTHON_SRCS) + endif() + # Disable certain warnings for GCC-9.X + if(CMAKE_COMPILER_IS_GNUCXX AND (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0.0)) + set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/autograd/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type") + set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/rpc/testing/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type") + set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/c10d/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type") + endif() + # NCCL is a private dependency of libtorch, but libtorch_python includes + # some private headers of libtorch, which in turn include NCCL. As a hacky + # alternative to making NCCL a public dependency of libtorch, we make it + # a private dependency of libtorch_python as well. + if(USE_NCCL) + list(APPEND TORCH_PYTHON_LINK_LIBRARIES __caffe2_nccl) + endif() + # Same for MPI. + if(USE_MPI) + list(APPEND TORCH_PYTHON_LINK_LIBRARIES ${MPI_CXX_LIBRARIES}) + endif() + list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D) + +endif() + +if(USE_NCCL AND NOT WIN32) + list(APPEND TORCH_PYTHON_SRCS + ${TORCH_SRC_DIR}/csrc/cuda/python_nccl.cpp) + list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_NCCL) +endif() + # WARNING- any TORCH_PYTHON_COMPILE_DEFINITIONS above this line # affect both torch_python and DEPLOY interpreter. if(USE_DEPLOY) add_library(torch_python_obj OBJECT ${TORCH_PYTHON_SRCS}) + if(USE_DISTRIBUTED) + # Set c10d-related compile definitions. For a "normal" build of + # libtorch_python, these are set on libtorch as PUBLIC so they are + # automatically propagated when libtorch_python links against libtorch. But + # since in the deploy build we are intentionally *not* linking against + # libtorch, we need to set them manually here. + list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_DISTRIBUTED) + if(USE_GLOO AND USE_C10D_GLOO) + list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D_GLOO) + endif() + if(USE_NCCL AND USE_C10D_NCCL) + list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D_NCCL) + # Put nccl headers on the include path. We are specifically only setting + # include dirs here instead of linking against __caffe2_nccl wholesale + # to ensure we aren't accidentally replicating the nccl lib. + target_include_directories(torch_python_obj PRIVATE $) + endif() + if(USE_MPI AND USE_C10D_MPI) + list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D_MPI) + endif() + + # Pass USE_RPC in order to reduce use of + # #if defined(USE_DISTRIBUTED) && !defined(_WIN32) + # need to be removed when RPC is supported + if(NOT WIN32) + target_compile_definitions(torch_cpu PUBLIC USE_RPC) + endif() + if(USE_TENSORPIPE) + list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_TENSORPIPE) + endif() + + # Set c10d-related include directories as well. + target_include_directories(torch_python_obj PRIVATE $) + endif() target_compile_definitions(torch_python_obj PRIVATE "-DTHP_BUILD_MAIN_LIB -DUSE_DEPLOY") target_compile_definitions(torch_python_obj PRIVATE ${TORCH_PYTHON_COMPILE_DEFINITIONS}) @@ -268,38 +335,6 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") set_source_files_properties(${TORCH_SRC_DIR}/csrc/utils/throughput_benchmark.cpp PROPERTIES COMPILE_FLAGS -Wno-attributes) endif() -if(USE_DISTRIBUTED) - if(WIN32) - append_filelist("libtorch_python_distributed_core_sources" TORCH_PYTHON_SRCS) - else() - append_filelist("libtorch_python_distributed_sources" TORCH_PYTHON_SRCS) - endif() - # Disable certain warnings for GCC-9.X - if(CMAKE_COMPILER_IS_GNUCXX AND (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0.0)) - set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/autograd/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type") - set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/rpc/testing/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type") - set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/c10d/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type") - endif() - # NCCL is a private dependency of libtorch, but libtorch_python includes - # some private headers of libtorch, which in turn include NCCL. As a hacky - # alternative to making NCCL a public dependency of libtorch, we make it - # a private dependency of libtorch_python as well. - if(USE_NCCL) - list(APPEND TORCH_PYTHON_LINK_LIBRARIES __caffe2_nccl) - endif() - # Same for MPI. - if(USE_MPI) - list(APPEND TORCH_PYTHON_LINK_LIBRARIES ${MPI_CXX_LIBRARIES}) - endif() - list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D) -endif() - -if(USE_NCCL AND NOT WIN32) - list(APPEND TORCH_PYTHON_SRCS - ${TORCH_SRC_DIR}/csrc/cuda/python_nccl.cpp) - list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_NCCL) -endif() - add_library(torch_python SHARED ${TORCH_PYTHON_SRCS}) if(HAVE_SOVERSION) set_target_properties(torch_python PROPERTIES diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index b683a60615dc5..091cb097d14e5 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -208,6 +208,7 @@ def _jit_get_schemas_for_operator(name :str) -> List[FunctionSchema]: ... def _jit_check_alias_annotation(g: Graph, args: Tuple[Any, ...], unqualified_op_name: str): ... def _jit_can_fuse_on_cpu() -> _bool: ... def _jit_can_fuse_on_gpu() -> _bool: ... +def _jit_can_fuse_on_cpu_legacy() -> _bool: ... def _debug_get_fusion_group_inlining() -> _bool: ... def _debug_set_fusion_group_inlining(enable: _bool): ... def _jit_texpr_fuser_enabled() -> _bool: ... @@ -215,6 +216,7 @@ def _jit_nvfuser_enabled() -> _bool: ... def _llvm_enabled() -> _bool: ... def _jit_override_can_fuse_on_cpu(override: _bool): ... def _jit_override_can_fuse_on_gpu(override: _bool): ... +def _jit_override_can_fuse_on_cpu_legacy(override: _bool): ... def _jit_set_symbolic_shapes_test_mode(override: _bool): ... def _jit_symbolic_shapes_test_mode_enabled() -> _bool: ... def _jit_set_texpr_fuser_enabled(enable: _bool): ... @@ -884,8 +886,13 @@ class _CudaEventBase: def ipc_handle(self) -> bytes: ... # Defined in torch/csrc/cuda/Graph.cpp -class _CudaGraphBase: - ... +class _CUDAGraph: + def capture_begin(self, + pool: Optional[Tuple[_int, _int]]=...) -> None: ... + def capture_end(self) -> None: ... + def replay(self) -> None: ... + def reset(self) -> None: ... + def pool(self) -> Tuple[_int, _int]: ... def _graph_pool_handle() -> Tuple[_int, _int]: ... @@ -994,6 +1001,9 @@ class TupleType(JitType): def __init__(self, a: List[Optional[JitType]]) -> None: ... def elements(self) -> List[JitType]: ... +class UnionType(JitType): + def __init__(self, a: List[JitType]) -> None: ... + class ClassType(JitType): def __init__(self, qualified_name: str) -> None: ... diff --git a/torch/_C/_autograd.pyi b/torch/_C/_autograd.pyi index 6468eb551f9cd..7ffb618e3f072 100644 --- a/torch/_C/_autograd.pyi +++ b/torch/_C/_autograd.pyi @@ -24,7 +24,7 @@ class DeviceType(Enum): IDEEP = ... HIP = ... FPGA = ... - MSNPU = ... + ORT = ... XLA = ... MLC = ... HPU = ... diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index cfa9c7cc1a46c..50e7602bdd838 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -134,7 +134,8 @@ class TCPStore(Store): world_size: int = ..., is_master: bool = ..., timeout: timedelta = ..., - wait_for_workers: bool = ... + wait_for_workers: bool = ..., + multi_tenant: bool = ... ): ... class PrefixStore(Store): diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index bd7b616996a24..806dae6d37f45 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -885,33 +885,28 @@ def is_dict(ann) -> bool: (getattr(ann, '__origin__', None) is Dict or getattr(ann, '__origin__', None) is dict) -def is_optional(ann) -> bool: - if ann is Optional: - raise_error_container_parameter_missing("Optional") +def is_union(ann): + if ann is Union: + raise_error_container_parameter_missing("Union") - # Optional[T] is just shorthand for Union[T, None], so check for both - def safe_is_subclass(the_type, super_type): - # Don't throw if `the_type` isn't a class type (e.g. if it is - # another type annotation instance) - if not inspect.isclass(the_type): - return False - return issubclass(the_type, super_type) + return (hasattr(ann, '__module__') and + ann.__module__ == 'typing' and + (getattr(ann, '__origin__', None) is Union)) - if not hasattr(ann, '__module__'): - return False +def is_optional(ann): + if ann is Optional: + raise_error_container_parameter_missing("Optional") - union_optional = False - if ann.__module__ == 'typing' and \ - (getattr(ann, '__origin__', None) is Union): - args = getattr(ann, '__args__', ()) - if len(args) == 2: - union_optional = (safe_is_subclass(args[1], type(None)) and not safe_is_subclass(args[0], type(None))) \ - or (safe_is_subclass(args[0], type(None)) and not safe_is_subclass(args[1], type(None))) + def is_optional_as_optional(ann): + return (hasattr(ann, '__module__') and + ann.__module__ == 'typing' and + (getattr(ann, '__origin__', None) is Optional)) - optional = ann.__module__ == 'typing' and \ - (getattr(ann, '__origin__', None) is Optional) + def is_union_as_optional(ann): + ann_args = ann.__args__ + return len(ann_args) == 2 and None in ann_args - return optional or union_optional + return is_optional_as_optional(ann) or (is_union(ann) and is_union_as_optional(ann)) def is_future(ann) -> bool: if ann is Future: @@ -1133,7 +1128,7 @@ def check_args_exist(target_type) -> None: def check_empty_containers(obj) -> None: - if not obj: + if obj == [] or obj == {} or obj == (): warnings.warn("The inner type of a container is lost when " "calling torch.jit.isinstance in eager mode. For " "example, List[int] would become list and " @@ -1192,15 +1187,16 @@ def container_checker(obj, target_type) -> bool: elif not isinstance(el, el_type): return False return True - elif origin_type is Union: # actually handles Optional Case + elif origin_type is Union: # also handles Optional if obj is None: # check before recursion because None is always fine return True - optional_type = get_args(target_type)[0] - optional_origin = get_origin(optional_type) - if optional_origin: - return container_checker(obj, optional_type) - elif isinstance(obj, optional_type): - return True + inner_types = get_args(target_type) + for t in inner_types: + t_origin = get_origin(t) + if (t_origin): + return container_checker(obj, t) + elif isinstance(obj, t): + return True return False diff --git a/torch/_tensor.py b/torch/_tensor.py index 2bd617d3971a9..e7bc4ed9165a2 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -90,7 +90,7 @@ def __deepcopy__(self, memo): # does accurate alias tracking; however, the code below # doesn't work because of # https://github.com/pytorch/pytorch/issues/47442 - if self.is_sparse or self.device.type in ['xla', 'mlc', 'meta']: + if self.is_sparse or self.device.type in ['xla', 'mlc', 'ort', 'meta']: new_tensor = self.clone() else: new_storage = self.storage().__deepcopy__(memo) @@ -153,28 +153,21 @@ def _reduce_ex_internal(self, proto): # See Note [Don't serialize hooks] torch.utils.hooks.warn_if_has_hooks(self) backward_hooks: Dict[Any, Any] = OrderedDict() - # Note: Numpy array is chosen to be the rebuild component for XLA Tensor. + # Note: Numpy array is chosen to be the rebuild component for XLA, ORT, MLC Tensors. # We considered a few options: # 1. CPU tensor can't be used here. # Otherwise in torch.load CPU storage is reconstructed with randomly - # initialized data, moved onto XLA device, and then storage is updated - # to the serialized content. This works perfectly for CPU/CUDA but not XLA. - # XLA tensor is disconnected with storage so it doesn't get the update. + # initialized data, moved onto backend device, and then storage is updated + # to the serialized content. This works perfectly for CPU/CUDA but not these backends; + # their tensors are disconnected with storage so they don't get the update. # 2. Python list is not a good fit due to performance reason. # `tolist()` converts every single element in the tensor into python objects # and serialize them one by one. - if self.device.type == 'xla': - arg_xla = (self.cpu().numpy(), - self.dtype, - str(self.device), - self.requires_grad) - return (torch._utils._rebuild_xla_tensor, arg_xla) - if self.device.type == 'mlc': - arg_mlc = (self.cpu().numpy(), - self.dtype, - str(self.device), - self.requires_grad) - return (torch._utils._rebuild_mlc_tensor, arg_mlc) + if self.device.type in ['xla', 'ort', 'mlc']: + return (torch._utils._rebuild_device_tensor_from_numpy, (self.cpu().numpy(), + self.dtype, + str(self.device), + self.requires_grad)) if self.device.type == 'meta': # NB: This implementation BREAKS storage sharing. Current # hypothesis is that no one cares for meta tensors. @@ -589,11 +582,21 @@ def __rpow__(self, other): @_wrap_type_error_to_not_implemented def __floordiv__(self, other): - return torch.floor_divide(self, other) + warnings.warn("__floordiv__ is deprecated, and its behavior will change in a future version of pytorch. " + "It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). " + "This results in incorrect rounding for negative values. " + "To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), " + "or for actual floor division, use torch.div(a, b, rounding_mode='floor').", stacklevel=3) + return torch.div(self, other, rounding_mode='trunc') @_wrap_type_error_to_not_implemented def __rfloordiv__(self, other): - return torch.floor_divide(other, self) + warnings.warn("__rfloordiv__ is deprecated, and its behavior will change in a future version of pytorch. " + "It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). " + "This results in incorrect rounding for negative values. " + "To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), " + "or for actual floor division, use torch.div(a, b, rounding_mode='floor').", stacklevel=3) + return torch.div(other, self, rounding_mode='trunc') @_wrap_type_error_to_not_implemented def __rlshift__(self, other): diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index e0d3fc011e4b7..bf981a81015e4 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -149,7 +149,7 @@ def merge_dicts(*dicts): absolute(input, *, out=None) -> Tensor Alias for :func:`torch.abs` -""".format(**common_args)) +""") add_docstr(torch.acos, r""" acos(input, *, out=None) -> Tensor @@ -211,28 +211,29 @@ def merge_dicts(*dicts): arccosh(input, *, out=None) -> Tensor Alias for :func:`torch.acosh`. -""".format(**common_args)) +""") add_docstr(torch.add, r""" -add(input, other, *, out=None) -> Tensor +add(input, other, *, alpha=1, out=None) -> Tensor -Adds the scalar :attr:`other` to each element of the input :attr:`input` -and returns a new resulting tensor. +Adds :attr:`other`, scaled by :attr:`alpha`, to :attr:`input`. .. math:: - \text{{out}} = \text{{input}} + \text{{other}} + \text{{out}}_i = \text{{input}}_i + \text{{alpha}} \times \text{{other}}_i +""" + r""" -If :attr:`input` is of type FloatTensor or DoubleTensor, :attr:`other` must be -a real number, otherwise it should be an integer. +Supports :ref:`broadcasting to a common shape `, +:ref:`type promotion `, and integer, float, and complex inputs. Args: {input} - other (Number): the number to be added to each element of :attr:`input` + other (Tensor or Number): the tensor or number to add to input. Keyword arguments: + alpha (Number): the multiplier for :attr:`other`. {out} -Example:: +Examples:: >>> a = torch.randn(4) >>> a @@ -240,42 +241,16 @@ def merge_dicts(*dicts): >>> torch.add(a, 20) tensor([ 20.0202, 21.0985, 21.3506, 19.3944]) -.. function:: add(input, other, *, alpha=1, out=None) -> Tensor - :noindex: - -Each element of the tensor :attr:`other` is multiplied by the scalar -:attr:`alpha` and added to each element of the tensor :attr:`input`. -The resulting tensor is returned. - -The shapes of :attr:`input` and :attr:`other` must be -:ref:`broadcastable `. - -.. math:: - \text{{out}} = \text{{input}} + \text{{alpha}} \times \text{{other}} - -If :attr:`other` is of type FloatTensor or DoubleTensor, :attr:`alpha` must be -a real number, otherwise it should be an integer. - -Args: - input (Tensor): the first input tensor - other (Tensor): the second input tensor - -Keyword args: - alpha (Number): the scalar multiplier for :attr:`other` - {out} - -Example:: - - >>> a = torch.randn(4) - >>> a - tensor([-0.9732, -0.3497, 0.6245, 0.4022]) - >>> b = torch.randn(4, 1) + >>> b = torch.randn(4) >>> b + tensor([-0.9732, -0.3497, 0.6245, 0.4022]) + >>> c = torch.randn(4, 1) + >>> c tensor([[ 0.3743], [-1.7724], [-0.5811], [-0.8017]]) - >>> torch.add(a, b, alpha=10) + >>> torch.add(b, c, alpha=10) tensor([[ 2.7695, 3.3930, 4.3672, 4.1450], [-18.6971, -18.0736, -17.0994, -17.3216], [ -6.7845, -6.1610, -5.1868, -5.4090], @@ -1881,6 +1856,13 @@ def merge_dicts(*dicts): -0.5790, 0.1497]]) """.format(**common_args)) +add_docstr(torch.concat, + r""" +concat(tensors, dim=0, *, out=None) -> Tensor + +Alias of :func:`torch.cat`. +""") + add_docstr(torch.ceil, r""" ceil(input, *, out=None) -> Tensor @@ -2287,7 +2269,7 @@ def merge_dicts(*dicts): clip(input, min=None, max=None, *, out=None) -> Tensor Alias for :func:`torch.clamp`. -""".format(**common_args)) +""") add_docstr(torch.column_stack, r""" @@ -4486,7 +4468,7 @@ def merge_dicts(*dicts): inverse(input, *, out=None) -> Tensor Alias for :func:`torch.linalg.inv` -""".format(**common_args)) +""") add_docstr(torch.isin, r""" isin(elements, test_elements, *, assume_unique=False, invert=False) -> Tensor @@ -5719,7 +5701,7 @@ def merge_dicts(*dicts): matrix_power(input, n, *, out=None) -> Tensor Alias for :func:`torch.linalg.matrix_power` -""".format(**common_args)) +""") add_docstr(torch.matrix_exp, r""" matrix_exp(input) -> Tensor @@ -6640,23 +6622,24 @@ def merge_dicts(*dicts): add_docstr(torch.mul, r""" mul(input, other, *, out=None) -> Tensor -Multiplies each element of the input :attr:`input` with the scalar -:attr:`other` and returns a new resulting tensor. +Multiplies :attr:`input` by :attr:`other`. + .. math:: - \text{out}_i = \text{other} \times \text{input}_i + \text{out}_i = \text{input}_i \times \text{other}_i """ + r""" -If :attr:`input` is of type `FloatTensor` or `DoubleTensor`, :attr:`other` -should be a real number, otherwise it should be an integer + +Supports :ref:`broadcasting to a common shape `, +:ref:`type promotion `, and integer, float, and complex inputs. Args: {input} - other (Number): the number to be multiplied to each element of :attr:`input` + other (Tensor or Number) - the tensor or number to multiply input by. Keyword args: {out} -Example:: +Examples:: >>> a = torch.randn(3) >>> a @@ -6664,38 +6647,16 @@ def merge_dicts(*dicts): >>> torch.mul(a, 100) tensor([ 20.1494, -42.5491, 260.8663]) -.. function:: mul(input, other, *, out=None) -> Tensor - :noindex: - -Each element of the tensor :attr:`input` is multiplied by the corresponding -element of the Tensor :attr:`other`. The resulting tensor is returned. - -The shapes of :attr:`input` and :attr:`other` must be -:ref:`broadcastable `. - -.. math:: - \text{{out}}_i = \text{{input}}_i \times \text{{other}}_i -""".format(**common_args) + r""" - -Args: - input (Tensor): the first multiplicand tensor - other (Tensor): the second multiplicand tensor - -Keyword args: - {out} - -Example:: - - >>> a = torch.randn(4, 1) - >>> a + >>> b = torch.randn(4, 1) + >>> b tensor([[ 1.1207], [-0.3137], [ 0.0700], [ 0.8378]]) - >>> b = torch.randn(1, 4) - >>> b + >>> c = torch.randn(1, 4) + >>> c tensor([[ 0.5146, 0.1216, -0.5244, 2.2382]]) - >>> torch.mul(a, b) + >>> torch.mul(b, c) tensor([[ 0.5767, 0.1363, -0.5877, 2.5083], [-0.1614, -0.0382, 0.1645, -0.7021], [ 0.0360, 0.0085, -0.0367, 0.1567], @@ -6706,7 +6667,7 @@ def merge_dicts(*dicts): multiply(input, other, *, out=None) Alias for :func:`torch.mul`. -""".format(**common_args)) +""") add_docstr(torch.multinomial, r""" @@ -7056,7 +7017,7 @@ def merge_dicts(*dicts): negative(input, *, out=None) -> Tensor Alias for :func:`torch.neg` -""".format(**common_args)) +""") add_docstr(torch.nextafter, r""" @@ -7429,7 +7390,7 @@ def merge_dicts(*dicts): polygamma(n, input, *, out=None) -> Tensor Alias for :func:`torch.special.polygamma`. -""".format(**common_args)) +""") add_docstr(torch.positive, r""" @@ -8288,7 +8249,7 @@ def merge_dicts(*dicts): row_stack(tensors, *, out=None) -> Tensor Alias of :func:`torch.vstack`. -""".format(**common_args)) +""") add_docstr(torch.round, r""" @@ -8977,10 +8938,10 @@ def merge_dicts(*dicts): Args: {input} - other (Tensor or Scalar): the tensor or scalar to subtract from :attr:`input` + other (Tensor or Number): the tensor or number to subtract from :attr:`input`. Keyword args: - alpha (Scalar): the scalar multiplier for :attr:`other` + alpha (Number): the multiplier for :attr:`other`. {out} Example:: @@ -9708,7 +9669,7 @@ def merge_dicts(*dicts): .. seealso:: - :func:`torch.t` for a function that transposes tensors with <=2 dimensions. + :func:`torch.t` swaps the dimensions of two-dimensional tensors (matrices). Args: {input} @@ -9729,7 +9690,7 @@ def merge_dicts(*dicts): add_docstr(torch.triangular_solve, r""" -triangular_solve(b, A, upper=True, transpose=False, unitriangular=False) -> (Tensor, Tensor) +triangular_solve(b, A, upper=True, transpose=False, unitriangular=False, *, out=None) -> (Tensor, Tensor) Solves a system of equations with a triangular coefficient matrix :math:`A` and multiple right-hand sides :math:`b`. @@ -9756,6 +9717,10 @@ def merge_dicts(*dicts): If True, the diagonal elements of :math:`A` are assumed to be 1 and not referenced from :math:`A`. Default: ``False``. +Keyword args: + out ((Tensor, Tensor), optional): tuple of two tensors to write + the output to. Ignored if `None`. Default: `None`. + Returns: A namedtuple `(solution, cloned_coefficient)` where `cloned_coefficient` is a clone of :math:`A` and `solution` is the solution :math:`X` to :math:`AX = b` @@ -10013,7 +9978,7 @@ def merge_dicts(*dicts): true_divide(dividend, divisor, *, out) -> Tensor Alias for :func:`torch.div` with ``rounding_mode=None``. -""".format(**common_args)) +""") add_docstr(torch.trunc, r""" @@ -10129,7 +10094,7 @@ def merge_dicts(*dicts): fix(input, *, out=None) -> Tensor Alias for :func:`torch.trunc` -""".format(**common_args)) +""") add_docstr(torch.unsqueeze, r""" @@ -10915,11 +10880,12 @@ def merge_dicts(*dicts): \sum_{i = 1}^{n-1} \frac{(x_i - x_{i-1})}{2} (y_i + y_{i-1}) \end{aligned} -When :attr:`y` is two or more dimensions, this computation is performed independently -along dimension :attr:`dim`. If :attr:`x` is also specified and is one-dimensional, -then that dimension defines the spacing for each computation. -If :attr:`x` is also specified and is not one-dimensional, then it is broadcast to -the shape of :attr:`y` and the corresponding sizes are used for each computation. +When :attr:`x` and :attr:`y` have the same size, the computation is as described above and no broadcasting is needed. +The broadcasting behavior of this function is as follows when their sizes are different. For both :attr:`x` +and :attr:`y`, the function computes the difference between consecutive elements along +dimension :attr:`dim`. This effectively creates two tensors, `x_diff` and `y_diff`, that have +the same shape as the original tensors except their lengths along the dimension :attr:`dim` is reduced by 1. +After that, those two tensors are broadcast together to compute final output as part of the trapezoidal rule. See the examples below for details. .. note:: diff --git a/torch/_utils.py b/torch/_utils.py index 210b0cde793a6..75e9075e4250f 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -173,16 +173,15 @@ def _rebuild_sparse_tensor(layout, data): raise NotImplementedError("rebuilding sparse tensor for layout %s" % (layout)) -def _rebuild_xla_tensor(data, dtype, device, requires_grad): +def _rebuild_device_tensor_from_numpy(data, dtype, device, requires_grad): tensor = torch.from_numpy(data).to(dtype=dtype, device=device) tensor.requires_grad = requires_grad return tensor -def _rebuild_mlc_tensor(data, dtype, device, requires_grad): - tensor = torch.from_numpy(data).to(dtype=dtype, device=device) - tensor.requires_grad = requires_grad - return tensor +# Should not be used, only here to be able to load Tensors serialized with older versions of pytorch +_rebuild_xla_tensor = _rebuild_device_tensor_from_numpy +_rebuild_mlc_tensor = _rebuild_device_tensor_from_numpy def _rebuild_meta_tensor_no_storage(dtype, size, stride, requires_grad): diff --git a/torch/ao/sparsity/__init__.py b/torch/ao/sparsity/__init__.py index ef03c71c64732..80aa30814eac8 100644 --- a/torch/ao/sparsity/__init__.py +++ b/torch/ao/sparsity/__init__.py @@ -12,13 +12,16 @@ # Parametrizations from .sparsifier.utils import FakeSparsity +from .sparsifier.utils import module_to_fqn +from .sparsifier.utils import fqn_to_module # === Experimental === # Parametrizations from .experimental.pruner.parametrization import PruningParametrization -from .experimental.pruner.parametrization import LinearActivationReconstruction -from .experimental.pruner.parametrization import Conv2dActivationReconstruction +from .experimental.pruner.parametrization import ZeroesParametrization +from .experimental.pruner.parametrization import ActivationReconstruction +from .experimental.pruner.parametrization import BiasHook # Pruner from .experimental.pruner.base_pruner import BasePruner diff --git a/torch/ao/sparsity/experimental/pruner/README.md b/torch/ao/sparsity/experimental/pruner/README.md new file mode 100644 index 0000000000000..da0afb0bf3fb1 --- /dev/null +++ b/torch/ao/sparsity/experimental/pruner/README.md @@ -0,0 +1,93 @@ +# Intro + +The Base Pruner inherits from the Base Sparsifier. + + +# Motivation + +Sparsifying weights allows us to skip some of the multiplications during the dot product (i.e. in the Linear layers), which ultimately translates into faster inference. With structured pruning, whole rows/columns of a tensor would be zeroed-out. This translates into model transformation (not just tensor transformation). Logically, the process of structured pruning is similar to removing some of the input/output channels in the layer completely. + +![prune logic](./images/prune_1.png) + + +# Design Choices + + +## Eager Mode + +**PruningParametrization:** After pruning, the shape of the weight changes (some of the output channels are pruned). That means the output of the current layer will have less output layers compared to the original. This means that the next layer should have less input channels. + +Consider an example below: + +![prune example](./images/prune_2.png) + +The dot product of the masked matrix A (weight) and matrix B (activation) produces the zeros at the sparse locations. However, if we remove the zeros, as in the example shown earlier, the result will change: + +![prune result](./images/prune_3.png) + +The resulting matrix is of different shape (2x2 vs. 4x2). + +**Forward Hook - ActivationReconstruction **(aka re-inserting zeros): To reconstruct the activation with the original shape, we will undo the sparsification before pushing that activation to the next layer. We do this with a forward hook -- forward hooks are functions that are called on the activation after the computation is complete. + +![prune reconstruction](./images/prune_4.png) + +**Forward Hook - Bias**: + +If the layer has a bias, it must be added to the activation AFTER zeros have been re-inserted, i.e. after the `ActivationReconstruction` forward hook. + +The pruner prunes the entire channel by default (weight & corresponding bias), so indices of the bias corresponding to pruned indices will be zeroed out. + + + +# Eager Mode APIs & Code Snippets + +Supported modules: nn.Linear, nn.Conv2d, nn.BatchNorm2d* + +* when provided in `config` with corresponding Conv2d layer + +`BasePruner`: base class with abstract method `update_mask` that computes the new pruner mask for all modules (see Write Your Own Pruner). The base pruner prunes the entire channel by default (weight & corresponding bias); if you don’t want the bias to be pruned, then set `also_prune_bias` to be False. + +`prepare`: registers the pruning parametrization (called `PruningParametrization`) to each module layer of the model; also adds forward hooks for bias support and re-inserting zeros to the output so the next layer received the correct size input. + +Note: for BatchNorm2d layers, the parametrization `ZeroesParametrization` is attached instead since its weight is 1d, so removing channels would affect the input dimension as well. `ZeroesParametrization` zeroes out channels rather than removing them like `PruningParametrization`. We need this when `also_prune_bias=True`, so BatchNorm2d channels get pruned with their corresponding Conv2d channels. + + +``` +pruner = ImplementedPruner(defaults=None, also_prune_bias=True) +pruner.prepare(model, config) +``` + + +`step`: applies `update_mask` logic (i.e. prunes the weight matrix) + + +``` +pruner.step() +``` + + +`squash_mask`: applies the parametrization one last time to the weight matrix, and then removes the pruning parametrization from the model + + +``` +pruner.squash_mask() +``` + + + +# Write Your Own Pruner + +To write a custom pruner, one could inherit from the `BasePruner` and implement some of the methods. For example, if implementing a pruner that computes the mask by randomly pruning ⅓ of channels: + + +``` +class ImplementedPruner(BasePruner): + def update_mask(self, layer, **kwargs): + param = layer.parametrizations.weight[0] # PruningParametrization + all_outputs = param.original_outputs + prune = random.sample(all_outputs, len(all_outputs) // 3) + param.pruned_outputs.update(prune) +``` + + +It is the responsibility of the base class to call the `self.update_mask` when appropriate. diff --git a/torch/ao/sparsity/experimental/pruner/base_pruner.py b/torch/ao/sparsity/experimental/pruner/base_pruner.py index 075a7ceae305a..6017e8f53ae69 100644 --- a/torch/ao/sparsity/experimental/pruner/base_pruner.py +++ b/torch/ao/sparsity/experimental/pruner/base_pruner.py @@ -1,6 +1,7 @@ -import abc import copy +import warnings +import abc import torch from torch import nn @@ -8,33 +9,22 @@ from torch.nn.modules.container import ModuleDict, ModuleList -from .parametrization import PruningParametrization, LinearActivationReconstruction, Conv2dActivationReconstruction +from .parametrization import PruningParametrization, ZeroesParametrization, ActivationReconstruction, BiasHook -SUPPORTED_MODULES = { +from torch.ao.sparsity import BaseSparsifier, module_to_fqn, fqn_to_module + +SUPPORTED_MODULES = { # added to config if None given nn.Linear, - nn.Conv2d + nn.Conv2d, + nn.BatchNorm2d, # will need manual update to match conv2d } -def _module_to_path(model, layer, prefix=''): - for name, child in model.named_children(): - new_name = prefix + '.' + name - if child is layer: - return new_name - child_path = _module_to_path(child, layer, prefix=new_name) - if child_path is not None: - return child_path - return None - -def _path_to_module(model, path): - path = path.split('.') - for name in path: - model = getattr(model, name, None) - if model is None: - return None - return model - - -class BasePruner(abc.ABC): +NEEDS_ZEROS = { # these layers should have pruned indices zero-ed, not removed + nn.BatchNorm2d +} + + +class BasePruner(BaseSparsifier): r"""Base class for all pruners. Abstract methods that need to be implemented: @@ -43,29 +33,90 @@ class BasePruner(abc.ABC): `module_groups`. Args: - - model [nn.Module]: model to configure. The model itself is not saved - but used for the state_dict saving / loading. - - config [list]: configuration elements could either be instances of - nn.Module or dict maps. The dicts must have a key 'module' with the - value being an instance of a nn.Module. - defaults [dict]: default configurations will be attached to the configuration. Only the keys that don't exist in the `config` will be updated. + - also_prune_bias [bool]: whether to prune bias in addition to weights (to prune full output channel) + or not; default=True. """ - def __init__(self, model, config, defaults): - super().__init__() - self.config = config - self.defaults = defaults - if self.defaults is None: - self.defaults = dict() + def __init__(self, defaults, also_prune_bias=True): + super().__init__(defaults) + self.prune_bias = also_prune_bias - self.module_groups = [] - self.enable_mask_update = False - self.activation_handles = [] + def _prepare(self, use_path=False, *args, **kwargs): + r"""Adds mask parametrization to the layer weight + """ + self.activation_handles = [] # store removable hook handles self.bias_handles = [] - self.model = model + for config in self.module_groups: + modules = [] + if use_path: + if type(config['module']) is tuple: # (Conv2d, BN) + for fqn in config['fqn']: + module = fqn_to_module(self.model, fqn) + modules.append(module) + else: + module = fqn_to_module(self.model, config['fqn']) + modules.append(module) + else: + if type(config['module']) is tuple: + for module in config['module']: + modules.append(module) + else: + module = config['module'] + modules.append(module) + + for module in modules: + if not isinstance(module, tuple(NEEDS_ZEROS)): + # add pruning parametrization and forward hooks + if getattr(module, 'mask', None) is None: + module.register_buffer('mask', torch.tensor(module.weight.shape[0])) + param = config.get('parametrization', PruningParametrization) + parametrize.register_parametrization(module, 'weight', param(module.mask), unsafe=True) + + assert isinstance(module.parametrizations, ModuleDict) # make mypy happy + assert isinstance(module.parametrizations.weight, ModuleList) + if isinstance(module, tuple(SUPPORTED_MODULES)): + self.activation_handles.append(module.register_forward_hook( + ActivationReconstruction(module.parametrizations.weight[0]) + )) + else: + raise NotImplementedError("This module type is not supported yet.") + + else: # needs zeros + if getattr(module, 'mask', None) is None: + module.register_buffer('mask', torch.tensor(module.weight.shape[0])) + param = config.get('parametrization', ZeroesParametrization) + parametrize.register_parametrization(module, 'weight', param(module.mask), unsafe=True) + + if module.bias is not None: + module.register_parameter('_bias', nn.Parameter(module.bias.detach())) + module.bias = None + self.bias_handles.append(module.register_forward_hook(BiasHook(module.parametrizations.weight[0], self.prune_bias))) + + if len(modules) == 2: # (Conv2d, BN) + # should have the same set of pruned outputs + modules[1].parametrizations.weight[0].pruned_outputs = modules[0].parametrizations.weight[0].pruned_outputs + + + def prepare(self, model, config): + r"""Prepares a model, by adding the parametrizations and forward post-hooks. + Note:: + The model is modified inplace. If you need to preserve the original + model, use copy.deepcopy. + + Args: + - model [nn.Module]: model to configure. The model itself is not saved + but used for the state_dict saving / loading. + - config [list]: configuration elements could either be instances of + nn.Module or dict maps. The dicts must have a key 'module' with the + value being an instance of a nn.Module. + """ + self.model = model # TODO: Need to figure out how to load without this. + self.config = config + # If no config -- try getting all the supported layers if self.config is None: # Add all models to the config @@ -77,108 +128,105 @@ def __init__(self, model, config, defaults): if type(child) in SUPPORTED_MODULES: self.config.append(child) else: + if type(child) in NEEDS_ZEROS and self.prune_bias: + warnings.warn(f"Models with {type(child)} layers have config provided by user.") stack.append(child) for module_config in self.config: - if isinstance(module_config, nn.Module): + if type(module_config) is tuple: + first_layer, next_layer = module_config + assert isinstance(first_layer, nn.Conv2d) and isinstance(next_layer, nn.BatchNorm2d) module_config = {'module': module_config} - local_args = copy.deepcopy(self.defaults) - local_args.update(module_config) - module = local_args['module'] - module_path = _module_to_path(self.model, module) - if module_path and module_path[0] == '.': - module_path = module_path[1:] - local_args['path'] = module_path + local_args = copy.deepcopy(self.defaults) + local_args.update(module_config) + fqn_list = [] + for module in local_args['module']: + module_fqn = module_to_fqn(model, module) + if module_fqn and module_fqn[0] == '.': + module_fqn = module_fqn[1:] + fqn_list.append(module_fqn) + local_args['fqn'] = fqn_list + else: + if isinstance(module_config, nn.Module): + module_config = {'module': module_config} + local_args = copy.deepcopy(self.defaults) + local_args.update(module_config) + module = local_args['module'] + module_fqn = module_to_fqn(model, module) + if module_fqn and module_fqn[0] == '.': + module_fqn = module_fqn[1:] + local_args['fqn'] = module_fqn + self.module_groups.append(local_args) - def __getstate__(self): - return { - 'defaults': self.defaults, - 'module_groups': self.module_groups, - } - - def __setstate__(self, state): - self.__dict__.update(state) - - def __repr__(self): - format_string = self.__class__.__name__ + ' (' - for i, sparse_args in enumerate(self.module_groups): - module = sparse_args['module'] - format_string += '\n' - format_string += f'\tModule Group {i}\n' - format_string += f'\t module: {module}\n' - for key in sorted(sparse_args.keys()): - if key == 'module': - continue - format_string += f'\t {key}: {sparse_args[key]}\n' - format_string += ')' - return format_string - - def bias_hook(self, module, input, output): - if getattr(module, '_bias', None) is not None: - idx = [1] * len(output.shape) - idx[1] = output.shape[1] - bias = module._bias.reshape(idx) - output += bias - return output - - def prepare(self, use_path=False, *args, **kwargs): - r"""Adds mask parametrization to the layer weight - """ + self._prepare() + + def squash_mask(self, use_path=False, *args, **kwargs): for config in self.module_groups: + modules = [] if use_path: - module = _path_to_module(self.model, config['path']) + if type(config['module']) is tuple: # (Conv2d, BN) + for fqn in config['fqn']: + module = fqn_to_module(self.model, fqn) + modules.append(module) + else: + module = fqn_to_module(self.model, config['fqn']) + modules.append(module) else: - module = config['module'] - - if getattr(module, 'mask', None) is None: - module.register_buffer('mask', torch.tensor(module.weight.shape[0])) - param = config.get('parametrization', PruningParametrization) - parametrize.register_parametrization(module, 'weight', - param(module.mask), - unsafe=True) - - assert isinstance(module.parametrizations, ModuleDict) # make mypy happy - assert isinstance(module.parametrizations.weight, ModuleList) - if isinstance(module, nn.Linear): - self.activation_handles.append(module.register_forward_hook( - LinearActivationReconstruction(module.parametrizations.weight[0]) - )) - elif isinstance(module, nn.Conv2d): - self.activation_handles.append(module.register_forward_hook( - Conv2dActivationReconstruction(module.parametrizations.weight[0]) - )) + if type(config['module']) is tuple: + for module in config['module']: + modules.append(module) + else: + module = config['module'] + modules.append(module) + + for module in modules: + parametrize.remove_parametrizations(module, 'weight', + leave_parametrized=True) + if getattr(module._parameters, 'mask', None): + del module._parameters['mask'] + elif getattr(module._buffers, 'mask', None): + del module._buffers['mask'] + delattr(module, 'mask') + + def get_module_pruned_outputs(self, module): + r"""Returns the set of pruned indices of module""" + assert parametrize.is_parametrized(module) # can only get pruned indices of pruned module + modules = {config['module'] for config in self.module_groups} + module_list = set() + for m in modules: + if type(m) is tuple: + module_list.update(m) else: - raise NotImplementedError("This module type is not supported yet.") - - if module.bias is not None: - module.register_parameter('_bias', nn.Parameter(module.bias.detach())) - module.bias = None - self.bias_handles.append(module.register_forward_hook(self.bias_hook)) + module_list.add(m) + assert module in module_list # check that module is in pruner.module_groups + return module.parametrizations.weight[0].pruned_outputs # assume only one parametrization attached - def convert(self, use_path=False, *args, **kwargs): - for config in self.module_groups: - if use_path: - module = _path_to_module(self.model, config['path']) - else: - module = config['module'] - parametrize.remove_parametrizations(module, 'weight', - leave_parametrized=True) - if getattr(module._parameters, 'mask', None): - del module._parameters['mask'] - elif getattr(module._buffers, 'mask', None): - del module._buffers['mask'] - delattr(module, 'mask') - - def step(self, use_path=True): + def step(self, use_path=False): if not self.enable_mask_update: return with torch.no_grad(): for config in self.module_groups: + modules = [] if use_path: - module = _path_to_module(self.model, config['path']) + if type(config['module']) is tuple: # (Conv2d, BN) + for fqn in config['fqn']: + module = fqn_to_module(self.model, fqn) + modules.append(module) + else: + module = fqn_to_module(self.model, config['fqn']) + modules.append(module) else: - module = config['module'] + if type(config['module']) is tuple: + for module in config['module']: + modules.append(module) + else: + module = config['module'] + modules.append(module) + + # only need to update the first module in modules if len(modules) > 1 + # since they should share the same set of pruned outputs + module = modules[0] self.update_mask(module, **config) @abc.abstractmethod diff --git a/torch/ao/sparsity/experimental/pruner/images/prune_1.png b/torch/ao/sparsity/experimental/pruner/images/prune_1.png new file mode 100644 index 0000000000000..f7f4875922572 Binary files /dev/null and b/torch/ao/sparsity/experimental/pruner/images/prune_1.png differ diff --git a/torch/ao/sparsity/experimental/pruner/images/prune_2.png b/torch/ao/sparsity/experimental/pruner/images/prune_2.png new file mode 100644 index 0000000000000..5aad9d0451bac Binary files /dev/null and b/torch/ao/sparsity/experimental/pruner/images/prune_2.png differ diff --git a/torch/ao/sparsity/experimental/pruner/images/prune_3.png b/torch/ao/sparsity/experimental/pruner/images/prune_3.png new file mode 100644 index 0000000000000..1af2c3cb4ed08 Binary files /dev/null and b/torch/ao/sparsity/experimental/pruner/images/prune_3.png differ diff --git a/torch/ao/sparsity/experimental/pruner/images/prune_4.png b/torch/ao/sparsity/experimental/pruner/images/prune_4.png new file mode 100644 index 0000000000000..fe7586edc13ce Binary files /dev/null and b/torch/ao/sparsity/experimental/pruner/images/prune_4.png differ diff --git a/torch/ao/sparsity/experimental/pruner/parametrization.py b/torch/ao/sparsity/experimental/pruner/parametrization.py index 1156ea8af4ef1..0ee937a4a8ae4 100644 --- a/torch/ao/sparsity/experimental/pruner/parametrization.py +++ b/torch/ao/sparsity/experimental/pruner/parametrization.py @@ -1,5 +1,6 @@ import torch from torch import nn +from typing import Any, List class PruningParametrization(nn.Module): @@ -13,27 +14,60 @@ def forward(self, x): return x[list(valid_outputs)] -class LinearActivationReconstruction: +class ZeroesParametrization(nn.Module): + r"""Zero out pruned channels instead of removing. + E.g. used for Batch Norm pruning, which should match previous Conv2d layer.""" + def __init__(self, original_outputs): + super().__init__() + self.original_outputs = set(range(original_outputs.item())) + self.pruned_outputs = set() # Will contain indicies of outputs to prune + + def forward(self, x): + x.data[list(self.pruned_outputs)] = 0 + return x + + +class ActivationReconstruction: def __init__(self, parametrization): self.param = parametrization def __call__(self, module, input, output): max_outputs = self.param.original_outputs pruned_outputs = self.param.pruned_outputs - reconstructed_tensor = torch.zeros((output.shape[0], len(max_outputs))) valid_columns = list(max_outputs - pruned_outputs) - reconstructed_tensor[:, valid_columns] = output + + # get size of reconstructed output + sizes = list(output.shape) + sizes[1] = len(max_outputs) + + # get valid indices of reconstructed output + indices: List[Any] = [] + for size in output.shape: + indices.append(slice(0, size, 1)) + indices[1] = valid_columns + + reconstructed_tensor = torch.zeros(sizes) + reconstructed_tensor[indices] = output return reconstructed_tensor -class Conv2dActivationReconstruction: - def __init__(self, parametrization): +class BiasHook: + def __init__(self, parametrization, prune_bias): self.param = parametrization + self.prune_bias = prune_bias def __call__(self, module, input, output): - max_outputs = self.param.original_outputs pruned_outputs = self.param.pruned_outputs - reconstructed_tensor = torch.zeros((output.shape[0], len(max_outputs), output.shape[2], output.shape[3])) - valid_columns = list(max_outputs - pruned_outputs) - reconstructed_tensor[:, valid_columns, :, :] = output - return reconstructed_tensor + + if getattr(module, '_bias', None) is not None: + bias = module._bias.data + if self.prune_bias: + bias[list(pruned_outputs)] = 0 + + # reshape bias to broadcast over output dimensions + idx = [1] * len(output.shape) + idx[1] = -1 + bias = bias.reshape(idx) + + output += bias + return output diff --git a/torch/ao/sparsity/sparsifier/base_sparsifier.py b/torch/ao/sparsity/sparsifier/base_sparsifier.py index d6bc7d75248cf..1d01b71daae25 100644 --- a/torch/ao/sparsity/sparsifier/base_sparsifier.py +++ b/torch/ao/sparsity/sparsifier/base_sparsifier.py @@ -8,30 +8,12 @@ from torch import nn from torch.nn.utils import parametrize -from .utils import FakeSparsity +from .utils import FakeSparsity, module_to_fqn, fqn_to_module SUPPORTED_MODULES = { nn.Linear } -def _module_to_fqn(model, layer, prefix=''): - for name, child in model.named_children(): - new_name = prefix + '.' + name - if child is layer: - return new_name - child_path = _module_to_fqn(child, layer, prefix=new_name) - if child_path is not None: - return child_path - return None - -def _fqn_to_module(model, path): - path = path.split('.') - for name in path: - model = getattr(model, name, None) - if model is None: - return None - return model - class BaseSparsifier(abc.ABC): r"""Base class for all sparsifiers. @@ -136,7 +118,7 @@ def load_state_dict(self, state_dict, strict=True): module_groups = copy.deepcopy(state_dict['module_groups']) states = state_dict['state'] for fqn, s in states.items(): - layer = _fqn_to_module(self.model, fqn) + layer = fqn_to_module(self.model, fqn) if strict and layer is None: raise RuntimeError(f'Error loading {fqn} into the model') @@ -186,7 +168,7 @@ def prepare(self, model, config): local_args = copy.deepcopy(self.defaults) local_args.update(module_config) module = local_args['module'] - module_fqn = _module_to_fqn(model, module) + module_fqn = module_to_fqn(model, module) if module_fqn and module_fqn[0] == '.': module_fqn = module_fqn[1:] local_args['fqn'] = module_fqn diff --git a/torch/ao/sparsity/sparsifier/utils.py b/torch/ao/sparsity/sparsifier/utils.py index 6271a8d502f0d..3124b1b767b0f 100644 --- a/torch/ao/sparsity/sparsifier/utils.py +++ b/torch/ao/sparsity/sparsifier/utils.py @@ -1,5 +1,23 @@ from torch import nn +def module_to_fqn(model, layer, prefix=''): + for name, child in model.named_children(): + new_name = prefix + '.' + name + if child is layer: + return new_name + child_path = module_to_fqn(child, layer, prefix=new_name) + if child_path is not None: + return child_path + return None + +def fqn_to_module(model, path): + path = path.split('.') + for name in path: + model = getattr(model, name, None) + if model is None: + return None + return model + # Parametrizations class FakeSparsity(nn.Module): r"""Parametrization for the weights. Should be attached to the 'weight' or diff --git a/torch/autocast_mode.py b/torch/autocast_mode.py index edf36d25745fc..97d51b8f1ca7b 100644 --- a/torch/autocast_mode.py +++ b/torch/autocast_mode.py @@ -80,7 +80,7 @@ def forward(self, input): c_float32 = torch.rand((8, 8), device="cpu") d_float32 = torch.rand((8, 8), device="cpu") - with autocast(fast_dtype=torch.bfloat16, device_type="cpu"): + with autocast(dtype=torch.bfloat16, device_type="cpu"): # torch.mm is on autocast's list of ops that should run in bfloat16. # Inputs are float32, but the op runs in bfloat16 and produces bfloat16 output. # No manual casts are required. @@ -125,7 +125,7 @@ def forward(self, input): Args: device_type(string, required): Whether to use 'cuda' or 'cpu' device enabled(bool, optional, default=True)": Whether autocasting should be enabled in the region. - fast_dtype(torch_dtype, optional): Whether to use torch.float16 or torch.bfloat16 + dtype(torch_dtype, optional): Whether to use torch.float16 or torch.bfloat16 """ def __init__(self, device_type, enabled=True, **kwargs): self.device = device_type @@ -135,13 +135,13 @@ def __init__(self, device_type, enabled=True, **kwargs): self.fast_dtype = torch.get_autocast_cpu_dtype() else: raise RuntimeError('User specified autocast device_type must be \'cuda\' or \'cpu\'') - if not torch.cuda.is_available() and self.device == 'cuda': + if torch.cuda.amp.common.amp_definitely_not_available() and self.device == 'cuda': warnings.warn('User provided device_type of \'cuda\', but CUDA is not available. Disabling') enabled = False for key, value in kwargs.items(): - if key == 'fast_dtype': + if key == 'dtype': self.fast_dtype = value - if not (key == 'fast_dtype'): + if not (key == 'dtype'): raise RuntimeError('Unrecognized optional argument supplied to autocast context manager: ' + str(key)) if self.device == 'cpu': @@ -152,8 +152,8 @@ def __init__(self, device_type, enabled=True, **kwargs): warnings.warn(error_message) enabled = False if self.device == 'cuda': - if self.fast_dtype == torch.bfloat16 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8: - raise RuntimeError('Current CUDA Device does not support bfloat16. Switching fast_dtype to float16.') + if self.fast_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported(): + raise RuntimeError('Current CUDA Device does not support bfloat16. Please switch dtype to float16.') self._enabled = enabled def __enter__(self): diff --git a/torch/autograd/__init__.py b/torch/autograd/__init__.py index 0d4f153d007c1..d11e261efcea1 100644 --- a/torch/autograd/__init__.py +++ b/torch/autograd/__init__.py @@ -173,17 +173,18 @@ def grad( gradients w.r.t. each of the outputs. If an output doesn't require_grad, then the gradient can be ``None``). - If ``only_inputs`` is ``True``, the function will only return a list of gradients - w.r.t the specified inputs. If it's ``False``, then gradient w.r.t. all remaining - leaves will still be computed, and will be accumulated into their ``.grad`` - attribute. - .. note:: If you run any forward ops, create ``grad_outputs``, and/or call ``grad`` in a user-specified CUDA stream context, see :ref:`Stream semantics of backward passes`. + .. note:: + + ``only_inputs`` argument is deprecated and is ignored now (defaults to ``True``). + To accumulate gradient for other parts of the graph, please use + ``torch.autograd.backward``. + Args: outputs (sequence of Tensor): outputs of the differentiated function. inputs (sequence of Tensor): Inputs w.r.t. which the gradient will be diff --git a/torch/autograd/function.py b/torch/autograd/function.py index 4fc25c5951d11..909e71959320b 100644 --- a/torch/autograd/function.py +++ b/torch/autograd/function.py @@ -8,24 +8,53 @@ from collections import OrderedDict from typing import Any, List, Optional +# Formerly known as: _ContextMethodMixin +class FunctionCtx(object): -class _ContextMethodMixin(object): - - def save_for_backward(self, *tensors): + def save_for_backward(self, *tensors: torch.Tensor): r"""Saves given tensors for a future call to :func:`~Function.backward`. **This should be called at most once, and only from inside the** - :func:`forward` **method.** + :func:`forward` **method. This should only be called with input or + output tensors** - Later, saved tensors can be accessed through the :attr:`saved_tensors` + In :func:`backward`, saved tensors can be accessed through the :attr:`saved_tensors` attribute. Before returning them to the user, a check is made to ensure they weren't used in any in-place operation that modified their content. - Arguments can also be ``None``. + Arguments can also be ``None``. This is a no-op. + + See :ref:`extending-autograd` for more details on how to use this method. + + Example:: + >>> class Func(Function): + >>> @staticmethod + >>> def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int): + >>> w = x * y * z + >>> out = x * y + y * z + w + >>> ctx.save_for_backward(x, y, out) + >>> ctx.z = z # z is not a tensor + >>> ctx.w = w # w is neither input nor output + >>> return out + >>> + >>> @staticmethod + >>> def backward(ctx, grad_out): + >>> x, y, out = ctx.saved_tensors + >>> z = ctx.z + >>> gx = grad_out * (y + y * z) + >>> gy = grad_out * (x + z + x * z) + >>> gz = None + >>> return gx, gy, gz + >>> + >>> a = torch.tensor(1., requires_grad=True, dtype=torch.double) + >>> b = torch.tensor(2., requires_grad=True, dtype=torch.double) + >>> c = 4 + >>> d = Func.apply(a, b, c) + """ self.to_save = tensors - def mark_dirty(self, *args): + def mark_dirty(self, *args: torch.Tensor): r"""Marks given tensors as modified in an in-place operation. **This should be called at most once, only from inside the** @@ -35,6 +64,28 @@ def mark_dirty(self, *args): should be given to this function, to ensure correctness of our checks. It doesn't matter whether the function is called before or after modification. + + Examples:: + >>> class Inplace(Function): + >>> @staticmethod + >>> def forward(ctx, x): + >>> x_npy = x.numpy() # x_npy shares storage with x + >>> x_npy += 1 + >>> ctx.mark_dirty(x) + >>> return x + >>> + >>> @staticmethod + >>> @once_differentiable + >>> def backward(ctx, grad_output): + >>> return grad_output + >>> + >>> a = torch.tensor(1., requires_grad=True, dtype=torch.double).clone() + >>> b = a * a + >>> Inplace.apply(a) # This would lead to wrong gradients! + >>> # but the engine would not know unless we mark_dirty + >>> b.backward() # RuntimeError: one of the variables needed for gradient + >>> # computation has been modified by an inplace operation + """ self.dirty_tensors = args @@ -44,11 +95,11 @@ def mark_shared_storage(self, *pairs): 'Tensors with shared storages are automatically tracked. Note ' 'that calls to `set_()` are not tracked') - def mark_non_differentiable(self, *args): + def mark_non_differentiable(self, *args: torch.Tensor): r"""Marks outputs as non-differentiable. **This should be called at most once, only from inside the** - :func:`forward` **method, and all arguments should be outputs.** + :func:`forward` **method, and all arguments should be tensor outputs.** This will mark outputs as not requiring gradients, increasing the efficiency of backward computation. You still need to accept a gradient @@ -56,20 +107,73 @@ def mark_non_differentiable(self, *args): be a zero tensor with the same shape as the shape of a corresponding output. - This is used e.g. for indices returned from a max :class:`Function`. + This is used e.g. for indices returned from a sort. See example:: + >>> class Func(Function): + >>> @staticmethod + >>> def forward(ctx, x): + >>> sorted, idx = x.sort() + >>> ctx.mark_non_differentiable(idx) + >>> ctx.save_for_backward(x, idx) + >>> return sorted, idx + >>> + >>> @staticmethod + >>> @once_differentiable + >>> def backward(ctx, g1, g2): # still need to accept g2 + >>> x, idx = ctx.saved_tensors + >>> grad_input = torch.zeros_like(x) + >>> grad_input.index_add_(0, idx, g1) + >>> return grad_input + """ self.non_differentiable = args - def set_materialize_grads(self, value): - r"""Sets whether to materialize output grad tensors. Default is true. + def set_materialize_grads(self, value: bool): + r"""Sets whether to materialize output grad tensors. Default is ``True``. **This should be called only from inside the** :func:`forward` **method** - If true, undefined output grad tensors will be expanded to tensors full + If ``True``, undefined output grad tensors will be expanded to tensors full of zeros prior to calling the :func:`backward` method. + + Example:: + >>> class SimpleFunc(Function): + >>> @staticmethod + >>> def forward(ctx, x): + >>> return x.clone(), x.clone() + >>> + >>> @staticmethod + >>> @once_differentiable + >>> def backward(ctx, g1, g2): + >>> return g1 + g2 # No check for None necessary + >>> + >>> # We modify SimpleFunc to handle non-materialized grad outputs + >>> class Func(Function): + >>> @staticmethod + >>> def forward(ctx, x): + >>> ctx.set_materialize_grads(False) + >>> ctx.save_for_backward(x) + >>> return x.clone(), x.clone() + >>> + >>> @staticmethod + >>> @once_differentiable + >>> def backward(ctx, g1, g2): + >>> x, = ctx.saved_tensors + >>> grad_input = torch.zeros_like(x) + >>> if g1 is not None: # We must check for None now + >>> grad_input += g1 + >>> if g2 is not None: + >>> grad_input += g2 + >>> return grad_input + >>> + >>> a = torch.tensor(1., requires_grad=True) + >>> b, _ = Func.apply(a) # induces g2 to be undefined + """ self.materialize_grads = value +# DO NOT USE: This is only defined to be able to load old serialized models +_ContextMethodMixin = FunctionCtx + class _HookMixin(object): @staticmethod @@ -81,10 +185,22 @@ def _register_hook(backward_hooks, hook): return backward_hooks, handle -class BackwardCFunction(_C._FunctionBase, _ContextMethodMixin, _HookMixin): +class BackwardCFunction(_C._FunctionBase, FunctionCtx, _HookMixin): def apply(self, *args): # _forward_cls is defined by derived class - return self._forward_cls.backward(self, *args) # type: ignore[attr-defined] + # The user should define either backward or vjp but never both. + backward_fn = self._forward_cls.backward # type: ignore[attr-defined] + vjp_fn = self._forward_cls.vjp # type: ignore[attr-defined] + if backward_fn is not Function.backward and vjp_fn is not Function.vjp: + raise RuntimeError("Implementing both 'backward' and 'vjp' for a custom " + "Function is not allowed. You should only implement one " + "of them.") + user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn + return user_fn(self, *args) + + def apply_jvp(self, *args): + # _forward_cls is defined by derived class + return self._forward_cls.jvp(self, *args) # type: ignore[attr-defined] class FunctionMeta(type): @@ -103,28 +219,23 @@ def __init__(cls, name, bases, attrs): # mypy doesn't understand `with_metaclass` from torch._six -class Function(with_metaclass(FunctionMeta, _C._FunctionBase, _ContextMethodMixin, _HookMixin)): # type: ignore[misc] - r"""Records operation history and defines formulas for differentiating ops. +class Function(with_metaclass(FunctionMeta, _C._FunctionBase, FunctionCtx, _HookMixin)): # type: ignore[misc] + r"""Base class to create custom `autograd.Function` - See the Note on extending the autograd engine for more details on how to use - this class: https://pytorch.org/docs/stable/notes/extending.html#extending-torch-autograd + To create a custom `autograd.Function`, subclass this class and implement + the :meth:`forward` and :meth`backward` static methods. Then, to use your custom + op in the forward pass, call the class method ``apply``. Do not call + :meth:`forward` directly. - Every operation performed on :class:`Tensor` s creates a new function - object, that performs the computation, and records that it happened. - The history is retained in the form of a DAG of functions, with edges - denoting data dependencies (``input <- output``). Then, when backward is - called, the graph is processed in the topological ordering, by calling - :func:`backward` methods of each :class:`Function` object, and passing - returned gradients on to next :class:`Function` s. + To ensure correctness and best performance, make sure you are calling the + correct methods on ``ctx`` and validating your backward function using + :func:`torch.autograd.gradcheck`. - Normally, the only way users interact with functions is by creating - subclasses and defining new operations. This is a recommended way of - extending torch.autograd. + See :ref:`extending-autograd` for more details on how to use this class. Examples:: >>> class Exp(Function): - >>> >>> @staticmethod >>> def forward(ctx, i): >>> result = i.exp() @@ -136,7 +247,7 @@ class Function(with_metaclass(FunctionMeta, _C._FunctionBase, _ContextMethodMixi >>> result, = ctx.saved_tensors >>> return grad_output * result >>> - >>> #Use it by calling the apply method: + >>> # Use it by calling the apply method: >>> output = Exp.apply(input) """ def __init__(self, *args, **kwargs): @@ -172,7 +283,8 @@ def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: @staticmethod def backward(ctx: Any, *grad_outputs: Any) -> Any: - r"""Defines a formula for differentiating the operation. + r"""Defines a formula for differentiating the operation with backward mode + automatic differentiation. This function is to be overridden by all subclasses. @@ -192,9 +304,33 @@ def backward(ctx: Any, *grad_outputs: Any) -> Any: first input to :func:`forward` needs gradient computated w.r.t. the output. """ - raise NotImplementedError("You must implement the backward function for custom" - " autograd.Function.") + raise NotImplementedError("You must implement either the backward or vjp method for " + "your custom autograd.Function to use it with backward " + "mode AD.") + + # vjp and backward are alias of each other + vjp = backward + @staticmethod + def jvp(ctx: Any, *grad_inputs: Any) -> Any: + r"""Defines a formula for differentiating the operation with forward mode + automatic differentiation. + This function is to be overridden by all subclasses. + It must accept a context :attr:`ctx` as the first argument, followed by + as many inputs as the :func:`forward` got (None will be passed in + for non tensor inputs of the forward function), + and it should return as many tensors as there were outputs to + :func:`forward`. Each argument is the gradient w.r.t the given input, + and each returned value should be the gradient w.r.t. the + corresponding output. If an output is not a Tensor or the function is not + differentiable with respect to that output, you can just pass None as a + gradient for that input. + + You can use the :attr:`ctx` object to pass any value from the forward to this + functions. + """ + raise NotImplementedError("You must implement the jvp function for custom " + "autograd.Function to use it with forward mode AD.") def once_differentiable(fn): @@ -224,7 +360,7 @@ def wrapper(ctx, *args): outputs = (outputs,) err_fn = _functions.DelayedError( - b"trying to differentiate twice a function that was marked" + b"trying to differentiate twice a function that was marked " b"with @once_differentiable", len(outputs)) # Create aliases of each output that has requires_grad=True. We need diff --git a/torch/backends/_nnapi/serializer.py b/torch/backends/_nnapi/serializer.py index f85d51a040995..a2530df478833 100644 --- a/torch/backends/_nnapi/serializer.py +++ b/torch/backends/_nnapi/serializer.py @@ -404,8 +404,8 @@ def add_tensor_operand_for_input(self, arg_idx, jitval, tensor): self.compute_operand_shape(operand_id, dim, f"args[{arg_idx}].shape[{dim}]") return operand_id - def add_tensor_operand_for_weight(self, tensor): - toper = self.torch_tensor_to_operand(tensor, DimOrder.UNKNOWN_CONSTANT) + def add_tensor_operand_for_weight(self, tensor, dim_order=DimOrder.UNKNOWN_CONSTANT): + toper = self.torch_tensor_to_operand(tensor, dim_order) operand_id = len(self.operands) self.operands.append(toper) tsize = tensor_size(toper.op_type, toper.shape) @@ -418,6 +418,9 @@ def add_tensor_operand_for_weight(self, tensor): buf_num, offset, tsize)) + # For NHWC NNAPI op, lay out data in the same dim order by permuting torch tensor + if dim_order == DimOrder.CHANNELS_LAST: + tensor = tensor.permute(0, 2, 3, 1) self.used_weights.append(tensor) return operand_id @@ -456,6 +459,9 @@ def add_immediate_int_vector(self, value): array.array("i", value).tobytes(), (len(value),)) + def has_operand_for_jitval(self, jitval): + return jitval in self.jitval_operand_map + def get_tensor_operand_by_jitval(self, jitval): operand_id = self.jitval_operand_map[jitval] return (operand_id, self.operands[operand_id]) @@ -469,11 +475,11 @@ def get_tensor_operand_by_jitval_fixed_size(self, jitval): raise Exception("Flexible size is not supported for this operand.") return op_id, oper - def get_tensor_operand_or_constant(self, jitval): + def get_tensor_operand_or_constant(self, jitval, dim_order=DimOrder.PRESUMED_CONTIGUOUS): operand_id = self.jitval_operand_map.get(jitval) if operand_id is None: _, value = self.get_constant_value(jitval, "TensorType") - operand_id = self.add_tensor_operand_for_weight(value) + operand_id = self.add_tensor_operand_for_weight(value, dim_order) return (operand_id, self.operands[operand_id]) def get_tensor_operand_for_weight(self, jitval): @@ -1233,9 +1239,14 @@ def _do_add_binary(self, node, opcode, fuse_code, *, qparams=None): assert node.inputsAt(0).type().kind() == "TensorType" assert node.inputsAt(1).type().kind() == "TensorType" - # TODO: Should support constant as either operand. - in0_id, in0_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) - in1_id, in1_oper = self.get_tensor_operand_by_jitval(node.inputsAt(1)) + if self.has_operand_for_jitval(node.inputsAt(0)): + in0_id, in0_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) + in1_id, in1_oper = self.get_tensor_operand_or_constant(node.inputsAt(1), in0_oper.dim_order) + elif self.has_operand_for_jitval(node.inputsAt(1)): + in1_id, in1_oper = self.get_tensor_operand_by_jitval(node.inputsAt(1)) + in0_id, in0_oper = self.get_tensor_operand_or_constant(node.inputsAt(0), in1_oper.dim_order) + else: + raise Exception(f"Can't do a NNAPI binary op: {opcode} on two constants") assert in0_oper.op_type == in1_oper.op_type in0_id, in0_oper, in1_id, in1_oper = self.transpose_for_broadcast( diff --git a/torch/cpu/amp/autocast_mode.py b/torch/cpu/amp/autocast_mode.py index 027ef382f1599..8c65f727753e2 100644 --- a/torch/cpu/amp/autocast_mode.py +++ b/torch/cpu/amp/autocast_mode.py @@ -5,5 +5,5 @@ class autocast(torch.autocast_mode.autocast): See :class:`torch.autocast`. ``torch.cpu.amp.autocast(args...)`` is equivalent to ``torch.autocast("cpu", args...)`` """ - def __init__(self, enabled=True, fast_dtype=torch.float16): - super().__init__("cpu", enabled=enabled, fast_dtype=fast_dtype) + def __init__(self, enabled=True, dtype=torch.bfloat16): + super().__init__("cpu", enabled=enabled, dtype=dtype) diff --git a/torch/csrc/CudaIPCTypes.cpp b/torch/csrc/CudaIPCTypes.cpp index 6b42ca078024a..9033d445081ea 100644 --- a/torch/csrc/CudaIPCTypes.cpp +++ b/torch/csrc/CudaIPCTypes.cpp @@ -19,24 +19,27 @@ void warnProducerTerminatedBeforeSharedTensorsReleased() { } struct CudaIPCGlobalEntities { + // This class is used as a singleton (see cuda_ipc_global_entities) + // This variable is used to track its lifetime to avoid accessing it + // after it was destroyed which would lead to segmentation faults + // Note that a trvial type is used which doesn't suffer from construction + // and destruction order issues + static bool alive; + std::mutex ref_counters_mutex_; std::atomic sync_events_used_{0}; std::map> ref_counters_files_; std::shared_ptr next_available_ref_counters_file_; CudaIPCSentDataLimbo CudaIPCSentDataLimbo_; - CudaIPCGlobalEntities() = default; + CudaIPCGlobalEntities() { alive = true; } ~CudaIPCGlobalEntities() { CudaIPCSentDataLimbo_.collect(); - // Clear shared blocks to avoid releasing shared blocks after - // ~CudaIPCGlobalEntities is done since circular references causes the - // destructor of ~CudaIPCSentData to access the cuda_ipc_global_entities - // again. - CudaIPCSentDataLimbo_.clear_shared_blocks(); safe_clean_current_file(); if (next_available_ref_counters_file_) { warnProducerTerminatedBeforeSharedTensorsReleased(); } + alive = false; } void safe_clean_current_file() { std::lock_guard lock(ref_counters_mutex_); @@ -48,19 +51,16 @@ struct CudaIPCGlobalEntities { } }; +bool CudaIPCGlobalEntities::alive = false; CudaIPCGlobalEntities cuda_ipc_global_entities; CudaIPCSentDataLimbo::~CudaIPCSentDataLimbo() { collect(); - if (shared_blocks_.size() > 0) { + if (size() > 0) { warnProducerTerminatedBeforeSharedTensorsReleased(); } } -void CudaIPCSentDataLimbo::clear_shared_blocks() { - shared_blocks_.clear(); -} - bool CudaIPCSentDataLimbo::collect() { bool freed_memory = false; std::vector> reset_blocks; @@ -99,9 +99,17 @@ void CudaIPCSentDataLimbo::add(std::unique_ptr shared_block) { shared_blocks_.push_back(std::move(shared_block)); } +uint64_t CudaIPCSentDataLimbo::size() { + std::lock_guard lock(limbo_mutex_); + return shared_blocks_.size(); +} + void CudaIPCSentDataDelete(void* ptr) { std::unique_ptr sent_data( static_cast(ptr)); + if(!CudaIPCGlobalEntities::alive) { + return; + } if (sent_data->counter_value() > 0) { cuda_ipc_global_entities.CudaIPCSentDataLimbo_.add(std::move(sent_data)); } @@ -109,6 +117,9 @@ void CudaIPCSentDataDelete(void* ptr) { } void ReturnRefCounter(const std::string& handle, uint64_t offset /* unused */) { + if(!CudaIPCGlobalEntities::alive) { + return; + } std::lock_guard lock( cuda_ipc_global_entities.ref_counters_mutex_); auto& map = cuda_ipc_global_entities.ref_counters_files_; @@ -180,6 +191,9 @@ CudaIPCSentData::~CudaIPCSentData() { if (event_sync_required_) { at::cuda::CUDAGuard device_guard(device_.index()); cudaEventDestroy(event_); + if(!CudaIPCGlobalEntities::alive) { + return; + } cuda_ipc_global_entities.sync_events_used_ --; } } catch (...) { /* No throw */ @@ -226,6 +240,9 @@ at::DataPtr GetNewRefCountedSentData(void* data, at::Device device) { } bool CudaIPCCollect() { + if(!CudaIPCGlobalEntities::alive) { + return true; + } bool freed_memory = cuda_ipc_global_entities.CudaIPCSentDataLimbo_.collect(); if (cuda_ipc_global_entities.CudaIPCSentDataLimbo_.size() == 0) { cuda_ipc_global_entities.safe_clean_current_file(); diff --git a/torch/csrc/CudaIPCTypes.h b/torch/csrc/CudaIPCTypes.h index 63e1d1d416a5a..ab9ede006916d 100644 --- a/torch/csrc/CudaIPCTypes.h +++ b/torch/csrc/CudaIPCTypes.h @@ -63,11 +63,8 @@ constexpr int64_t CUDA_IPC_MAXIMUM_EVENTS_TO_USE = 1000; struct CudaIPCSentDataLimbo final { ~CudaIPCSentDataLimbo(); bool collect(); - void clear_shared_blocks(); void add(std::unique_ptr shared_block); - uint64_t size() { - return shared_blocks_.size(); - } + uint64_t size(); private: // TODO: Can be changed to FIFO in order to avoid full traverse on every diff --git a/torch/csrc/Device.h b/torch/csrc/Device.h index b1f18dcebd1ab..32868120c06a1 100644 --- a/torch/csrc/Device.h +++ b/torch/csrc/Device.h @@ -17,6 +17,6 @@ inline bool THPDevice_Check(PyObject *obj) { return Py_TYPE(obj) == &THPDeviceType; } -PyObject * THPDevice_New(const at::Device& device); +TORCH_API PyObject * THPDevice_New(const at::Device& device); -void THPDevice_init(PyObject *module); +TORCH_API void THPDevice_init(PyObject *module); diff --git a/torch/csrc/api/include/torch/autograd.h b/torch/csrc/api/include/torch/autograd.h index 83aa102de0128..809fbe8bd3350 100644 --- a/torch/csrc/api/include/torch/autograd.h +++ b/torch/csrc/api/include/torch/autograd.h @@ -2,3 +2,4 @@ #include #include +#include diff --git a/torch/csrc/api/include/torch/imethod.h b/torch/csrc/api/include/torch/imethod.h index dfabf50ce7191..5ab9b83888214 100644 --- a/torch/csrc/api/include/torch/imethod.h +++ b/torch/csrc/api/include/torch/imethod.h @@ -4,7 +4,7 @@ namespace torch { -class IMethod { +class TORCH_API IMethod { /* IMethod provides a portable interface for torch methods, whether they are backed by torchscript or python/deploy. @@ -28,6 +28,8 @@ class IMethod { std::vector args, const IValueMap& kwargs = IValueMap()) const = 0; + virtual const std::string& name() const = 0; + // Returns an ordered list of argument names, possible in both // script and python methods. This is a more portable dependency // than a ScriptMethod FunctionSchema, which has more information diff --git a/torch/csrc/api/include/torch/nn/functional/loss.h b/torch/csrc/api/include/torch/nn/functional/loss.h index ea2f6066ddf15..1fa91ad6deb1f 100644 --- a/torch/csrc/api/include/torch/nn/functional/loss.h +++ b/torch/csrc/api/include/torch/nn/functional/loss.h @@ -824,13 +824,15 @@ inline Tensor cross_entropy( const Tensor& target, const Tensor& weight, int64_t ignore_index, - CrossEntropyFuncOptions::reduction_t reduction) { + CrossEntropyFuncOptions::reduction_t reduction, + double label_smoothing) { return torch::cross_entropy_loss( input, target, weight, enumtype::reduction_get_enum(reduction), - ignore_index); + ignore_index, + label_smoothing); } } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ @@ -855,7 +857,8 @@ inline Tensor cross_entropy( target, options.weight(), options.ignore_index(), - options.reduction()); + options.reduction(), + options.label_smoothing()); } // ============================================================================ diff --git a/torch/csrc/api/include/torch/nn/functional/pooling.h b/torch/csrc/api/include/torch/nn/functional/pooling.h index c8538858e8a74..f06b68ba2870d 100644 --- a/torch/csrc/api/include/torch/nn/functional/pooling.h +++ b/torch/csrc/api/include/torch/nn/functional/pooling.h @@ -776,7 +776,7 @@ inline std::tuple fractional_max_pool2d_with_indices( Tensor _random_samples_ = _random_samples; if (!_random_samples_.defined()) { - auto n_batch = 1 ? input.dim() == 3 : input.size(0); + auto n_batch = input.dim() == 3; _random_samples_ = torch::rand({n_batch, input.size(-1), 2}, torch::TensorOptions().dtype(input.dtype()).device(input.device())); } return torch::fractional_max_pool2d(input, kernel_size, *output_size_, _random_samples_); diff --git a/torch/csrc/api/include/torch/nn/options/loss.h b/torch/csrc/api/include/torch/nn/options/loss.h index d8ffd15c8660a..1479de571d13e 100644 --- a/torch/csrc/api/include/torch/nn/options/loss.h +++ b/torch/csrc/api/include/torch/nn/options/loss.h @@ -662,6 +662,8 @@ struct TORCH_API CrossEntropyLossOptions { TORCH_ARG(int64_t, ignore_index) = -100; /// Specifies the reduction to apply to the output. Default: Mean TORCH_ARG(reduction_t, reduction) = torch::kMean; + /// Specifies the amount of smoothing when computing the loss. Default: 0.0 + TORCH_ARG(double, label_smoothing) = 0.0; }; namespace functional { diff --git a/torch/csrc/api/src/nn/modules/activation.cpp b/torch/csrc/api/src/nn/modules/activation.cpp index 3c4d2b8c98f50..e724a75c58ec9 100644 --- a/torch/csrc/api/src/nn/modules/activation.cpp +++ b/torch/csrc/api/src/nn/modules/activation.cpp @@ -170,8 +170,8 @@ void Softmax2dImpl::pretty_print(std::ostream& stream) const { } Tensor Softmax2dImpl::forward(const Tensor& input) { - TORCH_CHECK(input.dim() == 4, "Softmax2d requires a 4D tensor as input"); - return F::detail::softmax(input, /*dim=*/1, c10::nullopt); + TORCH_CHECK(input.dim() == 4 || input.dim() == 3, "Softmax2d requires a 3D or 4D tensor as input"); + return F::detail::softmax(input, /*dim=*/-3, c10::nullopt); } // ============================================================================ diff --git a/torch/csrc/api/src/nn/modules/loss.cpp b/torch/csrc/api/src/nn/modules/loss.cpp index d5d8c687168e8..dda67fe9c728e 100644 --- a/torch/csrc/api/src/nn/modules/loss.cpp +++ b/torch/csrc/api/src/nn/modules/loss.cpp @@ -378,7 +378,8 @@ Tensor CrossEntropyLossImpl::forward( target, weight, options.ignore_index(), - options.reduction()); + options.reduction(), + options.label_smoothing()); } // ============================================================================ diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 86639c13ea678..95170f073fc38 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -1,5 +1,7 @@ #include #include +#include + #include #include @@ -44,10 +46,6 @@ bool isDefined(const c10::optional& t) { return t.has_value() && t->defined(); } -bool isFwGradDefined(const c10::optional& t) { - return t.has_value() && t->defined() && t->_fw_grad(/*level */ 0).defined(); -} - Tensor toNonOptTensor(const c10::optional& t) { return t.has_value() ? *t : Tensor(); } diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index d397f55d15189..31a972e3f3280 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -31,7 +31,6 @@ struct IndexRangeGenerator { size_t i = 0; }; -bool isFwGradDefined(const c10::optional& t); Tensor toNonOptFwGrad(const c10::optional& t); Tensor toNonOptPrimal(const c10::optional& t); Tensor toNonOptTensor(const c10::optional& t); diff --git a/torch/csrc/autograd/VariableTypeManual.cpp b/torch/csrc/autograd/VariableTypeManual.cpp index f409daa9b83d6..25f05fc110177 100644 --- a/torch/csrc/autograd/VariableTypeManual.cpp +++ b/torch/csrc/autograd/VariableTypeManual.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -100,7 +101,7 @@ Tensor _fw_primal(c10::DispatchKeySet ks, const Tensor & self, int64_t level) { if (grad_fn) { set_history(flatten_tensor_args( result ), grad_fn); } - if (generated::details::isFwGradDefined(self)) { + if (isFwGradDefined(self)) { // Modified from original codegen // We explicitly want to ignore the forward grad at the given level TORCH_CHECK(level == 0, "Invalid level given to _fw_primal"); @@ -131,7 +132,7 @@ Tensor & copy_(c10::DispatchKeySet ks, Tensor & self, const Tensor & src, bool n rebase_history(self , std::move(grad_fn)); if (isDifferentiableType(self.scalar_type()) && - (generated::details::isFwGradDefined(self) || generated::details::isFwGradDefined(src))) { + (isFwGradDefined(self) || isFwGradDefined(src))) { auto self_fw_grad = generated::details::toNonOptFwGrad(self); auto src_fw_grad = generated::details::toNonOptFwGrad(src); Tensor new_fw_grad; diff --git a/torch/csrc/autograd/VariableTypeUtils.h b/torch/csrc/autograd/VariableTypeUtils.h index bde2dc46352da..977e9e4cecd5c 100644 --- a/torch/csrc/autograd/VariableTypeUtils.h +++ b/torch/csrc/autograd/VariableTypeUtils.h @@ -1,14 +1,12 @@ #pragma once #include -#include #include #include #include #include #include -#include #include #include #include @@ -35,9 +33,6 @@ #endif #endif -using namespace at; -using namespace torch::autograd::generated; - namespace torch { namespace autograd { // The requires_grad argument is used to know if the inplace operation needs @@ -47,7 +42,7 @@ namespace torch { namespace autograd { // a = torch.rand(2) // b = torch.rand(2, requires_grad=True) // a.copy_(b) -inline void check_inplace(const Tensor& tensor, bool requires_grad) { +inline void check_inplace(const at::Tensor& tensor, bool requires_grad) { if (requires_grad && GradMode::is_enabled()) { auto diff_view_meta = impl::get_view_autograd_meta(tensor); if (diff_view_meta && diff_view_meta->has_bw_view()) { @@ -65,7 +60,7 @@ inline void check_inplace(const Tensor& tensor, bool requires_grad) { } } -inline void check_inplace(const TensorList tensors, bool requires_grad) { +inline void check_inplace(const at::TensorList tensors, bool requires_grad) { for (const auto& tensor : tensors) { check_inplace(tensor, requires_grad); } @@ -77,14 +72,14 @@ inline void throw_error_out_requires_grad(const char* name) { "but one of the arguments requires grad."); } -inline void throw_error_for_complex_autograd(const Tensor& tensor, const char* name) { +inline void throw_error_for_complex_autograd(const at::Tensor& tensor, const char* name) { if (tensor.requires_grad()) { TORCH_CHECK(!tensor.is_complex(), name, " does not support automatic differentiation for outputs with complex dtype."); } } -inline void throw_error_for_complex_autograd(const TensorList& tensorlist, const char* name) { +inline void throw_error_for_complex_autograd(const at::TensorList& tensorlist, const char* name) { for (const auto& tensor: tensorlist) { throw_error_for_complex_autograd(tensor, name); } @@ -114,7 +109,7 @@ inline void rebase_history(std::vector&& vars, std::shared_ptr g } } -inline void increment_version(const Tensor & t) { +inline void increment_version(const at::Tensor & t) { impl::bump_version(t); } @@ -138,8 +133,8 @@ template inline variable_list flatten_tensor_args(Args&&... ar } // See NOTE [ Autograd View Variables ] for details. -inline Tensor as_view(const Tensor & base, const Tensor & tensor, bool is_bw_differentiable, - bool is_fw_differentiable, std::function view_func=nullptr, +inline at::Tensor as_view(const at::Tensor & base, const at::Tensor & tensor, bool is_bw_differentiable, + bool is_fw_differentiable, std::function view_func=nullptr, CreationMeta creation_meta=CreationMeta::DEFAULT, bool allow_tensor_metadata_change=true) { // Note [View of inference tensor] // For inference tensor this code can only be hit outside InferenceMode @@ -202,7 +197,7 @@ inline Tensor as_view(const Tensor & base, const Tensor & tensor, bool is_bw_dif } // See NOTE [ Autograd View Variables ] for details. -inline std::vector as_view(const Tensor & base, std::vector& tensors, bool is_bw_differentiable, +inline std::vector as_view(const at::Tensor & base, std::vector& tensors, bool is_bw_differentiable, bool is_fw_differentiable, CreationMeta creation_meta=CreationMeta::DEFAULT) { // See Note [View of inference tensor] if (base.is_inference()) return tensors; @@ -228,7 +223,7 @@ inline std::vector as_view(const Tensor & base, std::vector& ten new_shared_info = ViewInfo(base, /* view_func */ nullptr); } - for(Tensor &tensor : tensors) { + for(at::Tensor &tensor : tensors) { if (is_fw_differentiable || is_bw_differentiable) { tensor = make_variable_differentiable_view(tensor, new_shared_info, c10::nullopt, /*shared_view_info*/ true, creation_meta); } else { @@ -282,7 +277,7 @@ inline std::vector as_view(const Tensor & base, std::vector& ten creation_meta = propagate_creation_meta(diff_view_meta->get_creation_meta(), creation_meta); } - for(Tensor &tensor : tensors) { + for(at::Tensor &tensor : tensors) { if (is_fw_differentiable || is_bw_differentiable) { tensor = make_variable_differentiable_view(tensor, new_bw_info, new_fw_info, /*shared_view_info*/ false, creation_meta); } else { @@ -292,20 +287,20 @@ inline std::vector as_view(const Tensor & base, std::vector& ten return tensors; } -inline void check_no_requires_grad(const Tensor& tensor, const char* name, +inline void check_no_requires_grad(const at::Tensor& tensor, const char* name, const char* fn_name="", bool check_grad_mode=true) { TORCH_CHECK(!(tensor.defined() && tensor.requires_grad()) || !(check_grad_mode && GradMode::is_enabled()), "The function '", fn_name, "' is not differentiable with respect to argument '", name, "'. This input cannot have requires_grad True."); } -inline void check_no_requires_grad(const c10::optional& tensor, const char* name, const char* fn_name="") { +inline void check_no_requires_grad(const c10::optional& tensor, const char* name, const char* fn_name="") { if (tensor.has_value()) { check_no_requires_grad(*tensor, name, fn_name); } } -inline void check_no_requires_grad(TensorList tensors, const char* name, const char* fn_name="") { +inline void check_no_requires_grad(at::TensorList tensors, const char* name, const char* fn_name="") { // GradMode check is expensive, so check it only once for TensorLists if (!GradMode::is_enabled()) { return; @@ -315,12 +310,12 @@ inline void check_no_requires_grad(TensorList tensors, const char* name, const c } } -inline void check_no_requires_grad(const c10::List>& tensors, const char* name, const char* fn_name="") { +inline void check_no_requires_grad(const c10::List>& tensors, const char* name, const char* fn_name="") { // GradMode check is expensive, so check it only once for TensorLists if (!GradMode::is_enabled()) { return; } - for (c10::optional tensor : tensors) { + for (c10::optional tensor : tensors) { if (tensor.has_value()) { check_no_requires_grad(*tensor, name, fn_name, /*check_grad_mode*/ false); } @@ -328,23 +323,23 @@ inline void check_no_requires_grad(const c10::List>& tenso } // Assumed that saved tensor lists are never inplace outputs -inline std::vector make_saved_variable_list(TensorList tensors) { - return fmap(tensors, [](const Tensor& tensor) -> SavedVariable { +inline std::vector make_saved_variable_list(at::TensorList tensors) { + return fmap(tensors, [](const at::Tensor& tensor) -> SavedVariable { return SavedVariable{tensor, false /* is output */}; }); } // Assumed that saved tensor lists are never inplace outputs inline std::vector make_saved_variable_list(const c10::List>& tensors) { - return fmap(tensors, [](const c10::optional& tensor) -> SavedVariable { + return fmap(tensors, [](const c10::optional& tensor) -> SavedVariable { if (tensor.has_value()) { return SavedVariable{*tensor, false /* is output */}; } else { - return SavedVariable{Tensor(), false /* is output */}; + return SavedVariable{at::Tensor(), false /* is output */}; } }); } -inline std::vector> to_args_sizes(TensorList tensors) { +inline std::vector> to_args_sizes(at::TensorList tensors) { std::vector> args_sizes(tensors.size()); for (const auto i : c10::irange(tensors.size())) { args_sizes[i] = tensors[i].sizes().vec(); @@ -352,11 +347,12 @@ inline std::vector> to_args_sizes(TensorList tensors) { return args_sizes; } -inline std::vector to_args_scalartypes(TensorList tensors) { - std::vector args_scalartypes(tensors.size()); +inline std::vector to_args_scalartypes(at::TensorList tensors) { + std::vector args_scalartypes(tensors.size()); for (const auto i : c10::irange(tensors.size())) { args_scalartypes[i] = tensors[i].scalar_type(); } return args_scalartypes; } + }} // namespace torch::autograd diff --git a/torch/csrc/autograd/autograd_meta.cpp b/torch/csrc/autograd/autograd_meta.cpp index 248847f66ca60..f35c122225831 100644 --- a/torch/csrc/autograd/autograd_meta.cpp +++ b/torch/csrc/autograd/autograd_meta.cpp @@ -183,6 +183,12 @@ void AutogradMeta::set_fw_grad(const Variable& new_grad_, const Variable& self, } const Variable& AutogradMeta::fw_grad(uint64_t level, const Variable& self) const { + // TLS that disables forward AD + // This is only used for custom Function implementation + if (!c10::AutogradState::get_tls_state().get_fw_grad_mode()) { + return ForwardGrad::undef_grad(); + } + // Ensure that concurent fw_grad() "reads" are thread safe std::lock_guard lock(mutex_); diff --git a/torch/csrc/autograd/autograd_not_implemented_fallback.cpp b/torch/csrc/autograd/autograd_not_implemented_fallback.cpp new file mode 100644 index 0000000000000..ab9cb49ec63a7 --- /dev/null +++ b/torch/csrc/autograd/autograd_not_implemented_fallback.cpp @@ -0,0 +1,189 @@ +#include + +#include + +#include +#include + +#include +#include +#include +#include +#include + +#include + +namespace torch { namespace autograd { + +namespace { + +template +void _foreach_tensor( + F fn, + torch::jit::Stack* stack, + size_t stack_start, + size_t size) { + // Enumerate over tensors in a stack, including ones in TensorLists + int idx_tensor = 0; + for (const auto idx_arg : c10::irange(size)) { + auto& ivalue = (*stack)[stack_start + idx_arg]; + if (ivalue.isTensor()) { // true for optional tensor that has value + const auto& tensor = ivalue.toTensor(); + fn(idx_tensor, idx_arg, tensor); + idx_tensor++; + } else if (ivalue.isTensorList()) { + for (const auto& iv : ivalue.toListRef()) { + const auto& tensor = iv.toTensor(); + fn(idx_tensor, idx_arg, tensor); + idx_tensor++; + } + } + } +} + +} + +void autogradNotImplementedFallbackImpl(const c10::OperatorHandle& op, c10::DispatchKeySet dispatch_keys, torch::jit::Stack* stack) { + // Mimics the logic of a VariableType NotImplemented kernel + const auto& schema = op.schema(); + const auto& op_name = schema.operator_name().name; + const auto& arguments = schema.arguments(); + const auto& returns = schema.returns(); + const auto num_arguments = arguments.size(); + const auto num_returns = returns.size(); + const auto stack_start = stack->size() - num_arguments; + const bool grad_mode = GradMode::is_enabled(); + std::vector tensors_requiring_grad_on_stack; + + // Keep track of which outputs are output of in-place modification + // so we can rebase_history if necessary + std::vector is_inplace_output; + bool any_is_inplace_output = false; + std::vector is_aliased_output; + is_inplace_output.reserve(num_returns); + is_aliased_output.reserve(num_returns); + + for (const auto i : c10::irange(num_returns)) { + const auto& alias_info = returns[i].alias_info(); + is_inplace_output.push_back(alias_info.has_value() && alias_info->isWrite()); + any_is_inplace_output |= alias_info.has_value() && alias_info->isWrite(); + is_aliased_output.push_back(alias_info.has_value()); + + } + int aliased_input_idx = -1; + int aliased_output_idx = -1; + for (const auto i : c10::irange(num_returns)) { + const auto& alias_info = returns[i].alias_info(); + if (alias_info.has_value() && !alias_info->isWrite()) { + AT_ASSERT( + aliased_output_idx == -1, + "Expected only a single output in the operator schema to have a non-write alias annotation (i.e., 'Tensor(a)'). " + "Non-composite functions where multiple outputs are aliased with inputs aren't supported." + "Please rewrite your function as a composite function."); + aliased_output_idx = i; + } + } + for (const auto i : c10::irange(num_arguments)) { + const auto& alias_info = arguments[i].alias_info(); + if (alias_info.has_value() && !alias_info->isWrite()) { + AT_ASSERT( + aliased_input_idx == -1, + "Expected only a single input in the operator schema to have a non-write alias annotation (i.e., 'Tensor(a)'). " + "Non-composite functions where multiple inputs are aliased with outputs aren't supported. " + "Please rewrite your function as a composite function."); + aliased_input_idx = i; + } + } + + size_t num_tensor_inputs = 0; // Only used for DEBUG-only checks + + _foreach_tensor([&](size_t _, size_t idx_arg, const at::Tensor& t) { + if (grad_mode && t.requires_grad()) { + tensors_requiring_grad_on_stack.push_back(&t); + } + num_tensor_inputs++; + TORCH_CHECK_NOT_IMPLEMENTED(!isFwGradDefined(t), "Trying to use forward AD with ", op_name, " that does not support it."); + }, stack, stack_start, num_arguments); + + const bool any_requires_grad = tensors_requiring_grad_on_stack.size() > 0; + + _foreach_tensor([&](size_t _, size_t i, const at::Tensor& t) { + const auto& alias_info = arguments[i].alias_info(); + if (alias_info.has_value() && alias_info->isWrite()) { + check_inplace(t, any_requires_grad); + } + }, stack, stack_start, num_arguments); + + std::shared_ptr grad_fn; + if (any_requires_grad) { + grad_fn = std::shared_ptr(new NotImplemented(op_name), deleteNode); + grad_fn->set_next_edges(collect_next_edges(tensors_requiring_grad_on_stack)); + } + + #ifndef NDEBUG + // See NOTE [ TensorImpl and Storage Pointer Sanity Checks ] + auto stack_args_copy = std::vector(stack->begin() + stack_start, stack->end()); + std::vector> impl_saved; + impl_saved.reserve(num_tensor_inputs); + std::vector> storage_saved; + storage_saved.reserve(num_tensor_inputs); + _foreach_tensor([&](size_t idx, size_t _, const at::Tensor& t) { + storage_saved.push_back(t.has_storage() ? c10::optional(t.storage()) : c10::nullopt); + impl_saved.push_back(t.getIntrusivePtr()); + }, &stack_args_copy, 0, num_arguments); + #endif + if (aliased_input_idx != -1 || any_is_inplace_output) { + at::AutoDispatchBelowAutograd guard; + op.redispatchBoxed(dispatch_keys & c10::after_autograd_keyset, stack); + } else { + // If neither in-place nor view + at::AutoDispatchBelowADInplaceOrView guard; + op.redispatchBoxed(dispatch_keys & c10::after_autograd_keyset, stack); + } + #ifndef NDEBUG + _foreach_tensor([&](size_t idx_tensor, size_t _, const at::Tensor& t) { + if (storage_saved.at(idx_tensor).has_value()) + TORCH_INTERNAL_ASSERT(storage_saved.at(idx_tensor).value().is_alias_of(t.storage()), op_name); + if (impl_saved.at(idx_tensor)) + TORCH_INTERNAL_ASSERT(impl_saved.at(idx_tensor) == t.getIntrusivePtr(), op_name); + }, &stack_args_copy, 0, num_arguments); + _foreach_tensor([&](size_t idx_tensor, size_t idx_ret, const at::Tensor& t) { + if (!is_inplace_output[idx_ret]) + TORCH_INTERNAL_ASSERT(t.use_count() <= 1, op_name); // Okay to return undefined tensor + if (!is_aliased_output[idx_ret] && t.has_storage()) + TORCH_INTERNAL_ASSERT(t.storage().use_count() == 1); + }, stack, stack->size() - num_returns, num_returns); + // There should be only a single base-view pair, make sure their storage is aliased + if (aliased_input_idx != -1 && aliased_output_idx != -1) { + const c10::IValue& aliased_input_iv = stack_args_copy[aliased_input_idx]; + const c10::IValue& aliased_output_iv = (*stack)[stack->size() - num_returns + aliased_output_idx]; + // We do not support views embedded inside tensorlist + TORCH_INTERNAL_ASSERT(aliased_input_iv.isTensor(), op_name); + TORCH_INTERNAL_ASSERT(aliased_output_iv.isTensor(), op_name); + const at::Tensor& aliased_input = aliased_input_iv.toTensor(); + const at::Tensor& aliased_output = aliased_input_iv.toTensor(); + if(is_aliased_output[aliased_input_idx] && aliased_input.has_storage()) + TORCH_INTERNAL_ASSERT(aliased_input.storage().is_alias_of(aliased_output.storage()), op_name); + } + #endif + + if (any_requires_grad) { + _foreach_tensor([&](size_t idx_tensor, size_t idx_ret, const at::Tensor& t) { + if (isDifferentiableType(t.scalar_type())) { + if (is_inplace_output[idx_ret]) { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + rebase_history(const_cast(t), grad_fn); + } else { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + set_history(const_cast(t), grad_fn); + } + } + }, stack, stack->size() - num_returns, num_returns); + } +} + +torch::CppFunction autogradNotImplementedFallback() { + return torch::CppFunction::makeFromBoxedFunction<&autogradNotImplementedFallbackImpl>(); +} + +}} // namespace torch::autograd diff --git a/torch/csrc/autograd/autograd_not_implemented_fallback.h b/torch/csrc/autograd/autograd_not_implemented_fallback.h new file mode 100644 index 0000000000000..4b2cbd14b9d86 --- /dev/null +++ b/torch/csrc/autograd/autograd_not_implemented_fallback.h @@ -0,0 +1,11 @@ +#pragma once + +#include +#include + +namespace torch { +namespace autograd { + +TORCH_API torch::CppFunction autogradNotImplementedFallback(); + +}} // namespace torch::autograd diff --git a/torch/csrc/autograd/custom_function.cpp b/torch/csrc/autograd/custom_function.cpp index 502919ff3a6a4..1bb4cb836f1e8 100644 --- a/torch/csrc/autograd/custom_function.cpp +++ b/torch/csrc/autograd/custom_function.cpp @@ -26,17 +26,180 @@ Variable VariableInfo::zeros(at::OptionalDeviceGuard& device_guard) const { } } -std::vector> _wrap_outputs(const variable_list &input_vars, +// This function has two main goals: +// 1) Use the user-provided jvp function to populate the the outputs' forward gradient +// 2) Perform error checking to ensure that view and inplace ops are properly handled +// +// For 1) we have to: +// - Create a variable_list of grad_inputs based on the function inputs +// - Call the user jvp function with these to get the grad_outputs +// - Set the forward grad field on each output based on these grad_outputs +// +// For 2) we want to check the following: +// - If an output is a view, then the generated forward grad must be a view as well and +// the output's base's forward grad must be the output's forward grad's base. +// - If an input was modified inplace (it must be an output as well) we make sure that its +// forward grad was also modified inplace and already present on the corresponding output. +void _process_forward_mode_AD(const variable_list &inputs, + std::unordered_map inputs_mapping, + const at::ArrayRef> raw_outputs, + const optional_variable_list &outputs, + const std::unordered_set &non_differentiable, + const std::unordered_set &dirty_inputs, + _jvp_fn_t jvp_user_function) { + + // TODO handle multiple levels here + uint64_t level = 0; + + const auto num_inputs = inputs.size(); + const auto num_outputs = outputs.size(); + + // The tracking info below are used to perform the view and inplace checks. + // They are lazily initialized to reduce the cost of this function in the common + // case where the user is not using forward mode AD. + variable_list input_grads; + std::vector grad_versions; + std::vector grad_impls; + std::unordered_map inputs_bases; + + auto init_tracked_info = [&] () { + input_grads.resize(num_inputs); + grad_versions.resize(num_inputs); + grad_impls.resize(num_inputs); + + for (const auto i: c10::irange(num_inputs)) { + const auto& inp = inputs[i]; + if (inp.is_view() && impl::get_view_autograd_meta(inp)->has_fw_view()) { + inputs_bases.emplace(impl::get_view_autograd_meta(inp)->get_forward_view().base_.unsafeGetTensorImpl(), i); + } else { + inputs_bases.emplace(inp.unsafeGetTensorImpl(), i); + } + + } + }; + + bool any_input_has_grad = false; + // Extract the input's forward gradients and record any info we will need later + for (const auto i : c10::irange(num_inputs)) { + const auto& inp = inputs[i]; + if (!inp.defined()) { + continue; + } + const auto& fw_grad = inp._fw_grad(level); + if (fw_grad.defined()) { + if (!any_input_has_grad) { + any_input_has_grad = true; + init_tracked_info(); + } + input_grads[i] = fw_grad; + grad_versions[i] = fw_grad._version(); + grad_impls[i] = fw_grad.unsafeGetTensorImpl(); + } + } + + // If no input has forward grad, nothing to do here + if (!any_input_has_grad) { + return; + } + + + auto forward_grads = jvp_user_function(inputs, input_grads); + + + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + const auto num_forward_grads = forward_grads.size(); + // contrary to backward mode, we don't allow returning too many gradients + TORCH_CHECK(num_forward_grads == num_outputs, "Function's jvp returned " + "an invalid number of of forward gradients (expected ", num_outputs, + " but got ", num_forward_grads, ")"); + + for (const auto i : c10::irange(num_outputs)) { + const auto& out = outputs[i].has_value()? outputs[i].value() : at::Tensor(); + const auto& out_grad = forward_grads[i]; + if (!out.defined()) { + TORCH_CHECK(!out_grad.defined(), "Function's jvp returned a gradient at position ", i, ", but " + " the corresponding forward output is not a differentiable Tensor"); + continue; + } + + TORCH_INTERNAL_ASSERT(raw_outputs[i].has_value()); + auto out_tensor_impl = raw_outputs[i].value().unsafeGetTensorImpl(); + bool is_input = inputs_mapping.count(out_tensor_impl) > 0; + bool is_modified = dirty_inputs.count(out_tensor_impl) > 0; + + if (is_modified) { + TORCH_CHECK(is_input, "Only input Tensors should be given to ctx.mark_dirty(). If a Tensor is not an input, there" + " is no need to pass it to mark_dirty()."); + auto inp_idx = inputs_mapping[out_tensor_impl]; + if (grad_impls[inp_idx]) { + // If there was already a forward grad for that input + // Just make sure that it is modified inplace and returned as-is + TORCH_CHECK(out_grad._version() != grad_versions[inp_idx], "An inplace custom Function is not modifying the " + "forward mode gradients inplace. If the forward is modifying an input inplace, then the jvp " + "function must modify the corresponding gradient inplace.") + TORCH_CHECK(out_grad.unsafeGetTensorImpl() == grad_impls[inp_idx], "An inplace custom Function is not returning the " + "forward mode gradients as-is. If the forward is modifying an input inplace, then the jvp " + "function must modify the gradient inplace and return it as-is.") + } else { + // If that Tensor didn't had gradients already, set the newly returned one + // We could also use inputs[inp_idx] here as it is the same as out + out._set_fw_grad(out_grad, level, /* is_inplace_op */ true); + } + } else { + // At this point, outputs[i] cannot be one of the input (raw_outputs[i] might be but was changed by the backward code) + TORCH_INTERNAL_ASSERT(!is_input); + + if (out.is_view() && impl::get_view_autograd_meta(out)->has_fw_view()) { + // If the output is a view + const auto& out_view_info = impl::get_view_autograd_meta(out)->get_forward_view(); + if (inputs_bases.count(out_view_info.base_.unsafeGetTensorImpl())) { + // And it is a view of an input (either that input is its base or they have a common base) + const auto matching_input_idx = inputs_bases[out_view_info.base_.unsafeGetTensorImpl()]; + const auto& matching_input = inputs[matching_input_idx]; + + const auto& matching_input_grad = matching_input._fw_grad(level); + + // If the matching input has a forward grad, the user should have returned a view of that Tensor + if (matching_input_grad.defined()) { + TORCH_CHECK(out_grad.is_view() && impl::get_view_autograd_meta(out_grad)->has_fw_view(), + "A custom Function's forward is returning a view but the jvp is not returning a view."); + + const auto& out_grad_base = impl::get_view_autograd_meta(out_grad)->get_forward_view().base_; + if (matching_input_grad.is_view() && impl::get_view_autograd_meta(matching_input_grad)->has_fw_view()) { + // If the matching input's grad is a view, ensure that the out_grad is a view of the same base + const auto& matching_input_grad_base = impl::get_view_autograd_meta(matching_input_grad)->get_forward_view().base_; + TORCH_CHECK(matching_input_grad_base.unsafeGetTensorImpl() == out_grad_base.unsafeGetTensorImpl(), + "A custom Function is returning a view but the jvp is not returning a view of the same base as " + "the given grad input."); + } else { + // If the matching input's grad is not a view, then it must be the output gradient's base + TORCH_CHECK(matching_input_grad.unsafeGetTensorImpl() == out_grad_base.unsafeGetTensorImpl(), + "A custom Function is returning a view but the jvp is not returning a view of the given grad input."); + } + } else { + // We have a view op where the input didn't have a forward grad but the user returned one for the output + // To ensure that we maintain the view/inplace constraints, we consider this as an inplace op + // This case CANNOT happen in codegen as all view ops are mapping from one Tensor to one Tensor and so the output + // of the view cannot have a forward grad if the base does not. + out._set_fw_grad(out_grad, level, /* is_inplace_op */ true); + return; + } + + } + } + + out._set_fw_grad(out_grad, level, /* is_inplace_op */ false); + } + } +} + +optional_variable_list _process_backward_mode_ad( + const std::unordered_map &inputs_mapping, const std::unordered_set &non_differentiable, const std::unordered_set &dirty_inputs, const at::ArrayRef> raw_outputs, const std::shared_ptr &cdata) { - std::unordered_set inputs; - inputs.reserve(input_vars.size()); - for (auto& var : input_vars) { - inputs.emplace(var.unsafeGetTensorImpl()); - } int num_outputs = raw_outputs.size(); @@ -63,7 +226,7 @@ std::vector> _wrap_outputs(const variable_list &input_va // Here, `y` requires_grad (!). } else if (is_modified) { if (var.is_leaf() && var.requires_grad()) { - throw std::runtime_error("a leaf Variable that requires grad has been used in an in-place operation."); + TORCH_CHECK(false, "a leaf Variable that requires grad has been used in an in-place operation."); } // No need to mark as modified Tensors that are not inputs. if (!is_input) { @@ -105,7 +268,7 @@ std::vector> _wrap_outputs(const variable_list &input_va } }; - std::vector> outputs; + optional_variable_list outputs; std::unordered_set outputs_impl; // For dirty_inputs check outputs.reserve(num_outputs); int num_diff_outputs = 0; @@ -125,7 +288,7 @@ std::vector> _wrap_outputs(const variable_list &input_va Variable var = raw_outputs[i].value(); auto out_tensor_impl = var.unsafeGetTensorImpl(); - bool is_input = inputs.count(out_tensor_impl) > 0; + bool is_input = inputs_mapping.count(out_tensor_impl) > 0; bool is_modified = dirty_inputs.count(out_tensor_impl) > 0; bool is_differentiable = cdata && non_differentiable.count(out_tensor_impl) == 0 && isDifferentiableType(var.scalar_type()); @@ -177,6 +340,30 @@ std::vector> _wrap_outputs(const variable_list &input_va return outputs; } + + +optional_variable_list _wrap_outputs(const variable_list &input_vars, + const std::unordered_set &non_differentiable, + const std::unordered_set &dirty_inputs, + const at::ArrayRef> raw_outputs, + const std::shared_ptr &cdata, + _jvp_fn_t jvp_user_function) { + + std::unordered_map inputs_mapping; + inputs_mapping.reserve(input_vars.size()); + for (const auto i: c10::irange(input_vars.size())) { + inputs_mapping.emplace(input_vars[i].unsafeGetTensorImpl(), i); + } + + auto outputs = _process_backward_mode_ad(inputs_mapping, non_differentiable, dirty_inputs, raw_outputs, cdata); + + // This must happen after the backward processing as we expect the computations happening here to track + // backward mode gradients. + _process_forward_mode_AD(input_vars, inputs_mapping, raw_outputs, outputs, non_differentiable, dirty_inputs, jvp_user_function); + + return outputs; +} + void check_variable_result(const Variable& original, const Variable& result, std::string hook_name) { if (!original.options().type_equal(result.options())) { std::stringstream ss; diff --git a/torch/csrc/autograd/custom_function.h b/torch/csrc/autograd/custom_function.h index 243622f650666..94e62bf7b63c7 100644 --- a/torch/csrc/autograd/custom_function.h +++ b/torch/csrc/autograd/custom_function.h @@ -9,12 +9,16 @@ namespace torch { namespace autograd { +using optional_variable_list = std::vector>; +using _jvp_fn_t = std::function; + TORCH_API std::vector> _wrap_outputs( const variable_list &input_vars, const std::unordered_set &non_differentiable, const std::unordered_set &dirty_inputs, const at::ArrayRef> raw_outputs, - const std::shared_ptr &cdata); + const std::shared_ptr &cdata, + _jvp_fn_t jvp_user_function); TORCH_API void check_variable_result(const Variable& original, const Variable& result, std::string hook_name); @@ -263,12 +267,18 @@ auto Function::apply(Args&&... args) -> std::enable_if_t::v outputs = T::forward(&node->ctx_, std::forward(args)...); } + _jvp_fn_t jvp_fn = [](variable_list inputs, variable_list gI) -> variable_list { + TORCH_CHECK(false, "jvp is not implemented for the c++ API of custom Function yet.", + "Please open a feature request on Github if you need this."); + }; + auto wrapped_outputs = _wrap_outputs( input_vars, node->ctx_.get_non_differentiable(), node->ctx_.get_and_bump_dirty(), to_optional(outputs), - is_executable ? node : nullptr); + is_executable ? node : nullptr, + jvp_fn); node->output_info_.reserve(wrapped_outputs.size()); for (auto& output : wrapped_outputs) { diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index 252a74b4c07c7..4ea002a8312f1 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -407,7 +407,11 @@ auto Engine::thread_main(const std::shared_ptr& graph_task) -> void { } if (task.fn_ && !local_graph_task->has_error_.load()) { - AutoGradMode grad_mode(local_graph_task->grad_mode_); + // Set the ThreadLocalState before calling the function. + // NB: The ThreadLocalStateGuard doesn't set the grad_mode because GraphTask + // always saves ThreadLocalState without grad_mode. + at::ThreadLocalStateGuard tls_guard(local_graph_task->thread_locals_); + try { // The guard sets the thread_local current_graph_task on construction // and restores it on exit. The current_graph_task variable helps @@ -415,7 +419,18 @@ auto Engine::thread_main(const std::shared_ptr& graph_task) -> void { // callbacks. GraphTaskGuard guard(local_graph_task); NodeGuard ndguard(task.fn_); - evaluate_function(local_graph_task, task.fn_.get(), task.inputs_, local_graph_task->cpu_ready_queue_); + { + RECORD_FUNCTION( + c10::str( + "autograd::engine::evaluate_function: ", + task.fn_.get()->name()), + std::vector()); + evaluate_function( + local_graph_task, + task.fn_.get(), + task.inputs_, + local_graph_task->cpu_ready_queue_); + } } catch (std::exception& e) { thread_on_exception(local_graph_task, task.fn_, e); } @@ -764,11 +779,6 @@ void Engine::evaluate_function( Node* func, InputBuffer& inputs, const std::shared_ptr& cpu_ready_queue) { - // Set the ThreadLocalState before calling the function. - // NB: The ThreadLocalStateGuard doesn't set the grad_mode because GraphTask - // always saves ThreadLocalState without grad_mode. - at::ThreadLocalStateGuard tls_guard(graph_task->thread_locals_); - // The InputBuffer::adds that supplied incoming grads took pains to // ensure they're safe to consume in the context of the present // func's stream (if applicable). So we guard onto that stream diff --git a/torch/csrc/autograd/engine.h b/torch/csrc/autograd/engine.h index 17318473bcfcd..dd465f96c350e 100644 --- a/torch/csrc/autograd/engine.h +++ b/torch/csrc/autograd/engine.h @@ -53,9 +53,8 @@ struct GraphTask: std::enable_shared_from_this { // true, it signals all threads to stop executing. std::atomic_bool has_error_{false}; std::atomic_bool future_completed_{false}; - // It is safe to read grad_mode_ and keep_graph_ without synchronization + // It is safe to read keep_graph_ without synchronization bool keep_graph_; - bool grad_mode_; // To protect reads/writes to not_ready_, dependencies_, captured_vars_, // has_error_, future_result_, cpu_ready_queue_, and leaf_streams. @@ -110,8 +109,9 @@ struct GraphTask: std::enable_shared_from_this { // out of the GraphTask and are no longer valid. std::vector captured_vars_; - at::ThreadLocalState thread_locals_ = - at::ThreadLocalState(/* keep_grad_mode */ false); + // Note: this field is not ready to be used until the proper `thread_locals_.set_grad_mode()` + // call in the constructor. + at::ThreadLocalState thread_locals_ = at::ThreadLocalState(); std::unordered_set leaf_streams; @@ -180,12 +180,13 @@ struct GraphTask: std::enable_shared_from_this { std::shared_ptr cpu_ready_queue, bool exit_on_error = false) : keep_graph_(keep_graph), - grad_mode_(grad_mode), owner_(NO_DEVICE), reentrant_depth_(reentrant_depth), exit_on_error_(exit_on_error), cpu_ready_queue_(std::move(cpu_ready_queue)), - future_result_(c10::make_intrusive(c10::ListType::create(c10::TensorType::get()))) {} + future_result_(c10::make_intrusive(c10::ListType::create(c10::TensorType::get()))) { + thread_locals_.set_grad_mode(grad_mode); + } private: // run GraphTask post processing void exec_post_processing(); diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h index 25336dfa9d911..2a1de8e82a774 100644 --- a/torch/csrc/autograd/function.h +++ b/torch/csrc/autograd/function.h @@ -563,6 +563,14 @@ struct MakeNextFunctionList : IterArgs { next_edges.emplace_back(); } } + void operator()(const Variable* variable) { + // NOLINTNEXTLINE(bugprone-branch-clone) + if (variable->defined()) { + next_edges.push_back(impl::gradient_edge(*variable)); + } else { + next_edges.emplace_back(); + } + } void operator()(const c10::optional& variable) { // NOLINTNEXTLINE(bugprone-branch-clone) if (variable.has_value() && variable->defined()) { diff --git a/torch/csrc/autograd/functions/utils.h b/torch/csrc/autograd/functions/utils.h index 90811e2a30a37..331db5d32cb79 100644 --- a/torch/csrc/autograd/functions/utils.h +++ b/torch/csrc/autograd/functions/utils.h @@ -86,4 +86,9 @@ inline void set_history( set_history(variable, grad_fn); } } + +inline bool isFwGradDefined(const c10::optional& t) { + return t.has_value() && t->defined() && t->_fw_grad(/*level */ 0).defined(); +} + }} diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index 2eacbf1cd3839..697ca871f83c5 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -114,7 +114,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) { .value("IDEEP", c10::DeviceType::IDEEP) .value("HIP", c10::DeviceType::HIP) .value("FPGA", c10::DeviceType::FPGA) - .value("MSNPU", c10::DeviceType::MSNPU) + .value("ORT", c10::DeviceType::ORT) .value("XLA", c10::DeviceType::XLA) .value("Lazy", c10::DeviceType::Lazy) .value("MLC", c10::DeviceType::MLC) diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp index dd58a68134b8f..eee56f71ed7d8 100644 --- a/torch/csrc/autograd/python_function.cpp +++ b/torch/csrc/autograd/python_function.cpp @@ -45,14 +45,29 @@ PyObject *THPFunctionClass = nullptr; #define THPFunction_assert(condition, ...) \ if (!(condition)) { THPUtils_setError(__VA_ARGS__); throw python_error(); } -namespace torch { namespace autograd { +// Anonymous namespace for helpful functions used in this file +namespace { -void PyNode::throw_python_error() { +// Throw a python_error with the PyErr state persisted, so that we +// don't lose the error state if the GIL is released when we don't +// have a PyThreadState created beforehand, this is made so that +// even for pure C++ thread without a pre-created PyThreadState could +// also capture the correct error message. +// TODO: This is a temporary approach to allow C++ thread to correctly +// capture Python Error in autograd, remove this when c10 thread pool +// allow to do one time initialization. +// see discussion in https://github.com/pytorch/pytorch/pull/34845 +// Follow up issue: https://github.com/pytorch/pytorch/issues/35006 +void throw_python_error() { python_error err; err.persist(); throw err; } +} + +namespace torch { namespace autograd { + // NOTE: this function is written in a way that assumes it's only called for backward; // it's used by engine.cpp. This is responsible for forwarding a call from // C++'s Node::apply to a Python method "apply". @@ -325,8 +340,61 @@ static void _wrap_outputs(const std::shared_ptr& cdata, THPFunction *sel } } + _jvp_fn_t jvp_user_function = [self](variable_list inputs, variable_list grad_inputs) { + pybind11::gil_scoped_acquire gil; + + // Massage a C++ variable_list into a Python arguments tuple + // Making sure to introduce the proper None for non-Tensor inputs + auto num_inputs = self->is_variable_input.size(); + THPObjectPtr pyInputs(PyTuple_New(num_inputs)); + if (!pyInputs) throw_python_error(); + auto var_input_idx = 0; + for (const auto i : c10::irange(num_inputs)) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + PyObject* input; + if (self->is_variable_input[i]) { + if (grad_inputs[i].defined() || !self->materialize_grads) { + input = THPVariable_Wrap(grad_inputs[i]); + } else { + input = THPVariable_Wrap(at::zeros_like(inputs[i])); + } + if (!input) throw_python_error(); + } else { + Py_INCREF(Py_None); + input = Py_None; + } + PyTuple_SET_ITEM(pyInputs.get(), i, input); + } + + THPObjectPtr apply_jvp_fn(PyObject_GetAttrString((PyObject*)self, "apply_jvp")); + if (!apply_jvp_fn) throw_python_error(); + THPObjectPtr r(PyObject_CallObject(apply_jvp_fn, pyInputs.get())); + if (!r) throw_python_error(); + ensure_tuple(r); + + // Massage the Python results tuple back into a C++ variable_list + // Don't do any check on the number of results here as + // it is handled by the caller + const int num_outputs = PyTuple_GET_SIZE(r.get()); + variable_list results; + results.reserve(num_outputs); + for (int i = 0; i != num_outputs; ++i) { + PyObject* output = PyTuple_GET_ITEM(r.get(), i); + if (output == Py_None) { + results.emplace_back(); + } else { + TORCH_CHECK(THPVariable_Check(output), "expected Variable or None (got ", + THPUtils_typename(output), ") for grad output ", i, ".") + results.emplace_back(THPVariable_Unpack(output)); + } + } + + return results; + }; + // Wrap only the tensor outputs. - auto wrapped_outputs = _wrap_outputs(input_vars, non_differentiable, dirty_inputs, raw_output_vars, cdata_if_executable); + auto wrapped_outputs = _wrap_outputs(input_vars, non_differentiable, dirty_inputs, + raw_output_vars, cdata_if_executable, jvp_user_function); for(const auto i : c10::irange(num_outputs)) { PyObject* obj = PyTuple_GetItem(raw_output, i); @@ -556,6 +624,9 @@ PyObject* process_outputs(PyObject *op_obj, const std::shared_ptr& cdata bool is_inplace = static_cast(grad_fn->dirty_tensors); _wrap_outputs(cdata, grad_fn, unpacked.input_vars, raw_output, outputs, is_executable); _trace_post_record(node, op_obj, unpacked.input_vars, outputs, is_inplace, unpack_output); + + // It is important that creating the SavedVariables happen after the output wrapping as the + // outputs must have their grad_fn/fw_grad properly set before we save them. if (is_executable) { _save_variables(cdata, grad_fn); } else { @@ -636,6 +707,7 @@ PyObject *THPFunction_apply(PyObject *cls, PyObject *inputs) THPObjectPtr tensor_outputs; { AutoGradMode grad_mode(false); + at::AutoFwGradMode fw_grad_mode(false); THPObjectPtr forward_fn(PyObject_GetAttrString(cls, "forward")); if (!forward_fn) return nullptr; tensor_outputs = PyObject_CallObject(forward_fn, ctx_input_tuple); diff --git a/torch/csrc/autograd/python_function.h b/torch/csrc/autograd/python_function.h index 8f4d12ba640fc..3657807f35964 100644 --- a/torch/csrc/autograd/python_function.h +++ b/torch/csrc/autograd/python_function.h @@ -27,17 +27,6 @@ struct PyNode : public Node { variable_list apply(variable_list&& inputs) override; - // Throw a python_error with the PyErr state persisted, so that we - // don't lose the error state if the GIL is released when we don't - // have a PyThreadState created beforehand, this is made so that - // even for pure C++ thread without a pre-created PyThreadState could - // also capture the correct error message. - // TODO: This is a temporary approach to allow C++ thread to correctly - // capture Python Error in autograd, remove this when c10 thread pool - // allow to do one time initialization. - // see discussion in https://github.com/pytorch/pytorch/pull/34845 - // Follow up issue: https://github.com/pytorch/pytorch/issues/35006 - void throw_python_error(); void release_variables() override; std::string name() const override; bool is_traceable() override; diff --git a/torch/csrc/autograd/python_torch_functions.h b/torch/csrc/autograd/python_torch_functions.h new file mode 100644 index 0000000000000..58257794812ee --- /dev/null +++ b/torch/csrc/autograd/python_torch_functions.h @@ -0,0 +1,25 @@ +#include + +#include + + +namespace torch { namespace autograd { + +extern PyObject* THPVariableFunctionsModule; + +// Wrapper converts a raised TypeError into returning NotImplemented +// Used to implement binary arithmetic operators +template +inline PyObject * TypeError_to_NotImplemented_(PyObject* self, PyObject* args, PyObject* kwargs) { + PyObject* ret = Func(self, args, kwargs); + if (!ret && PyErr_ExceptionMatches(PyExc_TypeError)) { + PyErr_Clear(); + Py_INCREF(Py_NotImplemented); + ret = Py_NotImplemented; + } + return ret; +} + +void initTorchFunctions(); + +}} // namespace torch::autograd diff --git a/torch/csrc/autograd/python_torch_functions_manual.cpp b/torch/csrc/autograd/python_torch_functions_manual.cpp new file mode 100644 index 0000000000000..a54d1017bcee8 --- /dev/null +++ b/torch/csrc/autograd/python_torch_functions_manual.cpp @@ -0,0 +1,826 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include + +using at::Tensor; +using at::Device; +using at::Layout; +using at::Scalar; +using at::ScalarType; +using at::Backend; +using at::OptionalDeviceGuard; +using at::DeviceGuard; +using at::TensorOptions; +using at::IntArrayRef; +using at::Generator; +using at::TensorList; +using at::Dimname; +using at::DimnameList; +using at::ArrayRef; + +using torch::utils::check_out_type_matches; +using namespace torch::autograd::utils; + +namespace torch { namespace autograd { + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +PyObject* THPVariableFunctionsModule = nullptr; + + +inline Tensor dispatch_arange(const Scalar& end, Tensor result) { + pybind11::gil_scoped_release no_gil; + return at::arange_out(result, end); +} + +inline Tensor dispatch_arange(const Scalar& end, const TensorOptions& options) { + torch::utils::maybe_initialize_cuda(options); + pybind11::gil_scoped_release no_gil; + return torch::arange(end, options); +} + +inline Tensor dispatch_arange(const Scalar& start, const Scalar& end, const Scalar& step, Tensor result) { + pybind11::gil_scoped_release no_gil; + return at::arange_out(result, start, end, step); +} + +inline Tensor dispatch_arange(const Scalar& start, const Scalar& end, const Scalar& step, const TensorOptions& options) { + torch::utils::maybe_initialize_cuda(options); + pybind11::gil_scoped_release no_gil; + return torch::arange(start, end, step, options); +} + +static PyObject * THPVariable_arange(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "arange(Scalar end, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)", + "arange(Scalar start, Scalar end, Scalar step=1, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)", + }, /*traceable=*/true); + + ParsedArgs<9> parsed_args; + auto r = parser.parse(args, kwargs, parsed_args); + + if(r.has_torch_function()) { + return handle_torch_function(r, args, kwargs, THPVariableFunctionsModule, "torch"); + } + + if (r.idx == 0) { + if (r.isNone(1)) { + auto end = r.scalar(0); + // NOTE: r.scalartype(X) gives the default dtype if r.isNone(X) + c10::optional scalarType = r.scalartypeOptional(2); + const auto options = TensorOptions() + .dtype(scalarType) + .device(r.device(4)) + .layout(r.layout(3)) + .requires_grad(r.toBool(6)) + .pinned_memory(r.toBool(5)); + return wrap(dispatch_arange(end, options)); + } else { + TORCH_CHECK(!r.toBool(5), " `pin_memory` and `out` parameters are incompatible"); + check_out_type_matches(r.tensor(1), r.scalartype(2), r.isNone(2), r.layout(3), + r.device(4), r.isNone(4)); + return wrap(dispatch_arange(r.scalar(0), r.tensor(1)).set_requires_grad(r.toBool(6))); + } + } else if (r.idx == 1) { + if (r.isNone(3)) { + auto start = r.scalar(0); + auto end = r.scalar(1); + auto step = r.scalar(2); + // NOTE: r.scalartype(X) gives the default dtype if r.isNone(X) + c10::optional scalarType = r.scalartypeOptional(4); + const auto options = TensorOptions() + .dtype(scalarType) + .device(r.device(6)) + .layout(r.layout(5)) + .requires_grad(r.toBool(8)) + .pinned_memory(r.toBool(7)); + return wrap(dispatch_arange(start, end, step, options)); + } else { + TORCH_CHECK(!r.toBool(7), " `pin_memory` and `out` parameters are incompatible"); + check_out_type_matches(r.tensor(3), r.scalartype(4), r.isNone(4), r.layout(5), + r.device(6), r.isNone(6)); + return wrap(dispatch_arange(r.scalar(0), r.scalar(1), r.scalar(2), r.tensor(3)).set_requires_grad(r.toBool(8))); + } + } + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +inline Tensor dispatch_range(const Scalar& start, const Scalar& end, const Scalar& step, Tensor result) { + pybind11::gil_scoped_release no_gil; + OptionalDeviceGuard device_guard(device_of(result)); + return at::range_out(result, start, end, step); +} + +inline Tensor dispatch_range(const Scalar& start, const Scalar& end, const Scalar& step, const TensorOptions& options) { + torch::utils::maybe_initialize_cuda(options); + pybind11::gil_scoped_release no_gil; + DeviceGuard device_guard(options.device()); + return torch::range(start, end, step, options); +} + +static PyObject * THPVariable_range(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "range(Scalar start, Scalar end, Scalar step=1, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool requires_grad=False)", + }); + + ParsedArgs<8> parsed_args; + auto r = parser.parse(args, kwargs, parsed_args); + + if (r.idx == 0) { + auto ret = PyErr_WarnEx( + PyExc_UserWarning, + "torch.range is deprecated and will be removed in a future release " + "because its behavior is inconsistent with Python's range builtin. " + "Instead, use torch.arange, which produces values in [start, end).", + 1); + if (ret != 0) throw python_error(); + if (r.isNone(3)) { + const auto options = TensorOptions() + .dtype(r.scalartype(4)) + .device(r.device(6)) + .layout(r.layout(5)) + .requires_grad(r.toBool(7)); + return wrap(dispatch_range(r.scalar(0), r.scalar(1), r.scalar(2), options)); + } else { + check_out_type_matches(r.tensor(3), r.scalartype(4), r.isNone(4), + r.layout(5), r.device(6), r.isNone(6)); + return wrap(dispatch_range(r.scalar(0), r.scalar(1), r.scalar(2), r.tensor(3)).set_requires_grad(r.toBool(7))); + } + } + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +inline Tensor dispatch_full( + IntArrayRef size, + const Scalar& fill_val, + const TensorOptions& options) { + torch::utils::maybe_initialize_cuda(options); + pybind11::gil_scoped_release no_gil; + return at::full(size, fill_val, options); +} + +inline Tensor dispatch_full( + IntArrayRef size, + const Scalar& fill_val, + c10::optional names, + const TensorOptions& options) { + torch::utils::maybe_initialize_cuda(options); + pybind11::gil_scoped_release no_gil; + return at::full(size, fill_val, names, options); +} + +inline Tensor dispatch_full( + IntArrayRef size, + const Scalar& fill_val, + Tensor result) { + pybind11::gil_scoped_release no_gil; + return at::full_out(result, size, fill_val); +} + +static PyObject * THPVariable_full(PyObject* self, PyObject* args, PyObject* kwargs) { + HANDLE_TH_ERRORS + + static PythonArgParser parser({ + "full(IntArrayRef size, Scalar fill_value, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)", + "full(IntArrayRef size, Scalar fill_value, *, DimnameList names=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)", + }, /*traceable=*/true); + + // Acquires (common) arguments + ParsedArgs<8> parsed_args; + auto r = parser.parse(args, kwargs, parsed_args); + + if(r.has_torch_function()) { + return handle_torch_function(r, args, kwargs, THPVariableFunctionsModule, "torch"); + } + + auto size = r.intlist(0); + auto fill_val = r.scalar(1); + const auto options = TensorOptions{} + .dtype(r.scalartypeOptional(3)) + .layout(r.layout(4)) + .device(r.device(5)) + .pinned_memory(r.toBool(6)); + + if (r.idx == 0) { + // full + if (r.isNone(2)) { + return wrap(dispatch_full(size, fill_val, options).set_requires_grad(r.toBool(7))); + } + + // full.out + // Validates out tensor and other kwargs + auto result = r.tensor(2); + TORCH_CHECK(!r.toBool(6), " `pin_memory` and `out` parameters are incompatible"); + check_out_type_matches(result, r.scalartype(3), r.isNone(3), r.layout(4), + r.device(5), r.isNone(5)); + + return wrap(dispatch_full(size, fill_val, result).set_requires_grad(r.toBool(7))); + } else if (r.idx == 1) { + // full.names + if (r.isNone(2)) { + return wrap(dispatch_full(size, fill_val, c10::nullopt, options).set_requires_grad(r.toBool(7))); + } + + // Converts from c10::optional to c10::optional + auto raw_names = r.toDimnameListOptional(2); + c10::optional names(*raw_names); + return wrap(dispatch_full(size, fill_val, names, options).set_requires_grad(r.toBool(7))); + } + + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +inline Tensor dispatch_randint(int64_t high, IntArrayRef size, c10::optional generator, Tensor result) { + pybind11::gil_scoped_release no_gil; + return at::randint_out(result, high, size, generator); +} +inline Tensor dispatch_randint(int64_t high, IntArrayRef size, c10::optional generator, const TensorOptions & options) { + torch::utils::maybe_initialize_cuda(options); + pybind11::gil_scoped_release no_gil; + return torch::randint(high, size, generator, options); +} +inline Tensor dispatch_randint(int64_t high, IntArrayRef size, Tensor result) { + pybind11::gil_scoped_release no_gil; + return at::randint_out(result, high, size); +} +inline Tensor dispatch_randint(int64_t high, IntArrayRef size, const TensorOptions & options) { + torch::utils::maybe_initialize_cuda(options); + pybind11::gil_scoped_release no_gil; + return torch::randint(high, size, options); +} +inline Tensor dispatch_randint(int64_t low, int64_t high, IntArrayRef size, c10::optional generator, Tensor result) { + pybind11::gil_scoped_release no_gil; + return at::randint_out(result, low, high, size, generator); +} +inline Tensor dispatch_randint(int64_t low, int64_t high, IntArrayRef size, c10::optional generator, const TensorOptions & options) { + torch::utils::maybe_initialize_cuda(options); + pybind11::gil_scoped_release no_gil; + return torch::randint(low, high, size, generator, options); +} +inline Tensor dispatch_randint(int64_t low, int64_t high, IntArrayRef size, Tensor result) { + pybind11::gil_scoped_release no_gil; + return at::randint_out(result, low, high, size); +} +inline Tensor dispatch_randint(int64_t low, int64_t high, IntArrayRef size, const TensorOptions & options) { + torch::utils::maybe_initialize_cuda(options); + pybind11::gil_scoped_release no_gil; + return torch::randint(low, high, size, options); +} + +static PyObject * THPVariable_randint(PyObject* self_, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "randint(int64_t high, IntArrayRef size, *, Generator generator=None, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool requires_grad=False)", + "randint(int64_t low, int64_t high, IntArrayRef size, *, Generator generator=None, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool requires_grad=False)", + }, /*traceable=*/false); + + ParsedArgs<9> parsed_args; + auto r = parser.parse(args, kwargs, parsed_args); + + if(r.has_torch_function()) { + return handle_torch_function(r, args, kwargs, THPVariableFunctionsModule, "torch"); + } + + if (r.idx == 0) { + if (r.isNone(3)) { + auto high = r.toInt64(0); + auto size = r.intlist(1); + auto generator = r.generator(2); + // NOTE: r.scalartype(X) gives the default dtype if r.isNone(X) + auto dtype = r.scalartypeWithDefault(4, at::ScalarType::Long); + auto device = r.device(6); + const auto options = TensorOptions() + .dtype(dtype) + .device(device) + .layout(r.layout(5)) + .requires_grad(r.toBool(7)); + return wrap(dispatch_randint(high, size, generator, options)); + } else { + check_out_type_matches(r.tensor(3), r.scalartype(4), r.isNone(4), + r.layout(5), r.device(6), r.isNone(6)); + return wrap(dispatch_randint(r.toInt64(0), r.intlist(1), r.generator(2), r.tensor(3)).set_requires_grad(r.toBool(7))); + } + } else if (r.idx == 1) { + if (r.isNone(4)) { + auto low = r.toInt64(0); + auto high = r.toInt64(1); + auto size = r.intlist(2); + auto generator = r.generator(3); + // NOTE: r.scalartype(X) gives the default dtype if r.isNone(X) + auto dtype = r.scalartypeWithDefault(5, at::ScalarType::Long); + auto device = r.device(7); + const auto options = TensorOptions() + .dtype(dtype) + .device(device) + .layout(r.layout(6)) + .requires_grad(r.toBool(8)); + return wrap(dispatch_randint(low, high, size, generator, options)); + } else { + check_out_type_matches(r.tensor(4), r.scalartype(5), r.isNone(5), + r.layout(6), r.device(7), r.isNone(7)); + return wrap(dispatch_randint(r.toInt64(0), r.toInt64(1), r.intlist(2), r.generator(3), r.tensor(4)).set_requires_grad(r.toBool(8))); + } + } + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +// implemented on python object to allow torch.as_tensor to be constructed with arbitrarily nested +// python objects - list, tuple, np array, scalar, etc. +static PyObject * THPVariable_as_tensor(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + jit::tracer::warn("torch.as_tensor", jit::tracer::WARN_CONSTRUCTOR); + return THPVariable_Wrap(torch::utils::as_tensor(torch::tensors::get_default_dispatch_key(), torch::tensors::get_default_scalar_type(), args, kwargs)); + END_HANDLE_TH_ERRORS +} + +// implemented on python object here because PyObject currently not natively declarable +// See: ATen/native/README.md for more context +static PyObject * THPVariable_from_numpy(PyObject* module, PyObject* arg) +{ + HANDLE_TH_ERRORS + jit::tracer::warn("torch.from_numpy", jit::tracer::WARN_CONSTRUCTOR); + return THPVariable_Wrap(torch::utils::tensor_from_numpy(arg)); + END_HANDLE_TH_ERRORS +} + +static Tensor dispatch_nonzero(const Tensor & self) { + pybind11::gil_scoped_release no_gil; + OptionalDeviceGuard device_guard(device_of(self)); + return self.nonzero(); +} + +static Tensor dispatch_nonzero(const Tensor & self, Tensor out) { + pybind11::gil_scoped_release no_gil; + OptionalDeviceGuard device_guard(device_of(self)); + return at::nonzero_out(out, self); +} + +static std::vector dispatch_nonzero_numpy(const Tensor & self) { + pybind11::gil_scoped_release no_gil; + OptionalDeviceGuard device_guard(device_of(self)); + return self.nonzero_numpy(); +} + +static PyObject * THPVariable_nonzero(PyObject* self, PyObject* args, PyObject* kwargs); + +static PyObject * THPVariable_sparse_csr_tensor(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + jit::tracer::warn("torch.sparse_csr_tensor", jit::tracer::WARN_CONSTRUCTOR); + return THPVariable_Wrap(torch::utils::sparse_csr_tensor_ctor(torch::tensors::get_default_dispatch_key(), torch::tensors::get_default_scalar_type(), args, kwargs)); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable__sparse_csr_tensor_unsafe(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + jit::tracer::warn("torch._sparse_csr_tensor_unsafe", jit::tracer::WARN_CONSTRUCTOR); + return THPVariable_Wrap(torch::utils::_sparse_csr_tensor_unsafe_ctor(torch::tensors::get_default_dispatch_key(), torch::tensors::get_default_scalar_type(), args, kwargs)); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_sparse_coo_tensor(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + jit::tracer::warn("torch.sparse_coo_tensor", jit::tracer::WARN_CONSTRUCTOR); + return THPVariable_Wrap(torch::utils::sparse_coo_tensor_ctor(torch::tensors::get_default_dispatch_key(), torch::tensors::get_default_scalar_type(), args, kwargs)); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable__sparse_coo_tensor_unsafe(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + jit::tracer::warn("torch._sparse_coo_tensor_unsafe", jit::tracer::WARN_CONSTRUCTOR); + return THPVariable_Wrap(torch::utils::_sparse_coo_tensor_unsafe_ctor(torch::tensors::get_default_dispatch_key(), torch::tensors::get_default_scalar_type(), args, kwargs)); + END_HANDLE_TH_ERRORS +} + +// implemented on python object to allow torch.tensor to be constructed with arbitrarily nested +// python objects - list, tuple, np array, scalar, etc. +static PyObject * THPVariable_tensor(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + jit::tracer::warn("torch.tensor", jit::tracer::WARN_CONSTRUCTOR); + return THPVariable_Wrap(torch::utils::tensor_ctor(torch::tensors::get_default_dispatch_key(), torch::tensors::get_default_scalar_type(), args, kwargs)); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_get_device(PyObject* self_, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "get_device(Tensor input)", + }, /*traceable=*/false); + + ParsedArgs<1> parsed_args; + auto r = parser.parse(args, kwargs, parsed_args); + + if (r.idx == 0) { + return wrap(r.tensor(0).get_device()); + } + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +}static PyObject * THPVariable_frombuffer(PyObject* self_, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "frombuffer(PyObject* buffer, *, ScalarType dtype, int64_t count=-1, int64_t offset=0, bool requires_grad=False)", + }, /*traceable=*/false); + + PyObject* ret = nullptr; + ParsedArgs<5> parsed_args; + auto r = parser.parse(args, kwargs, parsed_args); + + if (r.idx == 0) { + auto buffer = r.pyobject(0); + auto dtype = r.scalartype(1); + auto count = r.toInt64(2); + auto offset = r.toInt64(3); + auto requires_grad = r.toBool(4); + + auto elsize = at::elementSize(dtype); + size_t actual_count = 0; + Py_buffer view; + + TORCH_CHECK_VALUE( + PyObject_CheckBuffer(buffer) != 0, + "object does not implement Python buffer protocol."); + + if (PyObject_GetBuffer(buffer, &view, PyBUF_WRITABLE) < 0) { + TORCH_CHECK( + PyObject_GetBuffer(buffer, &view, PyBUF_SIMPLE) >= 0, + "could not retrieve buffer from object"); + TORCH_WARN_ONCE( + "The given buffer is not writable, and PyTorch does " + "not support non-writable tensors. This means you can write to the " + "underlying (supposedly non-writable) buffer using the tensor. " + "You may want to copy the buffer to protect its data or make it writable " + "before converting it to a tensor. This type of warning will be " + "suppressed for the rest of this program."); + PyErr_Clear(); + } + + Py_INCREF(view.obj); + THPObjectPtr obj(view.obj); + + auto len = view.len; + auto buf = view.buf; + PyBuffer_Release(&view); + + TORCH_CHECK_VALUE( + len > 0 && count != 0, + "both buffer length (", len, ") and count (", count, ") must not be 0"); + TORCH_CHECK_VALUE( + offset >= 0 && offset < len, + "offset (", offset, " bytes) must be non-negative and no greater than " + "buffer length (", len, " bytes) minus 1"); + TORCH_CHECK_VALUE( + count > 0 || (len - offset) % elsize == 0, + "buffer length (", len - offset, " bytes) after offset (", offset, " bytes) " + "must be a multiple of element size (", elsize, ")"); + + if (count < 0) { + actual_count = (len - offset) / elsize; + } else { + actual_count = static_cast(count); + } + + TORCH_CHECK_VALUE( + static_cast(offset) + actual_count * elsize <= len, + "requested buffer length (", actual_count, " * ", elsize, " bytes) " + "after offset (", offset, " bytes) must not be greater than actual " + "buffer length (", len, " bytes)"); + + auto offset_buf = static_cast(buf) + offset; + auto options = TensorOptions() + .dtype(dtype) + .device(c10::kCPU); + + auto tensor = at::for_blob(offset_buf, static_cast(actual_count)) + .options(options) + .deleter([obj = obj.release()](void*) { + pybind11::gil_scoped_acquire gil; + Py_DECREF(obj); + }) + .make_tensor(); + tensor.set_requires_grad(requires_grad); + ret = wrap(tensor); + } + + return ret; + + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_numel(PyObject* self_, PyObject* args, PyObject* kwargs); + +// linspace +static PyObject * THPVariable_linspace(PyObject* self_, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "linspace(Scalar start, Scalar end, int64_t? steps=None, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)", + }, /*traceable=*/true); + + ParsedArgs<9> parsed_args; + auto _r = parser.parse(nullptr, args, kwargs, parsed_args); + if(_r.has_torch_function()) { + return handle_torch_function(_r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch"); + } + if (_r.isNone(3)) { + // aten::linspace(Scalar start, Scalar end, int? steps=None, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + + // NOTE: r.scalartype(X) gives the default dtype if r.isNone(X) + // This leads to problem in the operator argument checks, + // when either `start` or `end` is complex and dtype is None + const auto options = TensorOptions() + .dtype(_r.scalartypeOptional(4)) + .device(_r.device(6)) + .layout(_r.layoutOptional(5)) + .requires_grad(_r.toBool(8)) + .pinned_memory(_r.toBool(7)); + torch::utils::maybe_initialize_cuda(options); + + auto dispatch_linspace = [](Scalar start, Scalar end, c10::optional steps, TensorOptions options) -> Tensor { + pybind11::gil_scoped_release no_gil; + return torch::linspace(start, end, steps, options); + }; + return wrap(dispatch_linspace(_r.scalar(0), _r.scalar(1), _r.toInt64Optional(2), options)); + } else { + // aten::linspace.out(Scalar start, Scalar end, int? steps=None, *, Tensor(a!) out) -> Tensor(a!) + check_out_type_matches(_r.tensor(3), _r.scalartype(4), + _r.isNone(4), _r.layoutOptional(5), + _r.device(6), _r.isNone(6)); + + auto dispatch_linspace_out = [](Tensor out, Scalar start, Scalar end, c10::optional steps) -> Tensor { + pybind11::gil_scoped_release no_gil; + return at::linspace_out(out, start, end, steps); + }; + return wrap(dispatch_linspace_out(_r.tensor(3), _r.scalar(0), _r.scalar(1), _r.toInt64Optional(2)).set_requires_grad(_r.toBool(8))); + } + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +// logspace +static PyObject * THPVariable_logspace(PyObject* self_, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "logspace(Scalar start, Scalar end, int64_t? steps=None, double base=10.0, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)", + }, /*traceable=*/true); + + ParsedArgs<10> parsed_args; + auto _r = parser.parse(nullptr, args, kwargs, parsed_args); + if(_r.has_torch_function()) { + return handle_torch_function(_r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch"); + } + if (_r.isNone(4)) { + // aten::logspace(Scalar start, Scalar end, int? steps=None, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + + // NOTE: r.scalartype(X) gives the default dtype if r.isNone(X) + // This leads to problem in the operator argument checks, + // when either `start` or `end` is complex and dtype is None + const auto options = TensorOptions() + .dtype(_r.scalartypeOptional(5)) + .device(_r.device(7)) + .layout(_r.layoutOptional(6)) + .requires_grad(_r.toBool(9)) + .pinned_memory(_r.toBool(8)); + torch::utils::maybe_initialize_cuda(options); + + auto dispatch_logspace = [](Scalar start, Scalar end, c10::optional steps, double base, TensorOptions options) -> Tensor { + pybind11::gil_scoped_release no_gil; + return torch::logspace(start, end, steps, base, options); + }; + return wrap(dispatch_logspace(_r.scalar(0), _r.scalar(1), _r.toInt64Optional(2), _r.toDouble(3), options)); + } else { + // aten::logspace.out(Scalar start, Scalar end, int? steps=None, float base=10.0, *, Tensor(a!) out) -> Tensor(a!) + check_out_type_matches(_r.tensor(4), _r.scalartype(5), + _r.isNone(5), _r.layoutOptional(6), + _r.device(7), _r.isNone(7)); + + auto dispatch_logspace_out = [](Tensor out, Scalar start, Scalar end, c10::optional steps, double base) -> Tensor { + pybind11::gil_scoped_release no_gil; + return at::logspace_out(out, start, end, steps, base); + }; + return wrap(dispatch_logspace_out(_r.tensor(4), _r.scalar(0), _r.scalar(1), _r.toInt64Optional(2), _r.toDouble(3)).set_requires_grad(_r.toBool(9))); + } + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +// XXX: ops that are bound here are not exposed to the C++ api nor the JIT. +// Any new ops added here should be accompanied with a comment why they are not +// being registered through native_functions.yaml, and be tagged cpp / JIT +// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) +static PyMethodDef torch_functions_manual[] = { + {"arange", castPyCFunctionWithKeywords(THPVariable_arange), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, + {"as_tensor", castPyCFunctionWithKeywords(THPVariable_as_tensor), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, + {"from_numpy", THPVariable_from_numpy, METH_STATIC | METH_O, nullptr}, + {"frombuffer", castPyCFunctionWithKeywords(THPVariable_frombuffer), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, + {"full", castPyCFunctionWithKeywords(THPVariable_full), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, + {"linspace", castPyCFunctionWithKeywords(THPVariable_linspace), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, + {"logspace", castPyCFunctionWithKeywords(THPVariable_logspace), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, + {"nonzero", castPyCFunctionWithKeywords(THPVariable_nonzero), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, + {"randint", castPyCFunctionWithKeywords(THPVariable_randint), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, + {"range", castPyCFunctionWithKeywords(THPVariable_range), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, + {"sparse_coo_tensor", castPyCFunctionWithKeywords(THPVariable_sparse_coo_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, + {"_sparse_coo_tensor_unsafe", castPyCFunctionWithKeywords(THPVariable__sparse_coo_tensor_unsafe), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, + {"sparse_csr_tensor", castPyCFunctionWithKeywords(THPVariable_sparse_csr_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, + {"_sparse_csr_tensor_unsafe", castPyCFunctionWithKeywords(THPVariable__sparse_csr_tensor_unsafe), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, + {"tensor", castPyCFunctionWithKeywords(THPVariable_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, + {"get_device", castPyCFunctionWithKeywords(THPVariable_get_device), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, + {"numel", castPyCFunctionWithKeywords(THPVariable_numel), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, +}; + +static PyObject * THPVariable_nonzero(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "nonzero(Tensor input, *, bool as_tuple=False, Tensor out=None)", + }); + ParsedArgs<3> parsed_args; + auto r = parser.parse(args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, args, kwargs, THPVariableFunctionsModule, "torch"); + } + + const auto as_tuple = r.toBool(1); + const auto has_out = !r.isNone(2); + + if (as_tuple) { + TORCH_CHECK(!has_out, "nonzero does not support the out kwarg when as_tuple is True"); + return wrap(dispatch_nonzero_numpy(r.tensor(0))); + } + + if (has_out) { + return wrap(dispatch_nonzero(r.tensor(0), r.tensor(2))); + } + + return wrap(dispatch_nonzero(r.tensor(0))); + + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_numel(PyObject* self_, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "numel(Tensor input)", + }, /*traceable=*/false); + + ParsedArgs<1> parsed_args; + auto r = parser.parse(args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, args, kwargs, THPVariableFunctionsModule, "torch"); + } + + if (r.idx == 0) { + return wrap(r.tensor(0).numel()); + } + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +// Sharded function definitions +void gatherTorchFunctions_0(std::vector &torch_functions); +void gatherTorchFunctions_1(std::vector &torch_functions); +void gatherTorchFunctions_2(std::vector &torch_functions); + +void gatherTorchFunctions(std::vector &torch_functions) { + constexpr size_t num_functions = sizeof(torch_functions_manual) / sizeof(torch_functions_manual[0]); + torch_functions.assign(torch_functions_manual, + torch_functions_manual + num_functions); + // NOTE: Must be synced with num_shards in tools/autograd/gen_python_functions.py + gatherTorchFunctions_0(torch_functions); + gatherTorchFunctions_1(torch_functions); + gatherTorchFunctions_2(torch_functions); + + static std::array, 4> aliases{{ + // Canonical function, alias name + {"sspaddmm", "saddmm"}, + {"mm", "spmm"}, + {"mm", "dsmm"}, + {"hspmm", "hsmm"} + }}; + + for (const auto& alias : aliases) { + auto it = std::find_if(torch_functions.begin(), torch_functions.end(), + [&](const PyMethodDef& def) { + return strcmp(def.ml_name, alias.first) == 0; + }); + TORCH_INTERNAL_ASSERT( + it != torch_functions.end(), + "Failed to create function alias from ", alias.first, " to ", alias.second); + PyMethodDef alias_def = *it; + alias_def.ml_name = alias.second; + + torch_functions.push_back(alias_def); + } + + torch_functions.push_back({nullptr}); + torch_functions.shrink_to_fit(); +} + +static PyTypeObject THPVariableFunctions = { + PyVarObject_HEAD_INIT(nullptr, 0) + "torch._C._VariableFunctionsClass", /* tp_name */ + 0, /* tp_basicsize */ + 0, /* tp_itemsize */ + nullptr, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + nullptr, /* tp_getattr */ + nullptr, /* tp_setattr */ + nullptr, /* tp_reserved */ + nullptr, /* tp_repr */ + nullptr, /* tp_as_number */ + nullptr, /* tp_as_sequence */ + nullptr, /* tp_as_mapping */ + nullptr, /* tp_hash */ + nullptr, /* tp_call */ + nullptr, /* tp_str */ + nullptr, /* tp_getattro */ + nullptr, /* tp_setattro */ + nullptr, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + nullptr, /* tp_doc */ + nullptr, /* tp_traverse */ + nullptr, /* tp_clear */ + nullptr, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + nullptr, /* tp_iter */ + nullptr, /* tp_iternext */ + nullptr, /* tp_methods */ + nullptr, /* tp_members */ + nullptr, /* tp_getset */ + nullptr, /* tp_base */ + nullptr, /* tp_dict */ + nullptr, /* tp_descr_get */ + nullptr, /* tp_descr_set */ + 0, /* tp_dictoffset */ + nullptr, /* tp_init */ + nullptr, /* tp_alloc */ + nullptr /* tp_new */ +}; + +void initTorchFunctions(PyObject *module) { + static std::vector torch_functions; + gatherTorchFunctions(torch_functions); + THPVariableFunctions.tp_methods = torch_functions.data(); + + if (PyType_Ready(&THPVariableFunctions) < 0) { + throw python_error(); + } + Py_INCREF(&THPVariableFunctions); + + // Steals + Py_INCREF(&THPVariableFunctions); + if (PyModule_AddObject(module, "_VariableFunctionsClass", + reinterpret_cast(&THPVariableFunctions)) < 0) { + throw python_error(); + } + // PyType_GenericNew returns a new reference + THPVariableFunctionsModule = PyType_GenericNew(&THPVariableFunctions, Py_None, Py_None); + // PyModule_AddObject steals a reference + if (PyModule_AddObject(module, "_VariableFunctions", THPVariableFunctionsModule) < 0) { + throw python_error(); + } +} + +}} // namespace torch::autograd diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index 9496d668b3468..50d6eb9ab7e05 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -834,6 +834,17 @@ PyObject *THPVariable_is_mlc(THPVariable *self, void *unused) END_HANDLE_TH_ERRORS } +PyObject *THPVariable_is_ort(THPVariable *self, void *unused) +{ + HANDLE_TH_ERRORS + if (check_has_torch_function((PyObject *)self)) { + return handle_torch_function_getter(self, "is_ort"); + } + auto& self_ = THPVariable_Unpack(self); + return torch::autograd::utils::wrap(self_.is_ort()); + END_HANDLE_TH_ERRORS +} + PyObject *THPVariable_is_vulkan(THPVariable *self, void *unused) { HANDLE_TH_ERRORS @@ -980,6 +991,7 @@ static struct PyGetSetDef THPVariable_properties[] = { {"is_sparse_csr", (getter)THPVariable_is_sparse_csr, nullptr, nullptr, nullptr}, {"is_mkldnn", (getter)THPVariable_is_mkldnn, nullptr, nullptr, nullptr}, {"is_mlc", (getter)THPVariable_is_mlc, nullptr, nullptr, nullptr}, + {"is_ort", (getter)THPVariable_is_ort, nullptr, nullptr, nullptr}, {"is_vulkan", (getter)THPVariable_is_vulkan, nullptr, nullptr, nullptr}, {"is_complex", (getter)THPVariable_is_complex, nullptr, nullptr, nullptr}, {"is_quantized", (getter)THPVariable_is_quantized, nullptr, nullptr, nullptr}, @@ -1562,7 +1574,7 @@ void concrete_dispatch_fn(const c10::impl::PyInterpreter*, const c10::OperatorHa if (ivalue.isTensor()) { const auto& tensor = ivalue.toTensor(); if (isPythonTensor(tensor)) { - overloaded_args.emplace_back(py::cast(tensor)); + append_overloaded_arg(&overloaded_args, py::cast(tensor).ptr()); } } else if (ivalue.isList()) { const auto& list = ivalue.toListRef(); @@ -1571,7 +1583,7 @@ void concrete_dispatch_fn(const c10::impl::PyInterpreter*, const c10::OperatorHa if (nv.isTensor()) { const auto& tensor = nv.toTensor(); if (isPythonTensor(tensor)) { - overloaded_args.emplace_back(py::cast(tensor)); + append_overloaded_arg(&overloaded_args, py::cast(tensor).ptr()); } } } @@ -1620,7 +1632,8 @@ c10::intrusive_ptr concrete_detach_fn(const c10::impl::PyInterpreter // TODO: fix the constness of target Tensor self_t = Tensor(c10::intrusive_ptr::unsafe_reclaim_from_nonowning(const_cast(self))); auto self_p = py::reinterpret_steal(THPVariable_Wrap(self_t)); - overloaded_args.emplace_back(self_p); + TORCH_INTERNAL_ASSERT(isPythonTensor(self_t)); + append_overloaded_arg(&overloaded_args, self_p.ptr()); auto args = py::reinterpret_steal(PyTuple_New(1)); PyTuple_SET_ITEM(args.ptr(), 0, self_p.release().ptr()); diff --git a/torch/csrc/autograd/record_function_ops.cpp b/torch/csrc/autograd/record_function_ops.cpp index 7e621f9e8b62e..9650c354c5868 100644 --- a/torch/csrc/autograd/record_function_ops.cpp +++ b/torch/csrc/autograd/record_function_ops.cpp @@ -79,7 +79,7 @@ c10::AliasAnalysisKind aliasAnalysisFromSchema() { jit::RegisterOperators reg_fut_ops({ jit::Operator( "profiler::_call_end_callbacks_on_jit_fut(Tensor x, Future(t) y) -> Future(t)", - [](jit::Stack* stack) { + [](jit::Stack& stack) { // Pop inputs, which should be a future and a tensor auto fut = jit::pop(stack).toFuture(); auto tensor = jit::pop(stack).toTensor(); diff --git a/torch/csrc/autograd/variable.cpp b/torch/csrc/autograd/variable.cpp index 32af5f97ad4e4..7ae1ac0bdee8d 100644 --- a/torch/csrc/autograd/variable.cpp +++ b/torch/csrc/autograd/variable.cpp @@ -551,10 +551,10 @@ const std::shared_ptr& VariableHooks::grad_fn(const Tenso // self = view_op_n(view_n-1) // self = inplace_op(self) // - // For CPU/CUDA backends, we employ one AsStridedBackward Node to represent the chain of + // For CPU/CUDA backends, we employ one AsStridedBackward0 Node to represent the chain of // view backward ops for effienciency. // - // However in XLA backend we don't have full support of AsStridedBackward, we instead run a full + // However in XLA backend we don't have full support of AsStridedBackward0, we instead run a full // forward pass with a tensor that requires gradient to get proper grad_fn setup, // then save it to DifferentiableViewMeta for future use. // This is fairly cheap for XLA lazy tensor approach (but would be really expensive for CPU/CUDA). @@ -572,7 +572,7 @@ const std::shared_ptr& VariableHooks::grad_fn(const Tenso auto diff_view = view_fn(view_info.base_); diff_view_meta->grad_fn_ = diff_view.grad_fn(); } else { - auto fn = std::make_shared(); + auto fn = std::make_shared(); fn->self_geometry = at::TensorGeometry(view_info.base_); fn->size = self.sizes().vec(); fn->stride = self.strides().vec(); diff --git a/torch/csrc/cuda/Graph.cpp b/torch/csrc/cuda/Graph.cpp index 123abb9666ee5..beacefa3f8878 100644 --- a/torch/csrc/cuda/Graph.cpp +++ b/torch/csrc/cuda/Graph.cpp @@ -23,36 +23,29 @@ void THCPGraph_init(PyObject *module) { auto torch_C_m = py::handle(module).cast(); torch_C_m - .def("_graph_pool_handle", &::at::cuda::graph_pool_handle); + .def("_graph_pool_handle", + &::at::cuda::graph_pool_handle); - shared_ptr_class_<::at::cuda::CUDAGraph>(torch_C_m, "_CudaGraphBase") + shared_ptr_class_<::at::cuda::CUDAGraph> + (torch_C_m, + "_CUDAGraph") .def(py::init<>()) // I'm not sure this is the correct order of all the arguments. Pybind11 docs // aren't clear. But it works. .def("capture_begin", &::at::cuda::CUDAGraph::capture_begin, py::call_guard(), - R"(``capture_begin`` begins Cuda graph capture on the current stream.)", py::arg("pool") = c10::cuda::MempoolId_t{0, 0}) .def("capture_end", &::at::cuda::CUDAGraph::capture_end, - py::call_guard(), - R"(``capture_end`` ends Cuda graph capture on the current stream. - After ``capture_end``, ``replay`` may be called on this instance.)") + py::call_guard()) .def("replay", &::at::cuda::CUDAGraph::replay, - py::call_guard(), - R"(``replay`` replays the Cuda graph captured by this instance.)") - // reset is called in __del__ on the Python side - // (see class Graph in torch/cuda/streams.py for reasons and caveats) + py::call_guard()) .def("reset", &::at::cuda::CUDAGraph::reset, - py::call_guard(), - R"(``reset`` deletes the graph currently held by this instance.)") + py::call_guard()) .def("pool", &::at::cuda::CUDAGraph::pool, - py::call_guard(), - R"(``pool`` retrieves the id of this graph's memory pool. - This id can optionally be passed to another graph's capture_begin, - which hints that other graph may share the same memory pool.)"); + py::call_guard()); } diff --git a/torch/csrc/deploy/deploy.h b/torch/csrc/deploy/deploy.h index 20364797edd8a..f34e4bc5fdbcc 100644 --- a/torch/csrc/deploy/deploy.h +++ b/torch/csrc/deploy/deploy.h @@ -232,6 +232,10 @@ class PythonMethodWrapper : public torch::IMethod { std::string method_name) : model_(std::move(model)), method_name_(std::move(method_name)) {} + const std::string& name() const override { + return method_name_; + } + c10::IValue operator()( std::vector args, const IValueMap& kwargs = IValueMap()) const override { diff --git a/torch/csrc/deploy/example/benchmark.cpp b/torch/csrc/deploy/example/benchmark.cpp index 348d84fec02b4..d2f1142965d40 100644 --- a/torch/csrc/deploy/example/benchmark.cpp +++ b/torch/csrc/deploy/example/benchmark.cpp @@ -295,6 +295,7 @@ struct Benchmark { std::function run_one_work_item; }; +// NOLINTNEXTLINE(bugprone-exception-escape) int main(int argc, char* argv[]) { int max_thread = atoi(argv[1]); cuda = std::string(argv[2]) == "cuda"; diff --git a/torch/csrc/deploy/example/generate_examples.py b/torch/csrc/deploy/example/generate_examples.py index 65f244373d954..0f279d922157c 100644 --- a/torch/csrc/deploy/example/generate_examples.py +++ b/torch/csrc/deploy/example/generate_examples.py @@ -79,3 +79,6 @@ def save(name, model, model_jit=None, eg=None, featurestore_meta=None): e.save_pickle("fn", "fn.pkl", load_library) generate_fx_example() + + with PackageExporter(p / "uses_distributed") as e: + e.save_source_string("uses_distributed", "import torch.distributed; assert torch.distributed.is_available()") diff --git a/torch/csrc/deploy/interpreter/freeze.py b/torch/csrc/deploy/interpreter/freeze.py index 24fa709cb01ac..31531746ed1b2 100644 --- a/torch/csrc/deploy/interpreter/freeze.py +++ b/torch/csrc/deploy/interpreter/freeze.py @@ -35,17 +35,13 @@ """ -MAIN_PREFIX = """ +MAIN_PREFIX_TEMPLATE = """ // Compiled standard library modules. These should be appended to the existing // `PyImport_FrozenModules` that ships with CPython. -struct _frozen _PyImport_FrozenModules_torch[] = { +struct _frozen {}[] = {{ """ -FAKE_PREFIX = """ -// Compiled standard library modules. These should be appended to the existing -// `PyImport_FrozenModules` that ships with CPython. -struct _frozen _PyImport_FrozenModules[] = { -""" +FAKE_PREFIX = MAIN_PREFIX_TEMPLATE.format("_PyImport_FrozenModules") MAIN_SUFFIX = """\ {0, 0, 0} /* sentinel */ @@ -133,7 +129,7 @@ def write_bytecode(self, install_root): for f in bytecode_files: f.close() - def write_main(self, install_root, oss): + def write_main(self, install_root, oss, symbol_name): """ Write the `main.c` file containing a table enumerating all the frozen modules. @@ -143,7 +139,7 @@ def write_main(self, install_root, oss): for m in self.frozen_modules: outfp.write(f"extern unsigned char {m.c_name}[];\n") - outfp.write(MAIN_PREFIX) + outfp.write(MAIN_PREFIX_TEMPLATE.format(symbol_name)) for m in self.frozen_modules: outfp.write(f'\t{{"{m.module_name}", {m.c_name}, {m.size}}},\n') outfp.write(MAIN_SUFFIX) @@ -246,6 +242,11 @@ def compile_file(self, path: Path, top_package_path: Path): parser.add_argument("--verbose", action="store_true", help="Print debug logs") parser.add_argument("--install_dir", help="Root directory for all output files") parser.add_argument("--oss", action="store_true", help="If it's OSS build, add a fake _PyImport_FrozenModules") +parser.add_argument( + "--symbol_name", + help="The name of the frozen module array symbol to generate", + default="_PyImport_FrozenModules_torch", +) args = parser.parse_args() @@ -264,4 +265,4 @@ def compile_file(self, path: Path, top_package_path: Path): f.compile_path(path, path) f.write_bytecode(args.install_dir) -f.write_main(args.install_dir, args.oss) +f.write_main(args.install_dir, args.oss, args.symbol_name) diff --git a/torch/csrc/deploy/test_deploy.cpp b/torch/csrc/deploy/test_deploy.cpp index f88a23c43bde0..53456cacca2ad 100644 --- a/torch/csrc/deploy/test_deploy.cpp +++ b/torch/csrc/deploy/test_deploy.cpp @@ -63,7 +63,7 @@ TEST(TorchpyTest, InitTwice) { TEST(TorchpyTest, DifferentInterps) { torch::deploy::InterpreterManager m(2); m.register_module_source("check_none", "check = id(None)\n"); - int64_t id0, id1; + int64_t id0 = 0, id1 = 0; { auto I = m.all_instances()[0].acquire_session(); id0 = I.global("check_none", "check").toIValue().toInt(); @@ -312,6 +312,7 @@ TEST(TorchpyTest, SharedLibraryLoad) { I.global("sys", "path").attr("append")({"torch/csrc/deploy"}); I.global("test_deploy_python", "setup")({getenv("PATH")}); } else { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) char buf[PATH_MAX]; strncpy(buf, test_lib_path, PATH_MAX); dirname(buf); @@ -365,3 +366,15 @@ TEST(TorchpyTest, SharedLibraryLoad) { } } #endif + +TEST(TorchpyTest, UsesDistributed) { + const auto model_filename = path( + "USES_DISTRIBUTED", + "torch/csrc/deploy/example/generated/uses_distributed"); + torch::deploy::InterpreterManager m(1); + torch::deploy::Package p = m.load_package(model_filename); + { + auto I = p.acquire_session(); + I.self.attr("import_module")({"uses_distributed"}); + } +} diff --git a/torch/csrc/deploy/test_deploy_gpu.cpp b/torch/csrc/deploy/test_deploy_gpu.cpp index 8287d1683edca..4e990adcd9e89 100644 --- a/torch/csrc/deploy/test_deploy_gpu.cpp +++ b/torch/csrc/deploy/test_deploy_gpu.cpp @@ -53,3 +53,15 @@ TEST(TorchDeployGPUTest, SimpleModel) { ASSERT_TRUE(ref_output.allclose(output, 1e-03, 1e-05)); } + +TEST(TorchDeployGPUTest, UsesDistributed) { + const auto model_filename = path( + "USES_DISTRIBUTED", + "torch/csrc/deploy/example/generated/uses_distributed"); + torch::deploy::InterpreterManager m(1); + torch::deploy::Package p = m.load_package(model_filename); + { + auto I = p.acquire_session(); + I.self.attr("import_module")({"uses_distributed"}); + } +} diff --git a/torch/csrc/deploy/test_deploy_python_ext.cpp b/torch/csrc/deploy/test_deploy_python_ext.cpp index 42700ead6678b..59a04f5e84853 100644 --- a/torch/csrc/deploy/test_deploy_python_ext.cpp +++ b/torch/csrc/deploy/test_deploy_python_ext.cpp @@ -7,7 +7,7 @@ bool run() { torch::deploy::InterpreterManager m(2); m.register_module_source("check_none", "check = id(None)\n"); - int64_t id0, id1; + int64_t id0 = 0, id1 = 0; { auto I = m.all_instances()[0].acquire_session(); id0 = I.global("check_none", "check").toIValue().toInt(); diff --git a/torch/csrc/distributed/autograd/engine/dist_engine.cpp b/torch/csrc/distributed/autograd/engine/dist_engine.cpp index 76f2eaebe5f77..e6522c33280a9 100644 --- a/torch/csrc/distributed/autograd/engine/dist_engine.cpp +++ b/torch/csrc/distributed/autograd/engine/dist_engine.cpp @@ -359,7 +359,7 @@ void DistEngine::execute_graph_task_until_ready_queue_empty( continue; } if (task.fn_ && !local_graph_task->has_error_.load()) { - AutoGradMode grad_mode(local_graph_task->grad_mode_); + at::ThreadLocalStateGuard tls_guard(local_graph_task->thread_locals_); try { GraphTaskGuard guard(local_graph_task); engine_.evaluate_function( diff --git a/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp b/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp index 0d82c07835f55..a492d9847fb37 100644 --- a/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp +++ b/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp @@ -15,7 +15,7 @@ RecvRpcBackward::RecvRpcBackward( const AutogradMetadata& autogradMetadata, ContextPtr autogradContext, rpc::worker_id_t fromWorkerId, - std::unordered_map deviceMap) + rpc::DeviceMap deviceMap) : autogradMetadata_(autogradMetadata), // NOLINTNEXTLINE(performance-move-const-arg) autogradContext_(std::move(autogradContext)), diff --git a/torch/csrc/distributed/autograd/functions/recvrpc_backward.h b/torch/csrc/distributed/autograd/functions/recvrpc_backward.h index 46bdb297cdf46..6e6678b128985 100644 --- a/torch/csrc/distributed/autograd/functions/recvrpc_backward.h +++ b/torch/csrc/distributed/autograd/functions/recvrpc_backward.h @@ -23,7 +23,7 @@ class TORCH_API RecvRpcBackward : public torch::autograd::Node { const AutogradMetadata& autogradMetadata, std::shared_ptr autogradContext, rpc::worker_id_t fromWorkerId, - std::unordered_map deviceMap); + rpc::DeviceMap deviceMap); torch::autograd::variable_list apply( torch::autograd::variable_list&& grads) override; @@ -41,7 +41,7 @@ class TORCH_API RecvRpcBackward : public torch::autograd::Node { rpc::worker_id_t fromWorkerId_; // Device mapping for tensors sent over RPC. - const std::unordered_map deviceMap_; + const rpc::DeviceMap deviceMap_; }; } // namespace autograd diff --git a/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp index 4d84e99753961..b8d28f7be7c2d 100644 --- a/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp +++ b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp @@ -19,7 +19,7 @@ RpcWithAutograd::RpcWithAutograd( MessageType messageType, const AutogradMetadata& autogradMetadata, c10::intrusive_ptr wrappedMessage, - std::unordered_map deviceMap) + rpc::DeviceMap deviceMap) : fromWorkerId_(fromWorkerId), messageType_(messageType), autogradMetadata_(autogradMetadata), @@ -39,7 +39,7 @@ RpcWithAutograd::RpcWithAutograd( std::unique_ptr wrappedRpc, MessageType wrappedMessageType, std::vector tensors, - std::unordered_map deviceMap) + rpc::DeviceMap deviceMap) : fromWorkerId_(fromWorkerId), messageType_(messageType), autogradMetadata_(autogradMetadata), @@ -112,7 +112,7 @@ std::unique_ptr RpcWithAutograd::fromMessage( auto c10DeviceMap = tupleElements[4].to>(); // Convert to regular map. - std::unordered_map deviceMap; + rpc::DeviceMap deviceMap; for (const auto& mapEntry : c10DeviceMap) { deviceMap.insert({mapEntry.key(), mapEntry.value()}); } @@ -169,7 +169,7 @@ rpc::worker_id_t RpcWithAutograd::fromWorkerId() const { return fromWorkerId_; } -const std::unordered_map& RpcWithAutograd:: +const rpc::DeviceMap& RpcWithAutograd:: deviceMap() { return deviceMap_; } diff --git a/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h index 1884cc9742939..6d0b6111cc88c 100644 --- a/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h +++ b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h @@ -19,7 +19,7 @@ class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase { rpc::MessageType messageType, const AutogradMetadata& autogradMetadata, c10::intrusive_ptr wrappedMessage, - std::unordered_map deviceMap = {}); + rpc::DeviceMap deviceMap = {}); // Used when receiving an RPC over the wire. RpcWithAutograd( @@ -29,7 +29,7 @@ class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase { std::unique_ptr wrappedRpc, rpc::MessageType wrappedMessageType, std::vector tensors, - std::unordered_map deviceMap = {}); + rpc::DeviceMap deviceMap = {}); c10::intrusive_ptr toMessageImpl() && override; @@ -55,7 +55,7 @@ class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase { rpc::worker_id_t fromWorkerId() const; // Retrieve the device map. - const std::unordered_map& deviceMap(); + const rpc::DeviceMap& deviceMap(); private: // WorkerId from which this RPC originated. This is necessary for knowing @@ -90,7 +90,7 @@ class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase { std::vector tensors_; // Device mapping for tensors that are sent across an RPC to another node. - std::unordered_map deviceMap_; + rpc::DeviceMap deviceMap_; }; } // namespace autograd diff --git a/torch/csrc/distributed/autograd/utils.cpp b/torch/csrc/distributed/autograd/utils.cpp index 4e29bfcc1ffe9..9db40766c598a 100644 --- a/torch/csrc/distributed/autograd/utils.cpp +++ b/torch/csrc/distributed/autograd/utils.cpp @@ -53,7 +53,7 @@ ContextPtr addRecvRpcBackward( const AutogradMetadata& autogradMetadata, std::vector& tensors, rpc::worker_id_t fromWorkerId, - const std::unordered_map& deviceMap) { + const rpc::DeviceMap& deviceMap) { // Initialize autograd context if necessary. auto& autogradContainer = DistAutogradContainer::getInstance(); auto autogradContext = @@ -105,7 +105,7 @@ c10::intrusive_ptr getMessageWithAutograd( c10::intrusive_ptr wrappedRpcMsg, MessageType msgType, bool forceGradRecording, - const std::unordered_map& deviceMap) { + const rpc::DeviceMap& deviceMap) { auto& autogradContainer = DistAutogradContainer::getInstance(); // If there is no valid context and no tensor requires grads, send original diff --git a/torch/csrc/distributed/autograd/utils.h b/torch/csrc/distributed/autograd/utils.h index fae675d3b81c6..94883ce605269 100644 --- a/torch/csrc/distributed/autograd/utils.h +++ b/torch/csrc/distributed/autograd/utils.h @@ -31,7 +31,7 @@ TORCH_API ContextPtr addRecvRpcBackward( const AutogradMetadata& autogradMetadata, std::vector& tensors, rpc::worker_id_t fromWorkerId, - const std::unordered_map& deviceMap); + const rpc::DeviceMap& deviceMap); // This method is a wrapper utility used internally to wrap autograd info // and attach autograd function for each type of rpc call if it has valid @@ -44,7 +44,7 @@ TORCH_API c10::intrusive_ptr getMessageWithAutograd( c10::intrusive_ptr wrappedRpcMsg, rpc::MessageType msgType, bool forceGradRecording = false, - const std::unordered_map& deviceMap = + const rpc::DeviceMap& deviceMap = {}); // Send message after autograd checking diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index e3ee14da0f542..bd50bba3606b9 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -50,6 +50,15 @@ const inline char* getNcclErrorDetailStr(ncclResult_t error) { #define ENABLE_NCCL_P2P_SUPPORT #endif +// NCCL BFloat16 is enabled only for CUDA 11+ and NCCL versions 2.9.7+ +#if (defined(__CUDA_BF16_TYPES_EXIST__) && \ + defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && \ + (defined(NCCL_MINOR) && ((NCCL_MINOR > 9) || \ + ((NCCL_MINOR == 9) && defined(NCCL_PATCH) && (NCCL_PATCH >= 7))))) || \ + (defined(__HIP_PLATFORM_HCC__) && (TORCH_HIP_VERSION >= 301)) +#define ENABLE_NCCL_BF16_DATATYPE +#endif + // Macro to throw on a non-successful NCCL return value. #define C10D_NCCL_CHECK(cmd) \ do { \ diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp index ba26409c9b990..b8f5aa3989ce4 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp @@ -923,7 +923,7 @@ c10::intrusive_ptr ProcessGroupGloo::broadcast( std::vector& inputs, const BroadcastOptions& opts) { static auto invalidArgument = [](const std::string& msg) { - throw std::invalid_argument("ProcessGroupGloo::broadcast: " + msg); + TORCH_CHECK(false, "ProcessGroupGloo::broadcast: " + msg); }; assertRootRank(invalidArgument, opts.rootRank, size_); @@ -1414,7 +1414,7 @@ c10::intrusive_ptr ProcessGroupGloo::allreduce( std::vector& inputs, const AllreduceOptions& opts) { static auto invalidArgument = [](const std::string& msg) { - throw std::invalid_argument("ProcessGroupGloo::allreduce: " + msg); + TORCH_CHECK(false, "ProcessGroupGloo::allreduce: " + msg); }; assertNonEmpty(invalidArgument, inputs); @@ -1475,7 +1475,7 @@ c10::intrusive_ptr ProcessGroupGloo::allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts) { static auto invalidArgument = [](const std::string& msg) { - throw std::invalid_argument( + TORCH_CHECK(false, "ProcessGroupGloo::allreduce_coalesced: " + msg); }; assertNonEmpty(invalidArgument, tensors); @@ -1644,7 +1644,7 @@ c10::intrusive_ptr ProcessGroupGloo::reduce( std::vector& inputs, const ReduceOptions& opts) { static auto invalidArgument = [](const std::string& msg) { - throw std::invalid_argument("ProcessGroupGloo::reduce: " + msg); + TORCH_CHECK(false, "ProcessGroupGloo::reduce: " + msg); }; assertRootRank(invalidArgument, opts.rootRank, size_); @@ -1821,7 +1821,7 @@ c10::intrusive_ptr ProcessGroupGloo::allgather( std::vector& inputs, const AllgatherOptions& opts) { static auto invalidArgument = [](const std::string& msg) { - throw std::invalid_argument("ProcessGroupGloo::allgather: " + msg); + TORCH_CHECK(false, "ProcessGroupGloo::allgather: " + msg); }; if (inputs.size() == 0) { @@ -1955,7 +1955,7 @@ c10::intrusive_ptr ProcessGroupGloo::allgather_coalesced( std::vector& input_list, const AllgatherOptions& /* unused */) { static auto invalidArgument = [](const std::string& msg) { - throw std::invalid_argument( + TORCH_CHECK(false, "ProcessGroupGloo::allgather_coalesced: " + msg); }; @@ -2152,7 +2152,7 @@ c10::intrusive_ptr ProcessGroupGloo::gather( std::vector& inputs, const GatherOptions& opts) { static auto invalidArgument = [](const std::string& msg) { - throw std::invalid_argument("ProcessGroupGloo::gather: " + msg); + TORCH_CHECK(false, "ProcessGroupGloo::gather: " + msg); }; assertRootRank(invalidArgument, opts.rootRank, size_); @@ -2336,7 +2336,7 @@ c10::intrusive_ptr ProcessGroupGloo::scatter( std::vector>& inputs, const ScatterOptions& opts) { static auto invalidArgument = [](const std::string& msg) { - throw std::invalid_argument("ProcessGroupGloo::scatter: " + msg); + TORCH_CHECK(false, "ProcessGroupGloo::scatter: " + msg); }; assertRootRank(invalidArgument, opts.rootRank, size_); @@ -2530,7 +2530,7 @@ c10::intrusive_ptr ProcessGroupGloo::alltoall_base( std::vector& inputCounts, const AllToAllOptions& /* unused */) { static auto invalidArgument = [](const std::string& msg) { - throw std::invalid_argument("ProcessGroupGloo::alltoall_base: " + msg); + TORCH_CHECK(false, "ProcessGroupGloo::alltoall_base: " + msg); }; TORCH_CHECK( diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp index 036ce91b85faf..5c0c76afa2453 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp @@ -318,6 +318,10 @@ class TORCH_API ProcessGroupGloo : public ProcessGroup { // may indicate that there is some sort of collective desynchronization. uint64_t getSequenceNumberForGroup() override; + int getNumThreads() { + return options_->threads; + } + protected: std::unique_ptr<::gloo::rendezvous::Store> store_; const c10::intrusive_ptr options_; diff --git a/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp b/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp index aa6d81bbe4a13..b75f4417e832a 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp @@ -4,6 +4,7 @@ #include #include +#include #include #include diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 3c7041a2dd691..9773b350e2cd7 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -11,6 +11,7 @@ #include #include +#include #include #include #include @@ -63,7 +64,7 @@ std::map ncclDataType = { {at::kLong, ncclInt64}, {at::kHalf, ncclHalf}, {at::kBool, ncclUint8}, -#if defined(__HIP_PLATFORM_HCC__) && TORCH_HIP_VERSION >= 301 +#if defined(ENABLE_NCCL_BF16_DATATYPE) {at::kBFloat16, ncclBfloat16}, #endif }; @@ -189,6 +190,17 @@ std::string getExceptionMsgFromExceptionPtr( } } +inline void errorIfCapturingNonCapturableNCCL() { + auto status = c10::cuda::currentStreamCaptureStatusMayInitCtx(); + // parentheses avoid some compiler warnings + static const uint64_t min_version = (((uint64_t)2) << 32) + (((uint64_t)9) << 16) + ((uint64_t)6); + static const uint64_t cur_version = torch::cuda::nccl::version(); + if (cur_version < min_version) { + TORCH_CHECK(status == c10::cuda::CaptureStatus::None, + "Capturing NCCL collectives is only allowed with NCCL >= 2.9.6"); + } +} + } // namespace const int64_t ProcessGroupNCCL::kWatchdogThreadSleepMillis = 10000; @@ -1079,6 +1091,8 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( OpType opType, const char* profilingTitle) { + errorIfCapturingNonCapturableNCCL(); + // Bump collective counter if (sequenceNum_) { sequenceNum_->increment(); diff --git a/torch/csrc/distributed/c10d/comm.hpp b/torch/csrc/distributed/c10d/comm.hpp index 9b45795683004..4690c355ce71b 100644 --- a/torch/csrc/distributed/c10d/comm.hpp +++ b/torch/csrc/distributed/c10d/comm.hpp @@ -18,12 +18,14 @@ class TORCH_API GradBucket { public: explicit GradBucket( size_t index, + size_t bucket_count, const at::Tensor& tensor, const std::vector& offsets, const std::vector& lengths, const std::vector& sizes_vec, const std::vector& parameters) : index_(index), + bucket_count_(bucket_count), buffer_(tensor), offsets_(offsets), lengths_(lengths), @@ -63,11 +65,12 @@ class TORCH_API GradBucket { // Returns whther this bucket is the last bucket to allreduce in an iteration. bool isLast() const { - return index_ == 0; + return index_ == bucket_count_ - 1; } private: size_t index_; + size_t bucket_count_; at::Tensor buffer_; // Per-variable info in buffer_. diff --git a/torch/csrc/distributed/c10d/default_comm_hooks.cpp b/torch/csrc/distributed/c10d/default_comm_hooks.cpp index 9d13099c424c6..30bc96b16f7db 100644 --- a/torch/csrc/distributed/c10d/default_comm_hooks.cpp +++ b/torch/csrc/distributed/c10d/default_comm_hooks.cpp @@ -1,4 +1,6 @@ #include +#include +#include #include #include @@ -16,21 +18,28 @@ c10::intrusive_ptr AllReduceCommHook::runHook( c10::intrusive_ptr FP16CompressCommHook::runHook( GradBucket& bucket) { - auto& tensor = bucket.getBufferRef(); - tensor.copy_(tensor.to(torch::kFloat16)); - std::vector tensors = {tensor}; + + auto compressed_tensor = bucket.getBufferRef().to(torch::kFloat16); // Apply the division first to avoid overflow. - tensors[0] /= state_->getSize(); + compressed_tensor /= state_->getSize(); + std::vector tensors = {compressed_tensor}; auto allreduce_fut = state_->allreduce(tensors)->getFuture(); - auto decompress = [](c10::ivalue::Future& allreduce_fut) { + auto decompressed_tensor = bucket.getBufferRef(); + auto decompress = [decompressed_tensor](c10::ivalue::Future& allreduce_fut) { auto result = allreduce_fut.value(); TORCH_INTERNAL_ASSERT( result.isTensorList(), "ProcessGroup::allreduce should return TensorList"); + auto reduce_tensor = result.toTensorVector()[0]; - reduce_tensor.copy_(reduce_tensor.to(torch::kFloat)); - return c10::IValue(reduce_tensor); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + reduce_tensor.scalar_type() == at::ScalarType::Half, + "Expected reduced tensor to be fp16 in FP16CompressHook, but got type ", + reduce_tensor.scalar_type() + ); + decompressed_tensor.copy_(reduce_tensor); + return c10::IValue(decompressed_tensor); }; return allreduce_fut->then(decompress, allreduce_fut->elementType()); diff --git a/torch/csrc/distributed/c10d/frontend.cpp b/torch/csrc/distributed/c10d/frontend.cpp index b65cba79884af..e5b59f28982f6 100644 --- a/torch/csrc/distributed/c10d/frontend.cpp +++ b/torch/csrc/distributed/c10d/frontend.cpp @@ -3,10 +3,11 @@ #include #include #include -#include #include #include #include +#include +#include #include #include @@ -17,10 +18,6 @@ #include #endif -#ifdef USE_C10D_NCCL -#include -#endif - #ifdef USE_C10D_MPI #include #endif @@ -29,6 +26,20 @@ namespace c10d { namespace { +// Constant initialization, so it is guaranteed to be initialized before +// static initialization calls which may invoke registerNCCLProcessGroupProvider +const NCCLProcessGroupProvider stubProvider; +constexpr const NCCLProcessGroupProvider* defaultStubProviderAddr = + &stubProvider; +inline const NCCLProcessGroupProvider*& getNCCLProcessGroupProviderAddress() { + static const NCCLProcessGroupProvider* stubs_ = defaultStubProviderAddr; + return stubs_; +} + +const NCCLProcessGroupProvider* GetNCCLProcessGroupProvider() { + return getNCCLProcessGroupProviderAddress(); +} + void maybePreprocessComplexTensor(at::Tensor& tensor) { if(!tensor.is_complex()) { return; @@ -63,6 +74,11 @@ void assertReduceOpSupportsComplexTensor(ReduceOp op) { } // namespace anonymous +void registerNCCLProcessGroupProvider(NCCLProcessGroupProvider* provider) { + getNCCLProcessGroupProviderAddress() = provider; +} + + std::string Backend::get(const std::string& backend_type) { return backend_type; } @@ -207,17 +223,7 @@ c10::intrusive_ptr DistributedC10d::newProcessGroupHelper( "Attempting to create GLOO-based process group while GLOO is either not enabled or built"); #endif // USE_C10D_GLOO } else if (backend == "nccl") { -#ifdef USE_C10D_NCCL - auto options = ProcessGroupNCCL::Options::create(); - - options->is_high_priority_stream = false; - options->timeout = timeout; - pg = c10::make_intrusive( - prefix_store, rank, world_size, options); -#else - AT_ERROR( - "Attempting to create NCCL-based process group while NCCL is either not enabled or built"); -#endif // USE_C10D_NCCL + pg = GetNCCLProcessGroupProvider()->get(prefix_store, rank, world_size, timeout); } else { // TODO: discuss to figure out how to extend this to third party backends? AT_ERROR("Unsupported backend type: ", backend); @@ -1008,7 +1014,7 @@ void initCustomClassBindings() { .def( "broadcast", [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, - std::vector data) { return self->broadcast(data); + std::vector data) { return self->broadcast(data); }) */ .def( @@ -1045,14 +1051,14 @@ void initCustomClassBindings() { .def( "allreduce", [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, - at::Tensor& tensor, - c10::intrusive_ptr<::c10d::ReduceOp> op) { + at::Tensor& tensor, + c10::intrusive_ptr<::c10d::ReduceOp> op) { ::c10d::AllreduceOptions opts; opts.reduceOp = *op; std::vector tensors = {tensor}; return self->allreduce(tensors, opts); - } - ) + } + ) */ // TODO: make AllreduceCoalescedOptions compatible with TorchBind to // provide the full API in python. @@ -1098,8 +1104,8 @@ void initCustomClassBindings() { .def( "allgather", [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, - std::vector output, - at::Tensor input) { + std::vector output, + at::Tensor input) { std::vector> outputs = { std::move(output)}; std::vector inputs = {std::move(input)}; @@ -1121,8 +1127,8 @@ void initCustomClassBindings() { .def( "gather", [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, - std::vector> output_tensors, - std::vector input_tensors) { + std::vector> output_tensors, + std::vector input_tensors) { ::c10d::GatherOptions opts; return self->gather(output_tensors, input_tensors, opts); }) @@ -1145,8 +1151,8 @@ void initCustomClassBindings() { .def( "scatter", [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, - std::vector outputTensors, - std::vector> inputTensors) { + std::vector outputTensors, + std::vector> inputTensors) { ::c10d::ScatterOptions opts; self->scatter(outputTensors, inputTensors, opts); }) @@ -1169,8 +1175,8 @@ void initCustomClassBindings() { // TODO: Enable this method when TorchBind supports ReduceScatterOptions. .def( "reduce_scatter", [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, - std::vector outputTensors, - std::vector> inputTensors) { + std::vector outputTensors, + std::vector> inputTensors) { ::c10d::ReduceScatterOptions opts; return self->reduce_scatter(outputTensors, inputTensors, opts); }) @@ -1241,95 +1247,6 @@ void initCustomClassBindings() { return self->barrier(opts); }); -#ifdef USE_C10D_NCCL - // XXX: Ideally the Options of ProcessGroupNCCL should be - // bound using `def_readwrite` like in pybind11, but we - // didn't do that because: 1. no milisecond support yet - // 2. no def_readwrite or property support yet. - // TODO: make this binding the same as pybind11 - static const auto ProcessGroupNCCLOptionsTorchBind = - torch::class_<::c10d::ProcessGroupNCCL::Options>( - "dist_c10d", "ProcessGroupNCCLOptions") - .def(torch::init([](int64_t timeout, bool isHighPriorityStream) { - auto opTimeout = std::chrono::milliseconds(timeout); - auto opts = - ::c10d::ProcessGroupNCCL::Options::create(isHighPriorityStream); - opts->timeout = opTimeout; - return opts; - })); - - static const auto ProcessGroupNCCLTorchBind = - torch::class_<::c10d::ProcessGroupNCCL>("dist_c10d", "ProcessGroupNCCL") - .def_pickle( - [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self) { - auto base_process_group = - ::c10::static_intrusive_pointer_cast<::c10d::ProcessGroup>(self); - auto name = - ::c10d::DistributedC10d::get()->getNameOfProcessGroup(self); - return std::vector{name}; - }, - [](std::vector state) { - TORCH_CHECK( - state.size() == 1, - "Expecting exactly 1 state when restoring ProcessGroupNCCL, got: ", - state.size()); - const auto& process_group_name = state.front(); - auto base_process_group = - ::c10d::DistributedC10d::get()->getProcessGroupByName( - process_group_name); - TORCH_CHECK( - base_process_group.defined(), - "Needed process group not found, ", - "please create a process group with name: ", - process_group_name); - c10::intrusive_ptr<::c10d::ProcessGroupNCCL> - process_group_nccl = ::c10::dynamic_intrusive_pointer_cast< - ::c10d::ProcessGroupNCCL>(base_process_group); - TORCH_CHECK( - process_group_nccl.defined(), - "Process group ", - process_group_name, - " isn't configured for NCCL backend"); - return process_group_nccl; - }) - .def(torch::init( - [](const c10::intrusive_ptr<::c10d::Store>& store, - int64_t rank, - int64_t size, - c10::intrusive_ptr<::c10d::ProcessGroupNCCL::Options> options, - const std::string& name) { - auto pg = c10::make_intrusive<::c10d::ProcessGroupNCCL>( - store, rank, size, options); - ::c10d::DistributedC10d::get()->registerProcessGroupName( - pg, name); - return pg; - })) - .def( - "alltoall_base", - [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self, - at::Tensor output, - at::Tensor input, - std::vector outputSplitSizes, - std::vector inputSplitSizes) { - return self->alltoall_base( - output, - input, - outputSplitSizes, - inputSplitSizes, - ::c10d::AllToAllOptions()); - }) - .def( - "size", - [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self) { - return (int64_t)self->getSize(); - }) - .def( - "rank", - [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self) { - return (int64_t)self->getRank(); - }); -#endif - static const auto DistributedC10dFrontendTorchBind = torch::class_<::c10d::DistributedC10d>("dist_c10d", "frontend") .def(torch::init([]() { return ::c10d::DistributedC10d::get(); })) @@ -1344,4 +1261,12 @@ void initCustomClassBindings() { &::c10d::DistributedC10d::getNameOfProcessGroup); } +TORCH_LIBRARY(q, m) { + m.def("_Bfloat16QuantizedToFloat(Tensor input) -> Tensor"); + m.def("_FloatToBfloat16Quantized(Tensor input) -> Tensor"); +} +TORCH_LIBRARY_IMPL(q, CPU, m) { + m.impl("_Bfloat16QuantizedToFloat", ::torch::distributed::c10d::quantization::_bfloat16_to_float_cpu); + m.impl("_FloatToBfloat16Quantized", ::torch::distributed::c10d::quantization::_float_to_bfloat16_cpu); +} } // namespace c10d diff --git a/torch/csrc/distributed/c10d/frontend.hpp b/torch/csrc/distributed/c10d/frontend.hpp index c90cc077b2823..b39d8b7a444bf 100644 --- a/torch/csrc/distributed/c10d/frontend.hpp +++ b/torch/csrc/distributed/c10d/frontend.hpp @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -259,7 +260,26 @@ class TORCH_PYTHON_API DistributedC10d : public torch::CustomClassHolder { int64_t group_count_; }; -// Must be called to initialize Torchbind bindings for c10d. -void initCustomClassBindings(); +// This class exists as a way to allow us to split NCCL-specific code into a +// different file. frontend_cuda.cpp will, if USE_C10D_NCCL is defined, +// override this NCCLProcessGroupProvider with one that will actually do +// something. +struct TORCH_API NCCLProcessGroupProvider { + virtual c10::intrusive_ptr get( + c10::intrusive_ptr /*prefix_store*/, + int64_t /*rank*/, + int64_t /*world_size*/, + std::chrono::milliseconds /*timeout*/) const { + AT_ERROR( + "Attempting to create NCCL-based process group while NCCL is either not enabled or built"); + } + + virtual ~NCCLProcessGroupProvider() = default; +}; + +TORCH_API void registerNCCLProcessGroupProvider( + NCCLProcessGroupProvider* provider); + +TORCH_API void initCustomClassBindings(); } // namespace c10d diff --git a/torch/csrc/distributed/c10d/frontend_cuda.cpp b/torch/csrc/distributed/c10d/frontend_cuda.cpp new file mode 100644 index 0000000000000..1b42f13b3c8df --- /dev/null +++ b/torch/csrc/distributed/c10d/frontend_cuda.cpp @@ -0,0 +1,136 @@ +#include + +#ifdef USE_C10D_NCCL + +#include +#include +#include +#include +#include + +namespace c10d { + +void initCustomClassBindingsNccl() { + // XXX: Ideally the Options of ProcessGroupNCCL should be + // bound using `def_readwrite` like in pybind11, but we + // didn't do that because: 1. no milisecond support yet + // 2. no def_readwrite or property support yet. + // TODO: make this binding the same as pybind11 + static const auto ProcessGroupNCCLOptionsTorchBind = + torch::class_<::c10d::ProcessGroupNCCL::Options>( + "dist_c10d", "ProcessGroupNCCLOptions") + .def(torch::init([](int64_t timeout, bool isHighPriorityStream) { + auto opTimeout = std::chrono::milliseconds(timeout); + auto opts = + ::c10d::ProcessGroupNCCL::Options::create(isHighPriorityStream); + opts->timeout = opTimeout; + return opts; + })); + + static const auto ProcessGroupNCCLTorchBind = + torch::class_<::c10d::ProcessGroupNCCL>("dist_c10d", "ProcessGroupNCCL") + .def_pickle( + [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self) { + auto base_process_group = + ::c10::static_intrusive_pointer_cast<::c10d::ProcessGroup>( + self); + auto name = + ::c10d::DistributedC10d::get()->getNameOfProcessGroup(self); + return std::vector{name}; + }, + [](std::vector state) { + TORCH_CHECK( + state.size() == 1, + "Expecting exactly 1 state when restoring ProcessGroupNCCL, got: ", + state.size()); + const auto& process_group_name = state.front(); + auto base_process_group = + ::c10d::DistributedC10d::get()->getProcessGroupByName( + process_group_name); + TORCH_CHECK( + base_process_group.defined(), + "Needed process group not found, ", + "please create a process group with name: ", + process_group_name); + c10::intrusive_ptr<::c10d::ProcessGroupNCCL> + process_group_nccl = ::c10::dynamic_intrusive_pointer_cast< + ::c10d::ProcessGroupNCCL>(base_process_group); + TORCH_CHECK( + process_group_nccl.defined(), + "Process group ", + process_group_name, + " isn't configured for NCCL backend"); + return process_group_nccl; + }) + .def(torch::init( + [](const c10::intrusive_ptr<::c10d::Store>& store, + int64_t rank, + int64_t size, + c10::intrusive_ptr<::c10d::ProcessGroupNCCL::Options> options, + const std::string& name) { + auto pg = c10::make_intrusive<::c10d::ProcessGroupNCCL>( + store, rank, size, options); + ::c10d::DistributedC10d::get()->registerProcessGroupName( + pg, name); + return pg; + })) + .def( + "alltoall_base", + [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self, + at::Tensor output, + at::Tensor input, + std::vector outputSplitSizes, + std::vector inputSplitSizes) { + return self->alltoall_base( + output, + input, + outputSplitSizes, + inputSplitSizes, + ::c10d::AllToAllOptions()); + }) + .def( + "size", + [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self) { + return (int64_t)self->getSize(); + }) + .def( + "rank", + [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self) { + return (int64_t)self->getRank(); + }); +} + +namespace { +struct RealNCCLProcessGroupProvider : public NCCLProcessGroupProvider { + c10::intrusive_ptr get( + c10::intrusive_ptr prefix_store, + int64_t rank, + int64_t world_size, + std::chrono::milliseconds timeout) const override { + auto options = ProcessGroupNCCL::Options::create(); + options->is_high_priority_stream = false; + options->timeout = timeout; + return c10::make_intrusive( + prefix_store, rank, world_size, options); + } +}; + +struct RegisterNCCLProcessGroupProvider { + RegisterNCCLProcessGroupProvider() { + static RealNCCLProcessGroupProvider provider; + registerNCCLProcessGroupProvider(&provider); + } +}; + +RegisterNCCLProcessGroupProvider reg; + +} // namespace +#define DISPATCH_TO_CUDA(name, function) \ + m.impl(name, torch::dispatch(c10::DispatchKey::CUDA, TORCH_FN(function))) +TORCH_LIBRARY_IMPL(q, CUDA, m) { + DISPATCH_TO_CUDA("_Bfloat16QuantizedToFloat", ::torch::distributed::c10d::quantization::_bfloat16_to_float_cuda); + DISPATCH_TO_CUDA("_FloatToBfloat16Quantized", ::torch::distributed::c10d::quantization::_float_to_bfloat16_cuda); +} +} // namespace c10d + +#endif // USE_C10D_NCCL diff --git a/torch/csrc/distributed/c10d/frontend_cuda.hpp b/torch/csrc/distributed/c10d/frontend_cuda.hpp new file mode 100644 index 0000000000000..a790f2e847b0d --- /dev/null +++ b/torch/csrc/distributed/c10d/frontend_cuda.hpp @@ -0,0 +1,12 @@ +#pragma once + +#ifdef USE_C10D_NCCL +#include + +namespace c10d { + +TORCH_API void initCustomClassBindingsNccl(); + +} + +#endif diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 201f0c2dd64f4..4bac0ca46edc4 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -17,6 +17,7 @@ #ifdef USE_C10D_NCCL #include +#include #endif #ifdef USE_C10D_MPI @@ -31,6 +32,7 @@ #include #include #include + #include #include #include @@ -230,6 +232,9 @@ void _register_builtin_comm_hook( PyObject* c10d_init(PyObject* _unused, PyObject* noargs) { C10_LOG_API_USAGE_ONCE("c10d.python.import"); ::c10d::initCustomClassBindings(); +#ifdef USE_C10D_NCCL + ::c10d::initCustomClassBindingsNccl(); +#endif auto c10d_module = THPObjectPtr(PyImport_ImportModule("torch.distributed")); if (!c10d_module) { @@ -1643,7 +1648,6 @@ static PyMethodDef methods[] = { // NOLINT PyMethodDef* python_functions() { return methods; } - } // namespace c10d } // namespace distributed } // namespace torch diff --git a/torch/csrc/distributed/c10d/logger.cpp b/torch/csrc/distributed/c10d/logger.cpp index 9fa7289c16568..92e16614a6612 100644 --- a/torch/csrc/distributed/c10d/logger.cpp +++ b/torch/csrc/distributed/c10d/logger.cpp @@ -4,11 +4,18 @@ #include #include +#ifdef USE_C10D_GLOO +#include +#endif + namespace c10d { -// When training runs at these iterations, log the runtime -// stats. -const int LoggingIterations[] = {10, 20, 100, 1000}; +// Logs runtime stats to configured destination. Note that since data collection +// only runs every ddp_runtime_logging_sample_rate iterations, the actual +// training iterations recorded will be like 10, +// (20-10) * ddp_runtime_logging_sample_rate, +// (50-10) * ddp_runtime_logging_sample_rate and so on. +const int LoggingIterations[] = {10, 20, 50, 100, 500, 800, 1000}; // NOLINT std::ostream& operator<<(std::ostream& output, const Logger& logger) { auto& ddp_logging_data = (*logger.ddp_logging_data_); @@ -68,6 +75,13 @@ void Logger::set_env_variables() { parse_env("GLOO_SOCKET_IFNAME"); ddp_logging_data_->strs_map["gloo_device_transport"] = parse_env("GLOO_DEVICE_TRANSPORT"); + + #ifdef USE_C10D_GLOO + auto gloo_pg = + static_cast(reducer_->process_group_.get()); + auto n_threads = gloo_pg->getNumThreads(); + ddp_logging_data_->ints_map["gloo_num_threads"] = n_threads; + #endif } } diff --git a/torch/csrc/distributed/c10d/quantization/quantization.cpp b/torch/csrc/distributed/c10d/quantization/quantization.cpp new file mode 100644 index 0000000000000..b9682d73ed139 --- /dev/null +++ b/torch/csrc/distributed/c10d/quantization/quantization.cpp @@ -0,0 +1,93 @@ +#include +#include + +namespace torch { +namespace distributed { +namespace c10d { +namespace quantization { + +void FloatToBFloat16Quantized_ref( + const float* const input, + const size_t nrows, + const size_t ncols, + uint16_t* const output){ + for (const auto row : c10::irange(nrows)) { + const float* input_row = input + row * ncols; + uint16_t* output_row = output + row * ncols; + + for (const auto col : c10::irange(ncols)) { + output_row[col] = + (*reinterpret_cast(input_row + col) + (1 << 15)) >> + 16; + } + } +} + +void BFloat16QuantizedToFloat_ref( + const at::BFloat16* const input, + const size_t nrows, + const size_t ncols, + float* const output){ + const int32_t output_columns = ncols; + + for (const auto row : c10::irange(nrows)) { + const at::BFloat16* input_row = input + row * ncols; + float* output_row = output + row * output_columns; + + for (const auto col : c10::irange(ncols)) { + uint32_t val_fp32 = static_cast( + reinterpret_cast(input_row)[col]) + << 16; + reinterpret_cast(output_row)[col] = val_fp32; + } + } +} + +at::Tensor _float_to_bfloat16_cpu(const at::Tensor& input) { + TENSOR_ON_CPU(input); + // Currently it supports 2D inputs + TENSOR_NDIM_EQUALS(input, 2); + + const auto input_sizes = input.sizes(); + const int32_t nrows = input_sizes[0]; + const int32_t ncols = input_sizes[1]; + const int32_t output_columns = ncols; + auto output = at::empty( + {nrows, output_columns}, + input.options().dtype(at::kHalf)); + + FloatToBFloat16Quantized_ref( + input.data_ptr(), + nrows, + ncols, + reinterpret_cast(output.data_ptr())); + + return output; +} + +at::Tensor _bfloat16_to_float_cpu(const at::Tensor& input) { + TENSOR_ON_CPU(input); + // Currently it supports 2D inputs + TENSOR_NDIM_EQUALS(input, 2); + + const auto input_sizes = input.sizes(); + const int32_t nrows = input_sizes[0]; + const int32_t ncols = input_sizes[1]; + const int32_t output_columns = ncols; + + auto output = at::empty( + {nrows, output_columns}, // 4 = sizeof(float) + input.options().dtype(at::kFloat)); // + BFloat16QuantizedToFloat_ref( + reinterpret_cast(input.data_ptr()), + nrows, + ncols, + output.data_ptr()); + + return output; +} + +} // namespace quantization +} // namespace c10d +} // namespace distributed +} // namespace torch diff --git a/torch/csrc/distributed/c10d/quantization/quantization.h b/torch/csrc/distributed/c10d/quantization/quantization.h new file mode 100644 index 0000000000000..658fa754488d1 --- /dev/null +++ b/torch/csrc/distributed/c10d/quantization/quantization.h @@ -0,0 +1,20 @@ +// (c) Facebook, Inc. and its affiliates. Confidential and proprietary. + +#pragma once + + +#include +#include + +namespace torch { +namespace distributed { +namespace c10d { +namespace quantization { + +at::Tensor _float_to_bfloat16_cpu(const at::Tensor& input); +at::Tensor _bfloat16_to_float_cpu(const at::Tensor& input); + +} // namespace quantization +} // namespace c10d +} // namespace distributed +} // namespace torch diff --git a/torch/csrc/distributed/c10d/quantization/quantization_gpu.cu b/torch/csrc/distributed/c10d/quantization/quantization_gpu.cu new file mode 100644 index 0000000000000..5590e035b0683 --- /dev/null +++ b/torch/csrc/distributed/c10d/quantization/quantization_gpu.cu @@ -0,0 +1,148 @@ +#include +#include +#include +#include + +// FP32 -> BF16 kernel +__global__ inline void _float_to_bfloat16_cuda_kernel( + const float* __restrict__ input, + const int nrows, + const int ncols, + uint16_t* __restrict__ output) { + const int row_incre = blockDim.y * gridDim.y; + const int col_incre = blockDim.x * gridDim.x; + for (int row = blockIdx.y * blockDim.y + threadIdx.y; row < nrows; + row += row_incre) { + const float* input_row = input + row * ncols; + uint16_t* output_row = output + row * ncols; + for (int col = blockIdx.x * blockDim.x + threadIdx.x; col < ncols; + col += col_incre) { + // Add 2^15 and right shift 16 to do round-nearest + output_row[col] = + (*reinterpret_cast(input_row + col) + (1 << 15)) >> + 16; + } + } +} + +// BF16 -> FP32 kernel +__global__ inline void _bfloat16_to_float_cuda_kernel( + const uint16_t* __restrict__ input, + const int nrows, + const int ncols, + float* __restrict__ output) { + const int row_incre = blockDim.y * gridDim.y; + const int col_incre = blockDim.x * gridDim.x; + for (int row = blockIdx.y * blockDim.y + threadIdx.y; row < nrows; + row += row_incre) { + for (int col = blockIdx.x * blockDim.x + threadIdx.x; col < ncols; + col += col_incre) { + const uint16_t* input_row = input + row * ncols; + float* output_row = output + row * ncols; + uint32_t val_fp32 = static_cast( + reinterpret_cast(input_row)[col]) + << 16; + reinterpret_cast(output_row)[col] = val_fp32; + } + } +} + +namespace torch { +namespace distributed { +namespace c10d { +namespace quantization { + +at::Tensor _float_to_bfloat16_cuda(const at::Tensor& input) { + TENSOR_ON_CUDA_GPU(input); + // Currently it supports 2D inputs + TENSOR_NDIM_EQUALS(input, 2); + + at::cuda::OptionalCUDAGuard device_guard; + device_guard.set_index(input.get_device()); + + const int nrows = input.size(0); + const int ncols = input.size(1); + const int output_columns = ncols; + + auto output = at::empty( + {nrows, output_columns}, + input.options().dtype(at::kHalf)); // at::kHalf + + if (nrows == 0 || output_columns == 0) { + return output; + } + + // TODO: replace Half by BFloat16, after BFloat16 is supported by Nvidia + // NCCL input.options().dtype(at::kBFloat16)); // at::kBFloat16 + + constexpr int threads_per_block = 256; + const int blockDim_x = std::min(output_columns, threads_per_block); + dim3 blockDim(blockDim_x, threads_per_block / blockDim_x); + const int gridDim_x = (output_columns + blockDim.x - 1) / blockDim.x; + const int gridDim_y = std::min((nrows + blockDim.y - 1) / blockDim.y, 65535u); + dim3 gridDim(gridDim_x, gridDim_y); + + _float_to_bfloat16_cuda_kernel<<< + gridDim, + blockDim, + 0, + at::cuda::getCurrentCUDAStream()>>>( + input.data_ptr(), + nrows, + ncols, + // TODO: replace Half by BFloat16, after BFloat16 is supported by Nvidia + // NCCL + reinterpret_cast(output.data_ptr())); + //C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return output; +} + +at::Tensor _bfloat16_to_float_cuda(const at::Tensor& input) { + TENSOR_ON_CUDA_GPU(input); + // Currently it supports 2D inputs + TENSOR_NDIM_EQUALS(input, 2); + + at::cuda::OptionalCUDAGuard device_guard; + device_guard.set_index(input.get_device()); + + const int nrows = input.size(0); + const int ncols = input.size(1); + const int output_columns = ncols; + + auto output = at::empty( + {nrows, output_columns}, // 4 = sizeof(float) + input.options().dtype(at::kFloat)); // at::kBytes for uint8_t + + if (nrows == 0 || output_columns == 0) { + return output; + } + + constexpr int threads_per_block = 256; + + const int blockDim_x = std::min(output_columns, threads_per_block); + dim3 blockDim(blockDim_x, threads_per_block / blockDim_x); + const int gridDim_x = (output_columns + blockDim.x - 1) / blockDim.x; + const int gridDim_y = std::min((nrows + blockDim.y - 1) / blockDim.y, 65535u); + dim3 gridDim(gridDim_x, gridDim_y); + + _bfloat16_to_float_cuda_kernel<<< + gridDim, + blockDim, + 0, + at::cuda::getCurrentCUDAStream()>>>( + // TODO: replace Half by BFloat16, after BFloat16 is supported by Nvidia + // NCCL + reinterpret_cast(input.data_ptr()), + nrows, + ncols, + output.data_ptr()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return output; +} + +} // namespace quantization +} // namespace c10d +} // namespace distributed +} // namespace torch diff --git a/torch/csrc/distributed/c10d/quantization/quantization_gpu.h b/torch/csrc/distributed/c10d/quantization/quantization_gpu.h new file mode 100644 index 0000000000000..2a0c8f8f8d39c --- /dev/null +++ b/torch/csrc/distributed/c10d/quantization/quantization_gpu.h @@ -0,0 +1,20 @@ +// (c) Facebook, Inc. and its affiliates. Confidential and proprietary. + +#pragma once + + +#include +#include + +namespace torch { +namespace distributed { +namespace c10d { +namespace quantization { + +at::Tensor _float_to_bfloat16_cuda(const at::Tensor& input); +at::Tensor _bfloat16_to_float_cuda(const at::Tensor& input); + +} // namespace quantization +} // namespace c10d +} // namespace distributed +} // namespace torch diff --git a/torch/csrc/distributed/c10d/quantization/quantization_utils.h b/torch/csrc/distributed/c10d/quantization/quantization_utils.h new file mode 100644 index 0000000000000..0467ba2769f5b --- /dev/null +++ b/torch/csrc/distributed/c10d/quantization/quantization_utils.h @@ -0,0 +1,31 @@ +// (c) Facebook, Inc. and its affiliates. Confidential and proprietary. + +#pragma once + +#include + +#include + +inline std::string torch_tensor_device_name(const at::Tensor& ten) { + return c10::DeviceTypeName(ten.device().type()); +} + +#define TENSOR_NDIM_EQUALS(ten, dims) \ + TORCH_CHECK( \ + (ten).ndimension() == (dims), \ + "Tensor '" #ten "' must have " #dims \ + " dimension(s). " \ + "Found ", \ + (ten).ndimension()) + +#define TENSOR_ON_CPU(x) \ + TORCH_CHECK( \ + !x.is_cuda(), \ + #x " must be a CPU tensor; it is currently on device ", \ + torch_tensor_device_name(x)) + +#define TENSOR_ON_CUDA_GPU(x) \ + TORCH_CHECK( \ + x.is_cuda(), \ + #x " must be a CUDA tensor; it is currently on device ", \ + torch_tensor_device_name(x)) diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp index d91f191602888..91db615181e56 100644 --- a/torch/csrc/distributed/c10d/reducer.cpp +++ b/torch/csrc/distributed/c10d/reducer.cpp @@ -377,9 +377,25 @@ void Reducer::mark_variable_ready_dense(size_t variable_index) { if (comm_hook_ == nullptr) { auto wrapped = at::native::wrapped_scalar_tensor(double(1.) / div_factor_); - // Divides while copying into the bucket view to save one scan over - // all the input parameters. - at::mul_out(bucket_view, grad, wrapped); + if (!grad.requires_grad()) { + // Divides while copying into the bucket view to save one scan over + // all the input parameters. + at::mul_out(bucket_view, grad, wrapped); + } else { + // If DDP is running with create_graph=True, gradients require_grad + // themselves in order to compute higher order derivatives. However, + // DDP will not sync up these gradients currently (see + // https://github.com/pytorch/pytorch/issues/63812). + LOG(WARNING) + << "Using DistributedDataParallel with create_graph=True " + << " is not well-supported. The higher-order gradient will " + << " not be synchronized across ranks, and backpropagation " + << " through all_reduce operations will not occur. If you require " + << " DDP to work with higher-order gradients for your use case, " + << " please ping https://github.com/pytorch/pytorch/issues/63929"; + auto div_result = at::mul(grad, wrapped); + bucket_view.copy_(div_result); + } } else { bucket_view.copy_(grad); } @@ -456,6 +472,7 @@ std::vector Reducer::get_grad_buckets( auto variables_for_bucket = get_variables_for_bucket(i, bucket); gradBuckets.emplace_back( i, + buckets_.size(), return_zero_tensors ? at::zeros_like(bucket.replicas[0].contents) : bucket.replicas[0].contents, bucket.replicas[0].offsets, @@ -872,6 +889,7 @@ void Reducer::all_reduce_bucket(Bucket& bucket) { auto variables_for_bucket = get_variables_for_bucket(next_bucket_, bucket); GradBucket grad_bucket( next_bucket_, + buckets_.size(), tensors[0], // Since we only support single-process single-device // mode, there is always only one replica in the bucket. diff --git a/torch/csrc/distributed/rpc/message.cpp b/torch/csrc/distributed/rpc/message.cpp index 02771140f69bb..7265ed400b2e9 100644 --- a/torch/csrc/distributed/rpc/message.cpp +++ b/torch/csrc/distributed/rpc/message.cpp @@ -68,10 +68,17 @@ void Message::setId(int64_t id) { std::vector> Message::getStorages() const { + // Sparse tensors do not have storage. Instead, a sparse tensor + // contains two tensors indices and values, and both contain storage. std::vector> storages; - storages.reserve(tensors_.size()); + storages.reserve(2 * tensors_.size()); for (const auto& tensor : tensors_) { - storages.emplace_back(tensor.storage().getWeakStorageImpl()); + if (tensor.is_sparse()) { + storages.emplace_back(tensor._indices().storage().getWeakStorageImpl()); + storages.emplace_back(tensor._values().storage().getWeakStorageImpl()); + } else { + storages.emplace_back(tensor.storage().getWeakStorageImpl()); + } } return storages; } diff --git a/torch/csrc/distributed/rpc/message.h b/torch/csrc/distributed/rpc/message.h index 93eff094243f8..17a7808912b11 100644 --- a/torch/csrc/distributed/rpc/message.h +++ b/torch/csrc/distributed/rpc/message.h @@ -101,9 +101,9 @@ enum MessageType { // can then serialize and send tensors chunck-by-chunk, in the streaming // fashion. // type (MessageType): type of the message. -// id (int64_t): message id, this is used by ProcessGroupAgent to match -// request and response. Other implementation can ignore it -// if they have their own ways to do matching. +// id (int64_t): message id, this is used to match request and response. +// Other implementation can ignore it if they have their own +// ways to do matching. // // Layers above ``RpcAgent`` only converts ScriptCall, ScriptResp, PythonCall, // and PythonResp into a Message, and it is up to the RpcAgent diff --git a/torch/csrc/distributed/rpc/python_functions.cpp b/torch/csrc/distributed/rpc/python_functions.cpp index 272377166fc5f..60d67c558dcae 100644 --- a/torch/csrc/distributed/rpc/python_functions.cpp +++ b/torch/csrc/distributed/rpc/python_functions.cpp @@ -152,6 +152,27 @@ c10::intrusive_ptr toPyJitFuture( IValue ivalue; try { ivalue = toPyIValue(message); + } catch (py::error_already_set& e) { + py::gil_scoped_acquire acquire; + // FIXME: this is a temporary solution to add a special-case for + // ValueError and TypeError, as those are already used in our + // tests. We should have a more comprehensive coverage for other + // types of exceptions as well. + if (e.matches(PyExc_ValueError)) { + child->setErrorIfNeeded( + std::make_exception_ptr(pybind11::value_error(e.what()))); + } else if (e.matches(PyExc_TypeError)) { + child->setErrorIfNeeded( + std::make_exception_ptr(pybind11::type_error(e.what()))); + } else { + // py::error_already_set requires GIL to destruct, take special + // care. + child->setErrorIfNeeded( + std::make_exception_ptr(std::runtime_error(e.what()))); + } + e.restore(); + PyErr_Clear(); + return; } catch (std::exception& e) { child->setErrorIfNeeded(std::current_exception()); return; diff --git a/torch/csrc/distributed/rpc/request_callback_impl.cpp b/torch/csrc/distributed/rpc/request_callback_impl.cpp index 7001209be9851..5fbe63ede321c 100644 --- a/torch/csrc/distributed/rpc/request_callback_impl.cpp +++ b/torch/csrc/distributed/rpc/request_callback_impl.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include diff --git a/torch/csrc/distributed/rpc/request_callback_no_python.cpp b/torch/csrc/distributed/rpc/request_callback_no_python.cpp index 07d5c61e0c53c..9e16061e0ad42 100644 --- a/torch/csrc/distributed/rpc/request_callback_no_python.cpp +++ b/torch/csrc/distributed/rpc/request_callback_no_python.cpp @@ -290,7 +290,7 @@ c10::intrusive_ptr RequestCallbackNoPython:: // Need to reverse the device map for the backward pass of distributed // autograd. - std::unordered_map reverseDeviceMap; + DeviceMap reverseDeviceMap; for (const auto& mapEntry : rpcWithAutograd.deviceMap()) { reverseDeviceMap.insert({mapEntry.second, mapEntry.first}); } @@ -582,7 +582,7 @@ c10::intrusive_ptr RequestCallbackNoPython::runJitOperator( std::vector streams) const { c10::MultiStreamGuard guard(streams); try { - op.getOperation()(&stack); + op.getOperation()(stack); } catch (const std::exception&) { return asFuture(std::current_exception()); } diff --git a/torch/csrc/distributed/rpc/rpc_agent.h b/torch/csrc/distributed/rpc/rpc_agent.h index a83e77bfe56f9..7cd228e57da8e 100644 --- a/torch/csrc/distributed/rpc/rpc_agent.h +++ b/torch/csrc/distributed/rpc/rpc_agent.h @@ -164,7 +164,7 @@ class TORCH_API RpcAgent { const WorkerInfo& to, c10::intrusive_ptr message, const float rpcTimeoutSeconds = kUnsetRpcTimeout, - const std::unordered_map& deviceMap = {}) = 0; + const DeviceMap& deviceMap = {}) = 0; // Retries sending the message up to maxRetries times until an ACK is // receieved. The duration between consecutive sends is increased over diff --git a/torch/csrc/distributed/rpc/rref_proto.cpp b/torch/csrc/distributed/rpc/rref_proto.cpp index 6f059b1022db0..49e3287f5d778 100644 --- a/torch/csrc/distributed/rpc/rref_proto.cpp +++ b/torch/csrc/distributed/rpc/rref_proto.cpp @@ -46,20 +46,6 @@ const RRefId& RRefMessageBase::rrefId() { return rrefId_; } -c10::intrusive_ptr RRefMessageBase::toMessageImpl() && { - return fromIValues({rrefId_.toIValue()}, type_); -} - -at::IValue RRefMessageBase::fromMessage( - const Message& message, - MessageType type) { - auto values = toIValues(message, type); - - TORCH_INTERNAL_ASSERT( - values.size() == 1, "ScriptUserDelete expects 1 IValue from message."); - return std::move(values.back()); -} - /////////////////////////// ForkMessageBase ////////////////////////////////// const ForkId& ForkMessageBase::forkId() { @@ -76,7 +62,7 @@ std::pair ForkMessageBase::fromMessage( auto ivalues = toIValues(message, type); TORCH_INTERNAL_ASSERT( - ivalues.size() == 2, "ScriptUserDelete expects 2 IValue from message."); + ivalues.size() == 2, "ForkMessageBase expects 2 IValue from message."); return std::make_pair( RRefId::fromIValue(ivalues[0]), ForkId::fromIValue(ivalues[1])); diff --git a/torch/csrc/distributed/rpc/rref_proto.h b/torch/csrc/distributed/rpc/rref_proto.h index d5a82c21f8632..4ce8066dfe1f7 100644 --- a/torch/csrc/distributed/rpc/rref_proto.h +++ b/torch/csrc/distributed/rpc/rref_proto.h @@ -22,9 +22,6 @@ class TORCH_API RRefMessageBase : public RpcCommandBase { const RRefId& rrefId(); - c10::intrusive_ptr toMessageImpl() && override; - static at::IValue fromMessage(const Message& message, MessageType type); - protected: // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) const RRefId rrefId_; diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp index df42248639f94..3769db054ab45 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp @@ -16,12 +16,6 @@ #include #include -#if TENSORPIPE_HAS_SHM_TRANSPORT -// Needed for ::getpid(), which is used to create a unique address. -#include -#include -#endif - namespace torch { namespace distributed { namespace rpc { @@ -54,7 +48,7 @@ std::vector getDevicesForTensors( "Request device mapping is not available for destination ", remoteName); std::vector devices; - devices.reserve(tensors.size()); + devices.reserve(2 * tensors.size()); bool hasMappedDevice = false; for (const auto& t : tensors) { if (t.device().is_cpu()) { @@ -73,7 +67,12 @@ std::vector getDevicesForTensors( " for device ", t.device(), " but received a tensor on that device."); - devices.push_back(deviceIter->second); + if (t.is_sparse()) { + devices.push_back(deviceIter->second); + devices.push_back(deviceIter->second); + } else { + devices.push_back(deviceIter->second); + } hasMappedDevice = true; } } @@ -209,22 +208,10 @@ C10_REGISTER_CREATOR(TensorPipeTransportRegistry, uv, makeUvTransport); #if TENSORPIPE_HAS_SHM_TRANSPORT -std::string createUniqueShmAddr() { - thread_local uint32_t threadLocalId = 0; - return c10::str( - "shm://tensorpipe_rpc_agent_", - std::this_thread::get_id(), - "_", - ::getpid(), - "_", - threadLocalId++); -} - std::unique_ptr makeShmTransport() { auto context = tensorpipe::transport::shm::create(); - std::string address = createUniqueShmAddr(); - return std::make_unique(TransportRegistration{ - std::move(context), kShmTransportPriority, std::move(address)}); + return std::make_unique( + TransportRegistration{std::move(context), kShmTransportPriority, ""}); } // The SHM implements connections using ringbuffers residing in anonymous shared diff --git a/torch/csrc/distributed/rpc/tensorpipe_utils.cpp b/torch/csrc/distributed/rpc/tensorpipe_utils.cpp index ee66f3108e522..aa21fdf65c0f9 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_utils.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_utils.cpp @@ -311,8 +311,9 @@ c10::intrusive_ptr tensorpipeDeserialize( tensors.emplace_back(std::move(t)); } - for (const auto i : c10::irange(tpDescriptor.tensors.size())) { - auto& tensor = tpDescriptor.tensors[i]; + size_t tpDescriptorIndex = 0; + for (size_t i = 0; i < tensors.size(); i++) { + auto& tensor = tpDescriptor.tensors[tpDescriptorIndex]; if (tensor.targetDevice.has_value() && tensor.targetDevice->type == tensorpipe::kCudaDeviceType) { TORCH_INTERNAL_ASSERT( @@ -326,6 +327,11 @@ c10::intrusive_ptr tensorpipeDeserialize( ", but got it on ", tensors[i].device()); } + if (tensors[i].is_sparse()) { + tpDescriptorIndex += 2; + } else { + tpDescriptorIndex += 1; + } } return c10::make_intrusive( diff --git a/torch/csrc/distributed/rpc/testing/faulty_tensorpipe_agent.cpp b/torch/csrc/distributed/rpc/testing/faulty_tensorpipe_agent.cpp index 72d4d5dfec82e..a2e052535efac 100644 --- a/torch/csrc/distributed/rpc/testing/faulty_tensorpipe_agent.cpp +++ b/torch/csrc/distributed/rpc/testing/faulty_tensorpipe_agent.cpp @@ -67,7 +67,7 @@ c10::intrusive_ptr FaultyTensorPipeAgent::send( const WorkerInfo& to, c10::intrusive_ptr message, const float rpcTimeoutSeconds, - const std::unordered_map& /* unused */) { + const DeviceMap& /* unused */) { // We only fail control messages that have been specified by the test case. // For all other messages, we just send them without any failures. if (!shouldFailMessage(message->type())) { diff --git a/torch/csrc/distributed/rpc/testing/faulty_tensorpipe_agent.h b/torch/csrc/distributed/rpc/testing/faulty_tensorpipe_agent.h index 5d6059747c219..22c732862620a 100644 --- a/torch/csrc/distributed/rpc/testing/faulty_tensorpipe_agent.h +++ b/torch/csrc/distributed/rpc/testing/faulty_tensorpipe_agent.h @@ -53,8 +53,7 @@ class TORCH_API FaultyTensorPipeAgent : public TensorPipeAgent { const WorkerInfo& to, c10::intrusive_ptr message, const float rpcTimeoutSeconds = torch::distributed::rpc::kUnsetRpcTimeout, - const std::unordered_map& deviceMap = {}) - override; + const DeviceMap& deviceMap = {}) override; // Add delay to writes void pipeWrite( diff --git a/torch/csrc/distributed/rpc/utils.cpp b/torch/csrc/distributed/rpc/utils.cpp index 615abbf300666..820ec31691a0a 100644 --- a/torch/csrc/distributed/rpc/utils.cpp +++ b/torch/csrc/distributed/rpc/utils.cpp @@ -177,7 +177,7 @@ std::unique_ptr deserializeResponse( // Need to reverse the device map for the backward pass of distributed // autograd. - std::unordered_map reverseDeviceMap; + DeviceMap reverseDeviceMap; for (const auto& mapEntry : rpcWithAutograd.deviceMap()) { reverseDeviceMap.insert({mapEntry.second, mapEntry.first}); } diff --git a/torch/csrc/jit/OVERVIEW.md b/torch/csrc/jit/OVERVIEW.md index f44c5988caab0..45e18afd20233 100644 --- a/torch/csrc/jit/OVERVIEW.md +++ b/torch/csrc/jit/OVERVIEW.md @@ -792,7 +792,7 @@ In practice, the interpreter will allocate one Stack, and it will eventually rea [runtime/operator.h](runtime/operator.h) -The Operator object represents a single registered operator in the system. It combines a FunctionSchema that describes how an Operation executes with a method to lookup the corresponding Operation given the `Node` representing the operator in a `Graph`. Most Operators are defined by providing a FunctionSchema and an Operation function. However, primitives like prim::Unpack require knowledge of their `Node` to know how to operate (e.g. how many elements to unpack). These Operators have a function that takes a `Node*` and returns an operation. +The Operator object represents a single registered operator in the system. It combines a FunctionSchema that describes how an Operation executes with a method to look up the corresponding Operation given the Node representing the operator in a Graph. Most Operators are defined by providing a FunctionSchema and an Operation function. However, primitives like prim::Unpack require knowledge of their Node to know how to operate (e.g. how many elements to unpack). These Operators have a function that takes a `Node*` and returns an operation. ## Interpreter ## @@ -1282,13 +1282,14 @@ Note the alias set `*`. This is the **wildcard set**. Optimization passes must a This annotation language is consumed by the `FunctionSchema` parser, which produces `AliasInfo` objects summarizing the aliasing relationships for each schema `Argument`. ### Alias Analysis in the IR + [ir/alias_analysis.h](ir/alias_analysis.h) An alias analysis pass consumes the per-operator aliasing information to construct a database of aliasing and mutation relationships in a graph, called `AliasDb`. This section focuses on the alias analysis pass; the public interface to `AliasDb` will be described later. -The core data structure in the AliasDb is called `AliasTracker`, which is a DAG where the edges are "may point to" relationships and the vertices are aliasing `Element`s. The most common kind of `Element` is an IR `Value`, but there are other kinds of things that can alias that aren't first-class `Value`s in the IR, like wildcards or contained types (such as in a list or tuple). +The core data structure in the AliasDb is called `MemoryDAG`, which is a DAG where the edges are "may point to" relationships and the vertices are aliasing `Element`s. The most common kind of `Element` is an IR `Value`, but there are other kinds of things that can alias that aren't first-class `Value`s in the IR, like wildcards or contained types (such as in a list or tuple). -The alias analysis pass walks through the nodes in a graph, examining schema `AliasInfo` objects and adding edges in the `AliasTracker` DAG accordingly. For example, for the node: +The alias analysis pass walks through the nodes in a graph, examining schema `AliasInfo` objects and adding edges in the `MemoryDAG` accordingly. For example, for the node: ``` %output : Tensor = aten::view(%self, %size) ``` @@ -1321,7 +1322,7 @@ A few things to note: The last point demonstrates a key concept: *leaf elements uniquely describe memory locations*. Since a leaf element doesn't point to anything, the memory that backs it must have been freshly allocated by some op. Thus we can use leaf elements to represent disjoint memory locations. -So to determine whether `a` and `b` may alias, we traverse the `AliasTracker` DAG and figure out if `a` and `b` share any leaf nodes. If they do, then we know `a` and `b` might point to the same memory location, i.e. `a` and `b` may alias. This kind of query is common enough that `AliasTracker` does path compression to speed up leaf-finding, so that aliasing queries can be serviced in amortized constant time. +So to determine whether `a` and `b` may alias, we traverse the `MemoryDAG` DAG and figure out if `a` and `b` share any leaf nodes. If they do, then we know `a` and `b` might point to the same memory location, i.e. `a` and `b` may alias. This kind of query is common enough that `MemoryDAG` does path compression to speed up leaf-finding, so that aliasing queries can be serviced in amortized constant time. ### Writing optimization passes with `AliasDb` `AliasDb` provides a high-level interface to help people write mutability-safe optimization passes. diff --git a/torch/csrc/jit/api/method.h b/torch/csrc/jit/api/method.h index bcd44a1df343a..3fcc4421891a0 100644 --- a/torch/csrc/jit/api/method.h +++ b/torch/csrc/jit/api/method.h @@ -46,7 +46,7 @@ struct TORCH_API Method : public torch::IMethod { return function_->graph(); } - const std::string& name() const { + const std::string& name() const override { return function_->name(); } diff --git a/torch/csrc/jit/backends/backend.h b/torch/csrc/jit/backends/backend.h index 941f27bfe2b11..5aae642fa5517 100644 --- a/torch/csrc/jit/backends/backend.h +++ b/torch/csrc/jit/backends/backend.h @@ -9,7 +9,7 @@ namespace torch { namespace jit { namespace { // NOLINTNEXTLINE(clang-diagnostic-unneeded-internal-declaration) -c10::FunctionSchema getIsAvailableSchema() { +inline c10::FunctionSchema getIsAvailableSchema() { c10::Argument self("self", c10::AnyType::get()); c10::Argument available("available", c10::BoolType::get()); c10::FunctionSchema preprocessor_schema( @@ -23,7 +23,7 @@ c10::FunctionSchema getIsAvailableSchema() { constexpr static auto kBackendsNamespace = "__backends__"; // NOLINTNEXTLINE(clang-diagnostic-unneeded-internal-declaration) -c10::FunctionSchema getCompileSchema() { +inline c10::FunctionSchema getCompileSchema() { c10::Argument self("self", c10::AnyType::get()); c10::Argument mod("processed", c10::AnyType::get()); auto any_dict_ty = @@ -40,7 +40,7 @@ c10::FunctionSchema getCompileSchema() { } // NOLINTNEXTLINE(clang-diagnostic-unneeded-internal-declaration) -c10::FunctionSchema getExecuteSchema() { +inline c10::FunctionSchema getExecuteSchema() { auto any_list_ty = c10::ListType::create(c10::AnyType::get()); c10::Argument self("self", c10::AnyType::get()); c10::Argument handle("handle", c10::AnyType::get()); diff --git a/torch/csrc/jit/backends/nnapi/nnapi_backend_lib.cpp b/torch/csrc/jit/backends/nnapi/nnapi_backend_lib.cpp index 0533b7d85175f..7d9dc18c12589 100644 --- a/torch/csrc/jit/backends/nnapi/nnapi_backend_lib.cpp +++ b/torch/csrc/jit/backends/nnapi/nnapi_backend_lib.cpp @@ -31,19 +31,8 @@ class NnapiBackend : public PyTorchBackendInterface { c10::impl::GenericDict compile( c10::IValue processed, c10::impl::GenericDict method_compile_spec) override { - auto dict = processed.toGenericDict(); - - // Prepare weights - auto weights = dict.at("weights").toTensorList(); - for (int i = 0; i < weights.size(); i++) { - weights.set(i, weights.get(i).contiguous()); - } - dict.insert("weights", weights); - - // Save ser_model to member variable - ser_model_ = dict.at("ser_model").toTensor(); - // Wrap procesed in dictionary: {"forward": processed} + auto dict = processed.toGenericDict(); c10::Dict handles( c10::StringType::get(), c10::AnyType::get()); handles.insert("forward", dict); @@ -86,8 +75,7 @@ class NnapiBackend : public PyTorchBackendInterface { fixed_inputs.push_back( tensorInp.get(i).permute({0, 2, 3, 1}).contiguous()); } else { - throw std::exception(); - std::cerr << "Invalid mem_fmt" << std::endl; + TORCH_CHECK(false, "Invalid mem_fmt"); } } @@ -103,9 +91,8 @@ class NnapiBackend : public PyTorchBackendInterface { // TODO: See if it's possible to use those directly. if (fmt == 1) { outputs.set(i, outputs.get(i).permute({0, 3, 1, 2})); - } else if (fmt != 0) { - throw std::exception(); - std::cerr << "Invalid mem_fmt" << std::endl; + } else { + TORCH_CHECK(fmt == 0, "Invalid mem_fmt"); } } @@ -117,8 +104,6 @@ class NnapiBackend : public PyTorchBackendInterface { // and cannot be passed through the handles dictionary std::unique_ptr comp_; c10::List out_templates_; - at::Tensor ser_model_; - mobile::Module shape_compute_module_; // Runs once per model initialization // Cannot be moved to compile(), because init() requires actual inputs @@ -126,19 +111,21 @@ class NnapiBackend : public PyTorchBackendInterface { TORCH_CHECK(comp_ == nullptr); auto dict = handle.toGenericDict(); + // Get ser_model + auto ser_model = dict.at("ser_model").toTensor(); // Load shape computation module std::stringstream ss; auto shape_ptr = dict.at("shape_compute_module").toString(); ss.str(*shape_ptr); - shape_compute_module_ = _load_for_mobile(ss); + auto shape_compute_module = _load_for_mobile(ss); out_templates_ = - shape_compute_module_.run_method("prepare", ser_model_, inputs) + shape_compute_module.run_method("prepare", ser_model, inputs) .toTensorList(); // Create and initialize NnapiComilation object comp_ = std::make_unique(); auto weights = dict.at("weights").toTensorVector(); - comp_->init(ser_model_, weights); + comp_->init(ser_model, weights); } }; diff --git a/torch/csrc/jit/backends/nnapi/nnapi_backend_preprocess.cpp b/torch/csrc/jit/backends/nnapi/nnapi_backend_preprocess.cpp index 2f68536b64107..be0dbe18d90d0 100644 --- a/torch/csrc/jit/backends/nnapi/nnapi_backend_preprocess.cpp +++ b/torch/csrc/jit/backends/nnapi/nnapi_backend_preprocess.cpp @@ -96,6 +96,9 @@ c10::IValue preprocess( // transform Python lists to C++ c10::List c10::List weights( py::cast>(nnapi_processed[2])); + for (int i = 0; i < weights.size(); i++) { + weights.set(i, weights.get(i).contiguous()); + } c10::List inp_mem_fmts( py::cast>(nnapi_processed[3])); c10::List out_mem_fmts( diff --git a/torch/csrc/jit/codegen/cuda/interface.cpp b/torch/csrc/jit/codegen/cuda/interface.cpp index 009ae21dad6d0..cf8f3787229ce 100644 --- a/torch/csrc/jit/codegen/cuda/interface.cpp +++ b/torch/csrc/jit/codegen/cuda/interface.cpp @@ -182,8 +182,8 @@ RegisterOperators reg_fusion({ Operator( prim::CudaFusionGroup, [](const Node* node) -> Operation { - return [node](Stack* stack) { - fuser::cuda::runFusionGroup(node, *stack); + return [node](Stack& stack) { + fuser::cuda::runFusionGroup(node, stack); }; }, aliasAnalysisSpecialCase()), @@ -196,7 +196,7 @@ RegisterOperators reg_guard({ // if we would ever return refined tensor, which would change aliasing // analysis, we should update aliasdb pass. [](const Node* node) -> Operation { - return [node](Stack* stack) { + return [node](Stack& stack) { // TODO: check latency here!!!! std::vector types = node->tys(attr::types); const auto num_inputs = types.size(); diff --git a/torch/csrc/jit/codegen/fuser/executor.cpp b/torch/csrc/jit/codegen/fuser/executor.cpp index b260e48b16c3f..46f2f41d07e36 100644 --- a/torch/csrc/jit/codegen/fuser/executor.cpp +++ b/torch/csrc/jit/codegen/fuser/executor.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include // TODO: remove, debugging only @@ -327,7 +328,7 @@ void launchFusion( bool runFusion(const int64_t key, Stack& stack, std::string* code_out) { // Short-circuits if fusion isn't enabled - if (!canFuseOnCPU() && !canFuseOnGPU()) + if (!canFuseOnCPULegacy() && !canFuseOnGPU()) return false; // Acquires the FusionSpec @@ -362,7 +363,7 @@ bool runFusion(const int64_t key, Stack& stack, std::string* code_out) { // Attempts to run fallback if device fusion is disabled if (device.is_cuda() && !canFuseOnGPU()) return false; - if (device.is_cpu() && !canFuseOnCPU()) + if (device.is_cpu() && !canFuseOnCPULegacy()) return false; if (device.is_xpu()) return false; diff --git a/torch/csrc/jit/codegen/fuser/fallback.cpp b/torch/csrc/jit/codegen/fuser/fallback.cpp index 59fe7e6f4fd25..60a5d72f3c439 100644 --- a/torch/csrc/jit/codegen/fuser/fallback.cpp +++ b/torch/csrc/jit/codegen/fuser/fallback.cpp @@ -26,7 +26,7 @@ RegisterOperators reg_fused_operators({Operator( [](const Node* node) -> Operation { int64_t dim = node->i(attr::dim); int64_t num_inputs = node->inputs().size(); - return [dim, num_inputs](Stack* stack) { + return [dim, num_inputs](Stack& stack) { auto result = at::cat( fmap( last(stack, num_inputs), diff --git a/torch/csrc/jit/codegen/fuser/interface.cpp b/torch/csrc/jit/codegen/fuser/interface.cpp index ec67c4bd83773..ef7e9e0b629d5 100644 --- a/torch/csrc/jit/codegen/fuser/interface.cpp +++ b/torch/csrc/jit/codegen/fuser/interface.cpp @@ -8,15 +8,12 @@ #include #include -C10_DEFINE_bool(torch_jit_enable_cpu_fusion, false, "enable cpu fusion"); - namespace torch { namespace jit { namespace detail { -// Note: CPU fusion is currently disabled due to test flakiness -#if defined(FBCODE_CAFFE2) +#ifdef TORCH_ENABLE_LLVM bool cpu_fuser_enabled = true; #else bool cpu_fuser_enabled = false; @@ -37,8 +34,7 @@ void runFusion(const int64_t key, Stack& stack) { } bool canFuseOnCPU() { - return fuser::hasFusionBackend(DeviceType::CPU) && - (detail::cpu_fuser_enabled || FLAGS_torch_jit_enable_cpu_fusion); + return fuser::hasFusionBackend(DeviceType::CPU) && detail::cpu_fuser_enabled; } bool canFuseOnGPU() { diff --git a/torch/csrc/jit/frontend/convert_to_ssa.cpp b/torch/csrc/jit/frontend/convert_to_ssa.cpp index 9b86c78c89d41..269c049dae64c 100644 --- a/torch/csrc/jit/frontend/convert_to_ssa.cpp +++ b/torch/csrc/jit/frontend/convert_to_ssa.cpp @@ -93,10 +93,8 @@ struct ControlFlowLoadStores { for (const auto& x : mutated_variables) { auto true_type = true_vars->findInAnyFrame(x); auto false_type = false_vars->findInAnyFrame(x); - auto unified = unifyTypes(true_type, false_type); - if (!unified) { - continue; - } + auto unified = + unifyTypes(true_type, false_type, /*default_to_union=*/true); addBlockOutput(true_block, true_type, x); addBlockOutput(false_block, false_type, x); diff --git a/torch/csrc/jit/frontend/exit_transforms.cpp b/torch/csrc/jit/frontend/exit_transforms.cpp index c91ec7bb634f3..71f534107575f 100644 --- a/torch/csrc/jit/frontend/exit_transforms.cpp +++ b/torch/csrc/jit/frontend/exit_transforms.cpp @@ -150,8 +150,10 @@ struct ExitTransformer { registerBlockOutputs(if_view.thenBlock(), true_outs); registerBlockOutputs(if_view.elseBlock(), false_outs); for (const auto i : c10::irange(true_outs.size())) { - auto out_type = - unifyTypes(true_outs.at(i)->type(), false_outs.at(i)->type()); + auto out_type = unifyTypes( + true_outs.at(i)->type(), + false_outs.at(i)->type(), + /*default_to_union=*/true); n->addOutput()->setType(*out_type); } } diff --git a/torch/csrc/jit/frontend/ir_emitter.cpp b/torch/csrc/jit/frontend/ir_emitter.cpp index d443f418e6eca..dd29f1eda6412 100644 --- a/torch/csrc/jit/frontend/ir_emitter.cpp +++ b/torch/csrc/jit/frontend/ir_emitter.cpp @@ -185,7 +185,9 @@ NoneStatus canBeNone(Value* v) { if (v->node()->mustBeNone()) { return ALWAYS; } - if (v->type()->kind() == OptionalType::Kind) { + if (v->type()->kind() == OptionalType::Kind || + (v->type()->kind() == UnionType::Kind && + v->type()->expect()->canHoldType(NoneType::get()))) { return MAYBE; } return NEVER; @@ -385,7 +387,7 @@ struct Environment { std::stringstream why_not; if (!as_simple_value->type()->isSubtypeOfExt(parent_type, &why_not)) { auto error = ErrorReport(loc); - error << "Variable '" << name << "' previously has type " + error << "Variable '" << name << "' previously had type " << simple_parent->type()->repr_str() << " but is now being assigned to a value of type " << as_simple_value->type()->repr_str(); @@ -547,6 +549,7 @@ struct Environment { if (!retval && required) { throwVarNotFoundError(ident, range); } + return retval; } @@ -1010,57 +1013,61 @@ struct to_ir { } void emitReturn(const Return& stmt) { - TypePtr result_type = def_stack_.back().declared_return_type_; - Value* result = emitExpr(stmt.expr(), result_type); + TypePtr declared_return_type = + def_stack_.back().declared_return_type_; // nullptr if not annotated + auto actual_return = emitExpr(stmt.expr(), declared_return_type); + // result type is annotated, every return must convert to that type - if (result_type) { + if (declared_return_type) { // this guard skips implicit conversion from None -> Tensor for the return // type. otherwise forgetting a return a function returning a tensor will // cause a None to be converted to a tensor. - if (!(result_type->isSubtypeOf(TensorType::get()) && - result->type()->isSubtypeOf(NoneType::get()))) { - result = tryConvertToType( + if (!(actual_return->type()->isSubtypeOf(TensorType::get()) && + actual_return->type()->isSubtypeOf(NoneType::get()))) { + actual_return = tryConvertToType( stmt.range(), *graph, - result_type, - result, + declared_return_type, + actual_return, /*allow_conversions=*/true); } - - if (!result->type()->isSubtypeOf(result_type)) { + if (!actual_return->type()->isSubtypeOf(declared_return_type)) { throw ErrorReport(stmt.range()) << "Return value was annotated as having type " - << result_type->repr_str() << " but is actually of type " - << result->type()->repr_str(); + << declared_return_type->repr_str() << " but is actually of type " + << actual_return->type()->repr_str(); } } else { - result_type = def_stack_.back().merged_return_type_; - if (!result_type) { - result_type = result->type(); + declared_return_type = def_stack_.back().merged_return_type_; + if (!declared_return_type) { + declared_return_type = actual_return->type(); } - auto merged_result_type = unifyTypes(result_type, result->type()); - if (!merged_result_type) { + auto merged_return_type = + unifyTypes(declared_return_type, actual_return->type()); + if (!merged_return_type) { throw ErrorReport(stmt.range()) << "Previous return statement returned a value of type " - << result_type->repr_str() + << declared_return_type->repr_str() << " but this return statement returns a value of type " - << result->type()->repr_str(); + << actual_return->type()->repr_str(); } - result_type = merged_result_type.value(); + declared_return_type = merged_return_type.value(); } - AT_ASSERT(result_type); + AT_ASSERT(declared_return_type); - def_stack_.back().merged_return_type_ = result_type; + def_stack_.back().merged_return_type_ = declared_return_type; // If the annotated return type is Any and the result type is not Any, // cast the result to Any to facilitate type unification between return // statements on different code paths (e.g. different branches of an if, // body and containing scope of a loop). - if (result_type == AnyType::get() && result->type() != AnyType::get()) { - result = graph->insertUncheckedCast(result, result_type); + if (declared_return_type == AnyType::get() && + actual_return->type() != AnyType::get()) { + actual_return = + graph->insertUncheckedCast(actual_return, declared_return_type); } - graph->insertNode(graph->create(prim::ReturnStmt, {result}, 0)); + graph->insertNode(graph->create(prim::ReturnStmt, {actual_return}, 0)); exit_blocks.insert(environment_stack->block()); } @@ -1142,10 +1149,10 @@ struct to_ir { return {}; } // statement must be var {is, is not} None - auto name = Var(lhs).name().name(); - // XXX - while it should in theory be possible to specialize - // the `x is None` to know x has type NoneType, we have previously not - // done this. Unfortunately, doing this will make the type None + const std::string& name = Var(lhs).name().name(); + // While it should in theory be possible to specialize + // the `x is None` to know x has type NoneType, we have previously + // not done this. Unfortunately, doing this will make the type None // propagate further in all loaded models. The handling of // unwrap_optional will fail in these cases since export did // not expect that the input would be none and an unannotated None. @@ -1154,7 +1161,7 @@ struct to_ir { // and (2) only enable this OPTIONAL_NONE when loading newer // graphs because it is incompatible with older graphs. // Refinement none(name, RefinementKind::OPTIONAL_NONE); - if (auto optional_type = lhs_value->type()->cast()) { + if (const auto optional_type = lhs_value->type()->cast()) { Refinement present(name, optional_type->getElementType()); if (tok == TK_IS) { return RefinementSet({}, {present}); @@ -1162,6 +1169,21 @@ struct to_ir { return RefinementSet({present}, {}); } } + if (const auto union_type = lhs_value->type()->cast()) { + std::vector to_subtract{NoneType::get()}; + c10::optional remaining = + union_type->subtractTypeSet(to_subtract); + std::vector all_present; + if (remaining) { + Refinement present{name, *remaining}; + all_present.push_back(std::move(present)); + } + if (tok == TK_IS) { + return RefinementSet({}, all_present); + } else { // TK_ISNOT + return RefinementSet(all_present, {}); + } + } return RefinementSet(); } @@ -1340,7 +1362,7 @@ struct to_ir { auto unified = unifyTypes( lt->getElementType(), out->type(), - /*default_to_any=*/true, + /*default_to_union=*/true, element_type_hint); if (lt->getElementType() != AnyType::get() && @@ -1458,7 +1480,7 @@ struct to_ir { c10::optional unified = unifyTypes( dt->getValueType(), v->type(), - /*default_to_any=*/true, + /*default_to_union=*/true, value_type_hint); // Warn the user if we inferred the type of the values to be `Any` @@ -1755,13 +1777,32 @@ struct to_ir { graph->createStore(x, fv)->insertBefore(false_block->return_node()); } - auto unified = unifyTypes(tv->type(), fv->type()); + SugaredValuePtr maybe_sugared_x = environment_stack->findInAnyFrame(x); + TypePtr full_type = nullptr; + if (maybe_sugared_x) { + Value* maybe_simple = asSimple(maybe_sugared_x); + if (maybe_simple) { + full_type = maybe_simple->type(); + } + } - // attempt to unify the types. we allow variables to be set to different - // types in each branch as long as that variable is not already in scope, - // or if that variable does not get used later. here, we save the error - // so that the error message will be more informative in the case that is - // used later. When a is accessed in (a + 1), the error will get printed + // Try to unify the types. If we found a type annotation earlier + // in the environment, and if that type annotation is some form + // of union, then we need to tell `unifyTypes` not to throw an + // error if the branched return types we found are heterogenous + bool default_to_union = full_type && + (full_type->kind() == UnionType::Kind || + full_type->kind() == OptionalType::Kind || + full_type->kind() == NumberType::Kind); + auto unified = unifyTypes( + tv->type(), fv->type(), /*default_to_union=*/default_to_union); + + // We allow variables to be set to different types in each branch + // as long as that variable is not already in scope or if that + // variable does not get used later. Here, we save the error so + // that the error message will be more informative in the case + // that is used later. When `a` is accessed in `(a + 1)`, the + // error will get printed: // if cond: // a = 1 // else: @@ -1799,76 +1840,146 @@ struct to_ir { } CondValue emitIsInstance(const Expr& obj, const Expr& classinfo) { - // turn (float, (int, tuple)) into a flat list of types and type kind - // category checks: tuple_check = true, types = {float, int} - struct GatheredTypes { - GatheredTypes(ScriptTypeParser parser) : typeParser_(std::move(parser)) {} - void gather(const Expr& classinfo) { - if (classinfo.kind() == TK_TUPLE_LITERAL) { - for (Expr e : TupleLiteral(classinfo).inputs()) { - gather(e); - } - return; + Value* lhs_val = emitExpr(obj); + std::vector lhs_types; + std::vector rhs_types; + + std::function gather_rhs = [&](const Expr& expr) { + if (expr.kind() == TK_TUPLE_LITERAL) { + for (Expr e : TupleLiteral(expr).inputs()) { + gather_rhs(e); } - TypePtr type = typeParser_.parseTypeFromExpr(classinfo); - types.emplace_back(type); + return; } - bool staticallyTrue(const TypePtr& actual_type) { - // is this isinstance check statically true? - for (const TypePtr& typ : types) { - if (actual_type->isSubtypeOf(typ)) { - return true; - } + TypePtr type = typeParser_.parseTypeFromExpr(expr); + rhs_types.emplace_back(type); + }; + + lhs_types.push_back(lhs_val->type()); + gather_rhs(classinfo); + + standardizeVectorForUnion(&lhs_types); + standardizeVectorForUnion(&rhs_types); + + RefinementSet refinement; + + TypePtr unified_true = nullptr; + TypePtr unified_false = nullptr; + + std::vector isinstance_types; + std::vector not_isinstance_types; + + std::vector true_refinements; + std::vector false_refinements; + + bool all_lhs_subtype_some_rhs = true; + + // We can discard any rhs types that we know statically would be + // impossible. For example, if we had: + // + // def fn(x: Optional[str]): + // if isinstance(x, (List[str], str, int)): + // ... + // + // then `x` would be `str` in the true branch and `None` in the + // false branch, not `(List[str], str, int)` in the true branch + // and `None` in the false branch + for (const TypePtr& lhs_type : lhs_types) { + if (lhs_type == AnyType::get()) { + isinstance_types.insert( + isinstance_types.end(), rhs_types.begin(), rhs_types.end()); + not_isinstance_types.push_back(AnyType::get()); + // Edge case: we can still say that all lhs types subtype some + // rhs type if `lhs` is `Any` and `rhs` is `Any` + if (isinstance_types.size() != 1 || + isinstance_types[0] != AnyType::get()) { + all_lhs_subtype_some_rhs = false; } - return false; + break; } - bool maybeOfKind(TypeKind kind, const TypePtr& actual_type) { - if (actual_type->kind() == AnyType::Kind) { - return true; + + auto get_smaller_type = [&](TypePtr t1, TypePtr t2) -> TypePtr { + if (t1->isSubtypeOf(t2)) { + return t1; + } else if (t2->isSubtypeOf(t1)) { + return t2; + } else { + return nullptr; } - if (auto op = actual_type->cast()) { - return op->getElementType()->kind() == kind; + }; + + TypePtr found_refinement = nullptr; + for (const TypePtr& rhs_type : rhs_types) { + TypePtr maybe_smaller_type = get_smaller_type(lhs_type, rhs_type); + if (!maybe_smaller_type) { + continue; + } else if (*maybe_smaller_type == *lhs_type) { + // Cover the case that we have something like + // lhs = `List[str]` and rhs = `list` + found_refinement = lhs_type; + } else if (*maybe_smaller_type == *rhs_type) { + // We want the narrowest possible type + found_refinement = found_refinement + ? *(unifyTypes(found_refinement, rhs_type)) + : rhs_type; } - return false; } - bool staticallyFalse(const TypePtr& actual_type) { - for (const TypePtr& typ : types) { - if (typ->isSubtypeOf(actual_type)) { - return false; - } - if ((typ->isSubtypeOf(AnyListType::get()) && - maybeOfKind(ListType::Kind, actual_type)) || - (typ->isSubtypeOf(AnyTupleType::get()) && - maybeOfKind(TupleType::Kind, actual_type))) { - return false; - } + + if (found_refinement) { + if (*found_refinement == *lhs_type) { + all_lhs_subtype_some_rhs &= true; } - return true; + isinstance_types.push_back(found_refinement); + } else { + // If the lhs couldn't be a subtype of the rhs (or couldn't + // be "refined" to itself, as in the `List[str]` and `list` + // case above), then we add `lhs_type` to the false branch + // refinements. This is because the type can still be itself + // if the `isinstance` check is false + not_isinstance_types.push_back(lhs_type); + all_lhs_subtype_some_rhs = false; } - ScriptTypeParser typeParser_; - std::vector types; - }; - GatheredTypes gathered(typeParser_); - gathered.gather(classinfo); - auto val = emitExpr(obj); - RefinementSet refinement; - if (gathered.types.size() == 1 && - gathered.types.at(0)->isSubtypeOf(val->type()) && - obj.kind() == TK_VAR) { + } + + // For use with `unifyTypeList` + std::stringstream nowhere; + + // Get a single type for the true and false branches + if (!isinstance_types.empty()) { + unified_true = + *unifyTypeList(isinstance_types, nowhere, /*default_to_union=*/true); + } + if (obj.kind() == TK_VAR && unified_true) { + std::string ident = Var(obj).name().name(); + true_refinements = {Refinement(ident, unified_true)}; + } + + // Get a single type for the true and false branches + if (!not_isinstance_types.empty()) { + unified_false = *unifyTypeList( + not_isinstance_types, nowhere, /*default_to_union=*/true); + } + if (obj.kind() == TK_VAR && unified_false) { std::string ident = Var(obj).name().name(); - Refinement isinstance(std::move(ident), gathered.types.at(0)); - refinement = RefinementSet({isinstance}, {}); + false_refinements = {Refinement(ident, unified_false)}; } - if (gathered.staticallyTrue(val->type())) { + refinement = RefinementSet(true_refinements, false_refinements); + + bool is_statically_false = isinstance_types.empty(); + + // If the statement is statically true + if (all_lhs_subtype_some_rhs) { return CondValue(*graph, obj.range(), true, std::move(refinement)); } - if (gathered.staticallyFalse(val->type())) { + + if (is_statically_false) { return CondValue(*graph, obj.range(), false, std::move(refinement)); } + // check maybe true/false at runtime, need an actual op Value* result = - graph->insertNode(graph->createIsInstance(val, gathered.types)) + graph->insertNode(graph->createIsInstance(lhs_val, rhs_types)) ->output(); return CondValue(result, std::move(refinement), c10::nullopt); } @@ -2124,6 +2235,7 @@ struct to_ir { } // emit assserions as an if branch so that assertions will reuse the + // message void emitAssert(const Assert& stmt) { CondValue cond_value = emitCondExpr(stmt.test()); List true_branch = List::create(stmt.range(), {}); @@ -2979,7 +3091,9 @@ struct to_ir { // after annotation so that variables assigned to this None will still // get the right type. To do this, we make a None constant that // has the type Optional[T] - if (type->kind() == OptionalType::Kind && + if ((type->kind() == OptionalType::Kind || + (type->kind() == UnionType::Kind && + type->expect()->canHoldType(NoneType::get()))) && expr->type()->isSubtypeOf(NoneType::get())) { Node* none = graph->createNone(); none->output()->setType(type); @@ -3435,8 +3549,9 @@ struct to_ir { size_t n_binders, const TypePtr& type_hint = nullptr) { switch (tree.kind()) { - case TK_VAR: + case TK_VAR: { return environment_stack->getSugaredVar(Var(tree).name()); + } case '.': { auto select = Select(tree); auto sv = emitSugaredExpr(select.value(), 1); @@ -3710,7 +3825,7 @@ struct to_ir { type_hint ? type_hint->expect()->getElementType() : nullptr; c10::optional unified = unifyTypeList( - types, nowhere, /*default_to_any=*/true, element_type_hint); + types, nowhere, /*default_to_union=*/true, element_type_hint); if (!type_hint && *unified == AnyType::get()) { TORCH_WARN( @@ -3881,7 +3996,7 @@ struct to_ir { c10::optional unified = unifyTypeList( types, /*why_not=*/nowhere, - /*default_to_any=*/true, + /*default_to_union=*/true, value_type_hint); if (!type_hint && *unified == AnyType::get()) { diff --git a/torch/csrc/jit/frontend/schema_matching.h b/torch/csrc/jit/frontend/schema_matching.h index 6b434882eb798..fb6d1ab7f92e5 100644 --- a/torch/csrc/jit/frontend/schema_matching.h +++ b/torch/csrc/jit/frontend/schema_matching.h @@ -8,9 +8,10 @@ namespace torch { namespace jit { -// try to match a list of inputs and keyword 'attributes' to this schema, -// if it works return the flat list of positional inputs to the call -// if it returns nullopt, then failure_messages contains a good error report +// Try to match a list of inputs and keyword 'attributes' to this +// schema. Return the flat list of positional inputs to the call or +// `c10::nullopt` on failure (`failure_messages` contains a good error +// report in this case) struct MatchedSchema { std::vector inputs; diff --git a/torch/csrc/jit/frontend/schema_type_parser.cpp b/torch/csrc/jit/frontend/schema_type_parser.cpp index db1a1e83bc7ce..a543b5b6fbe5d 100644 --- a/torch/csrc/jit/frontend/schema_type_parser.cpp +++ b/torch/csrc/jit/frontend/schema_type_parser.cpp @@ -32,6 +32,7 @@ using c10::StringType; using c10::Symbol; using c10::TensorType; using c10::TupleType; +using c10::UnionType; using c10::VarType; namespace torch { @@ -235,7 +236,7 @@ TypePtr SchemaTypeParser::parseRefinedTensor() { const std::string& num = L.expect(TK_NUMBER).text(); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::string::size_type num_len; - size_t stride = c10::stoi(num, &num_len); + auto stride = c10::stoll(num, &num_len); strides.push_back(stride); }); return; @@ -260,7 +261,7 @@ TypePtr SchemaTypeParser::parseRefinedTensor() { const std::string& num = L.expect(TK_NUMBER).text(); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::string::size_type num_len; - size_t dim = c10::stoi(num, &num_len); + auto dim = c10::stoll(num, &num_len); dims.emplace_back(dim); }); if (seen_strides) { @@ -331,6 +332,18 @@ std::pair> SchemaTypeParser::parseType() { L.expect(')'); alias_info = parseAliasAnnotation(); value = DictType::create(key_type, value_type); + } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Union") { + L.next(); + L.expect('('); + std::vector types; + types.emplace_back(parseType().first); + while (L.cur().kind != ')') { + L.expect(','); + types.emplace_back(parseType().first); + } + L.expect(')'); + alias_info = parseAliasAnnotation(); + value = UnionType::create(types); } else if ( complete_tensor_types && L.cur().kind == TK_IDENT && parseTensorDType(L.cur().text())) { diff --git a/torch/csrc/jit/frontend/script_type_parser.cpp b/torch/csrc/jit/frontend/script_type_parser.cpp index eac51ab527d52..bafe5188cc4eb 100644 --- a/torch/csrc/jit/frontend/script_type_parser.cpp +++ b/torch/csrc/jit/frontend/script_type_parser.cpp @@ -42,7 +42,7 @@ TypePtr ScriptTypeParser::subscriptToType( } std::vector subscript_expr_types; for (auto expr : subscript.subscript_exprs()) { - subscript_expr_types.push_back(parseTypeFromExprImpl(expr)); + subscript_expr_types.emplace_back(parseTypeFromExprImpl(expr)); } return TupleType::create(subscript_expr_types); } else if (typeName == "List" || typeName == "list") { @@ -65,6 +65,13 @@ TypePtr ScriptTypeParser::subscriptToType( parseTypeFromExprImpl(*subscript.subscript_exprs().begin()); return OptionalType::create(elem_type); + } else if (typeName == "Union") { + std::vector subscript_expr_types; + subscript_expr_types.reserve(subscript.subscript_exprs().size()); + for (auto expr : subscript.subscript_exprs()) { + subscript_expr_types.emplace_back(parseTypeFromExprImpl(expr)); + } + return UnionType::create(subscript_expr_types); } else if (typeName == "Future" || typeName == "torch.jit.Future") { if (subscript.subscript_exprs().size() != 1) { throw ErrorReport(subscript) @@ -83,30 +90,6 @@ TypePtr ScriptTypeParser::subscriptToType( auto elem_type = parseTypeFromExprImpl(*subscript.subscript_exprs().begin()); return RRefType::create(elem_type); - } else if (typeName == "Union") { - // In Python 3.9+, Union[NoneType, T] or Union[T, NoneType] are - // treated as Optional[T]. Adding the same support for Union in Torchscript. - const char* const err = - "General Union types are not currently supported." - " Only Union[T, NoneType] (i.e. Optional[T]) is " - "supported."; - if (subscript.subscript_exprs().size() != 2) { - throw ErrorReport(subscript) << (err); - } - auto first_type = parseTypeFromExprImpl(subscript.subscript_exprs()[0]); - auto second_type = parseTypeFromExprImpl(subscript.subscript_exprs()[1]); - - bool first_none = first_type == NoneType::get(); - bool second_none = second_type == NoneType::get(); - - if (first_none && !second_none) { - return OptionalType::create(second_type); - } else if (!first_none && second_none) { - return OptionalType::create(first_type); - } else { - throw ErrorReport(subscript.range()) << err; - } - } else if (typeName == "Dict" || typeName == "dict") { if (subscript.subscript_exprs().size() != 2) { throw ErrorReport(subscript) diff --git a/torch/csrc/jit/frontend/sugared_value.cpp b/torch/csrc/jit/frontend/sugared_value.cpp index ab70d6c6f326a..a5f000769badc 100644 --- a/torch/csrc/jit/frontend/sugared_value.cpp +++ b/torch/csrc/jit/frontend/sugared_value.cpp @@ -119,7 +119,7 @@ std::shared_ptr SimpleValue::attr( {"layout", "prim"}, {"T", "prim"}, {"ndim", "prim"}, {"name", "prim"}, {"real", "aten"}, {"imag", "aten"}, - {"retains_grad", "aten"}, + {"retains_grad", "aten"}, {"is_ort", "prim"}, }}, {TypeKind::DeviceObjType, {{"type", "prim"}, {"index", "prim"}}}}; auto kind = value_->type()->kind(); diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp index 18512b4617d6c..03afbdd3508b2 100644 --- a/torch/csrc/jit/ir/alias_analysis.cpp +++ b/torch/csrc/jit/ir/alias_analysis.cpp @@ -13,94 +13,139 @@ namespace jit { namespace { -// For any mutable type, map it to a type such that all other types which it can -// alias will be mapped to the same type. This function follows a similar logic -// to `unifyTypes` because any two mutable types which can be unified -// can alias each other. -// getMutableTypePtr(Optional[List[int]]) == getMutableTypePtr([List[int]]) -// If a type is not mutable, return nullopt -// This class helps convert types to their mutable equivalent by looking up -// cached conversions. +TypePtr toSingleType(AliasTypeSet& mut_types) { + return mut_types.size() == 1 ? mut_types[0] + : c10::UnionType::create(mut_types); +} + +// This class determines whether a type is mutable, and, if so, it maps +// the type to its "mutable equivalent" (see definition in +// `mapTypeToAliasTypeSet`). It uses a cache of TypePtrs to speed up these +// type lookups class MutableTypePtrHelper { public: explicit MutableTypePtrHelper( - std::unordered_map* mutable_type_cache) + std::unordered_map* mutable_type_cache) : mutable_type_cache_(mutable_type_cache) {} - c10::optional getMutableType(const TypePtr& type) { + // Map any mutable type to a type such that all other types which the + // mutable type can alias will be mapped to the same type. For + // example, calling this method on `Optional[List[int]]` should be + // the same as calling this method on `List[int]`. + // + // Rules: + // - If the type is not mutable, return `nullopt` + // - If the type is a `Tuple`, that means that it's an immutable + // object that can itself contain mutable objects. We want to make + // sure that the mutable objects are correctly aliased, so we + // remove the immutable objects. (For example, + // `Tuple[int, Tensor]` would become `Tuple[Tensor]`, while + // `Tuple[int, str]` would be returned as `nullopt`.) This is a + // convenience that makes it easy to check if the `Tuple` + // contains only immutable objects, though it's not technically + // necessary + // - For any Tensor type (including Tensor types that are part of + // a larger container, e.g. `List[Tensor]`), return the + // "unshaped" version of that Tensor. An "unshaped" Tensor is a + // Tensor with shape information removed. For example, a Tensor + // of dimension 4 would map to the same type as a Tensor of + // dimension 1. This allows us to treat all subclasses of Tensor + // as a single, homogenous "Tensor" type. + c10::optional mapTypeToAliasTypeSet(const TypePtr& type) { if (mutable_type_cache_) { - auto maybe_type = mutable_type_cache_->find(type); - if (maybe_type != mutable_type_cache_->end()) { - return maybe_type->second; + auto maybe_type_mapping = mutable_type_cache_->find(type); + if (maybe_type_mapping != mutable_type_cache_->end()) { + return maybe_type_mapping->second; } } - auto mutable_type = getMutableTypeImpl(type); - if (mutable_type_cache_ && mutable_type) { - mutable_type_cache_->emplace(type, *mutable_type); + auto mutable_types = mapTypeToAliasTypeSetImpl(type); + if (mutable_type_cache_ && mutable_types) { + mutable_type_cache_->emplace(type, *mutable_types); } - return mutable_type; + return mutable_types; } private: - c10::optional getMutableTypeImpl(const TypePtr& type) { + c10::optional mapTypeToAliasTypeSetImpl(const TypePtr& type) { switch (type->kind()) { case TypeKind::ListType: case TypeKind::DictType: case TypeKind::ClassType: case TypeKind::TensorType: - // TODO: lookup cached contained types. this is kind of tricky - // because a List[Optional[T]] should still be - // List[Optional[Unshaped(T)]], however the getMutableType(Optional[T]) - // == T - return unshapedType(type); - case TypeKind::OptionalType: - return getMutableType(type->castRaw()->getElementType()); + // TODO: Look up cached contained types. this is kind of tricky + // because a `List[Optional[T]]` should still be + // `List[Optional[Unshaped(T)]]`, but + // `mapTypeToAliasTypeSet(Optional[T])` should be `T` + return AliasTypeSet{unshapedType(type)}; + case TypeKind::UnionType: { + AliasTypeSet mutable_types; + for (const TypePtr& inner : + type->expect()->containedTypes()) { + if (auto maybe_inner_types = mapTypeToAliasTypeSet(inner)) { + mutable_types.insert( + mutable_types.end(), + (*maybe_inner_types).begin(), + (*maybe_inner_types).end()); + } + } + if (mutable_types.size() == 0) { + return c10::nullopt; + } + return mutable_types; + } + case TypeKind::OptionalType: { + auto inner = type->castRaw()->getElementType(); + return mapTypeToAliasTypeSet(inner); + } case TypeKind::AnyType: - return type; + return {AliasTypeSet{type}}; case TypeKind::FutureType: { - if (auto elem = - getMutableType(type->castRaw()->getElementType())) { - return FutureType::create(*elem); + if (auto maybe_mut_types = mapTypeToAliasTypeSet( + type->castRaw()->getElementType())) { + auto mut_type = toSingleType(*maybe_mut_types); + return {AliasTypeSet{FutureType::create(mut_type)}}; } return c10::nullopt; } case TypeKind::TupleType: { std::vector mutable_types; - for (const auto& elem : type->expectRef().elements()) { - if (auto mut_elem = getMutableType(elem)) { - mutable_types.push_back(*mut_elem); + for (const TypePtr& inner : type->expectRef().elements()) { + if (auto maybe_inner_types = mapTypeToAliasTypeSet(inner)) { + mutable_types.insert( + mutable_types.end(), + (*maybe_inner_types).begin(), + (*maybe_inner_types).end()); } } if (mutable_types.size() == 0) { return c10::nullopt; - } else { - return TupleType::create(mutable_types); } + return {AliasTypeSet{TupleType::create(mutable_types)}}; } default: return c10::nullopt; } } - std::unordered_map* mutable_type_cache_; + std::unordered_map* mutable_type_cache_; }; bool isMutableTypeImpl( const TypePtr& type, - std::unordered_map* mutable_type_cache) { - // check common cases to avoid recursively constructing type in - // getMutableTypePtrImpl + std::unordered_map* mutable_type_cache) { + // Check common cases to avoid recursively constructing type in + // `mapTypeToAliasTypeSetPtrImpl` auto kind = type->kind(); if (kind == TypeKind::TensorType || kind == TypeKind::ListType || kind == TypeKind::ClassType || kind == TypeKind::DictType) { return true; } MutableTypePtrHelper helper(mutable_type_cache); - return helper.getMutableType(type) != c10::nullopt; + return helper.mapTypeToAliasTypeSet(type) != c10::nullopt; } } // namespace -// static isMutableType does not use cache of type -> mutable type equivalent +// Static `isMutableType` does not use cache of type -> mutable type equivalent bool AliasDb::isMutableType(const TypePtr& type) { return isMutableTypeImpl(type, nullptr); } @@ -109,7 +154,7 @@ bool AliasDb::isMutableType(const Value* v) { return isMutableType(v->type()); } -// makes use of type -> mutable cache +// Make use of type -> mutable cache bool AliasDb::isMutableTypeInternal(const TypePtr& type) const { return isMutableTypeImpl(type, &mapped_mutable_types_); } @@ -118,21 +163,17 @@ bool AliasDb::isMutableTypeInternal(const Value* v) const { return isMutableTypeInternal(v->type()); } -c10::optional AliasDb::getMutableTypePtr(const TypePtr& type) const { +c10::optional AliasDb::mapTypeToAliasTypeSetPtr( + const TypePtr& type) const { MutableTypePtrHelper helper(&mapped_mutable_types_); - return helper.getMutableType(type); -} - -bool AliasDb::isContainerType(const TypePtr& type) const { - auto mut_type = getMutableTypePtr(type); - return mut_type && (*mut_type)->containedTypes().size() > 0; + return helper.mapTypeToAliasTypeSet(type); } AliasDb::~AliasDb() = default; -// Structure used during analysis to keeps track of all writes at a high level. -// When analysis is completed this will be used to construct a more efficient -// WriteIndex. +// Structure used during analysis to keep track of all writes at a high +// level. When the analysis is completed, this will be used to construct +// a more efficient WriteIndex struct AliasDb::WriteRegistry { void registerWrite(const Value* v, Node* n) { writes_[n].emplace_back(v); @@ -170,7 +211,7 @@ AliasDb::AliasDb(std::shared_ptr graph, bool isFrozen) writeIndex_ = TWriteIndex(); auto& writeIndex = *writeIndex_; // to make operator[] less ugly - // build the write index + // Build the write index for (const auto& write : writeRegistry_->writes_) { Node* node = write.first; const std::vector writtenValues = write.second; @@ -207,7 +248,7 @@ AliasDb::AliasDb(std::shared_ptr graph, bool isFrozen) // out of sync (since we have no way of registering new writes) writeRegistry_ = nullptr; - // initialize the write cache + // Initialize the write cache buildWrittenToLocationsIndex(); GRAPH_DEBUG(toString()); } @@ -324,10 +365,10 @@ MemoryLocations AliasDb::getReads(Node* n) const { std::string AliasDb::getElementName(const Element* e) const { if (e->values.empty()) { - // not the most efficient way, but given the fact there are + // Not the most efficient way, but given the fact there are // not too many types and even fewer of them will end up in - // wildcardIndex_, we should be fine with a linear search - // each time we hit a wildcard leaf + // `wildcardIndex_`, we should be fine with a linear search + // each time we hit a Wildcard leaf for (const auto& ent : wildcardIndex_) { if (ent.second == e) { return std::string("WILDCARD for type ") + ent.first->str(); @@ -362,17 +403,27 @@ std::string AliasDb::toString() const { ss << "\n===2. ALIAS DB===\n"; for (const auto& ptrPair : elementMap_) { const auto element = ptrPair.second; + int ct = 0; if (!element->pointsTo.empty()) { ss << getElementName(element) << " points to: "; for (const auto pointedTo : element->pointsTo) { - ss << getElementName(memoryDAG_->fromIndex(pointedTo)) << ", "; + if (ct > 0) { + ss << ", "; + } + ++ct; + ss << getElementName(memoryDAG_->fromIndex(pointedTo)); } ss << "\n"; } + ct = 0; if (!element->containedElements.empty()) { ss << getElementName(element) << " contains: "; for (const auto contained : element->containedElements) { - ss << getElementName(memoryDAG_->fromIndex(contained)) << ", "; + ss << getElementName(memoryDAG_->fromIndex(contained)); + if (ct > 0) { + ss << ", "; + } + ++ct; } ss << "\n"; } @@ -839,8 +890,7 @@ void AliasDb::analyzeLoop(Node* node) { TORCH_INTERNAL_ASSERT(blockOutputs.size() == node->outputs().size()); // Run alias analysis on the loop body, iterating until the block output - // alias info converges. - // Copy node input aliases to block input + // alias info converges. Copy node input aliases to block input mapAliases(blockInputs, loopCarriedInputs); // Populate block output alias info by analyzing the body @@ -996,7 +1046,7 @@ bool AliasDb::functionalNonEscapingListUse(const Use& use) const { return false; } -// List or dict or tuple: construct: create an aliasing element for the actual +// List or dict or tuple construct: create an aliasing element for the actual // container, then mark all inputs as wildcards, since they've gone inside the // container. Then, add the wildcard sets of appropriate type to the contained // elements of the container. @@ -1073,52 +1123,50 @@ void AliasDb::makePointerTo(const Value* from, const Value* to) { return; } - // the contained types of immutable type containers (optional, tuple, future) - // are unified, so these types can be mutable or immutable - // and point to a type which is mutable or immutable. - // Any is mutable but can point to an immutable type through refinement + // The contained types of immutable type containers (`Optional`, + // `Tuple`, `Future`, and `Union`) are unified, so these types can be + // mutable or immutable and point to a type which is mutable or + // immutable. `Any` is mutable but can point to an immutable type + // through refinement if (isMutableTypeInternal(from) != isMutableTypeInternal(to)) { bool expected_kind = false; for (auto kind : {from->type()->kind(), to->type()->kind()}) { expected_kind = expected_kind || (kind == TypeKind::OptionalType || kind == TypeKind::FutureType || - kind == TypeKind::TupleType) // immutable type containers + kind == TypeKind::TupleType || + kind == TypeKind::UnionType) // immutable type containers || kind == TypeKind::AnyType; } TORCH_INTERNAL_ASSERT( expected_kind, from->type()->str(), to->type()->str()); return; } - // both immutable if (!isMutableTypeInternal(from)) { return; } - if (from == to) { return; } - // At this point, we are dealing with two mutable types. - auto fromEl = getOrCreateElement(from); - auto toEl = getOrCreateElement(to); + // At this point, we are dealing with two mutable types + auto from_el = getOrCreateElement(from); + auto to_el = getOrCreateElement(to); - memoryDAGBuilder_->makePointerTo(fromEl, toEl); + memoryDAGBuilder_->makePointerTo(from_el, to_el); } void AliasDb::addToContainedElements( - const Value* elem, + const Value* inner, const Value* container) { - if (!isMutableTypeInternal(elem)) { + if (!isMutableTypeInternal(inner)) { return; } - TORCH_INTERNAL_ASSERT(isContainerType(container->type())); - - auto elemEl = getOrCreateElement(elem); - auto contEl = getOrCreateElement(container); + auto inner_el = getOrCreateElement(inner); + auto cont_el = getOrCreateElement(container); - memoryDAGBuilder_->addToContainedElements(elemEl, contEl); + memoryDAGBuilder_->addToContainedElements(inner_el, cont_el); } bool AliasDb::mayAlias(const Value* a, const Value* b) const { @@ -1203,8 +1251,8 @@ void AliasDb::createValue(const Value* value) { void AliasDb::giveFreshAlias( const Value* value, bool add_wildcard_to_contained_elems) { - auto maybe_mut_type = getMutableTypePtr(value->type()); - if (!maybe_mut_type) { + auto maybe_mut_types = mapTypeToAliasTypeSetPtr(value->type()); + if (!maybe_mut_types) { return; } @@ -1217,7 +1265,11 @@ void AliasDb::giveFreshAlias( auto new_elem = memoryDAGBuilder_->makeFreshValue(value); elementMap_[value] = new_elem; if (add_wildcard_to_contained_elems) { - addContainedTypesToFreshElement(new_elem, *maybe_mut_type); + if ((*maybe_mut_types).size() > 1) { + pointUnionTypeElementToAllContainedTypes(new_elem, *maybe_mut_types); + } else { + addContainedTypesToFreshElement(new_elem, *maybe_mut_types); + } } } @@ -1639,29 +1691,47 @@ bool AliasDb::mayAliasWildcard(const at::ArrayRef vs) const { } c10::optional AliasDb::tryGetOrCreateWildcard(const TypePtr& type) { - auto updated_type = getMutableTypePtr(type); - if (!updated_type) { + auto maybe_mut_types = mapTypeToAliasTypeSetPtr(type); + if (!maybe_mut_types) { return c10::nullopt; } - auto mapped_type = *updated_type; - auto existing_wildcard = wildcardIndex_.find(mapped_type); + auto mut_type = toSingleType(*maybe_mut_types); + auto existing_wildcard = wildcardIndex_.find(mut_type); if (existing_wildcard != wildcardIndex_.end()) { return existing_wildcard->second; } auto wildcard_elem = memoryDAGBuilder_->makeFreshValue(nullptr); - wildcardIndex_.emplace(mapped_type, wildcard_elem); - addContainedTypesToFreshElement(wildcard_elem, mapped_type); + wildcardIndex_.emplace(mut_type, wildcard_elem); + if ((*maybe_mut_types).size() > 1) { + pointUnionTypeElementToAllContainedTypes(wildcard_elem, *maybe_mut_types); + } else { + addContainedTypesToFreshElement(wildcard_elem, *maybe_mut_types); + } return wildcard_elem; } -void AliasDb::addContainedTypesToFreshElement( +void AliasDb::pointUnionTypeElementToAllContainedTypes( Element* container_elem, - const TypePtr& mut_type) { - for (const auto& contained : mut_type->containedTypes()) { - auto maybe_elem = tryGetOrCreateWildcard(contained); + const AliasTypeSet& mut_types) { + for (const auto& mut_type : mut_types) { + auto maybe_elem = tryGetOrCreateWildcard(mut_type); if (maybe_elem) { - memoryDAGBuilder_->addToContainedElements(*maybe_elem, container_elem); + TORCH_INTERNAL_ASSERT(*maybe_elem != container_elem); + memoryDAGBuilder_->makePointerTo(container_elem, *maybe_elem); + } + } +} + +void AliasDb::addContainedTypesToFreshElement( + Element* container_elem, + const AliasTypeSet& mut_types) { + for (const auto& mut_type : mut_types) { + for (const auto& contained : mut_type->containedTypes()) { + auto maybe_elem = tryGetOrCreateWildcard(contained); + if (maybe_elem) { + memoryDAGBuilder_->addToContainedElements(*maybe_elem, container_elem); + } } } } @@ -1669,26 +1739,38 @@ void AliasDb::addContainedTypesToFreshElement( // Search the wildcard index for an element that corresponds to the given type. // Const version returns nullptr Element* AliasDb::getWildcard(const TypePtr& type) const { - auto maybe_mut_type = getMutableTypePtr(type); - if (!maybe_mut_type) { - return nullptr; - } - TypePtr mut_type = *maybe_mut_type; - auto wildcard = wildcardIndex_.find(mut_type); - if (wildcard != wildcardIndex_.end()) { - return wildcard->second; + auto maybe_mut_types = mapTypeToAliasTypeSetPtr(type); + if (!maybe_mut_types) { + return {}; + } + if ((*maybe_mut_types).size() > 1) { + auto union_type = UnionType::create(*maybe_mut_types); + // Get a pair where the TypePtr is this Union + // type and the Element is the corresponding Wildcard + auto maybe_union_pair = wildcardIndex_.find(union_type); + if (maybe_union_pair != wildcardIndex_.end()) { + return (*maybe_union_pair).second; + } + } else { + // Get a pair where the TypePtr is the given + // type and the Element is the corresponding Wildcard + auto type_pair = wildcardIndex_.find((*maybe_mut_types)[0]); + if (type_pair != wildcardIndex_.end()) { + return type_pair->second; + } } - return nullptr; + return {}; } // Register `v` as a wildcard value. c10::optional AliasDb::setWildcard(const Value* v) { - auto maybe_wildcardElement = tryGetOrCreateWildcard(v->type()); + c10::optional maybe_wildcardElement = + tryGetOrCreateWildcard(v->type()); if (!maybe_wildcardElement) { return c10::nullopt; } - // Ensure that we create a corresponding element for `v` still, as it is an - // invariant that all mutable values have an element. + // Ensure that we create a corresponding Element for `v` still, as it is an + // invariant that all mutable values have an Element getOrCreateElement(v); wildcards_.insert(v); return *maybe_wildcardElement; diff --git a/torch/csrc/jit/ir/alias_analysis.h b/torch/csrc/jit/ir/alias_analysis.h index cd888ade69291..7feb2b9938d8b 100644 --- a/torch/csrc/jit/ir/alias_analysis.h +++ b/torch/csrc/jit/ir/alias_analysis.h @@ -34,6 +34,12 @@ namespace jit { * Values that contain other mutable types, such as List[Tensor], are * initialized as containing the Wildcard set for all contained mutable types. * + * The AliasDb API references the idea of "mutable" vs "immutable" + * types. "Mutable" means that the object's value can change, while + * "immutable" means that the value is fixed. (For example, `List` is + * mutable, so you can add and delete elements from it. On the other + * hand, you can't modify a Tuple once you create it, making `Tuple` an + * immutable container.) */ class AliasDb { public: @@ -95,7 +101,7 @@ class AliasDb { const at::ArrayRef& a, const at::ArrayRef& b) const; - // Move 'n' (already in the graph) after 'movePoint' in the topological order. + // Move `n` (already in the graph) after `movePoint` in the topological order. // // Tries to preserve value dependencies, so other nodes might be moved. We // make two guarantees about the postcondition of the node list: @@ -125,6 +131,10 @@ class AliasDb { TORCH_API bool dumpToGraphvizFile(const char* filename) const; TORCH_API std::string toGraphviz() const; + // Returns `true` if the given element is mutable or if it is a + // container type with an internal mutable element (e.g. + // `Tuple[int, Tensor]` has an internal mutable type `Tensor`, so + // it would be considered a "mutable type" in AliasDb) static bool isMutableType(const Value* v); static bool isMutableType(const TypePtr& type); @@ -181,7 +191,7 @@ class AliasDb { // Register `v` as a wildcard value. c10::optional setWildcard(const Value* v); - // Is this a value which will not alias + // Is this a value which will not alias? bool nonAliasingValue(const Value* elem) const; /** @@ -221,11 +231,10 @@ class AliasDb { bool add_wildcard_to_contained_elems = true); Element* getOrCreateElement(const Value* value); - c10::optional getMutableTypePtr(const TypePtr& type) const; + c10::optional mapTypeToAliasTypeSetPtr( + const TypePtr& type) const; bool functionalNonEscapingListUse(const Use& use) const; - bool isContainerType(const TypePtr& type) const; - std::shared_ptr graph_; // If the Module is frozen then consider attributes as freshly created @@ -239,21 +248,24 @@ class AliasDb { // Mapping of values to MemoryDAG elements ska::flat_hash_map elementMap_; - // All wildcard elements (one for each unique mutable type). + // All wildcard Elements (one for each unique mutable type) std::unordered_map wildcardIndex_; Element* getWildcard(const TypePtr& type) const; c10::optional tryGetOrCreateWildcard(const TypePtr& type); void addContainedTypesToFreshElement( Element* container_elem, - const TypePtr& mut_type); + const AliasTypeSet& mut_types); + void pointUnionTypeElementToAllContainedTypes( + Element* container_elem, + const AliasTypeSet& mut_types); std::vector getElements(at::ArrayRef vs) const; bool mayAliasWildcard(const Value* v) const; bool mayAliasWildcard(const at::ArrayRef vs) const; bool hasWriters(const at::ArrayRef& values) const; - // cached mapping of type ptrs to their mutable types - mutable std::unordered_map mapped_mutable_types_; + // Cached mapping of type ptrs to their mutable types + mutable std::unordered_map mapped_mutable_types_; /** * State for tracking write info. diff --git a/torch/csrc/jit/ir/ir.cpp b/torch/csrc/jit/ir/ir.cpp index 05ce8d40ea7c5..e62ef93b57379 100644 --- a/torch/csrc/jit/ir/ir.cpp +++ b/torch/csrc/jit/ir/ir.cpp @@ -511,7 +511,7 @@ void Graph::lint() const { // - Params and return do NOT occur in nodes // - next_unique_ is greater than all uniques in graph // - uniques in all_nodes are unique - // - every use will occur later in the topsort + // - every use will occur later in the toposort struct LintScope { LintScope() = default; @@ -787,7 +787,9 @@ bool Value::mustBeNone() const { } bool Value::mustNotBeNone() const { return node_->kind() != prim::AutogradAdd && type() != NoneType::get() && - !type()->cast(); + !type()->cast() && + !(type()->cast() && + type()->expect()->canHoldType(NoneType::get())); } std::string Value::debugNameBase() const { @@ -1765,20 +1767,23 @@ Node* Graph::createEnumValue(Value* e) { return n; } -Node* Graph::createList(const TypePtr& elem_type, at::ArrayRef values) { +Node* Graph::createList( + const TypePtr& contained_type, + at::ArrayRef values) { auto n = create(prim::ListConstruct, values); for (const auto& v : values) { TORCH_CHECK( - v->type()->isSubtypeOf(elem_type), + v->type()->isSubtypeOf(contained_type), "Expected a list element that subtypes '", - elem_type->repr_str(), + contained_type->repr_str(), "' but got an element of type '", v->type()->repr_str(), "'"); } - n->output()->setType(ListType::create(elem_type)); + n->output()->setType(ListType::create(contained_type)); return n; } + Node* Graph::createListUnpack(Value* v, size_t size) { ListTypePtr list_type = v->type()->expect(); TypePtr elem_type = list_type->getElementType(); diff --git a/torch/csrc/jit/ir/ir.h b/torch/csrc/jit/ir/ir.h index dee222bd480df..99f6a6ce5c57b 100644 --- a/torch/csrc/jit/ir/ir.h +++ b/torch/csrc/jit/ir/ir.h @@ -84,7 +84,7 @@ using namespace ::c10::cuda; struct Function; struct MatchedSchema; -// Graph represents one "function" of computation. +// A Graph represents one "function" of computation. // It uses a simple ownership model where the graph owns all the nodes inside // it. All references inside the graph are raw pointers. Destroying the Graph // will invalidate any pointers to nodes in the graph. @@ -104,9 +104,9 @@ TORCH_API std::ostream& operator<<(std::ostream& out, const Node& n); // A list of nodes, with inputs and outputs struct Block; -// Each use is represented by this type, see Node::uses() -// 'user' is the consumer of the value, offset is the index into -// 'user's input this where the produces will be found. +// Each use is represented by this type, see 'Node::uses()' +// 'user' is the consumer of the value, 'offset' is the index into +// 'user's input this where the producers will be found. struct Use { Use(Node* user, size_t offset) : user(user), offset(offset) {} Node* user; @@ -338,14 +338,16 @@ struct TORCH_API Node { protected: Node(Graph* graph_, NodeKind kind_); // defined after graph public: - // each node but Return/Param - // is associated with exactly one place in the node list... - // of the graph_ - // this circular is a doubly-linked list, the Return node is used as the - // sentinel for the beginning and end of the list such that the list never has - // null pointers next_in_graph[0] is next pointer next_in_graph[1] is prev - // pointer using an array to allow the same iterator class for forward and - // reverse node lists This list represents a topological sort + // Each Node but Return/Param Nodes are associated with exactly one + // place in the Node list of the Graph. The Graph itself is a circular + // doubly-linked list. The Return Node is used as the sentinel for the + // "beginning"/"end" of the list. This means that you can tell when + // you've traversed the entire list without means worrying about null + // pointers. `next_in_graph[0]` is the pointer to the next Node, while + // `next_in_graph[1]` is the pointer to the previous Node. The + // linked list is implemented as an array to allow the same iterator + // class for forward and reversed Node lists. Taken together, this + // list also represents a topological sort of the Nodes in the Graph. // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-non-private-member-variables-in-classes,modernize-avoid-c-arrays) Node* next_in_graph[2] = {nullptr, nullptr}; @@ -980,7 +982,6 @@ struct TORCH_API Node { // subclasses should extend if they have additional information to copy. // 'this' will be allocated with s->allocNewInstance(g) so it should have // the same concrete type as 's' - // virtual void cloneFrom(Node* s); }; @@ -1247,7 +1248,7 @@ struct Graph { TORCH_API Node* createEnumName(Value* e); TORCH_API Node* createEnumValue(Value* e); TORCH_API Node* createList( - const TypePtr& elem_type, + const TypePtr& contained_type, at::ArrayRef values); TORCH_API Node* createListUnpack(Value* v, size_t size); TORCH_API Node* createDict( diff --git a/torch/csrc/jit/mobile/debug_info.cpp b/torch/csrc/jit/mobile/debug_info.cpp index 9c734f40a25a2..a75ffe16c61f5 100644 --- a/torch/csrc/jit/mobile/debug_info.cpp +++ b/torch/csrc/jit/mobile/debug_info.cpp @@ -13,6 +13,12 @@ namespace jit { namespace { +C10_ALWAYS_INLINE std::string debugHandlesNotFoundMessage( + const std::string& debug_handles_string) { + return "Debug info for handle(s): " + debug_handles_string + + ", was not found."; +} + std::pair, std::string> getStackTraceWithModuleHierarchy( const DebugInfoTuple& source_callstack, const std::string& caller_name) { @@ -49,11 +55,7 @@ std::pair, std::string> getStackTraceWithModuleHierarchy // Now add source range info to stack entries.emplace_back( StackEntry{prev_function_name, callstack_ptr->source_range()}); - if (callstack_ptr->function()) { - prev_function_name = callstack_ptr->function()->name(); - } else { - prev_function_name = callstack_ptr->function_name(); - } + prev_function_name = callstack_ptr->function_name(); // Function name appended here // It is renamed to prev_function_name because for StackEntry // it will be appended in the next iteration. This is the format @@ -156,8 +158,7 @@ std::string MobileDebugTable::getModuleHierarchyInfo( const std::string& top_module_type_name) const { const auto it = callstack_ptr_map_.find(debug_handle); if (it == callstack_ptr_map_.end()) { - return "Module info for handle, " + std::to_string(debug_handle) + - ", not found."; + return debugHandlesNotFoundMessage(std::to_string(debug_handle)); } return (getStackTraceWithModuleHierarchy( {it->second}, "top", top_module_type_name)) @@ -176,8 +177,7 @@ std::string MobileDebugTable::getSourceDebugString( const std::string& top_module_type_name) const { const auto it = callstack_ptr_map_.find(debug_handle); if (it == callstack_ptr_map_.end()) { - return "Debug info for handle, " + std::to_string(debug_handle) + - ", not found."; + return debugHandlesNotFoundMessage(std::to_string(debug_handle)); } return (getStackTraceWithModuleHierarchy( {it->second}, "top", top_module_type_name)) @@ -212,8 +212,7 @@ std::pair MobileDebugTable:: debug_handles_string += std::to_string(debug_handle); } debug_handles_string += "}"; - debug_handles_string = - "Debug info for handles: " + debug_handles_string + ", was not found."; + debug_handles_string = debugHandlesNotFoundMessage(debug_handles_string); return {debug_handles_string, debug_handles_string}; } return (getStackTraceWithModuleHierarchy( diff --git a/torch/csrc/jit/mobile/function.cpp b/torch/csrc/jit/mobile/function.cpp index 0775a550d2a79..fad8c39bd1f4d 100644 --- a/torch/csrc/jit/mobile/function.cpp +++ b/torch/csrc/jit/mobile/function.cpp @@ -67,7 +67,7 @@ bool Function::append_operator( auto jit_op = findOperatorFor(opname); std::vector args; if (jit_op) { - fn = [jit_op](Stack& stack) { jit_op->getOperation()(&stack); }; + fn = [jit_op](Stack& stack) { jit_op->getOperation()(stack); }; args = jit_op->schema().arguments(); } else { auto op = c10::Dispatcher::singleton().findSchema(opname_c10); @@ -99,21 +99,35 @@ bool Function::append_operator( // from model. We can use it to handle backward compatibility. if (num_specified_args && num_specified_args.value() < static_cast(args.size())) { - // Sanity check at load time, to save perf at runtime - for (size_t i = num_specified_args.value(); i < args.size(); ++i) { - auto default_val = args[i].default_value(); - TORCH_CHECK( - default_val.has_value(), - "Error happened at preparing for default values for the argument. The ", - i, - "th arguement of operator", - opname, - " does not have a specified value or default value. "); - } fn = [fn, num_specified_args, args](Stack& stack) { - for (size_t i = num_specified_args.value(); i < args.size(); ++i) { + std::vector out_args; + // The following logic pops and temporarily stores all out arguments + // from the stack (which can be 0 or more, and always appended to the + // schema), in order to push the necessary default values. Finally, the + // out arguments are pushed back into the stack. + for (size_t i = args.size() - 1; i > 0 && args.at(i).is_out(); i--) { + out_args.push_back(stack.back()); + stack.pop_back(); + } + size_t start_index = num_specified_args.value() - out_args.size(); + TORCH_CHECK( + start_index >= 0, + "The number of output arguments is: ", + out_args.size(), + ", which is more then the number of specified arguments: ", + num_specified_args.value()); + for (size_t i = start_index; i < (args.size() - out_args.size()); ++i) { + TORCH_CHECK( + args[i].default_value().has_value(), + "Error happened at preparing for default values for the argument. The ", + i, + "th argument ", + args[i].name(), + " does not have a specified value or default value. "); + stack.push_back(args[i].default_value()); } + stack.insert(stack.end(), out_args.rbegin(), out_args.rend()); fn(stack); }; } diff --git a/torch/csrc/jit/mobile/import.cpp b/torch/csrc/jit/mobile/import.cpp index db9f0b8c20cf5..99be225255ffb 100644 --- a/torch/csrc/jit/mobile/import.cpp +++ b/torch/csrc/jit/mobile/import.cpp @@ -85,8 +85,8 @@ using caffe2::serialize::ReadAdapterInterface; OpCode parseOpCode(const char* str); -IValue expect_field( - IValue tup, +const IValue& expect_field( + const IValue& tup, const std::string& expected_name, size_t entry) { auto row = tup.toTuple()->elements().at(entry).toTuple(); @@ -317,16 +317,15 @@ void BytecodeDeserializer::parseMethods( caffe2::serialize::kMinSupportedBytecodeVersion <= model_version && // NOLINTNEXTLINE(clang-diagnostic-sign-compare) model_version <= caffe2::serialize::kMaxSupportedBytecodeVersion, - "Lite Interpreter verson number does not match. ", + "Lite Interpreter version number does not match. ", "The model version must be between ", caffe2::serialize::kMinSupportedBytecodeVersion, " and ", caffe2::serialize::kMaxSupportedBytecodeVersion, - "But the model version is ", + " but the model version is ", model_version); - bool has_debug_handles = debug_handles.has_value(); - if (has_debug_handles) { + if (debug_handles) { TORCH_CHECK( debug_handles->size() == vals.size(), "The numbers of bytecode values and debug info values do not match."); @@ -340,12 +339,11 @@ void BytecodeDeserializer::parseMethods( const auto& element = vals[i]; const auto& m_tuple = element.toTuple()->elements(); const std::string& function_name = m_tuple[0].toStringRef(); - IValue codeTable = m_tuple[1]; - auto schemaTable = // older files do not store function schema + const IValue& codeTable = m_tuple[1]; + const IValue* schemaTable = // older files do not store function schema (model_version > 0x4L || (model_version == 0x4L && m_tuple.size() >= 3)) - ? at::optional{m_tuple[2]} - : at::nullopt; - + ? &m_tuple[2] + : nullptr; auto function = std::make_unique(c10::QualifiedName(function_name)); @@ -369,8 +367,8 @@ void BytecodeDeserializer::parseMethods( expect_field(codeTable, "register_size", BYTECODE_INDEX_REGISTER_SIZE) .toInt(); - std::vector debug_handles_list; - if (has_debug_handles) { + c10::List debug_handles_list; + if (debug_handles) { const auto& debug_handles_element = (*debug_handles)[i]; const auto& debug_handles_m_tuple = debug_handles_element.toTuple()->elements(); @@ -379,22 +377,21 @@ void BytecodeDeserializer::parseMethods( TORCH_CHECK( debug_info_function_name == function_name, "The function names in the bytecode table and the debug info table do not match."); - IValue debug_handles_table = debug_handles_m_tuple[1]; + const IValue& debug_handles_table = debug_handles_m_tuple[1]; debug_handles_list = (expect_field( debug_handles_table, "function_debug_handles", BYTECODE_INDEX_MODULE_DEBUG_HANDLES) .toTuple() ->elements())[0] - .toList() - .vec(); + .toIntList(); TORCH_CHECK( debug_handles_list.size() == ins_list.size(), "The numbers of instructions and debug handles strings do not match."); } for (const auto j : c10::irange(ins_list.size())) { - auto ins_item = ins_list[j].toTuple()->elements(); + const auto& ins_item = ins_list[j].toTuple()->elements(); TORCH_CHECK( ins_item.size() == 3, "There should be three parts in an instruction. The function name is ", @@ -402,8 +399,8 @@ void BytecodeDeserializer::parseMethods( OpCode op_code = parseOpCode(ins_item[0].toString()->string().c_str()); int X = ins_item[1].toInt(); int N = ins_item[2].toInt(); - if (has_debug_handles) { - int64_t debug_handle = debug_handles_list[j].toInt(); + if (debug_handles) { + int64_t debug_handle = debug_handles_list[j]; function->append_instruction(op_code, X, N, debug_handle); } else { function->append_instruction(op_code, X, N); @@ -451,14 +448,9 @@ void BytecodeDeserializer::parseMethods( const auto& type = resolveTypeName( (expect_field(argTable, "type", BYTECODE_INDEX_ARGUMENT_TYPE)) .toStringRef()); - auto default_value = expect_field( - argTable, - "default_value", - BYTECODE_INDEX_ARGUMENT_DEFAULT_VALUE) - .toIValue(); - auto arg = - c10::Argument(name, type, c10::nullopt /*N*/, default_value); - args.emplace_back(std::move(arg)); + const IValue& default_value = expect_field( + argTable, "default_value", BYTECODE_INDEX_ARGUMENT_DEFAULT_VALUE); + args.emplace_back(name, type, c10::nullopt /*N*/, default_value); } return args; }; @@ -522,15 +514,18 @@ mobile::Module BytecodeDeserializer::deserialize( // being a Tuple (int, table), and the integer stands for the bytecode version // number. The rest of the elements are the same as before. // - auto bvals = readArchive("bytecode", mcu).toTuple()->elements(); + auto bvals = std::move(*readArchive("bytecode", mcu).toTuple()).elements(); c10::optional> debug_handles; + bool has_debug_handles{false}; if (reader_->hasRecord("mobile_debug_handles.pkl")) { debug_handles = readArchive("mobile_debug_handles", mcu).toTuple()->elements(); + has_debug_handles = true; } parseMethods(bvals, debug_handles, *mcu); auto m = mobile::Module(readArchive("data", mcu).toObject(), mcu); + m.setHasDebugHandles(has_debug_handles); #if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE) MobileDebugTable debug_table = MobileDebugTable(reader_, compilation_unit_); m.setDebugTable(std::move(debug_table)); diff --git a/torch/csrc/jit/mobile/interpreter.cpp b/torch/csrc/jit/mobile/interpreter.cpp index 275b84beba97b..ab558cd2bf5e0 100644 --- a/torch/csrc/jit/mobile/interpreter.cpp +++ b/torch/csrc/jit/mobile/interpreter.cpp @@ -57,15 +57,19 @@ bool InterpreterState::run(Stack& stack) { auto inst_with_handle = code_->instructions_with_handles_.at(pc); Instruction inst = inst_with_handle.instruction; DebugHandle debug_handle = inst_with_handle.debug_handle; + // If no valid debug handle found then just log pc. + // This is possible when we did not save debug handles + debug_handle = debug_handle == -1 ? pc : debug_handle; - // std::cout << "RUNNING " << pc << " " << code_->instructions_[pc]; - // if (inst.op == OP) { - // std::cout << ", " << code_->op_names_[inst.X].name; - // if (!code_->op_names_[inst.X].overload_name.empty()) { - // std::cout << "." << code_->op_names_[inst.X].overload_name; - // } - // } - // std::cout << std::endl; + // std::cout << "RUNNING " << pc << " " + // << code_->instructions_with_handles_[pc].instruction; + // if (inst.op == OP) { + // std::cout << ", " << code_->op_names_[inst.X].name; + // if (!code_->op_names_[inst.X].overload_name.empty()) { + // std::cout << "." << code_->op_names_[inst.X].overload_name; + // } + // } + // std::cout << std::endl; // TODO(iliacher): remove the workaround after RecordFunction is in // Dispatcher diff --git a/torch/csrc/jit/mobile/module.cpp b/torch/csrc/jit/mobile/module.cpp index c04d9f74b7378..c74ca138d848a 100644 --- a/torch/csrc/jit/mobile/module.cpp +++ b/torch/csrc/jit/mobile/module.cpp @@ -145,8 +145,7 @@ std::string Module::getCallStack(const int64_t debug_handle) const { // We really need to change this part, so in the next step for profiling support // for delegates, the first thing will be to rewrite how profiling is done // for lite interpreter. -std::string Module::get_forward_method_debug_info(size_t pc) const { - auto debug_handle = find_method("forward")->get_debug_handle(pc); +std::string Module::get_forward_method_debug_info(int64_t debug_handle) const { #if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE) return getDebugTable().getModuleHierarchyInfo( debug_handle, getTopModuleTypeName(*this)); diff --git a/torch/csrc/jit/mobile/module.h b/torch/csrc/jit/mobile/module.h index 73637aa4584a0..6102aa517df66 100644 --- a/torch/csrc/jit/mobile/module.h +++ b/torch/csrc/jit/mobile/module.h @@ -78,7 +78,7 @@ class TORCH_API Module { } const std::vector parameters() const; const std::map named_parameters() const; - std::string get_forward_method_debug_info(size_t pc) const; + std::string get_forward_method_debug_info(int64_t debug_handle) const; std::string getModuleHierarchy(const int64_t debug_handle) const; std::string getCallStack(const int64_t debug_handle) const; /// Enables "training" mode. @@ -115,11 +115,20 @@ class TORCH_API Module { return debug_table_; } + void setHasDebugHandles(bool has_debug_handles) { + has_debug_handles_ = has_debug_handles; + } + + bool hasDebugHandles() const { + return has_debug_handles_; + } + private: c10::intrusive_ptr object_; std::unordered_map metadata_; std::shared_ptr cu_; MobileDebugTable debug_table_; + bool has_debug_handles_; }; } // namespace mobile } // namespace jit diff --git a/torch/csrc/jit/mobile/profiler_edge.cpp b/torch/csrc/jit/mobile/profiler_edge.cpp index bcd5a6258ee7c..162e43f0982a6 100644 --- a/torch/csrc/jit/mobile/profiler_edge.cpp +++ b/torch/csrc/jit/mobile/profiler_edge.cpp @@ -2,7 +2,6 @@ #include #include -namespace profiler = torch::autograd::profiler; namespace torch { namespace jit { namespace mobile { @@ -27,17 +26,26 @@ KinetoEdgeCPUProfiler::KinetoEdgeCPUProfiler( if (with_modules || with_stack) { auto post_processing = [this, with_stack, with_modules]( std::vector& events) { + std::string no_debug_info("Model was not saved with debug information"); for (auto& e : events) { if (with_modules) { // Since KinetoEvents's module hierarchy takes vector of strings we // just construct a temporary vector using one string element - e.moduleHierarchy(std::vector( - {this->m_.getModuleHierarchy(e.debugHandle())})); + if (this->m_.hasDebugHandles()) { + e.moduleHierarchy(std::vector( + {this->m_.getModuleHierarchy(e.debugHandle())})); + } else { + e.moduleHierarchy(std::vector({no_debug_info})); + } } else if (with_stack) { // Since KinetoEvents's stack trace takes vector of strings we just // construct a temporary vector using one string element - e.stack(std::vector( - {this->m_.getCallStack(e.debugHandle())})); + if (this->m_.hasDebugHandles()) { + e.stack(std::vector( + {this->m_.getCallStack(e.debugHandle())})); + } else { + e.stack(std::vector({no_debug_info})); + } } } }; @@ -55,8 +63,33 @@ KinetoEdgeCPUProfiler::KinetoEdgeCPUProfiler( trace_file_name_ = fname; } +const std::unique_ptr& KinetoEdgeCPUProfiler:: + disableProfiler() { + TORCH_CHECK( + !profiler_result_, + "KinetoEdgeCPUProfiler already disabled. " + "To get list of events use getProfilerResults()"); + profiler_result_ = profiler::disableProfiler(); + return profiler_result_; +} + +const std::unique_ptr& KinetoEdgeCPUProfiler:: + getProfilerResult() { + TORCH_CHECK( + profiler_result_, + "KinetoEdgeCPUProfiler has not been disabled. " + "use disableProfiler() API first, which returns the ProfilerResult."); + return profiler_result_; +} + KinetoEdgeCPUProfiler::~KinetoEdgeCPUProfiler() { - profiler::disableProfiler()->save(trace_file_name_); + if (!trace_file_name_.empty()) { + if (profiler_result_) { + profiler_result_->save(trace_file_name_); + } else { + profiler::disableProfiler()->save(trace_file_name_); + } + } } } // namespace mobile } // namespace jit diff --git a/torch/csrc/jit/mobile/profiler_edge.h b/torch/csrc/jit/mobile/profiler_edge.h index a245034e34f9b..ef37e01ed4c71 100644 --- a/torch/csrc/jit/mobile/profiler_edge.h +++ b/torch/csrc/jit/mobile/profiler_edge.h @@ -2,6 +2,7 @@ #include #include +namespace profiler = torch::autograd::profiler; namespace torch { namespace jit { namespace mobile { @@ -53,6 +54,9 @@ class TORCH_API KinetoEdgeCPUProfiler { const bool with_flops = false, const bool with_modules = false); + const std::unique_ptr& disableProfiler(); + const std::unique_ptr& getProfilerResult(); + ~KinetoEdgeCPUProfiler(); private: @@ -62,6 +66,7 @@ class TORCH_API KinetoEdgeCPUProfiler { */ const mobile::Module& m_; std::string trace_file_name_; + std::unique_ptr profiler_result_; }; } // namespace mobile } // namespace jit diff --git a/torch/csrc/jit/mobile/type_parser.cpp b/torch/csrc/jit/mobile/type_parser.cpp index 42814e5fe5aad..6b955ab6454a7 100644 --- a/torch/csrc/jit/mobile/type_parser.cpp +++ b/torch/csrc/jit/mobile/type_parser.cpp @@ -42,6 +42,17 @@ class TypeParser { return simpleTypeIt->second; } else if (token == "List") { return CreateSingleElementType(); + } else if (token == "Union") { + std::vector types; + expect("["); + while (cur() != "]") { + types.emplace_back(parse()); + if (cur() != "]") { + expect(","); + } + } + expect("]"); + return UnionType::create(types); } else if (token == "Optional") { return CreateSingleElementType(); } else if (token == "Future") { diff --git a/torch/csrc/jit/passes/batch_mm.cpp b/torch/csrc/jit/passes/batch_mm.cpp index 815a1bc0ea649..944e27805cf18 100644 --- a/torch/csrc/jit/passes/batch_mm.cpp +++ b/torch/csrc/jit/passes/batch_mm.cpp @@ -109,11 +109,11 @@ bool shape_is_fast_for_reduce(const at::Tensor& lhs, const at::Tensor& rhs) { RegisterOperators mm_tree_reduction_reg({Operator( "prim::MMTreeReduce(...) -> Tensor", - [](Stack* stack) { + [](Stack& stack) { auto num_inputs = pop(stack).toInt(); std::vector inputs; inputs.reserve(num_inputs); - for (auto it = stack->end() - num_inputs; it != stack->end(); ++it) { + for (auto it = stack.end() - num_inputs; it != stack.end(); ++it) { inputs.push_back(std::move(*it).toTensor()); } drop(stack, num_inputs); @@ -320,11 +320,11 @@ RegisterOperators mm_batch_side_reg({Operator( [](const Node* node) -> Operation { size_t num_other_side_inputs = node->inputs().size() - 1; Side single_side = static_cast(node->i(Symbol::attr("side"))); - return [num_other_side_inputs, single_side](Stack* stack) { + return [num_other_side_inputs, single_side](Stack& stack) { at::Tensor side_input; std::vector other_side_inputs; other_side_inputs.reserve(num_other_side_inputs); - for (auto it = stack->end() - num_other_side_inputs; it != stack->end(); + for (auto it = stack.end() - num_other_side_inputs; it != stack.end(); ++it) { other_side_inputs.push_back(std::move(*it).toTensor()); } @@ -343,18 +343,18 @@ RegisterOperators mm_batch_side_reg({Operator( mm_out, num_other_side_inputs, /*dim=*/single_side == Side::LHS ? 1 : 0); - stack->insert( - stack->end(), + stack.insert( + stack.end(), std::make_move_iterator(outputs.begin()), std::make_move_iterator(outputs.end())); } else { if (single_side == Side::LHS) { for (at::Tensor& other : other_side_inputs) { - stack->emplace_back(side_input.mm(other)); + stack.emplace_back(side_input.mm(other)); } } else { for (at::Tensor& other : other_side_inputs) { - stack->emplace_back(other.mm(side_input)); + stack.emplace_back(other.mm(side_input)); } } } diff --git a/torch/csrc/jit/passes/concat_opt.cpp b/torch/csrc/jit/passes/concat_opt.cpp index aa2573ebb42f2..81c8a6745007a 100644 --- a/torch/csrc/jit/passes/concat_opt.cpp +++ b/torch/csrc/jit/passes/concat_opt.cpp @@ -497,95 +497,5 @@ void ExpandConcatAndEliminateRedundancy(const std::shared_ptr& graph) { GRAPH_DUMP("After expanding Concat and eliminating redundancy", graph); } -namespace { - -class VariadicCatUpdater { - public: - explicit VariadicCatUpdater(std::shared_ptr graph) - : graph_(std::move(graph)) {} - - bool run() { - collectCatNodes(graph_->block()); - bool changed = false; - for (auto c : cat_nodes_) { - changed = replaceWithVariadicCat(c) || changed; - } - return changed; - } - - private: - void collectCatNodes(Block* block) { - for (auto node : block->nodes()) { - if (node->kind() == aten::cat) { - cat_nodes_.push_back(node); - } - for (Block* b : node->blocks()) { - collectCatNodes(b); - } - } - } - - bool replaceWithVariadicCat(Node* cat) { - if (cat->input(0)->node()->kind() != prim::ListConstruct) { - return false; - } - auto list = cat->input(0)->node(); - // We do not transform cat ops whose list input can not be moved to the - // position before cat. This in turn implies that there is some mutation - // of the input list before cat. - if (!getOrCreateAliasDb()->couldMoveBeforeTopologically(list, cat)) { - return false; - } - std::vector inputs = list->inputs().vec(); - inputs.push_back(cat->input(1)); - auto var_cat = cat->owningGraph()->create(prim::VarConcat, inputs); - GRAPH_UPDATE("Adding\n", *var_cat); - var_cat->insertBefore(cat); - GRAPH_UPDATE("Replacing\n", *cat, "with\n", *var_cat); - cat->output()->replaceAllUsesWith(var_cat->output()); - GRAPH_UPDATE("Deleting\n", *cat); - cat->destroy(); - if (!list->hasUses()) { - GRAPH_UPDATE("Deleting\n", *list); - list->destroy(); - } - return true; - } - - AliasDb* getOrCreateAliasDb() { - if (!aliasDb_) { - aliasDb_ = std::make_unique(graph_); - } - return aliasDb_.get(); - } - - std::shared_ptr graph_; - std::unique_ptr aliasDb_ = nullptr; - - std::vector cat_nodes_; -}; - -} // namespace - -bool UseVariadicCat(const std::shared_ptr& graph) { - GRAPH_DUMP("Before VariadicCat", graph); - bool changed = VariadicCatUpdater(graph).run(); - if (changed) { - GRAPH_DUMP("After VariadicCat", graph); - } - return changed; -} - -bool RemoveListMutationAndUseVariadicCat(const std::shared_ptr& graph) { - bool changed_in_last_iter = true; - bool changed = false; - while (changed_in_last_iter) { - changed_in_last_iter = RemoveListMutation(graph); - changed_in_last_iter = changed_in_last_iter || UseVariadicCat(graph); - changed = changed || changed_in_last_iter; - } - return changed; -} - } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/concat_opt.h b/torch/csrc/jit/passes/concat_opt.h index b82dc25e612a4..ef4d9432438e6 100644 --- a/torch/csrc/jit/passes/concat_opt.h +++ b/torch/csrc/jit/passes/concat_opt.h @@ -13,12 +13,5 @@ TORCH_API bool EliminateConcatCommonInputs(const std::shared_ptr& graph); TORCH_API void ExpandConcatAndEliminateRedundancy( const std::shared_ptr& graph); -// Replaces the `aten::cat` ops in the given graph with variadic cat ops. -// Returns true if the graph is modified. -TORCH_API bool UseVariadicCat(const std::shared_ptr& graph); - -TORCH_API bool RemoveListMutationAndUseVariadicCat( - const std::shared_ptr& graph); - } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/constant_propagation.cpp b/torch/csrc/jit/passes/constant_propagation.cpp index a7f831abd88f6..3a28eaeba46e6 100644 --- a/torch/csrc/jit/passes/constant_propagation.cpp +++ b/torch/csrc/jit/passes/constant_propagation.cpp @@ -78,7 +78,7 @@ c10::optional> runNodeIfInputsAreConstant( try { auto op = n->getOperation(); - op(&stack); + op(stack); } catch (...) { return c10::nullopt; } diff --git a/torch/csrc/jit/passes/decompose_ops.cpp b/torch/csrc/jit/passes/decompose_ops.cpp index 7f935a1c1cbd5..0706c9c14ae98 100644 --- a/torch/csrc/jit/passes/decompose_ops.cpp +++ b/torch/csrc/jit/passes/decompose_ops.cpp @@ -59,7 +59,7 @@ bool isDecomposableNorm(Node* normalize_op) { RegisterOperators reg_ops( {Operator( "aten::_ncf_unsqueeze(Tensor(a) self, int ndim) -> Tensor(a)", - [](Stack* stack) { + [](Stack& stack) { const int64_t ndim = pop(stack).toInt(); auto self = pop(stack).toTensor(); c10::SmallVector sizes(ndim, 1); @@ -70,7 +70,7 @@ RegisterOperators reg_ops( aliasAnalysisFromSchema()), Operator( "aten::_ncf_view(Tensor(a) self, int[] input_shape, int normalized_ndim) -> Tensor(a)", - [](Stack* stack) { + [](Stack& stack) { const int64_t normalized_ndim = pop(stack).toInt(); auto input_shape = pop(stack).toIntList(); auto self = pop(stack).toTensor(); diff --git a/torch/csrc/jit/passes/freeze_module.cpp b/torch/csrc/jit/passes/freeze_module.cpp index 063b867319629..0debc97ac8241 100644 --- a/torch/csrc/jit/passes/freeze_module.cpp +++ b/torch/csrc/jit/passes/freeze_module.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include @@ -89,7 +90,9 @@ class AttributePropagator { }; auto applyOptimizations = [](std::shared_ptr& subgraph) { runOptimization( - subgraph, /* unroll? */ false, /* const_prop_user_classes? */ false); + subgraph, + /* unroll_non_constant_loops? */ false, + /* const_prop_user_classes? */ false); LowerSimpleTuples(subgraph); }; diff --git a/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp b/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp index e6faf90b6f2b6..542e136280520 100644 --- a/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp +++ b/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp @@ -1,9 +1,11 @@ #include #include #include +#include #include #include #include + #include #include #include @@ -33,6 +35,7 @@ #if AT_MKLDNN_ENABLED() #include #include +#include #include #include #endif @@ -182,7 +185,8 @@ void InplaceMKLDNNSubgraph(std::shared_ptr graph) { if (k == aten::relu || k == aten::sigmoid || k == aten::dropout || k == prim::MKLDNNHardSwish || k == prim::MKLDNNHardSigmoid || k == prim::MKLDNNHardTanh || k == aten::tanh || - k == prim::MKLDNNClamp || k == Symbol::prim("MKLDNNScalarMul")) { + k == prim::MKLDNNClamp || k == Symbol::prim("MKLDNNScalarMul") || + k == Symbol::prim("MKLDNNLayerNorm")) { if (set_liveness[alias_mapping[node->inputs().at(0)]]->isAfter(node)) { continue; } @@ -231,7 +235,7 @@ void InplaceMKLDNNSubgraph(std::shared_ptr graph) { Operation createUnaryOp( std::function aten_op, bool inplace = false) { - return [aten_op, inplace](Stack* stack) { + return [aten_op, inplace](Stack& stack) { auto a = pop(stack).toTensor(); c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset); // we cast `a` to an `ideep::tensor`, so we can get at its descriptor @@ -271,8 +275,35 @@ Operation createUnaryOp( }; } +void MKLDNNLayerNormOp(Stack& stack, bool inplace) { + c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset); + + // enable_cudnn not used + pop(stack); + auto eps = pop(stack).toDouble(); + + Tensor bias{}; + Tensor weight{}; + auto bias_ival = pop(stack); + TORCH_INTERNAL_ASSERT(bias_ival.isTensor()); + bias = bias_ival.toTensor(); + + auto weight_ival = pop(stack); + TORCH_INTERNAL_ASSERT(weight_ival.isTensor()); + weight = weight_ival.toTensor(); + + auto shape = pop(stack).toIntVector(); + auto input = pop(stack).toTensor(); + + at::Tensor dst, mean, rstd; + std::tie(dst, mean, rstd) = + at::native::mkldnn_layer_norm_last_index_weight_bias_f32( + input, shape, weight, bias, eps, inplace); + push(stack, dst); +}; + Operation BroadOp(const Node* node) { - return [](Stack* stack) { + return [](Stack& stack) { auto b = pop(stack).toTensor(); auto a = pop(stack).toTensor(); auto b_size = b.sizes(); @@ -437,9 +468,20 @@ const RegisterOperators BroadOpReg({ AliasAnalysisKind::INTERNAL_SPECIAL_CASE), }); +const RegisterOperators MKLDNNLayerNormOpReg({ + torch::jit::Operator( + "prim::MKLDNNLayerNorm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor", + [](Stack& stack) { MKLDNNLayerNormOp(stack, false); }, + AliasAnalysisKind::FROM_SCHEMA), + torch::jit::Operator( + "prim::MKLDNNLayerNorm_(Tensor(a!) input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor(a!)", + [](Stack& stack) { MKLDNNLayerNormOp(stack, true); }, + AliasAnalysisKind::FROM_SCHEMA), +}); + Operation ConstantMKLDNNTensorOp(const Node* node) { const auto& t = node->t(attr::value); - return [t](Stack* stack) { + return [t](Stack& stack) { push(stack, t); return 0; }; @@ -467,7 +509,7 @@ jit::RegisterOperators reg_fut_ops({ // XXX: this follows the schema convention of conv2d/conv3d, not // aten::mkldnn_convolution, which is different for some reason! "prim::mkldnn_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor", - [](jit::Stack* stack) { + [](jit::Stack& stack) { int64_t groups = pop(stack).toInt(); auto dilation = pop(stack).toIntVector(); auto padding = pop(stack).toIntVector(); @@ -516,7 +558,7 @@ jit::RegisterOperators reg_fut_ops({ // in default bindings jit::Operator( "prim::MKLDNNScalarMul(Tensor self, Scalar other) -> Tensor", - [](jit::Stack* stack) { + [](jit::Stack& stack) { c10::impl::ExcludeDispatchKeyGuard edkg( c10::autograd_dispatch_keyset); float other = pop(stack).toScalar().toFloat(); @@ -534,7 +576,7 @@ jit::RegisterOperators reg_fut_ops({ aliasAnalysisFromSchema()), jit::Operator( "prim::MKLDNNScalarMul_(Tensor(a!) self, Scalar other) -> Tensor(a!)", - [](jit::Stack* stack) { + [](jit::Stack& stack) { c10::impl::ExcludeDispatchKeyGuard edkg( c10::autograd_dispatch_keyset); float other = pop(stack).toScalar().toFloat(); @@ -719,6 +761,13 @@ void ComputeSubgraphInMKLDNN(Node* subgraph_node) { continue; } + if (body_node->matches( + "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor")) { + body_node->replaceWithNewSymbol(Symbol::prim("MKLDNNLayerNorm")); + body_node->destroy(); + continue; + } + if (body_node->kind() == aten::hardswish) { body_node->replaceWithNewSymbol(prim::MKLDNNHardSwish); body_node->destroy(); @@ -917,6 +966,16 @@ class MKLDNNSubgraphSlicer { return false; } } + + if (n->matches( + "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor") && + n->namedInput("weight")->type() != NoneType::get() && + n->namedInput("bias")->type() != NoneType::get()) { + auto norm_shape = + constant_as>(n->namedInput("normalized_shape")); + return norm_shape.has_value() && norm_shape->size() == 1; + } + // unary ops we dont need to prove anything else than // the input is mkldnn supported switch (n->kind()) { diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp index f7dd466de4ff4..653f9fec08b32 100644 --- a/torch/csrc/jit/passes/graph_fuser.cpp +++ b/torch/csrc/jit/passes/graph_fuser.cpp @@ -183,7 +183,7 @@ struct GraphFuser { return !strict_fuser_check; } if ((*device).is_cpu()) { - return canFuseOnCPU(); + return canFuseOnCPULegacy(); } else if ((*device).is_cuda()) { return canFuseOnGPU(); } else if ((*device).is_xpu()) { @@ -1244,6 +1244,16 @@ void PeepholeOptimizeShapeExpressions(Block* block, AliasDb* db) { } // anonymous namespace +static bool cpu_fuser_enabled_legacy = false; + +bool canFuseOnCPULegacy() { + return cpu_fuser_enabled_legacy; +} + +void overrideCanFuseOnCPULegacy(bool value) { + cpu_fuser_enabled_legacy = value; +} + void FuseGraph(std::shared_ptr& graph, bool strict_fuser_check) { AliasDb db(graph); GraphFuser(&db, graph->block(), strict_fuser_check).run(); diff --git a/torch/csrc/jit/passes/graph_fuser.h b/torch/csrc/jit/passes/graph_fuser.h index 0cdcc2e20f469..aafb442eafb6f 100644 --- a/torch/csrc/jit/passes/graph_fuser.h +++ b/torch/csrc/jit/passes/graph_fuser.h @@ -5,6 +5,9 @@ namespace torch { namespace jit { +TORCH_API bool canFuseOnCPULegacy(); +TORCH_API void overrideCanFuseOnCPULegacy(bool value); + // NB: Be sure to run DCE before fusion, because dead instructions // can prevent fusion opportunities from being exploited. // On Windows will noop, NYI diff --git a/torch/csrc/jit/passes/normalize_ops.cpp b/torch/csrc/jit/passes/normalize_ops.cpp index cc6444e8a9dfd..67637031868c1 100644 --- a/torch/csrc/jit/passes/normalize_ops.cpp +++ b/torch/csrc/jit/passes/normalize_ops.cpp @@ -102,8 +102,10 @@ const std::unordered_map& getOperatorAliasMap() { {aten::divide_, aten::div_}, {aten::multiply, aten::mul}, {aten::multiply_, aten::mul_}, + {aten::linalg_matmul, aten::matmul}, {aten::true_divide, aten::div}, {aten::true_divide_, aten::div_}, + {aten::concat, aten::cat}, {aten::row_stack, aten::vstack}, {aten::swapdims, aten::transpose}, {aten::swapdims_, aten::transpose_}, diff --git a/torch/csrc/jit/passes/onnx/constant_fold.cpp b/torch/csrc/jit/passes/onnx/constant_fold.cpp index 901844cd62380..76c0674e11fd8 100644 --- a/torch/csrc/jit/passes/onnx/constant_fold.cpp +++ b/torch/csrc/jit/passes/onnx/constant_fold.cpp @@ -266,9 +266,7 @@ c10::optional runTorchBackendForOnnx( if (node->kind() == onnx::Slice) { if (opset_version == ONNX_OPSET_9) { return runTorchSlice_opset9(node, inputTensorValues); - } else if ( - opset_version == ONNX_OPSET_10 || opset_version == ONNX_OPSET_11 || - opset_version == ONNX_OPSET_12 || opset_version == ONNX_OPSET_13) { + } else if (opset_version >= ONNX_OPSET_10) { return runTorchSlice_opset10(node, inputTensorValues); } else { std::cerr << "Warning: Constant folding - unsupported opset version. " @@ -351,7 +349,7 @@ c10::optional runTorchBackendForOnnx( } } else if (node->kind() == onnx::Squeeze) { assert(inputTensorValues.size() == 2 || inputTensorValues.size() == 1); - if (opset_version == ONNX_OPSET_13) { + if (opset_version >= ONNX_OPSET_13) { // Squeeze version 13 input axes is optional, inputTensorValues.size() == // 1 means axes equal to None updated_val = inputTensorValues[0]; @@ -415,13 +413,18 @@ c10::optional runTorchBackendForOnnx( std::vector shape(inputTensorValues[1].sizes()[0], 0); auto shape_a = inputTensorValues[1].accessor(); assert(inputTensorValues[1].sizes()[0] >= 0); + // Set value of allowzero + int64_t allowzero = 0; + if (node->hasAttributeS("allowzero")) { + allowzero = node->i(attr::allowzero); + } for (size_t i = 0; i < (size_t)(inputTensorValues[1].sizes()[0]); ++i) { // All shape dim values should be >= -1 // onnx::Reshape supports a shape dim value to be zero, in // which case the actual dim value remains unchanged. However, // at::reshape does not support shape dim value to be zero assert(shape_a[i] >= -1); - if (shape_a[i] == 0) { + if (shape_a[i] == 0 && !allowzero) { if (i >= inputTensorValues[0].sizes().size()) { throw std::runtime_error( "Dimension with value 0 exceeds the input size dimensions."); diff --git a/torch/csrc/jit/passes/onnx/constant_fold.h b/torch/csrc/jit/passes/onnx/constant_fold.h index 1c54412ccd7a1..8bfb0dd081c39 100644 --- a/torch/csrc/jit/passes/onnx/constant_fold.h +++ b/torch/csrc/jit/passes/onnx/constant_fold.h @@ -13,6 +13,7 @@ const int ONNX_OPSET_10 = 10; const int ONNX_OPSET_11 = 11; const int ONNX_OPSET_12 = 12; const int ONNX_OPSET_13 = 13; +const int ONNX_OPSET_14 = 14; namespace onnx_constant_fold { diff --git a/torch/csrc/jit/passes/onnx/eval_peephole.cpp b/torch/csrc/jit/passes/onnx/eval_peephole.cpp index 18dea16cb97ae..05afb69ef0f23 100644 --- a/torch/csrc/jit/passes/onnx/eval_peephole.cpp +++ b/torch/csrc/jit/passes/onnx/eval_peephole.cpp @@ -47,14 +47,20 @@ static void fuseConvBatchNorm(Block* b, ValueToParamPairMap& valsToParamsMap) { fuseConvBatchNorm(child_block, valsToParamsMap); } if (it->kind() == onnx::Conv) { - if (it->output()->uses().size() != 1) { + auto oldConv = *it; + if (oldConv->outputs().at(0)->uses().size() != 1) { continue; } - auto bnNode = it->output()->uses()[0].user; + auto bnNode = oldConv->outputs().at(0)->uses()[0].user; if (bnNode->kind() != onnx::BatchNormalization) { continue; } - auto oldConv = *it; + + if (oldConv->outputs().size() != + bnNode->outputs().size()) { // BN layer is not in eval mode + continue; + } + auto epsilon = bnNode->f(attr::epsilon); auto convInputVals = getValues(oldConv, valsToParamsMap); if (convInputVals.size() < 1 || @@ -109,11 +115,8 @@ static void fuseConvBatchNorm(Block* b, ValueToParamPairMap& valsToParamsMap) { convB = bnB; } - Node* newConv = - b->owningGraph()->create(onnx::Conv, bnNode->outputs().size()); - for (size_t i = 0; i < newConv->outputs().size(); ++i) { - newConv->outputs()[i]->copyMetadata(bnNode->outputs()[i]); - } + Node* newConv = b->owningGraph()->create(onnx::Conv, 1); + newConv->outputs().at(0)->copyMetadata(bnNode->outputs().at(0)); newConv->copyAttributes(*oldConv); newConv->insertBefore(bnNode); @@ -131,9 +134,7 @@ static void fuseConvBatchNorm(Block* b, ValueToParamPairMap& valsToParamsMap) { newConvB->inferTypeFrom(convB); newConv->addInput(newConvB); - bnNode->replaceAllUsesWith(newConv); - bnNode->removeAllInputs(); - it->removeAllInputs(); + bnNode->outputs().at(0)->replaceAllUsesWith(newConv->outputs().at(0)); bnNode->destroy(); it.destroyCurrent(); } diff --git a/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp b/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp index abfb547ed5e94..b0a310bfe20ad 100644 --- a/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp +++ b/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp @@ -238,9 +238,7 @@ std::vector FixupONNXLoopNode(Node* node, int opset_version) { auto new_outputs = ConvertSequenceDependencies(node, opset_version); // Copy type of block output to node output. - for (size_t i = 0; i < node->outputs().size(); ++i) { - node->output(i)->setType(node->blocks().at(0)->outputs().at(i + 1)->type()); - } + FixupONNXControlflowNodeOutputs(node); TORCH_INTERNAL_ASSERT(output_size == new_outputs.size()); return new_outputs; } @@ -347,25 +345,90 @@ void ONNXFixupUninitializedOutput(Node* node) { graph, else_block, else_block_output, then_block_output); if_node->outputs()[i]->setType(else_block->outputs()[i]->type()); } - auto then_tensor_type = - then_block->outputs().at(i)->type()->castRaw(); - auto else_tensor_type = - else_block->outputs().at(i)->type()->castRaw(); - if (then_tensor_type && else_tensor_type) { - const auto& then_shape = then_tensor_type->symbolic_sizes(); - const auto& else_shape = else_tensor_type->symbolic_sizes(); - std::vector<::c10::ShapeSymbol> dims; - if (then_shape.rank() && else_shape.rank() && - then_shape.rank() == else_shape.rank()) { - for (const auto j : c10::irange(then_shape.rank().value())) { - if (then_shape[j] == else_shape[j]) { - dims.emplace_back(then_shape[j]); - } else { - dims.emplace_back(::c10::ShapeSymbol::newSymbol()); - } + } +} + +void ONNXMergeIfBlockOutputShapes(Node* node) { + TORCH_INTERNAL_ASSERT(node->kind() == ::c10::onnx::If); + Block* then_block = node->blocks().at(0); + Block* else_block = node->blocks().at(1); + + TORCH_INTERNAL_ASSERT( + then_block->outputs().size() == else_block->outputs().size()) + + auto findCommonShape = + [](const ::c10::SymbolicShape& a, + const ::c10::SymbolicShape& b) -> ::c10::SymbolicShape { + std::vector<::c10::ShapeSymbol> dims; + if (a.rank() && b.rank() && a.rank() == b.rank()) { + for (const auto j : c10::irange(a.rank().value())) { + if (a[j] == b[j]) { + dims.emplace_back(a[j]); + } else { + dims.emplace_back(::c10::ShapeSymbol::newSymbol()); } - if_node->output(i)->setType( - then_tensor_type->withSymbolicShapes(::c10::SymbolicShape(dims))); + } + return ::c10::SymbolicShape(dims); + } + if (a.rank() && a.rank().value() > 0) { + return a; + } + if (b.rank() && b.rank().value() > 0) { + return b; + } + + return ::c10::SymbolicShape(); + }; + + auto mergeTensorType = + [&findCommonShape](TensorTypePtr a, TensorTypePtr b) -> TensorTypePtr { + if (a && b) { + const auto& a_shape = a->symbolic_sizes(); + const auto& b_shape = b->symbolic_sizes(); + auto commonShape = findCommonShape(a_shape, b_shape); + return a->withSymbolicShapes(commonShape); + } else if (a) { + return a; + } else if (b) { + return b; + } + return nullptr; + }; + + auto mergeListType = [&mergeTensorType]( + ListTypePtr a, ListTypePtr b) -> ListTypePtr { + if (a && b) { + auto a_tensor_type = a->getElementType()->cast(); + auto b_tensor_type = b->getElementType()->cast(); + auto tensor_type = mergeTensorType(a_tensor_type, b_tensor_type); + if (tensor_type) { + return a->withContained({tensor_type})->cast(); + } + // Both branches produce ListType without tensor shape. + return a; + } else if (a) { + return a; + } else if (b) { + return b; + } + return nullptr; + }; + + for (const auto i : c10::irange(else_block->outputs().size())) { + auto then_type = then_block->outputs().at(i)->type(); + auto else_type = else_block->outputs().at(i)->type(); + auto then_tensor_type = then_type->cast(); + auto else_tensor_type = else_type->cast(); + auto then_list_type = then_type->cast(); + auto else_list_type = else_type->cast(); + if (then_tensor_type || else_tensor_type) { + if (auto tensor_type = + mergeTensorType(then_tensor_type, else_tensor_type)) { + node->output(i)->setType(tensor_type); + } + } else if (then_list_type || else_list_type) { + if (auto list_type = mergeListType(then_list_type, else_list_type)) { + node->output(i)->setType(list_type); } } } @@ -376,16 +439,13 @@ std::vector FixupONNXIfNode(Node* node, int opset_version) { return node->outputs().vec(); } GRAPH_DUMP("Graph before fixing controlflow: ", node->owningGraph()); - auto* if_node = node; FixupONNXSubblockOutputs(node); - ONNXFixupUninitializedOutput(if_node); + ONNXFixupUninitializedOutput(node); // Copy type of block output to node output. - for (size_t i = 0; i < node->outputs().size(); ++i) { - node->output(i)->setType(node->blocks().at(0)->outputs().at(i)->type()); - } + ONNXMergeIfBlockOutputShapes(node); GRAPH_DUMP("Graph after fixing controlflow: ", node->owningGraph()); - return if_node->outputs().vec(); + return node->outputs().vec(); } std::vector FixupONNXControlflowNode(Node* n, int opset_version) { @@ -401,5 +461,36 @@ std::vector FixupONNXControlflowNode(Node* n, int opset_version) { } } +void FixupONNXControlflowNodeOutputs(Node* n) { + switch (n->kind()) { + case ::c10::onnx::Loop: { + auto loop_carried_output_size = n->blocks().at(0)->inputs().size() - 2; + for (auto i : c10::irange(n->outputs().size())) { + auto type = n->blocks().at(0)->outputs().at(i + 1)->type(); + if (i < loop_carried_output_size) { + n->output(i)->setType(type); + } else { + if (auto t_type = type->cast()) { + auto sizes = t_type->symbolic_sizes().sizes(); + if (sizes.has_value()) { + sizes.value().emplace( + sizes.value().begin(), c10::ShapeSymbol::newSymbol()); + type = t_type->withSymbolicShapes(sizes.value()); + } + } + n->output(i)->setType(type); + } + } + break; + } + case ::c10::onnx::If: { + ONNXMergeIfBlockOutputShapes(n); + break; + } + default: + break; + } +} + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.h b/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.h index fad7611085223..8d33c2dd1fb5e 100644 --- a/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.h +++ b/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.h @@ -6,6 +6,7 @@ namespace torch { namespace jit { std::vector FixupONNXControlflowNode(Node* n, int opset_version); +void FixupONNXControlflowNodeOutputs(Node* n); } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp index ccadf53713466..9c751bbae9e12 100644 --- a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp +++ b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp @@ -76,6 +76,7 @@ std::vector getParamAttributes( WithInsertPoint guard(m); std::vector parameterIValues = {}; + std::unordered_set nodesToDestroy; for (auto it = block->nodes().begin(); it != block->nodes().end();) { Node* n = *it; it++; // node n can be destroyed @@ -142,7 +143,7 @@ std::vector getParamAttributes( // This attr is constant for ONNX. auto attrVal = tryInsertConstant(*graph, attr); n->output()->replaceAllUsesWith(*attrVal); - n->destroy(); + nodesToDestroy.emplace(n); } } } @@ -156,6 +157,9 @@ std::vector getParamAttributes( std::end(nextParameterIValues)); } } + for (auto n : nodesToDestroy) { + n->destroy(); + } return parameterIValues; } diff --git a/torch/csrc/jit/passes/onnx/pattern_conversion/common.cpp b/torch/csrc/jit/passes/onnx/pattern_conversion/common.cpp index 2854c3ab2fe2e..bc646308424b0 100644 --- a/torch/csrc/jit/passes/onnx/pattern_conversion/common.cpp +++ b/torch/csrc/jit/passes/onnx/pattern_conversion/common.cpp @@ -4,8 +4,8 @@ namespace torch { namespace jit { bool IndexingPatternFinder::IsSameSource(const Node* n, const Node* m) { - const auto& source_n = n->sourceRange().source(); - const auto& source_m = m->sourceRange().source(); + const auto source_n = n->sourceRange().source(); + const auto source_m = m->sourceRange().source(); return ( (source_n->text() == source_m->text()) && (source_n->starting_line_no() == source_m->starting_line_no())); diff --git a/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp b/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp index 913f4dc2b6edb..2cef76a7391ae 100644 --- a/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp +++ b/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp @@ -317,26 +317,33 @@ static void PrepareForRemoveMutations(MutationRemover& mr, Block* b) { } for (auto input : b->inputs()) { - for (auto use : input->uses()) { - Node* node = use.user; - if (!mr.inplaceOpVariant(node)) { - continue; - } - auto it = std::find(node->inputs().begin(), node->inputs().end(), input); - if (it != node->inputs().end()) { - int index = std::distance(node->inputs().begin(), it); - std::cerr << "Warning: ONNX Preprocess - Removing mutation from node " - << node->kind().toQualString() << " on block input: '" - << (*it)->debugName() << "'. This changes graph semantics." - << std::endl; - - Node* newNode = - addDummyClone(b->owningGraph(), input, false, b->return_node()); - TORCH_INTERNAL_ASSERT(nullptr != newNode); - node->replaceInput(index, newNode->output()); - input->replaceAllUsesAfterNodeWith(node, newNode->output()); + bool needsRestart = false; + do { + needsRestart = false; + for (auto use : input->uses()) { + Node* node = use.user; + if (!mr.inplaceOpVariant(node)) { + continue; + } + auto it = + std::find(node->inputs().begin(), node->inputs().end(), input); + if (it != node->inputs().end()) { + int index = std::distance(node->inputs().begin(), it); + std::cerr << "Warning: ONNX Preprocess - Removing mutation from node " + << node->kind().toQualString() << " on block input: '" + << (*it)->debugName() << "'. This changes graph semantics." + << std::endl; + + Node* newNode = + addDummyClone(b->owningGraph(), input, false, b->return_node()); + TORCH_INTERNAL_ASSERT(nullptr != newNode); + node->replaceInput(index, newNode->output()); + input->replaceAllUsesAfterNodeWith(node, newNode->output()); + needsRestart = true; + break; + } } - } + } while (needsRestart); } } diff --git a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp index 634d8d2e1db99..8ade722fb8bd9 100644 --- a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp +++ b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -180,7 +181,21 @@ void UpdateTorchValueByOnnxValueInfo( } } -bool IsSupportedNode(const Node* n) { +bool IsValidONNXControlflowNode(const Node* n) { + // Skip when block size is zero. This is when the node is being created, + // and doesn't have subblocks attached yet. Run shape inference for these + // nodes later, when the subgraph has already completed shape inferencing. + auto node_kind = n->kind(); + if (node_kind == ::c10::onnx::Loop || node_kind == ::c10::onnx::If) { + if (n->blocks().size() == 0) { + return false; + } + } + + return true; +} + +bool IsValidONNXNode(const Node* n) { auto node_kind = n->kind(); if (!node_kind.is_onnx()) { @@ -188,18 +203,14 @@ bool IsSupportedNode(const Node* n) { return false; } - // Skip when block size is zero. This is when the node is first created, - // doesn't have subblocks attached yet. Run shape inference for these nodes - // when the subgraph has already completed shape inferencing. - if (node_kind == ::c10::onnx::Loop || node_kind == ::c10::onnx::If) { - if (n->blocks().size() == 0) { - return false; - } - for (auto b : n->blocks()) { - for (auto b_n : b->nodes()) { - if (!IsSupportedNode(b_n)) { - return false; - } + if (!IsValidONNXControlflowNode(n)) { + return false; + } + + for (auto b : n->blocks()) { + for (auto b_n : b->nodes()) { + if (!IsValidONNXNode(b_n)) { + return false; } } } @@ -404,8 +415,10 @@ c10::optional ComputeConstantFolding(Node* n, int opset_version) { // When the Reshape node's two inputs are constant, compute the output shape. // The reshape value 0 and -1 are converted to the real value explicitly. std::vector ComputeShapeFromReshape( + Node* n, const std::vector& input_shape, - const std::vector& reshape) { + const std::vector& reshape, + int opset_version) { TORCH_INTERNAL_ASSERT( input_shape.size() > 0 || reshape.size() > 0, "Reshape node should have at least one input size > 0 when constant folding."); @@ -427,6 +440,17 @@ std::vector ComputeShapeFromReshape( auto reshape_size = static_cast(reshape.size()); auto it_0 = std::find(reshape.begin(), reshape.end(), 0); auto reshape_has_zero = it_0 != reshape.end(); + + // Allowzero is set to 0 by default + // When opset version > 14, assign appropriate allowzero value + int allowzero = 0; + if (opset_version >= 14 && n->hasAttributeS("allowzero")) { + allowzero = n->i(attr::allowzero); + if (allowzero == 1 && reshape_has_zero) { + return reshape; + } + } + auto input_shape_size = static_cast(input_shape.size()); auto it_minus_one = std::find(reshape.begin(), reshape.end(), -1); int minus_one_pos = it_minus_one == reshape.end() @@ -594,7 +618,7 @@ c10::optional> GetValueFromListConstructNode( : c10::nullopt; } -void ProcessReshapeNode(Node* n) { +void ProcessReshapeNode(Node* n, int opset_version) { if (ConstantValueMap::HasValue(n->input(1)->debugName())) { auto shape_temp = ConstantValueMap::GetValueInto1DInt64Vector(n->input(1)->debugName()); @@ -602,8 +626,8 @@ void ProcessReshapeNode(Node* n) { ConstantValueMap::GetShapeInto1DInt64VectorWithOneUnknown( n->input(0)->debugName()); if (shape_vector_0.has_value()) { - auto final_shape = - ComputeShapeFromReshape(shape_vector_0.value(), shape_temp); + auto final_shape = ComputeShapeFromReshape( + n, shape_vector_0.value(), shape_temp, opset_version); UpdateShapeFromVector(n->output(), final_shape); return; } @@ -865,7 +889,7 @@ void ComputeConstant(Node* n, int opset_version) { break; } case ::c10::onnx::Reshape: { - ProcessReshapeNode(n); + ProcessReshapeNode(n, opset_version); break; } case ::c10::onnx::Gather: { @@ -1297,6 +1321,20 @@ void SpecialPostProcess(Node* n) { } break; } + case ::c10::onnx::If: { + if (!IsValidONNXControlflowNode(n)) { + break; + } + FixupONNXControlflowNodeOutputs(n); + break; + } + case ::c10::onnx::Loop: { + if (!IsValidONNXControlflowNode(n)) { + break; + } + FixupONNXControlflowNodeOutputs(n); + break; + } } } @@ -1378,64 +1416,67 @@ void ONNXShapeTypeInference( int opset_version) { GRAPH_UPDATE( "Running ONNX shape inference for node: ", n->kind().toDisplayString()); - if (!IsSupportedNode(n)) { - return; - } - // Create a Graph containing only the single node n. - // This graph is later converted to ONNX to run shape inference. - auto n_graph = std::make_shared(); - auto clone_node = CloneNodeToGraph(n, n_graph, params_dict, opset_version); - n_graph->insertNode(clone_node); - - // Register all node outputs as graph outputs. - for (auto output : clone_node->outputs()) { - n_graph->registerOutput(output); - } + if (IsValidONNXNode(n)) { + // Create a Graph containing only the single node n. + // This graph is later converted to ONNX to run shape inference. + auto n_graph = std::make_shared(); + auto clone_node = CloneNodeToGraph(n, n_graph, params_dict, opset_version); + n_graph->insertNode(clone_node); + + // Register all node outputs as graph outputs. + for (auto output : clone_node->outputs()) { + n_graph->registerOutput(output); + } - // Use scalar_type_analysis without low precision cast - ScalarTypeAnalysisForONNX(n_graph, false, opset_version); + // Use scalar_type_analysis without low precision cast + ScalarTypeAnalysisForONNX(n_graph, false, opset_version); - GRAPH_DEBUG("Original torch graph: ", n->owningGraph()->toString()); - GRAPH_DEBUG( - "Cloned torch graph to run shape inference: ", n_graph->toString()); - - if (IsGraphValidForInference(n_graph)) { - // TODO: Some ops have conversion happen at Peephole pass. - // The conversion here is incomplete for these ops. - // e.g: ListConstruct, ListUnpack, etc. - std::shared_ptr model_proto; - SymbolDimMap symbol_map; - ConvertGraphToONNXProto(n_graph, model_proto, symbol_map, opset_version); + GRAPH_DEBUG("Original torch graph: ", n->owningGraph()->toString()); GRAPH_DEBUG( - "ONNX graph to run shape inference: ", prettyPrint(*model_proto)); - - // infer shape - try { - onnx::shape_inference::InferShapes(*model_proto); - UpdateOutputTypeByONNXProto(n, clone_node, *model_proto, symbol_map); - } catch (std::runtime_error& ex) { - // TODO: include this as warning once we have a more consolidated warning - // system. + "Cloned torch graph to run shape inference: ", n_graph->toString()); + + if (IsGraphValidForInference(n_graph)) { + // TODO: Some ops have conversion happen at Peephole pass. + // The conversion here is incomplete for these ops. + // e.g: ListConstruct, ListUnpack, etc. + std::shared_ptr model_proto; + SymbolDimMap symbol_map; + ConvertGraphToONNXProto(n_graph, model_proto, symbol_map, opset_version); GRAPH_DEBUG( - "ONNX shape inference fails with: ", - ex.what(), - " on graph: ", - n_graph->toString()); - // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) - const char shape_err[] = "ShapeInferenceError"; - // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) - const char type_err[] = "TypeInferenceError"; - if ((strstr(ex.what(), shape_err) == nullptr) && - (strstr(ex.what(), type_err) == nullptr)) { - throw; + "ONNX graph to run shape inference: ", prettyPrint(*model_proto)); + + // infer shape + try { + onnx::shape_inference::InferShapes(*model_proto); + UpdateOutputTypeByONNXProto(n, clone_node, *model_proto, symbol_map); + } catch (std::runtime_error& ex) { + // TODO: include this as warning once we have a more consolidated + // warning system. + GRAPH_DEBUG( + "ONNX shape inference fails with: ", + ex.what(), + " on graph: ", + n_graph->toString()); + // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) + const char shape_err[] = "ShapeInferenceError"; + // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) + const char type_err[] = "TypeInferenceError"; + // NOLINTNEXTLINE(modernize-use-nullptr) + if ((strstr(ex.what(), shape_err) == NULL) && + // NOLINTNEXTLINE(modernize-use-nullptr) + (strstr(ex.what(), type_err) == NULL)) { + throw; + } } + GRAPH_DEBUG( + "ONNX graph after shape inference: ", prettyPrint(*model_proto)); } - GRAPH_DEBUG( - "ONNX graph after shape inference: ", prettyPrint(*model_proto)); } SpecialPostProcess(n); - ProcessConstantValueMap(n, opset_version); + if (IsValidONNXNode(n)) { + ProcessConstantValueMap(n, opset_version); + } GRAPH_DEBUG( "Torch graph after shape inference:", n->owningGraph()->toString()); } diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index 3024811fef6bd..c74c6ee40221a 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -41,7 +41,6 @@ bool mergeTypes( return changed; } - namespace prim { using namespace ::c10::prim; } @@ -289,6 +288,24 @@ class ShapePropagator { return zerodim; } + bool mergeTypes( + ArrayRef lhs, + ArrayRef rhs, + ArrayRef outputs) { + AT_ASSERT(lhs.size() == rhs.size() && rhs.size() == outputs.size()); + bool changed = false; + for (size_t i = 0; i < lhs.size(); ++i) { + auto old_output_type = outputs[i]->type(); + auto new_type = + unifyTypes(lhs[i]->type(), rhs[i]->type(), /*default_to_union=*/true); + AT_ASSERT(new_type); + outputs[i]->setType(*new_type); + if (*old_output_type != *outputs[i]->type()) + changed = true; + } + return changed; + } + void broadcastBinary( Node* node, std::vector& types, @@ -411,7 +428,7 @@ class ShapePropagator { // is to uncover any mistakes we could make when editing this code, // and eventually it shouldn't matter, because this phase should be // preceded by schema checking. - op(&stack); + op(stack); AT_ASSERT(stack.size() == node->outputs().size()); for (const auto i : c10::irange(stack.size())) { diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp index d4add03506c4f..75305d63e072f 100644 --- a/torch/csrc/jit/passes/tensorexpr_fuser.cpp +++ b/torch/csrc/jit/passes/tensorexpr_fuser.cpp @@ -1,6 +1,5 @@ #include -#include #include #include #include @@ -136,7 +135,7 @@ const OperatorSet& supported_eltwise_set() { "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor", // "aten::masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> Tensor", // "aten::masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor", TODO: requires 0-dim Tensor - "aten::remainder.Scalar(Tensor self, Scalar other) -> Tensor", + // "aten::remainder.Scalar(Tensor self, Scalar other) -> Tensor", "aten::remainder.Tensor(Tensor self, Tensor other) -> Tensor", "aten::sigmoid(Tensor self) -> Tensor", "aten::relu(Tensor self) -> Tensor", @@ -250,15 +249,6 @@ bool isSupported(Node* node) { } // namespace tensorexpr static bool texpr_fuser_enabled_ = true; -static bool texpr_parallel_cpu_enabled = false; - -bool texprParallelCPUEnabled() { - return texpr_parallel_cpu_enabled; -} - -void setTexprParallelCPUEnabled(bool val) { - texpr_parallel_cpu_enabled = val; -} void setTensorExprFuserEnabled(bool val) { texpr_fuser_enabled_ = val; @@ -898,14 +888,7 @@ class TensorExprFuser { return false; } if (device->is_cpu()) { - // CPU fusion is only supported for single-thread. - if (!canFuseOnCPU()) { - return false; - } - if (at::get_num_threads() == 1 || texprParallelCPUEnabled()) { - return true; - } - return false; + return canFuseOnCPU(); } else if (device->is_cuda()) { return canFuseOnGPU(); } else if (device->is_xpu()) { @@ -948,6 +931,14 @@ class TensorExprFuser { "aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor", "aten::matmul(Tensor self, Tensor other) -> Tensor", }; + static const OperatorSet gpu_only_operator_set{ + // On CPU, these are slower and less accurate than ATen kernels, because + // ATen is able to use MKL-VML, whereas the fuser currently can't. The + // fuser uses sleef instead because sleef provides functions that operate + // on vectors, instead of large buffers. + "aten::erf(Tensor self) -> Tensor", + "aten::erfc(Tensor self) -> Tensor", + }; static const OperatorSet pow{ "aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor", }; @@ -975,7 +966,9 @@ class TensorExprFuser { // but on top of that Float16 has a few kinks on LLVM. Thus, on CPU we // additionally disable it until we either move to a more stable version // or find workarounds. - if (*st == c10::ScalarType::Half) { + if ((*st == c10::ScalarType::Half || + *st == c10::ScalarType::BFloat16) && + *device == c10::kCPU) { return false; } @@ -1026,6 +1019,17 @@ class TensorExprFuser { } } + // Operator is only supported on GPU. + if (node->isMemberOf(gpu_only_operator_set)) { + auto device = tensorexpr::pickDeviceType(node->inputs()); + if (!device) { + device = tensorexpr::pickDeviceType(node->outputs()); + } + if (!device || !device->is_cuda()) { + return false; + } + } + if (node->kind() == aten::to) { // only support same-device conversion auto device = tensorexpr::pickDeviceType(node->inputs()); @@ -1096,8 +1100,7 @@ class TensorExprFuser { // All tensor types should be known. return false; } - if (c10::isComplexType(*st) || c10::isQIntType(*st) || - *st == c10::ScalarType::BFloat16) { + if (c10::isComplexType(*st) || c10::isQIntType(*st)) { return false; } } @@ -1297,9 +1300,9 @@ void FuseTensorExprs( Operation createTensorExprOp(const Node* node) { auto kernel = std::make_shared(node->g(attr::Subgraph)); - return [kernel](Stack* stack) { + return [kernel](Stack& stack) { RECORD_FUNCTION("TensorExpr", std::vector()); - kernel->run(*stack); + kernel->run(stack); return 0; }; } diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.h b/torch/csrc/jit/passes/tensorexpr_fuser.h index 3f6538b7e587a..254aebd91d12f 100644 --- a/torch/csrc/jit/passes/tensorexpr_fuser.h +++ b/torch/csrc/jit/passes/tensorexpr_fuser.h @@ -24,8 +24,6 @@ TORCH_API void setTensorExprFuserEnabled(bool val); TORCH_API bool tensorExprFuserEnabled(); TORCH_API bool setTexprReductionsEnabled(bool value); TORCH_API bool texprReductionsEnabled(); -TORCH_API bool texprParallelCPUEnabled(); -TORCH_API void setTexprParallelCPUEnabled(bool val); TORCH_API void RemoveProfileNodesAndSpecializeTypes( std::shared_ptr& graph); diff --git a/torch/csrc/jit/passes/utils/check_alias_annotation.cpp b/torch/csrc/jit/passes/utils/check_alias_annotation.cpp index cd894b46ff69b..ae3a962509994 100644 --- a/torch/csrc/jit/passes/utils/check_alias_annotation.cpp +++ b/torch/csrc/jit/passes/utils/check_alias_annotation.cpp @@ -253,7 +253,7 @@ void checkAliasAnnotation( const auto inputsDeepCopy = deepCopy(stack); // Run the op - node->getOperation()(&stack); + node->getOperation()(stack); const auto outputs = std::move(stack); diff --git a/torch/csrc/jit/passes/utils/memory_dag.cpp b/torch/csrc/jit/passes/utils/memory_dag.cpp index 6a880c86e4102..3f6cc8079b6f9 100644 --- a/torch/csrc/jit/passes/utils/memory_dag.cpp +++ b/torch/csrc/jit/passes/utils/memory_dag.cpp @@ -8,6 +8,7 @@ namespace torch { namespace jit { namespace { + void makePointerToImpl(Element* from, Element* to) { from->pointsTo.set(to->index); to->pointedFrom.set(from->index); @@ -131,11 +132,13 @@ Element* MemoryDAGBuilder::makeFreshValue(const Value* v) { return makeFreshValueImpl(v, indexToElementMap_); } +// This function builds up a bitset representing the "alias set" for +// `e` (`MemoryLocations` is just a typedef'd c10::SparseBitVector). const MemoryLocations& MemoryDAG::getMemoryLocations(const Element* e) const { // Note on cache invalidation: all mutation should occur through - // MemoryDAGBuilder. Thus, once we consume the builder to create an immutable - // MemoryDAG, we can cache here without worrying that we might potentially get - // invalidated. + // MemoryDAGBuilder. Thus, once we consume the builder to create an + // immutable MemoryDAG, we can cache here without worrying that we + // might potentially get invalidated. if (e->cachedMemoryLocations_) { return *e->cachedMemoryLocations_; } @@ -174,7 +177,6 @@ void MemoryDAG::setWildcards( makePointerToImpl(from, wildcardElement); } } - // Track which memory locations we edited with a new pointer to the wildcard // element. cacheUpdates[wildcardElement] |= pointeeSet; @@ -189,7 +191,6 @@ void MemoryDAG::setWildcards( for (const std::unique_ptr& e : this->indexToElementMap_) { if (e->values.empty()) { // This element is a wildcard element, we can skip it. - TORCH_INTERNAL_ASSERT(e->pointsTo.empty()); continue; } diff --git a/torch/csrc/jit/passes/utils/memory_dag.h b/torch/csrc/jit/passes/utils/memory_dag.h index 38432ff69c9c1..3e3a19c31729c 100644 --- a/torch/csrc/jit/passes/utils/memory_dag.h +++ b/torch/csrc/jit/passes/utils/memory_dag.h @@ -1,9 +1,12 @@ #pragma once +#include #include #include #include #include +#include +#include #include #include #include @@ -20,6 +23,9 @@ struct Element; struct Value; class MemoryDAG; +using TypePtr = std::shared_ptr; +using AliasTypeSet = std::vector; + /** * Helper to build up the points-to graph. * @@ -38,13 +44,15 @@ class TORCH_API MemoryDAGBuilder { void addToContainedElements(Element* contained, Element* container); - // Make a fresh element (i.e. an element that doesn't point to anything) and + // Make a fresh Element (i.e. an Element that doesn't point to anything) and // return it. Element* makeFreshValue(const Value* v); friend MemoryDAG; private: + // `MemoryDAGBuilder` builds up `indexToElementMap_`, then uses + // the map to construct the `MemoryDAG` std::vector> indexToElementMap_; }; @@ -54,8 +62,8 @@ class TORCH_API MemoryDAGBuilder { // AliasDb to provide a higher-level API. // // We maintain a DAG where: -// - Vertices (called "elements") represent values and -// other aliasing entities (e.g. like the stuff inside a list) +// - Vertices (called "Elements") represent Values and +// other aliasing entities (e.g. the stuff inside a list) // - Edges represent a "points-to" relationship. // // Leaves in this DAG are entities that don't point to anything, and thus @@ -80,7 +88,7 @@ class TORCH_API MemoryDAG { bool mayAlias(const Element* a, const Element* b) const; bool mayAlias(Element* a, Element* b) const; - // Does a hold reference to any memory that is stored in elem, or vice versa? + // Does `a` hold reference to any memory that is stored in `b`, or vice versa? bool mayContainAlias(const Element* a, const Element* b) const; bool mayContainAlias(Element* a, Element* b) const; @@ -96,12 +104,13 @@ class TORCH_API MemoryDAG { MemoryLocations& cont) const; /** - * The following methods are special cases where we need to reach mutate the + * The following methods are special cases where we need to mutate the * internals of MemoryDAG for efficiency reasons. Don't call them unless you * know what you're doing! In particular, don't add new mutating methods * without ensuring that you are maintaining cache consistency for memory * locations. */ + // Adding wildcards can trigger extremely expensive cache invalidations. This // method adds them in a more efficient cache-aware way. void setWildcards( @@ -117,9 +126,10 @@ class TORCH_API MemoryDAG { std::vector> indexToElementMap_; }; -// `Element` represents the vertex in the points-to graph. It represents -// anything that could have an aliasing relationship, mostly IR `Value`s, but -// also the "inside of a list", or wildcards. +// `Element` represents a vertex in the points-to graph. It represents +// anything that could have an aliasing relationship--mostly IR +// `Value`s, but also wildcards or the type inside a container (e.g. `T` +// in `List[T]`) struct Element { Element(const Value* value_, unsigned index_); // wildcard constructor diff --git a/torch/csrc/jit/passes/variadic_ops.cpp b/torch/csrc/jit/passes/variadic_ops.cpp new file mode 100644 index 0000000000000..a827d3a2371d8 --- /dev/null +++ b/torch/csrc/jit/passes/variadic_ops.cpp @@ -0,0 +1,156 @@ +#include + +#include +#include +#include + +namespace torch { +namespace jit { + +namespace { + +class VariadicUpdater { + public: + explicit VariadicUpdater( + std::shared_ptr graph, + NodeKind op, + NodeKind variadic_op, + size_t list_idx = 0) + : graph_(std::move(graph)), + op_(op), + variadic_op_(variadic_op), + list_idx_(list_idx) {} + + bool run() { + collectOpNodes(graph_->block()); + bool changed = false; + for (auto n : op_nodes_) { + changed |= replaceWithVariadicOp(n); + } + return changed; + } + + private: + void collectOpNodes(Block* block) { + for (auto node : block->nodes()) { + if (node->kind() == op_) { + op_nodes_.push_back(node); + } + for (Block* b : node->blocks()) { + collectOpNodes(b); + } + } + } + + bool replaceWithVariadicOp(Node* op_node) { + const size_t num_inputs = op_node->inputs().size(); + TORCH_CHECK(list_idx_ < num_inputs); + if (op_node->input(list_idx_)->node()->kind() != prim::ListConstruct) { + return false; + } + auto list = op_node->input(list_idx_)->node(); + const size_t list_len = list->inputs().size(); + + // We do not transform ops whose list input can not be moved to the + // position before op. This in turn implies that there is some mutation + // of the input list before op. + if (!getOrCreateAliasDb()->couldMoveBeforeTopologically(list, op_node)) { + return false; + } + + // Construct new inputs + std::vector inputs; + inputs.reserve(num_inputs + list_len - 1); + inputs.insert( + inputs.end(), + op_node->inputs().begin(), + op_node->inputs().begin() + list_idx_); + inputs.insert(inputs.end(), list->inputs().begin(), list->inputs().end()); + inputs.insert( + inputs.end(), + op_node->inputs().begin() + list_idx_ + 1, + op_node->inputs().end()); + + auto var_op_node = op_node->owningGraph()->create(variadic_op_, inputs); + GRAPH_UPDATE("Adding\n", *var_op_node); + var_op_node->insertBefore(op_node); + GRAPH_UPDATE("Replacing\n", *op_node, "with\n", *var_op_node); + op_node->output()->replaceAllUsesWith(var_op_node->output()); + GRAPH_UPDATE("Deleting\n", *op_node); + op_node->destroy(); + if (!list->hasUses()) { + GRAPH_UPDATE("Deleting\n", *list); + list->destroy(); + } + return true; + } + + AliasDb* getOrCreateAliasDb() { + if (!aliasDb_) { + aliasDb_ = std::make_unique(graph_); + } + return aliasDb_.get(); + } + + std::shared_ptr graph_; + std::unique_ptr aliasDb_ = nullptr; + + std::vector op_nodes_; + + NodeKind op_; + NodeKind variadic_op_; + + size_t list_idx_; +}; + +} // namespace + +bool UseVariadicOp( + const std::shared_ptr& graph, + NodeKind op, + NodeKind variadic_op, + size_t list_idx) { + const std::string pass_name = std::string("variadic ") + op.toQualString(); + GRAPH_DUMP("Before " + pass_name, graph); + bool changed = VariadicUpdater(graph, op, variadic_op, list_idx).run(); + if (changed) { + GRAPH_DUMP("After " + pass_name, graph); + } + return changed; +} + +bool RemoveListMutationAndUseVariadicOp( + const std::shared_ptr& graph, + NodeKind op, + NodeKind variadic_op, + size_t list_idx) { + bool changed_in_last_iter = true; + bool changed = false; + while (changed_in_last_iter) { + changed_in_last_iter = RemoveListMutation(graph); + changed_in_last_iter = + UseVariadicOp(graph, op, variadic_op, list_idx) || changed_in_last_iter; + changed = changed || changed_in_last_iter; + } + return changed; +} + +bool UseVariadicCat(const std::shared_ptr& graph) { + return UseVariadicOp(graph, aten::cat, prim::VarConcat); +} + +bool RemoveListMutationAndUseVariadicCat(const std::shared_ptr& graph) { + return RemoveListMutationAndUseVariadicOp(graph, aten::cat, prim::VarConcat); +} + +bool UseVariadicStack(const std::shared_ptr& graph) { + return UseVariadicOp(graph, aten::stack, prim::VarStack); +} + +bool RemoveListMutationAndUseVariadicStack( + const std::shared_ptr& graph) { + return RemoveListMutationAndUseVariadicOp(graph, aten::stack, prim::VarStack); +} + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/passes/variadic_ops.h b/torch/csrc/jit/passes/variadic_ops.h new file mode 100644 index 0000000000000..e5f6a680c5039 --- /dev/null +++ b/torch/csrc/jit/passes/variadic_ops.h @@ -0,0 +1,35 @@ +#pragma once + +#include + +namespace torch { +namespace jit { + +// Replaces the `aten::cat` ops in the given graph with variadic cat ops. +// Returns true if the graph is modified. +TORCH_API bool UseVariadicCat(const std::shared_ptr& graph); + +TORCH_API bool RemoveListMutationAndUseVariadicCat( + const std::shared_ptr& graph); + +// Replaces the `aten::stack` ops in the given graph with variadic cat ops. +// Returns true if the graph is modified. +TORCH_API bool UseVariadicStack(const std::shared_ptr& graph); + +TORCH_API bool RemoveListMutationAndUseVariadicStack( + const std::shared_ptr& graph); + +TORCH_API bool UseVariadicOp( + const std::shared_ptr& graph, + NodeKind op, + NodeKind variadic_op, + size_t list_idx = 0); + +TORCH_API bool RemoveListMutationAndUseVariadicOp( + const std::shared_ptr& graph, + NodeKind op, + NodeKind variadic_op, + size_t list_idx = 0); + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/passes/xnnpack_rewrite.cpp b/torch/csrc/jit/passes/xnnpack_rewrite.cpp index 11210a4ea05b9..9b2cac6e25f9e 100644 --- a/torch/csrc/jit/passes/xnnpack_rewrite.cpp +++ b/torch/csrc/jit/passes/xnnpack_rewrite.cpp @@ -26,8 +26,8 @@ namespace { void replaceConv1dWithConv2d(std::shared_ptr& graph) { std::string conv_1d_pattern = R"( graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int): - %r = aten::conv1d(%input, %weight, %bias, %stride, %padding, %dilation, %groups) - return (%r) )"; + %res = aten::conv1d(%input, %weight, %bias, %stride, %padding, %dilation, %groups) + return (%res) )"; std::string conv_2d_pattern = R"( graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int): @@ -47,8 +47,24 @@ void replaceConv1dWithConv2d(std::shared_ptr& graph) { %output : Tensor = aten::squeeze(%output_2d, %two) return (%output) )"; + std::vector> value_mappings( + {{"zero", "res"}, + {"one", "res"}, + {"stride_w", "res"}, + {"stride_2d", "res"}, + {"padding_w", "res"}, + {"padding_2d", "res"}, + {"dilation_w", "res"}, + {"dilation_2d", "res"}, + {"two", "res"}, + {"input_2d", "res"}, + {"weight_2d", "res"}, + {"output_2d", "res"}, + {"output", "res"}}); + SubgraphRewriter rewriter; - rewriter.RegisterRewritePattern(conv_1d_pattern, conv_2d_pattern); + rewriter.RegisterRewritePattern( + conv_1d_pattern, conv_2d_pattern, value_mappings); rewriter.runOnGraph(graph); } @@ -80,8 +96,8 @@ void insertPrePackedLinearOp(std::shared_ptr& graph) { std::string linear_before_inline = R"( graph(%linear, %input, %weight, %bias): - %r = prim::CallFunction(%linear, %input, %weight, %bias) - return (%r))"; + %res = prim::CallFunction(%linear, %input, %weight, %bias) + return (%res))"; std::string prepacked_ops_pattern_before_inline = R"( graph(%linear, %input, %weight, %bias): %output_min_max : None = prim::Constant() @@ -91,8 +107,8 @@ void insertPrePackedLinearOp(std::shared_ptr& graph) { return (%res))"; std::string linear_pattern = R"( graph(%input, %weight, %bias): - %r = aten::linear(%input, %weight, %bias) - return (%r))"; + %res = aten::linear(%input, %weight, %bias) + return (%res))"; std::string prepacked_ops_pattern = R"( graph(%input, %weight, %bias): %output_min_max : None = prim::Constant() @@ -112,13 +128,24 @@ void insertPrePackedLinearOp(std::shared_ptr& graph) { return false; }; + std::vector> value_mappings( + {{"output_min_max", "res"}, + {"packed_weight_bias", "res"}, + {"res", "res"}}); + SubgraphRewriter linear_call_fn_rewriter; linear_call_fn_rewriter.RegisterRewritePattern( - linear_before_inline, prepacked_ops_pattern_before_inline); + linear_before_inline, + prepacked_ops_pattern_before_inline, + value_mappings); linear_call_fn_rewriter.runOnGraph(graph, filter); + value_mappings = { + {"output_min_max", "res"}, {"packed_weight_bias", "res"}, {"res", "res"}}; + SubgraphRewriter linear_rewriter; - linear_rewriter.RegisterRewritePattern(linear_pattern, prepacked_ops_pattern); + linear_rewriter.RegisterRewritePattern( + linear_pattern, prepacked_ops_pattern, value_mappings); linear_rewriter.runOnGraph(graph); } @@ -128,8 +155,8 @@ void insertPrePackedConv2dOp(std::shared_ptr& graph) { std::string conv_2d_pattern = R"( graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int): - %r = aten::conv2d(%input, %weight, %bias, %stride, %padding, %dilation, %groups) - return (%r) )"; + %res = aten::conv2d(%input, %weight, %bias, %stride, %padding, %dilation, %groups) + return (%res) )"; std::string prepacked_ops_conv2d_pattern = R"( graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int): @@ -137,19 +164,24 @@ void insertPrePackedConv2dOp(std::shared_ptr& graph) { %packed_weight_bias = prepacked::conv2d_clamp_prepack( %weight, %bias, %stride, %padding, %dilation, %groups, %output_min_max, %output_min_max) - %r = prepacked::conv2d_clamp_run(%input, %packed_weight_bias) - return (%r) )"; + %res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias) + return (%res) )"; + + std::vector> value_mappings( + {{"output_min_max", "res"}, + {"packed_weight_bias", "res"}, + {"res", "res"}}); SubgraphRewriter rewriter; rewriter.RegisterRewritePattern( - conv_2d_pattern, prepacked_ops_conv2d_pattern); + conv_2d_pattern, prepacked_ops_conv2d_pattern, value_mappings); rewriter.runOnGraph(graph); std::string conv_2d_transpose_pattern = R"( graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %output_padding:int[], %groups:int): - %r = aten::conv_transpose2d(%input, %weight, %bias, %stride, %padding, %output_padding, %groups, %dilation) - return (%r) )"; + %res = aten::conv_transpose2d(%input, %weight, %bias, %stride, %padding, %output_padding, %groups, %dilation) + return (%res) )"; std::string prepacked_ops_conv2d_transpose_pattern = R"( graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %output_padding:int[], %groups:int): @@ -157,12 +189,17 @@ void insertPrePackedConv2dOp(std::shared_ptr& graph) { %packed_weight_bias = prepacked::conv2d_transpose_clamp_prepack( %weight, %bias, %stride, %padding, %output_padding, %dilation, %groups, %output_min_max, %output_min_max) - %r = prepacked::conv2d_transpose_clamp_run(%input, %packed_weight_bias) - return (%r) )"; + %res = prepacked::conv2d_transpose_clamp_run(%input, %packed_weight_bias) + return (%res) )"; + + value_mappings = { + {"output_min_max", "res"}, {"packed_weight_bias", "res"}, {"res", "res"}}; SubgraphRewriter transpose_rewriter; transpose_rewriter.RegisterRewritePattern( - conv_2d_transpose_pattern, prepacked_ops_conv2d_transpose_pattern); + conv_2d_transpose_pattern, + prepacked_ops_conv2d_transpose_pattern, + value_mappings); transpose_rewriter.runOnGraph(graph); } @@ -182,8 +219,8 @@ void fuseHardtanhWithPackedOps(std::shared_ptr& graph) { %packed_weight_bias : __torch__.torch.classes.xnnpack.Conv2dOpContext = prepacked::conv2d_clamp_prepack( %weight, %bias, %stride, %padding, %dilation, %groups, %output_min, %output_max) - %r = prepacked::conv2d_clamp_run(%input, %packed_weight_bias) - return (%r) )"; + %res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias) + return (%res) )"; std::string linear_prepack_run_hardtanh = R"( graph(%input, %weight, %bias, %output_min, %output_max, %dummy_min_max): @@ -193,8 +230,13 @@ void fuseHardtanhWithPackedOps(std::shared_ptr& graph) { %res = aten::hardtanh(%linear_res, %output_min, %output_max) return (%res))"; + std::vector> value_mappings( + {{"packed_weight_bias", "packed_weight_bias"}, {"res", "res"}}); + rewriter.RegisterRewritePattern( - linear_prepack_run_hardtanh, linear_prepack_run_hardtanh_fused); + linear_prepack_run_hardtanh, + linear_prepack_run_hardtanh_fused, + value_mappings); std::string conv2d_prepack_run_hardtanh = R"( graph(%input, %weight, %bias, %stride:int[], %padding:int[], @@ -203,11 +245,16 @@ void fuseHardtanhWithPackedOps(std::shared_ptr& graph) { %weight, %bias, %stride, %padding, %dilation, %groups, %dummy_min_max, %dummy_min_max) %conv2d_res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias) - %r = aten::hardtanh(%conv2d_res, %output_min, %output_max) - return (%r) )"; + %res = aten::hardtanh(%conv2d_res, %output_min, %output_max) + return (%res) )"; + + value_mappings = { + {"packed_weight_bias", "packed_weight_bias"}, {"res", "res"}}; rewriter.RegisterRewritePattern( - conv2d_prepack_run_hardtanh, conv2d_prepack_run_hardtanh_fused); + conv2d_prepack_run_hardtanh, + conv2d_prepack_run_hardtanh_fused, + value_mappings); std::string linear_prepack_run_hardtanh_inplace = R"( graph(%input, %weight, %bias, %output_min, %output_max, %dummy_min_max): @@ -224,13 +271,24 @@ void fuseHardtanhWithPackedOps(std::shared_ptr& graph) { %weight, %bias, %stride, %padding, %dilation, %groups, %dummy_min_max, %dummy_min_max) %conv2d_res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias) - %r = aten::hardtanh_(%conv2d_res, %output_min, %output_max) - return (%r) )"; + %res = aten::hardtanh_(%conv2d_res, %output_min, %output_max) + return (%res) )"; + + value_mappings = { + {"packed_weight_bias", "packed_weight_bias"}, {"res", "res"}}; rewriter.RegisterRewritePattern( - linear_prepack_run_hardtanh_inplace, linear_prepack_run_hardtanh_fused); + linear_prepack_run_hardtanh_inplace, + linear_prepack_run_hardtanh_fused, + value_mappings); + + value_mappings = { + {"packed_weight_bias", "packed_weight_bias"}, {"res", "res"}}; + rewriter.RegisterRewritePattern( - conv2d_prepack_run_hardtanh_inplace, conv2d_prepack_run_hardtanh_fused); + conv2d_prepack_run_hardtanh_inplace, + conv2d_prepack_run_hardtanh_fused, + value_mappings); rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable); } @@ -255,8 +313,8 @@ void fuseReluWithPackedOps(std::shared_ptr& graph) { %packed_weight_bias : __torch__.torch.classes.xnnpack.Conv2dOpContext = prepacked::conv2d_clamp_prepack( %weight, %bias, %stride, %padding, %dilation, %groups, %output_min, %output_max) - %r = prepacked::conv2d_clamp_run(%input, %packed_weight_bias) - return (%r) )"; + %res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias) + return (%res) )"; std::string linear_prepack_run_relu = R"( graph(%input, %weight, %bias, %dummy_min_max): @@ -266,8 +324,14 @@ void fuseReluWithPackedOps(std::shared_ptr& graph) { %res = aten::relu(%linear_res) return (%res))"; + std::vector> value_mappings( + {{"output_min", "packed_weight_bias"}, + {"output_max", "packed_weight_bias"}, + {"packed_weight_bias", "packed_weight_bias"}, + {"res", "res"}}); + rewriter.RegisterRewritePattern( - linear_prepack_run_relu, linear_prepack_run_relu_fused); + linear_prepack_run_relu, linear_prepack_run_relu_fused, value_mappings); std::string conv2d_prepack_run_relu = R"( graph(%input, %weight, %bias, %stride:int[], %padding:int[], @@ -276,11 +340,17 @@ void fuseReluWithPackedOps(std::shared_ptr& graph) { %weight, %bias, %stride, %padding, %dilation, %groups, %dummy_min_max, %dummy_min_max) %conv2d_res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias) - %r = aten::relu(%conv2d_res) - return (%r) )"; + %res = aten::relu(%conv2d_res) + return (%res) )"; + + value_mappings = { + {"output_min", "packed_weight_bias"}, + {"output_max", "packed_weight_bias"}, + {"packed_weight_bias", "packed_weight_bias"}, + {"res", "res"}}; rewriter.RegisterRewritePattern( - conv2d_prepack_run_relu, conv2d_prepack_run_relu_fused); + conv2d_prepack_run_relu, conv2d_prepack_run_relu_fused, value_mappings); std::string linear_prepack_run_relu_inplace = R"( graph(%input, %weight, %bias, %dummy_min_max): @@ -297,13 +367,30 @@ void fuseReluWithPackedOps(std::shared_ptr& graph) { %weight, %bias, %stride, %padding, %dilation, %groups, %dummy_min_max, %dummy_min_max) %conv2d_res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias) - %r = aten::relu_(%conv2d_res) - return (%r) )"; + %res = aten::relu_(%conv2d_res) + return (%res) )"; + + value_mappings = { + {"output_min", "packed_weight_bias"}, + {"output_max", "packed_weight_bias"}, + {"packed_weight_bias", "packed_weight_bias"}, + {"res", "res"}}; rewriter.RegisterRewritePattern( - linear_prepack_run_relu_inplace, linear_prepack_run_relu_fused); + linear_prepack_run_relu_inplace, + linear_prepack_run_relu_fused, + value_mappings); + + value_mappings = { + {"output_min", "packed_weight_bias"}, + {"output_max", "packed_weight_bias"}, + {"packed_weight_bias", "packed_weight_bias"}, + {"res", "res"}}; + rewriter.RegisterRewritePattern( - conv2d_prepack_run_relu_inplace, conv2d_prepack_run_relu_fused); + conv2d_prepack_run_relu_inplace, + conv2d_prepack_run_relu_fused, + value_mappings); rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable); } diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 5fca575593551..35197e4ea1423 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -589,6 +589,8 @@ void initJITBindings(PyObject* module) { .def("_jit_override_can_fuse_on_gpu", &overrideCanFuseOnGPU) .def("_jit_can_fuse_on_cpu", &canFuseOnCPU) .def("_jit_can_fuse_on_gpu", &canFuseOnGPU) + .def("_jit_can_fuse_on_cpu_legacy", &canFuseOnCPULegacy) + .def("_jit_override_can_fuse_on_cpu_legacy", &overrideCanFuseOnCPULegacy) .def( "_jit_differentiate", [](Graph& g) { @@ -711,8 +713,6 @@ void initJITBindings(PyObject* module) { .def("_jit_texpr_set_fallback_allowed", &tensorexpr::setFallbackAllowed) .def("_jit_set_texpr_reductions_enabled", &setTexprReductionsEnabled) .def("_jit_texpr_reductions_enabled", &texprReductionsEnabled) - .def("_jit_set_texpr_parallel_cpu_enabled", &setTexprParallelCPUEnabled) - .def("_jit_texpr_parallel_cpu_enabled", &texprParallelCPUEnabled) .def( "_jit_set_te_generate_block_code", [](bool gen_block_code) { @@ -1280,11 +1280,15 @@ void initJITBindings(PyObject* module) { [](const FunctionSchema& self, const FunctionSchema& other) { return self == other; }) - .def("__str__", [](FunctionSchema& self) { - std::stringstream ss; - ss << self; - return ss.str(); - }); + .def( + "__str__", + [](FunctionSchema& self) { + std::stringstream ss; + ss << self; + return ss.str(); + }) + .def_property_readonly( + "is_mutable", [](FunctionSchema& self) { return self.is_mutable(); }); py::class_(m, "Argument") .def_property_readonly("name", [](Argument& self) { return self.name(); }) .def_property_readonly("type", [](Argument& self) { return self.type(); }) diff --git a/torch/csrc/jit/python/pybind_utils.cpp b/torch/csrc/jit/python/pybind_utils.cpp index f81632bc0fb0a..f8fae19ed8f50 100644 --- a/torch/csrc/jit/python/pybind_utils.cpp +++ b/torch/csrc/jit/python/pybind_utils.cpp @@ -89,6 +89,19 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional N) { ? c10::ivalue::Tuple::createNamed(std::move(values), tuple_type) : c10::ivalue::Tuple::create(std::move(values)); } + case TypeKind::UnionType: { + auto actual_type = toTypeInferredIValue(obj); + auto actual_type_ptr = actual_type.type(); + auto union_type = type->expect(); + if (!actual_type_ptr->isSubtypeOf(union_type)) { + throw py::cast_error(c10::str( + "Expected a member of ", + union_type->annotation_str(), + " but instead found type ", + actual_type.type()->annotation_str())); + } + return actual_type; + } case TypeKind::StringType: return ConstantString::create(py::cast(obj)); case TypeKind::DeviceObjType: { diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h index 0138231d3bc3f..eff1ddc243999 100644 --- a/torch/csrc/jit/python/pybind_utils.h +++ b/torch/csrc/jit/python/pybind_utils.h @@ -1151,7 +1151,7 @@ inline py::object invokeOperatorFromPython( Stack stack = std::get<1>(opWithStack); { pybind11::gil_scoped_release no_gil_guard; - found_op->getOperation()(&stack); + found_op->getOperation()(stack); } return createPyObjectForStack(std::move(stack)); diff --git a/torch/csrc/jit/python/python_interpreter.cpp b/torch/csrc/jit/python/python_interpreter.cpp index 82a0d22c54fa2..29b7929fcd690 100644 --- a/torch/csrc/jit/python/python_interpreter.cpp +++ b/torch/csrc/jit/python/python_interpreter.cpp @@ -43,7 +43,7 @@ Operation createPythonOperation(const Node* op_) { AT_ASSERT(op->outputs().size() == 1); - return [=](Stack* stack) { + return [=](Stack& stack) { pybind11::gil_scoped_acquire gil; py::tuple py_inputs(op->cconv.size()); size_t i = 0; @@ -66,7 +66,7 @@ Operation createPythonOperation(const Node* op_) { drop(stack, num_inputs); try { py::object py_output(func(*py_inputs)); - stack->push_back(returnToIValue(op->output()->type(), py_output)); + stack.push_back(returnToIValue(op->output()->type(), py_output)); } catch (py::error_already_set& e) { throw std::runtime_error(e.what()); } diff --git a/torch/csrc/jit/python/python_ir.cpp b/torch/csrc/jit/python/python_ir.cpp index e0951c3ebbfbc..2c8246daec92b 100644 --- a/torch/csrc/jit/python/python_ir.cpp +++ b/torch/csrc/jit/python/python_ir.cpp @@ -869,6 +869,12 @@ void initPythonIRBindings(PyObject* module_) { } return types; }); + py::class_>(m, "UnionType") + .def(py::init( + [](const std::vector& a) { return UnionType::create(a); })) + .def("containedTypes", [](UnionType& self) { + return self.containedTypes().vec(); + }); py::class_>(m, "ListType") .def(py::init([](TypePtr a) { return ListType::create(a); })) .def_static("ofInts", &ListType::ofInts) diff --git a/torch/csrc/jit/runtime/calculate_necessary_args.h b/torch/csrc/jit/runtime/calculate_necessary_args.h index 5f37660ee14a8..07df670b01040 100644 --- a/torch/csrc/jit/runtime/calculate_necessary_args.h +++ b/torch/csrc/jit/runtime/calculate_necessary_args.h @@ -7,18 +7,42 @@ namespace torch { namespace jit { -inline size_t CalculateNecessaryArgs( +inline std::pair CalculateNecessaryArgs( const std::vector& schema_args, - at::ArrayRef actual_inputs) { + at::ArrayRef actual_inputs, + bool allow_trailing_out_args) { + if (schema_args.size() == 0) { + return std::make_pair(0, 0); + } + + // count number of out arguments + auto schema_idx = schema_args.size() - 1; + if (allow_trailing_out_args) { + // skip over out arguments in the end. + while (schema_idx >= 0) { + auto current_arg = schema_args.at(schema_idx); + if (!current_arg.is_out()) { + break; + } + schema_idx--; + } + } + + auto num_out = schema_args.size() - schema_idx - 1; + if (schema_args.size() < actual_inputs.size()) { - return actual_inputs.size(); + return std::make_pair(actual_inputs.size(), num_out); + } + + // if it is the default args, we reset the index to the last element + if (!allow_trailing_out_args) { + schema_idx = schema_args.size() - 1; } // keeps track of trailing unnecessary args - int schema_size = schema_args.size(); - for (int schema_idx = schema_size - 1; schema_idx > -1; schema_idx--) { + while (schema_idx >= 0) { // this means it is not default argument, so it is necessary if (!schema_args.at(schema_idx).default_value().has_value()) { - return schema_idx + 1; + return std::make_pair(schema_idx + 1, num_out); } else { auto schema_value = schema_args.at(schema_idx).default_value().value().toIValue(); @@ -27,16 +51,17 @@ inline size_t CalculateNecessaryArgs( // well. auto actual_value = toIValue(actual_inputs[schema_idx]); if (!actual_value.has_value()) { - return schema_idx + 1; + return std::make_pair(schema_idx + 1, num_out); } // if the IR has same value as default value of the schema, // it is not neccessary argument. if (schema_value != actual_value.value()) { - return schema_idx + 1; + return std::make_pair(schema_idx + 1, num_out); } } + schema_idx--; } - return 0; + return std::make_pair(0, num_out); } } // namespace jit diff --git a/torch/csrc/jit/runtime/custom_operator.h b/torch/csrc/jit/runtime/custom_operator.h index 45ad6676376ce..e39789bfe9da3 100644 --- a/torch/csrc/jit/runtime/custom_operator.h +++ b/torch/csrc/jit/runtime/custom_operator.h @@ -19,7 +19,7 @@ struct TORCH_API RegisterOperators { /// Registers a vector of already created `Operator`s. /// The operator element is now optional to filter null ops. It's backward /// compatible and works for selective operator registration. - RegisterOperators(std::vector> operators) { + explicit RegisterOperators(std::vector> operators) { for (c10::optional& o : operators) { if (o) { registerOperator(std::move(o.value())); diff --git a/torch/csrc/jit/runtime/graph_executor.cpp b/torch/csrc/jit/runtime/graph_executor.cpp index 476882650a1dd..39742c7815d3b 100644 --- a/torch/csrc/jit/runtime/graph_executor.cpp +++ b/torch/csrc/jit/runtime/graph_executor.cpp @@ -377,7 +377,7 @@ struct DifferentiableGraphOp { num_outputs(this->grad.f->outputs().size()) {} // XXX: keep in mind that stack can be larger than the inputs we need! - void operator()(Stack* stack) const { + void operator()(Stack& stack) const { auto grad_fn = std::make_shared( grad_executor, grad.df_input_vjps.size(), @@ -394,13 +394,13 @@ struct DifferentiableGraphOp { captureInputs(*grad_fn, inputs); } - detachVariables(*stack); + detachVariables(stack); if (IsNewExecutorEnabled()) { ExecutionPlan plan = - f_ptr->getPlanFor(*stack, GraphExecutor::getDefaultNumBailOuts()); - InterpreterState(plan.code).run(*stack); + f_ptr->getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts()); + InterpreterState(plan.code).run(stack); } else { - InterpreterState(legacy_f).run(*stack); + InterpreterState(legacy_f).run(stack); } { @@ -419,7 +419,7 @@ struct DifferentiableGraphOp { // drop the temporary outputs so that we return the same number of // outputs as if we were not also calculating gradient const size_t num_temporary_outputs = num_outputs - grad.f_real_outputs; - stack->erase(stack->end() - num_temporary_outputs, stack->end()); + stack.erase(stack.end() - num_temporary_outputs, stack.end()); } } @@ -908,7 +908,7 @@ void runNondiffOptimization( void runOptimization( std::shared_ptr& graph, - bool unroll, + bool unroll_non_constant_loops, bool const_prop_user_classes) { // Basic graph preprocessing to eliminate noise. GRAPH_DEBUG( @@ -935,9 +935,17 @@ void runOptimization( // Unroll small loops, and eliminate expressions that are the same at every // iteration. - if (unroll) { - UnrollLoops(graph); + bool unroll_success = false; + if (unroll_non_constant_loops) { + unroll_success = UnrollLoops(graph); GRAPH_DEBUG("After UnrollLoops, before RemoveListMutation\n", *graph); + } else { + unroll_success = UnrollConstantLoops(graph); + GRAPH_DEBUG( + "After UnrollConstantLoops, before RemoveListMutation\n", *graph); + } + + if (unroll_success) { // run again with unrolled loops RemoveListMutation(graph); GRAPH_DEBUG("After RemoveListMutation, before PeepholeOptimize\n", *graph); diff --git a/torch/csrc/jit/runtime/graph_executor_impl.h b/torch/csrc/jit/runtime/graph_executor_impl.h index 516ad1f55c812..3815d26c87f4d 100644 --- a/torch/csrc/jit/runtime/graph_executor_impl.h +++ b/torch/csrc/jit/runtime/graph_executor_impl.h @@ -33,7 +33,7 @@ void packGradient(const Gradient& gradient, Node* dnode); bool needsGradient(const std::shared_ptr& graph); void runOptimization( std::shared_ptr& graph, - bool unroll = true, + bool unroll_non_constant_loops = true, bool const_prop_user_classes = true); void runNondiffOptimization( std::shared_ptr& graph, diff --git a/torch/csrc/jit/runtime/interpreter.cpp b/torch/csrc/jit/runtime/interpreter.cpp index be2019e532f98..b34827176b2f3 100644 --- a/torch/csrc/jit/runtime/interpreter.cpp +++ b/torch/csrc/jit/runtime/interpreter.cpp @@ -297,13 +297,13 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { } case INST(OP): { INST_GUARD; - frame.function->operator_table_[inst.X](&stack); + frame.function->operator_table_[inst.X](stack); } INST_NEXT; case INST(OPN): { INST_GUARD; stack.push_back(inst.N); - frame.function->operator_table_[inst.X](&stack); + frame.function->operator_table_[inst.X](stack); } INST_NEXT; case INST(LOAD): { @@ -978,11 +978,13 @@ MobileCode::MobileCode( const std::shared_ptr& graph, std::string function_name, bool emit_default_input_instructions, + bool support_default_args_before_out, size_t remaining_bailout_depth) : Code(new interpreter::MobileCodeImpl( graph, std::move(function_name), emit_default_input_instructions, + support_default_args_before_out, remaining_bailout_depth)) {} MobileCode::~MobileCode() = default; diff --git a/torch/csrc/jit/runtime/interpreter.h b/torch/csrc/jit/runtime/interpreter.h index 80720ea2ca42f..3471e558e5a41 100644 --- a/torch/csrc/jit/runtime/interpreter.h +++ b/torch/csrc/jit/runtime/interpreter.h @@ -82,6 +82,7 @@ struct TORCH_API MobileCode : Code { const std::shared_ptr& graph, std::string function_name, bool emit_default_input_instructions = true, + bool support_default_args_before_out = false, size_t remaining_bailout_depth = 0); ~MobileCode(); }; diff --git a/torch/csrc/jit/runtime/interpreter/code_impl.h b/torch/csrc/jit/runtime/interpreter/code_impl.h index 00648de905767..15ba0cec04d33 100644 --- a/torch/csrc/jit/runtime/interpreter/code_impl.h +++ b/torch/csrc/jit/runtime/interpreter/code_impl.h @@ -105,6 +105,8 @@ struct CodeImpl { // This is because for all usages, at most 3 args are used. std::unordered_map op_to_num_specified_args_; + std::unordered_map op_to_num_out_args_; + // running count of uses as we emit. When we reach use_count_[v] = // v.uses().size() we know it is the final use and we can move rather than // load. @@ -292,6 +294,12 @@ struct CodeImpl { } } + void emitLoadInputs(at::ArrayRef inputs, size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + emitUse(inputs[i], false); + } + } + virtual void emitOperator(Node* node) { emitLoadInputs(node->inputs()); const Operator& op = node->getOperator(); @@ -713,9 +721,11 @@ struct MobileCodeImpl : CodeImpl { const std::shared_ptr& graph, std::string function_name, bool emit_default_input_instructions, + bool support_default_args_before_out, size_t remaining_bailout_depth) : CodeImpl(graph, function_name, remaining_bailout_depth, false), - emit_default_input_instructions_(emit_default_input_instructions) { + emit_default_input_instructions_(emit_default_input_instructions), + support_default_args_before_out_(support_default_args_before_out) { // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) run(); } @@ -737,13 +747,20 @@ struct MobileCodeImpl : CodeImpl { auto op_schema = node->getOperator().schema(); // skip if schema has vararg if (!op_schema.is_vararg()) { - auto numInclude = - CalculateNecessaryArgs(op_schema.arguments(), node->inputs()); + auto specifiedArgs = CalculateNecessaryArgs( + op_schema.arguments(), + node->inputs(), + support_default_args_before_out_); + + size_t numInclude = specifiedArgs.first + + (support_default_args_before_out_ ? specifiedArgs.second : 0); auto unique_name = op_schema.overload_name() != "" ? op_schema.name() + "." + op_schema.overload_name() : op_schema.name(); auto it = op_to_num_specified_args_.insert( std::pair(unique_name, 0)); + op_to_num_out_args_.insert(std::pair( + unique_name, specifiedArgs.second)); auto prev_value = it.first->second; it.first->second = std::max(numInclude, prev_value); } @@ -768,14 +785,27 @@ struct MobileCodeImpl : CodeImpl { if (it != op_to_num_specified_args_.end()) { num_include = it->second; } - emitLoadInputs(node->inputs(), num_include); + if (support_default_args_before_out_) { + auto num_out = op_to_num_out_args_.find(unique_op_name)->second; + auto num_specified_before_out = num_include - num_out; + emitLoadInputs(node->inputs(), 0, num_specified_before_out); + emitLoadInputs( + node->inputs(), + node->inputs().size() - num_out, + node->inputs().size()); + } else { + emitLoadInputs(node->inputs(), num_include); + } insertInstruction(OP, operator_table_.size()); } operator_table_.emplace_back(op.getOperation(node)); } } + // To support forward compatibility for bytecode version bump from v5 to v6 bool emit_default_input_instructions_; + // To support forward compatibility for bytecode version bump from v6 to v7 + bool support_default_args_before_out_; }; } // namespace interpreter diff --git a/torch/csrc/jit/runtime/operator.h b/torch/csrc/jit/runtime/operator.h index e243e8ff57f2d..ccdbfa03f5e8c 100644 --- a/torch/csrc/jit/runtime/operator.h +++ b/torch/csrc/jit/runtime/operator.h @@ -220,13 +220,24 @@ TORCH_API bool aliasAnalysisHasSpecialCaseFor(c10::Symbol sym); // string. template c10::optional OperatorGenerator( - torch::detail::SelectiveStr schema_str, + const char* schema_str, Func&& op, AliasAnalysisKind alias_analysis) { return c10::optional(Operator( std::string(schema_str), std::forward(op), alias_analysis)); } +template +c10::optional OperatorGenerator( + torch::detail::SelectiveStr schema_str, + Func&& op, + AliasAnalysisKind alias_analysis) { + return OperatorGenerator( + static_cast(schema_str), + std::forward(op), + alias_analysis); +} + template c10::optional OperatorGenerator( torch::detail::SelectiveStr schema_str, diff --git a/torch/csrc/jit/runtime/register_c10_ops.cpp b/torch/csrc/jit/runtime/register_c10_ops.cpp index 993d41194e84b..4d541ec46bbbf 100644 --- a/torch/csrc/jit/runtime/register_c10_ops.cpp +++ b/torch/csrc/jit/runtime/register_c10_ops.cpp @@ -12,7 +12,7 @@ namespace jit { namespace { Operator createOperatorFromC10(const c10::OperatorHandle& op) { - return Operator(op, [op](Stack* stack) { op.callBoxed(stack); }); + return Operator(op, [op](Stack& stack) { op.callBoxed(stack); }); } class RegistrationListener final : public c10::OpRegistrationListener { diff --git a/torch/csrc/jit/runtime/register_cuda_ops.cpp b/torch/csrc/jit/runtime/register_cuda_ops.cpp index f7a989d7acef9..599fd5398c370 100644 --- a/torch/csrc/jit/runtime/register_cuda_ops.cpp +++ b/torch/csrc/jit/runtime/register_cuda_ops.cpp @@ -38,7 +38,7 @@ void _device_synchronize(int64_t device_index) { RegisterOperators const reg({ Operator( "cuda::current_stream.device(Device? device) -> __torch__.torch.classes.cuda.Stream", - [](Stack* stack) { + [](Stack& stack) { auto device = pop(stack).toOptional(); c10::DeviceIndex device_index = device.has_value() ? device->index() @@ -50,7 +50,7 @@ RegisterOperators const reg({ aliasAnalysisFromSchema()), Operator( "cuda::current_stream.int(int? val) -> __torch__.torch.classes.cuda.Stream", - [](Stack* stack) { + [](Stack& stack) { auto idx = pop(stack).toOptional(); c10::DeviceIndex device_index = idx.has_value() ? static_cast(idx.value()) @@ -62,7 +62,7 @@ RegisterOperators const reg({ aliasAnalysisFromSchema()), Operator( "cuda::default_stream.device(Device? device) -> __torch__.torch.classes.cuda.Stream", - [](Stack* stack) { + [](Stack& stack) { auto device = pop(stack).toOptional(); c10::DeviceIndex device_index = device.has_value() ? device->index() @@ -74,7 +74,7 @@ RegisterOperators const reg({ aliasAnalysisFromSchema()), Operator( "cuda::default_stream.int(int? val) -> __torch__.torch.classes.cuda.Stream", - [](Stack* stack) { + [](Stack& stack) { auto idx = pop(stack).toOptional(); c10::DeviceIndex device_index = idx.has_value() ? static_cast(idx.value()) @@ -86,14 +86,14 @@ RegisterOperators const reg({ aliasAnalysisFromSchema()), Operator( "cuda::_current_device() -> int", - [](Stack* stack) { + [](Stack& stack) { auto v = c10::cuda::current_device(); push(stack, static_cast(v)); }, aliasAnalysisFromSchema()), Operator( "cuda::_set_device(int64_t val) -> ()", - [](Stack* stack) { + [](Stack& stack) { int64_t idx = -1; pop(stack, idx); c10::cuda::set_device(static_cast(idx)); @@ -101,7 +101,7 @@ RegisterOperators const reg({ aliasAnalysisFromSchema()), Operator( "cuda::device_index(Device device) -> int", - [](Stack* stack) { + [](Stack& stack) { auto device = pop(stack); auto idx = device.toDevice().index(); push(stack, idx); @@ -109,11 +109,11 @@ RegisterOperators const reg({ aliasAnalysisFromSchema()), Operator( "cuda::device_count() -> int", - [](Stack* stack) { push(stack, at::cuda::device_count()); }, + [](Stack& stack) { push(stack, at::cuda::device_count()); }, aliasAnalysisFromSchema()), Operator( "cuda::set_stream(__torch__.torch.classes.cuda.Stream stream) -> ()", - [](Stack* stack) { + [](Stack& stack) { auto v = pop(stack); auto s = v.toCustomClass(); auto stream_device_idx = static_cast(s->device_index()); @@ -141,11 +141,11 @@ RegisterOperators const reg({ aliasAnalysisFromSchema()), Operator( "cuda::synchronize() -> ()", - [](Stack* stack) { c10::cuda::device_synchronize(); }, + [](Stack& stack) { c10::cuda::device_synchronize(); }, aliasAnalysisFromSchema()), Operator( "cuda::synchronize.device(Device? device) -> ()", - [](Stack* stack) { + [](Stack& stack) { auto device = pop(stack).toOptional(); c10::DeviceIndex device_index = device.has_value() ? device->index() @@ -155,7 +155,7 @@ RegisterOperators const reg({ aliasAnalysisFromSchema()), Operator( "cuda::synchronize.int(int? val) -> ()", - [](Stack* stack) { + [](Stack& stack) { auto idx = pop(stack).toOptional(); c10::DeviceIndex device_index = idx.has_value() ? static_cast(idx.value()) diff --git a/torch/csrc/jit/runtime/register_distributed_ops.cpp b/torch/csrc/jit/runtime/register_distributed_ops.cpp index 2c8277d106f3c..edf7a0ccff23a 100644 --- a/torch/csrc/jit/runtime/register_distributed_ops.cpp +++ b/torch/csrc/jit/runtime/register_distributed_ops.cpp @@ -29,11 +29,11 @@ static auto workerInfo = // prepare the rpc input arguments and call the C++ impls void prepare_and_call_rpc_op( - Stack* stack, + Stack& stack, int num_inputs, const std::string& rpc_op) { // Get inputs from the stack. - auto stackIter = stack->end() - num_inputs; + auto stackIter = stack.end() - num_inputs; auto& dstWorkerIValue = *stackIter++; auto& qualifiedNameIValue = *stackIter++; IValue emptyTuple(c10::ivalue::Tuple::create({})); @@ -137,7 +137,7 @@ void prepare_and_call_rpc_op( rpcTimeout); // Push output to the stack. drop(stack, num_inputs); - stack->emplace_back(std::move(futureIValuePtr)); + stack.emplace_back(std::move(futureIValuePtr)); } else if (rpc_op == "rpc_sync") { // Send RPC request. auto futureIValuePtr = dist_rpc::rpcTorchscript( @@ -154,7 +154,7 @@ void prepare_and_call_rpc_op( auto res = futureIValuePtr->value(); // Push output to the stack. drop(stack, num_inputs); - stack->emplace_back(std::move(res)); + stack.emplace_back(std::move(res)); } } else if (rpc_op == "rpc_remote") { auto rrefPtr = dist_rpc::remoteTorchscript( @@ -165,7 +165,7 @@ void prepare_and_call_rpc_op( rpcTimeout); // Push output to the stack. drop(stack, num_inputs); - stack->emplace_back( + stack.emplace_back( c10::static_intrusive_pointer_cast(rrefPtr)); } else { throw std::runtime_error( @@ -178,7 +178,7 @@ RegisterOperators reg_rpc_ops( fmt::format( "aten::to_here(RRef(t) self, float timeout = {}) -> t(*)", torch::distributed::rpc::kDefaultRpcTimeoutSeconds), - [](Stack* stack) { + [](Stack& stack) { auto timeout = pop(stack).toDouble(); auto rref = pop(stack).toRRef(); IValue res; @@ -195,7 +195,7 @@ RegisterOperators reg_rpc_ops( aliasAnalysisFromSchema()), Operator( "aten::local_value(RRef(t) self) -> t(*)", - [](Stack* stack) { + [](Stack& stack) { auto rref = pop(stack).toRRef(); TORCH_CHECK( rref->isOwner(), @@ -208,14 +208,14 @@ RegisterOperators reg_rpc_ops( aliasAnalysisFromSchema()), Operator( "aten::is_owner(RRef(t) self) -> bool", - [](Stack* stack) { + [](Stack& stack) { auto rref = pop(stack).toRRef(); push(stack, rref->isOwner()); }, aliasAnalysisFromSchema()), Operator( "aten::owner(RRef(t) self) -> __torch__.torch.classes.dist_rpc.WorkerInfo", - [](Stack* stack) { + [](Stack& stack) { auto rref = pop(stack).toRRef(); push( stack, @@ -225,21 +225,21 @@ RegisterOperators reg_rpc_ops( aliasAnalysisFromSchema()), Operator( "aten::owner_name(RRef(t) self) -> str", - [](Stack* stack) { + [](Stack& stack) { auto rref = pop(stack).toRRef(); push(stack, rref->ownerName()); }, aliasAnalysisFromSchema()), Operator( "aten::confirmed_by_owner(RRef(t) self) -> bool", - [](Stack* stack) { + [](Stack& stack) { auto rref = pop(stack).toRRef(); push(stack, rref->confirmedByOwner()); }, aliasAnalysisFromSchema()), Operator( "aten::dist_backward(int context_id, Tensor[] roots, bool retain_graph=False) -> ()", - [](Stack* stack) { + [](Stack& stack) { bool retain_graph = pop(stack).toBool(); auto roots_list = pop(stack).toTensorList(); int64_t context_id = pop(stack).toInt(); @@ -252,7 +252,7 @@ RegisterOperators reg_rpc_ops( prim::rpc_sync, [](const Node* node) -> Operation { int num_inputs = node->inputs().size(); - return [num_inputs](Stack* stack) { + return [num_inputs](Stack& stack) { prepare_and_call_rpc_op(stack, num_inputs, "rpc_sync"); }; }, @@ -261,7 +261,7 @@ RegisterOperators reg_rpc_ops( prim::rpc_remote, [](const Node* node) -> Operation { int num_inputs = node->inputs().size(); - return [num_inputs](Stack* stack) { + return [num_inputs](Stack& stack) { prepare_and_call_rpc_op(stack, num_inputs, "rpc_remote"); }; }, @@ -270,7 +270,7 @@ RegisterOperators reg_rpc_ops( prim::rpc_async, [](const Node* node) -> Operation { int num_inputs = node->inputs().size(); - return [num_inputs](Stack* stack) { + return [num_inputs](Stack& stack) { prepare_and_call_rpc_op(stack, num_inputs, "rpc_async"); }; }, diff --git a/torch/csrc/jit/runtime/register_ops_utils.cpp b/torch/csrc/jit/runtime/register_ops_utils.cpp index 91ff2c738a1bf..64bb3abc57584 100644 --- a/torch/csrc/jit/runtime/register_ops_utils.cpp +++ b/torch/csrc/jit/runtime/register_ops_utils.cpp @@ -13,7 +13,7 @@ c10::impl::GenericList make_result_list(const TypePtr& elemType) { } template <> -void listIndex(Stack* stack) { +void listIndex(Stack& stack) { at::Tensor elem = pop(stack).to(); c10::List list = pop(stack).to>(); @@ -31,7 +31,7 @@ void listIndex(Stack* stack) { } template <> -void listCount(Stack* stack) { +void listCount(Stack& stack) { at::Tensor elem = pop(stack).to(); c10::List list = pop(stack).to>(); @@ -44,21 +44,21 @@ void listCount(Stack* stack) { } template <> -void listEq(Stack* stack) { +void listEq(Stack& stack) { c10::List b = pop(stack).to>(); c10::List a = pop(stack).to>(); push(stack, tensor_list_equal(a, b)); } template <> -void listNe(Stack* stack) { +void listNe(Stack& stack) { c10::List b = pop(stack).to>(); c10::List a = pop(stack).to>(); push(stack, !tensor_list_equal(a, b)); } template <> -void listSort(Stack* stack) { +void listSort(Stack& stack) { bool reverse = pop(stack).toBool(); c10::List list = pop(stack).toTensorList(); std::sort( @@ -74,7 +74,7 @@ void listSort(Stack* stack) { } template <> -void listCopyAndSort(Stack* stack) { +void listCopyAndSort(Stack& stack) { c10::List list = pop(stack).toTensorList(); auto list_copied = list.copy(); std::sort( @@ -87,7 +87,7 @@ void listCopyAndSort(Stack* stack) { } template <> -void listRemove(Stack* stack) { +void listRemove(Stack& stack) { at::Tensor elem = pop(stack).to(); c10::List list = pop(stack).to>(); @@ -268,7 +268,7 @@ int64_t normalizeIndex(int64_t idx, int64_t list_size) { return idx; } -void listAppend(Stack* stack) { +void listAppend(Stack& stack) { IValue el = pop(stack).to(); c10::List list = pop(stack).to>(); @@ -276,13 +276,13 @@ void listAppend(Stack* stack) { push(stack, std::move(list)); } -void listReverse(Stack* stack) { +void listReverse(Stack& stack) { c10::List list = pop(stack).to>(); std::reverse(list.begin(), list.end()); } -void listPopImpl(Stack* stack, const char* empty_message) { +void listPopImpl(Stack& stack, const char* empty_message) { int64_t idx = pop(stack).to(); c10::List list = pop(stack).to>(); @@ -297,22 +297,22 @@ void listPopImpl(Stack* stack, const char* empty_message) { list.erase(list.begin() + normalized_idx); } -void listPop(Stack* stack) { +void listPop(Stack& stack) { return listPopImpl(stack, "pop from empty list"); } -void listClear(Stack* stack) { +void listClear(Stack& stack) { c10::List list = pop(stack).to>(); list.clear(); } -void listDelete(Stack* stack) { +void listDelete(Stack& stack) { listPopImpl(stack, "pop index out of range"); pop(stack); } -void listInsert(Stack* stack) { +void listInsert(Stack& stack) { IValue elem = pop(stack).to(); int64_t idx = pop(stack).to(); c10::List list = pop(stack).to>(); @@ -331,7 +331,7 @@ void listInsert(Stack* stack) { } } -void listExtend(Stack* stack) { +void listExtend(Stack& stack) { c10::List b = pop(stack).to>(); c10::List a = pop(stack).to>(); @@ -341,12 +341,12 @@ void listExtend(Stack* stack) { } } -void listCopy(Stack* stack) { +void listCopy(Stack& stack) { c10::List list = pop(stack).to>(); push(stack, list.copy()); } -void listSelect(Stack* stack) { +void listSelect(Stack& stack) { int64_t idx = pop(stack).to(); c10::List list = pop(stack).to>(); @@ -354,19 +354,19 @@ void listSelect(Stack* stack) { push(stack, std::move(element)); } -void listLen(Stack* stack) { +void listLen(Stack& stack) { c10::List a = pop(stack).to>(); const int64_t size = a.size(); push(stack, size); } -void listList(Stack* stack) { +void listList(Stack& stack) { c10::List a = pop(stack).to>(); push(stack, a.copy()); } -void listAdd(Stack* stack) { +void listAdd(Stack& stack) { c10::List b = pop(stack).to>(); c10::List a = pop(stack).to>(); @@ -383,14 +383,14 @@ void listAdd(Stack* stack) { push(stack, std::move(ret)); } -void listInplaceAdd(Stack* stack) { +void listInplaceAdd(Stack& stack) { c10::List b = pop(stack).to>(); c10::List a = pop(stack).to>(); a.append(std::move(b)); push(stack, std::move(a)); } -void listMulIntLeftInPlace(Stack* stack) { +void listMulIntLeftInPlace(Stack& stack) { int64_t n = pop(stack).to(); c10::List list = pop(stack).to>(); if (n <= 0) { @@ -408,7 +408,7 @@ void listMulIntLeftInPlace(Stack* stack) { push(stack, std::move(list)); } -void listMulIntLeft(Stack* stack) { +void listMulIntLeft(Stack& stack) { int64_t n = pop(stack).to(); c10::List list = pop(stack).to>(); @@ -426,7 +426,7 @@ void listMulIntLeft(Stack* stack) { push(stack, std::move(ret)); } -void listMulIntRight(Stack* stack) { +void listMulIntRight(Stack& stack) { c10::List list = pop(stack).to>(); int64_t n = pop(stack).to(); @@ -444,7 +444,7 @@ void listMulIntRight(Stack* stack) { push(stack, std::move(ret)); } -void listSlice(Stack* stack) { +void listSlice(Stack& stack) { auto step_val = pop(stack); auto end_val = pop(stack); auto start_val = pop(stack); @@ -477,7 +477,7 @@ void listSlice(Stack* stack) { push(stack, std::move(sliced_list)); } -void listSetItem(Stack* stack) { +void listSetItem(Stack& stack) { IValue value = pop(stack).to(); int64_t idx = pop(stack).to(); c10::List list = pop(stack).to>(); diff --git a/torch/csrc/jit/runtime/register_ops_utils.h b/torch/csrc/jit/runtime/register_ops_utils.h index e068b7877aff1..a4efb67943569 100644 --- a/torch/csrc/jit/runtime/register_ops_utils.h +++ b/torch/csrc/jit/runtime/register_ops_utils.h @@ -35,15 +35,15 @@ namespace torch { namespace jit { -inline c10::AliasAnalysisKind aliasAnalysisFromSchema() { +constexpr inline c10::AliasAnalysisKind aliasAnalysisFromSchema() { return c10::AliasAnalysisKind::FROM_SCHEMA; } -inline c10::AliasAnalysisKind aliasAnalysisConservative() { +constexpr inline c10::AliasAnalysisKind aliasAnalysisConservative() { return c10::AliasAnalysisKind::CONSERVATIVE; } -inline c10::AliasAnalysisKind aliasAnalysisSpecialCase() { +constexpr inline c10::AliasAnalysisKind aliasAnalysisSpecialCase() { return c10::AliasAnalysisKind::INTERNAL_SPECIAL_CASE; } @@ -55,7 +55,7 @@ c10::List make_result_list(const TypePtr& elemType) { template <> c10::impl::GenericList make_result_list(const TypePtr& elemType); -inline void noop(Stack* n) {} +inline void noop(Stack& n) {} // As described in https://docs.python.org/3/library/functions.html#round // When a number is exactly halfway between two integers, python builtin round @@ -181,12 +181,12 @@ void setItem(const c10::List& list, int64_t idx, T&& value) { list.set(normalized_idx, std::forward(value)); } -void listAppend(Stack* stack); +void listAppend(Stack& stack); -void listReverse(Stack* stack); +void listReverse(Stack& stack); template -void minList(Stack* stack) { +void minList(Stack& stack) { c10::List a = pop(stack).to>(); c10::List b = pop(stack).to>(); @@ -204,7 +204,7 @@ void minList(Stack* stack) { } template -void maxList(Stack* stack) { +void maxList(Stack& stack) { c10::List a = pop(stack).to>(); c10::List b = pop(stack).to>(); @@ -221,18 +221,18 @@ void maxList(Stack* stack) { push(stack, b.size() > a.size() ? b : a); } -void listPopImpl(Stack* stack, const char* empty_message); +void listPopImpl(Stack& stack, const char* empty_message); -void listPop(Stack* stack); +void listPop(Stack& stack); -void listClear(Stack* stack); +void listClear(Stack& stack); -void listDelete(Stack* stack); +void listDelete(Stack& stack); -void listInsert(Stack* stack); +void listInsert(Stack& stack); template -void listRemove(Stack* stack) { +void listRemove(Stack& stack) { T elem = pop(stack).to(); c10::List list = pop(stack).to>(); @@ -246,7 +246,7 @@ void listRemove(Stack* stack) { } template -void listMin(Stack* stack) { +void listMin(Stack& stack) { c10::List list = pop(stack).to>(); size_t list_size = list.size(); if (list_size == 0) { @@ -259,11 +259,11 @@ void listMin(Stack* stack) { min_elem = elem < min_elem ? elem : min_elem; } - stack->push_back(min_elem); + stack.push_back(min_elem); } template -void listMax(Stack* stack) { +void listMax(Stack& stack) { c10::List list = pop(stack).to>(); size_t list_size = list.size(); if (list_size == 0) { @@ -276,14 +276,14 @@ void listMax(Stack* stack) { max_elem = elem > max_elem ? elem : max_elem; } - stack->push_back(max_elem); + stack.push_back(max_elem); } template <> -void listRemove(Stack* stack); +void listRemove(Stack& stack); template -void listIndex(Stack* stack) { +void listIndex(Stack& stack) { T elem = pop(stack).to(); c10::List list = pop(stack).to>(); @@ -297,10 +297,10 @@ void listIndex(Stack* stack) { } template <> -void listIndex(Stack* stack); +void listIndex(Stack& stack); template -void listCount(Stack* stack) { +void listCount(Stack& stack) { T elem = pop(stack).to(); c10::List list = pop(stack).to>(); @@ -309,25 +309,25 @@ void listCount(Stack* stack) { } template <> -void listCount(Stack* stack); +void listCount(Stack& stack); -void listExtend(Stack* stack); +void listExtend(Stack& stack); -void listCopy(Stack* stack); +void listCopy(Stack& stack); -void listSelect(Stack* stack); +void listSelect(Stack& stack); -void listLen(Stack* stack); +void listLen(Stack& stack); template -void listEq(Stack* stack) { +void listEq(Stack& stack) { c10::List b = pop(stack).to>(); c10::List a = pop(stack).to>(); push(stack, a == b); } template -void listNe(Stack* stack) { +void listNe(Stack& stack) { c10::List b = pop(stack).to>(); c10::List a = pop(stack).to>(); push(stack, a != b); @@ -357,16 +357,16 @@ inline bool tensor_list_equal( // Specialization for at::Tensor, since it doesn't define operator== template <> -void listEq(Stack* stack); +void listEq(Stack& stack); // Specialization for at::Tensor, since it doesn't define operator== template <> -void listNe(Stack* stack); +void listNe(Stack& stack); -void listList(Stack* stack); +void listList(Stack& stack); template -void listContains(Stack* stack) { +void listContains(Stack& stack) { auto key = pop(stack).to(); auto list = pop(stack).to>(); // NOLINTNEXTLINE(performance-implicit-conversion-in-loop) @@ -379,20 +379,20 @@ void listContains(Stack* stack) { push(stack, false); } -void listAdd(Stack* stack); +void listAdd(Stack& stack); -void listInplaceAdd(Stack* stack); +void listInplaceAdd(Stack& stack); -void listMulIntLeftInPlace(Stack* stack); +void listMulIntLeftInPlace(Stack& stack); -void listMulIntLeft(Stack* stack); +void listMulIntLeft(Stack& stack); -void listMulIntRight(Stack* stack); +void listMulIntRight(Stack& stack); -void listSlice(Stack* stack); +void listSlice(Stack& stack); template -void listSort(Stack* stack) { +void listSort(Stack& stack) { bool reverse = pop(stack).toBool(); c10::List list = pop(stack).to>(); std::sort(list.begin(), list.end(), [reverse](const T& a, const T& b) { @@ -408,10 +408,10 @@ void listSort(Stack* stack) { // Specialization for at::Tensor template <> -void listSort(Stack* stack); +void listSort(Stack& stack); template -void listCopyAndSort(Stack* stack) { +void listCopyAndSort(Stack& stack) { c10::List list = pop(stack).to>(); auto list_copied = list.copy(); std::sort(list_copied.begin(), list_copied.end(), [](const T& a, const T& b) { @@ -426,36 +426,73 @@ void listCopyAndSort(Stack* stack) { // Specialization for at::Tensor template <> -void listCopyAndSort(Stack* stack); - -void listSetItem(Stack* stack); +void listCopyAndSort(Stack& stack); + +void listSetItem(Stack& stack); + +struct OperatorGeneratorArgs { + const char* schema_str; + bool isOperationCreator; + union { + void (*operation)(Stack&); + OperationCreator operationCreator; + }; + AliasAnalysisKind aliasAnalysis; + + explicit constexpr OperatorGeneratorArgs( + torch::detail::SelectiveStr schema_str, + void (*op)(Stack&), + AliasAnalysisKind aa) + : schema_str(schema_str), + isOperationCreator(false), + operation(op), + aliasAnalysis(aa) {} + + explicit constexpr OperatorGeneratorArgs( + torch::detail::SelectiveStr schema_str, + OperationCreator opCreator, + AliasAnalysisKind aa) + : schema_str(schema_str), + isOperationCreator(true), + operationCreator(opCreator), + aliasAnalysis(aa) {} + + template + explicit constexpr OperatorGeneratorArgs( + torch::detail::SelectiveStr, + Args...) + : schema_str(nullptr), + isOperationCreator(false), + operation(nullptr), + aliasAnalysis(AliasAnalysisKind::INTERNAL_SPECIAL_CASE) {} +}; #define DEFINE_GENERIC_BINARY_OP( \ aten_op, op, int_float_result, complex_result) \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA(#aten_op \ ".int_int(int a, int b) -> " #int_float_result), \ - [](Stack* stack) { \ + [](Stack& stack) { \ int64_t a, b; \ pop(stack, a, b); \ push(stack, op); \ }, \ aliasAnalysisFromSchema()), \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA( \ #aten_op \ ".float_float(float a, float b) -> " #int_float_result), \ - [](Stack* stack) { \ + [](Stack& stack) { \ double a, b; \ pop(stack, a, b); \ push(stack, op); \ }, \ aliasAnalysisFromSchema()), \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA( \ #aten_op \ ".complex_complex(complex a, complex b) -> " #complex_result), \ - [](Stack* stack) { \ + [](Stack& stack) { \ c10::complex a, b; \ pop(stack, a, b); \ push(stack, op); \ @@ -464,18 +501,18 @@ void listSetItem(Stack* stack); // define implementations for primitive number ops #define DEFINE_GENERIC_OP(aten_op, int_op, float_op, int_result, float_result) \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA(#aten_op ".int(int a, int b) -> " #int_result), \ - [](Stack* stack) { \ + [](Stack& stack) { \ int64_t a, b; \ pop(stack, a, b); \ push(stack, int_op); \ }, \ aliasAnalysisFromSchema()), \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA( \ #aten_op ".float(float a, float b) -> " #float_result), \ - [](Stack* stack) { \ + [](Stack& stack) { \ double a, b; \ pop(stack, a, b); \ push(stack, float_op); \ @@ -483,20 +520,20 @@ void listSetItem(Stack* stack); aliasAnalysisFromSchema()) #define DEFINE_INT_FLOAT_OP(aten_op, op, result) \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA(#aten_op \ ".int_float(int a, float b) -> " #result), \ - [](Stack* stack) { \ + [](Stack& stack) { \ int64_t a; \ double b; \ pop(stack, a, b); \ push(stack, op); \ }, \ aliasAnalysisFromSchema()), \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA(#aten_op \ ".float_int(float a, int b) -> " #result), \ - [](Stack* stack) { \ + [](Stack& stack) { \ double a; \ int64_t b; \ pop(stack, a, b); \ @@ -505,9 +542,9 @@ void listSetItem(Stack* stack); aliasAnalysisFromSchema()) #define DEFINE_INT_OP(aten_op, op) \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA(#aten_op ".int(int a, int b) -> int"), \ - [](Stack* stack) { \ + [](Stack& stack) { \ int64_t a, b; \ pop(stack, a, b); \ push(stack, op); /* NOLINT(hicpp-signed-bitwise) */ \ @@ -515,9 +552,9 @@ void listSetItem(Stack* stack); aliasAnalysisFromSchema()) #define DEFINE_STR_CMP_OP(aten_op, op) \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA(#aten_op ".str(str a, str b) -> bool"), \ - [](Stack* stack) { \ + [](Stack& stack) { \ auto b = pop(stack).toStringRef(); \ auto a = pop(stack).toStringRef(); \ push(stack, op); \ @@ -530,10 +567,10 @@ void listSetItem(Stack* stack); // in unintended implicit conversions #define DEFINE_SCALAR_BINARY_OP_AVOID_COLLISION_GENERIC( \ aten_op, int_op, float_op, result, string_val) \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA(#aten_op string_val \ "(Scalar a, Scalar b) -> " #result), \ - [](Stack* stack) { \ + [](Stack& stack) { \ IValue x, y; \ pop(stack, x, y); \ if (x.isDouble()) { \ @@ -586,9 +623,9 @@ void listSetItem(Stack* stack); DEFINE_STR_CMP_OP(aten_op, op) #define DEFINE_UNARY_INT_OP(aten_op, op, result) \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA(#aten_op ".int(int a) -> " #result), \ - [](Stack* stack) { \ + [](Stack& stack) { \ int64_t a; \ pop(stack, a); \ push(stack, op); \ @@ -596,9 +633,9 @@ void listSetItem(Stack* stack); aliasAnalysisFromSchema()) #define DEFINE_UNARY_FLOAT_OP(aten_op, op, result) \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA(#aten_op ".float(float a) -> " #result), \ - [](Stack* stack) { \ + [](Stack& stack) { \ double a; \ pop(stack, a); \ push(stack, op); \ @@ -608,9 +645,9 @@ void listSetItem(Stack* stack); #define DEFINE_UNARY_OP(aten_op, op, int_result, float_result) \ DEFINE_UNARY_INT_OP(aten_op, op, int_result), \ DEFINE_UNARY_FLOAT_OP(aten_op, op, float_result), \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA(#aten_op ".Scalar(Scalar a) -> Scalar"), \ - [](Stack* stack) { \ + [](Stack& stack) { \ IValue x; \ pop(stack, x); \ if (x.isDouble()) { \ @@ -623,18 +660,18 @@ void listSetItem(Stack* stack); }, \ aliasAnalysisFromSchema()) #define DEFINE_BOOL_OP(aten_op, op) \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA(#aten_op ".bool(bool a, bool b) -> bool"), \ - [](Stack* stack) { \ + [](Stack& stack) { \ bool a, b; \ pop(stack, a, b); \ push(stack, op); \ }, \ aliasAnalysisFromSchema()) #define DEFINE_STRING_OP(op_name, string_op, result) \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA(#op_name ".str(str a, str b) ->" #result), \ - [](Stack* stack) { \ + [](Stack& stack) { \ auto b = pop(stack).toStringRef(); \ auto a = pop(stack).toStringRef(); \ push(stack, string_op); \ @@ -646,9 +683,9 @@ void listSetItem(Stack* stack); //----------------------------------------------------------------------------- //----------------------------------------------------------------------------- #define DEFINE_UNARY_COMPLEX_OP(aten_op, op, result) \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA(#aten_op ".complex(complex a) -> " #result), \ - [](Stack* stack) { \ + [](Stack& stack) { \ c10::complex a; \ pop(stack, a); \ push(stack, op); \ @@ -670,9 +707,9 @@ void listSetItem(Stack* stack); DEFINE_UNARY_INT_OP(aten_op, op, int_result), \ DEFINE_UNARY_FLOAT_OP(aten_op, op, float_result), \ DEFINE_UNARY_COMPLEX_OP(aten_op, op, complex_result), \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA(#aten_op ".Scalar(Scalar a) -> Scalar"), \ - [](Stack* stack) { \ + [](Stack& stack) { \ IValue x; \ pop(stack, x); \ if (x.isDouble()) { \ @@ -700,27 +737,27 @@ void listSetItem(Stack* stack); int_result, \ float_result, \ complex_result) \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA(#aten_op ".int(int a, int b) -> " #int_result), \ - [](Stack* stack) { \ + [](Stack& stack) { \ int64_t a, b; \ pop(stack, a, b); \ push(stack, int_op); \ }, \ aliasAnalysisFromSchema()), \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA( \ #aten_op ".complex(complex a, complex b) -> " #complex_result), \ - [](Stack* stack) { \ + [](Stack& stack) { \ c10::complex a, b; \ pop(stack, a, b); \ push(stack, complex_op); \ }, \ aliasAnalysisFromSchema()), \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA( \ #aten_op ".float(float a, float b) -> " #float_result), \ - [](Stack* stack) { \ + [](Stack& stack) { \ double a, b; \ pop(stack, a, b); \ push(stack, float_op); \ @@ -728,20 +765,20 @@ void listSetItem(Stack* stack); aliasAnalysisFromSchema()) #define DEFINE_INT_COMPLEX_OP(aten_op, op, result) \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA(#aten_op \ ".int_complex(int a, complex b) -> " #result), \ - [](Stack* stack) { \ + [](Stack& stack) { \ int64_t a; \ c10::complex b; \ pop(stack, a, b); \ push(stack, op); \ }, \ aliasAnalysisFromSchema()), \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA( \ #aten_op ".complex_int(complex a, int b) -> " #result), \ - [](Stack* stack) { \ + [](Stack& stack) { \ c10::complex a; \ int64_t b; \ pop(stack, a, b); \ @@ -750,20 +787,20 @@ void listSetItem(Stack* stack); aliasAnalysisFromSchema()) #define DEFINE_FLOAT_COMPLEX_OP(aten_op, op, result) \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA( \ #aten_op ".float_complex(float a, complex b) -> " #result), \ - [](Stack* stack) { \ + [](Stack& stack) { \ double a; \ c10::complex b; \ pop(stack, a, b); \ push(stack, op); \ }, \ aliasAnalysisFromSchema()), \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA( \ #aten_op ".complex_float(complex a, float b) -> " #result), \ - [](Stack* stack) { \ + [](Stack& stack) { \ c10::complex a; \ double b; \ pop(stack, a, b); \ @@ -773,10 +810,10 @@ void listSetItem(Stack* stack); #define DEFINE_SCALAR_BINARY_OP_WITH_COMPLEX_AVOID_COLLISION_GENERIC( \ aten_op, int_op, float_op, complex_op, result, string_val) \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA(#aten_op string_val \ "(Scalar a, Scalar b) -> " #result), \ - [](Stack* stack) { \ + [](Stack& stack) { \ IValue x, y; \ pop(stack, x, y); \ if (x.isComplexDouble()) { \ @@ -821,9 +858,9 @@ void listSetItem(Stack* stack); #define DEFINE_SCALAR_BINARY_OP_WITH_COMPLEX_WITHOUT_INT_COMPLEX_PAIR( \ aten_op, int_op, float_op, complex_op, result) \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA(#aten_op "(Scalar a, Scalar b) -> " #result), \ - [](Stack* stack) { \ + [](Stack& stack) { \ IValue x, y; \ pop(stack, x, y); \ if (x.isComplexDouble()) { \ diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp index a61cb48b1ddce..9164471dfddf7 100644 --- a/torch/csrc/jit/runtime/register_prim_ops.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops.cpp @@ -86,935 +86,862 @@ auto powWrapper(T a, U b) { return pow(a, b); } -RegisterOperators reg( - {OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::str(t elem) -> str"), - [](Stack* stack) { - std::stringstream ss; - ss << pop(stack); - push(stack, ss.str()); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::list(str t) -> str[]"), - [](Stack* stack) { - auto str = pop(stack).toStringRef(); - c10::List chars; - chars.reserve(str.size()); - for (auto c : str) { - chars.push_back(std::string(1, c)); - } - push(stack, std::move(chars)); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::cpu(Tensor(a) self) -> Tensor(a|b)"), - [](Stack* stack) { - at::Tensor a; - pop(stack, a); - push(stack, a.cpu()); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("prim::layout(Tensor a) -> int"), - [](Stack* stack) { - at::Tensor a; - pop(stack, a); - push(stack, a.layout()); - }, - aliasAnalysisFromSchema()), - Operator( - prim::tolist, - // This operator has to be unschematized because the return type - // depends on the type hint and input. The implementation of this - // operator below is intended to be as close to the Python - // implementation in torch/csrc/utils/tensor_list.cpp as possible. - [](const Node* /*node*/) -> Operation { - return [](Stack* stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int elem_ty_val; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int dim_val; - at::Tensor t; - - pop(stack, elem_ty_val); - pop(stack, dim_val); - pop(stack, t); - - // If the Tensor is not on the CPU, transfer it. - if (!t.device().is_cpu()) { - t = t.cpu(); - } - - // Rebuild the output type using elem_ty_val and dim_val. Start - // with the element type corresponding to elem_ty_val. - TypePtr out_ty; - if (elem_ty_val == 0) { - out_ty = IntType::get(); - } else if (elem_ty_val == 1) { - out_ty = FloatType::get(); - } else if (elem_ty_val == 2) { - out_ty = BoolType::get(); - } else if (elem_ty_val == 3) { - out_ty = ComplexType::get(); - } else { - TORCH_CHECK( - false, - "Unsupported element type for tolist; only int, float, complex and bool are supported"); - } - - // Check that type of the Tensor matches that of the annotation. - // Make an exception for the case in which the annotated type is - // float/complex and the Tensor data type is also float/complex; - // the elements will be casted to double/c10::complex - // later. - TORCH_CHECK( - (out_ty == FloatType::get() && t.is_floating_point()) || - (out_ty == ComplexType::get() && t.is_complex()) || - tryScalarTypeFromJitType(out_ty) == t.scalar_type(), - "Output annotation element type and runtime tensor element type must match for tolist()"); - - // Check that the dimension of the Tensor matches that of the - // annotation. - TORCH_CHECK( - dim_val == t.dim(), - "Output annotation list dimension and runtime tensor dimension must match for tolist()"); - - // Wrap out_ty in a ListType dim times. - for (const auto i : c10::irange(dim_val)) { - (void)i; // Suppress unused variable warning - out_ty = ListType::create(out_ty); - } - - int64_t dim = t.dim(); - auto sizes = t.sizes(); - auto strides = t.strides(); - size_t element_size = t.element_size(); - char* data = static_cast(t.data_ptr()); - auto result = tensorToListRecursive( - data, - 0, - dim, - out_ty, - t.scalar_type(), - sizes, - strides, - element_size); - push(stack, std::move(result)); - }; - }, - aliasAnalysisSpecialCase()), - // only used internally in range() translation - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA( - "aten::__range_length(int lo, int hi, int step) -> int"), - [](Stack* stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t lo, hi, step; - pop(stack, lo, hi, step); - // error handling when step_val = 0 during runtime - if (step == 0) { - throw std::runtime_error("range() arg 3 must not be zero"); - } - if (step > 0 && lo < hi) { - push(stack, 1 + (hi - 1 - lo) / step); - } else if (step < 0 && lo > hi) { - push(stack, 1 + (lo - 1 - hi) / (0 - step)); - } else { - push(stack, 0); - } - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA( - "aten::__derive_index(int index, int start, int step) -> int"), - [](Stack* stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t index, start, step; - pop(stack, index, start, step); - push(stack, start + index * step); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("prim::TupleUnpack(Any tup) -> ..."), - [](Stack* stack) { tupleUnpack(*stack); }, - aliasAnalysisSpecialCase()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("prim::unchecked_cast(t x) -> t"), - noop, - aliasAnalysisSpecialCase()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::IntImplicit(Tensor a) -> int"), - [](Stack* stack) { - at::Tensor a; - pop(stack, a); - checkImplicitTensorToNum(a, /*to int*/ true); - push(stack, a.item()); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::ComplexImplicit(Tensor a) -> complex"), - [](Stack* stack) { - at::Tensor a; - pop(stack, a); - checkImplicitTensorToNum(a, /*to int*/ false); - push(stack, a.item>()); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::FloatImplicit(Tensor a) -> float"), - [](Stack* stack) { - at::Tensor a; - pop(stack, a); - checkImplicitTensorToNum(a, /*to int*/ false); - push(stack, a.item()); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::ScalarImplicit(Tensor a) -> Scalar"), - [](Stack* stack) { - at::Tensor a; - pop(stack, a); - checkImplicitTensorToNum(a, /*to int*/ false); - push(stack, a.item()); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::Bool.Tensor(Tensor a) -> bool"), - [](Stack* stack) { - at::Tensor a; - pop(stack, a); - push(stack, a.is_nonzero()); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::Bool.int(int a) -> bool"), - [](Stack* stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t i; - pop(stack, i); - push(stack, (bool)i); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::Bool.float(float a) -> bool"), - [](Stack* stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double d; - pop(stack, d); - push(stack, (bool)d); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::Int.Tensor(Tensor a) -> int"), - [](Stack* stack) { - at::Tensor a; - pop(stack, a); - push(stack, a.item()); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::Int.bool(bool a) -> int"), - [](Stack* stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool b; - pop(stack, b); - push(stack, static_cast(b)); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::Int.float(float a) -> int"), - [](Stack* stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double d; - pop(stack, d); - push(stack, static_cast(d)); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::Int.Scalar(Scalar a) -> int"), - [](Stack* stack) { - IValue scalar; - pop(stack, scalar); - if (scalar.isInt()) { - push(stack, std::move(scalar)); - } else { - // toScalar() needed to avoid strict type check in IValue::toInt. - push(stack, static_cast(scalar.toScalar().toInt())); - } - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::Int.str(str a) -> int"), - [](Stack* stack) { - auto s = pop(stack).toString(); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - std::string::size_type sz; - int64_t val = static_cast(c10::stoll(s->string(), &sz)); - if (sz == s->string().size()) { - push(stack, val); - } else { - std::stringstream error_str; - error_str << "invalid literal for int() " - << "with base 10: '" << s->string() << "'"; - throw std::runtime_error(error_str.str()); - } - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::Float.Tensor(Tensor a) -> float"), - [](Stack* stack) { - at::Tensor a; - pop(stack, a); - push(stack, a.item()); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::Float.Scalar(Scalar a) -> float"), - [](Stack* stack) { - IValue scalar; - pop(stack, scalar); - if (scalar.isDouble()) { - push(stack, std::move(scalar)); - } else if (scalar.isComplexDouble()) { - push(stack, scalar.toComplexDouble().real()); - } else { - push(stack, static_cast(scalar.toInt())); - } - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::Float.int(int a) -> float"), - [](Stack* stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t i; - pop(stack, i); - push(stack, (float)i); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::Float.bool(bool a) -> float"), - [](Stack* stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool b; - pop(stack, b); - push(stack, (float)b); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::Float.str(str a) -> float"), - [](Stack* stack) { - auto s = pop(stack).toString(); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - std::string::size_type sz; - double b = c10::stod(s->string(), &sz); - if (sz == s->string().size()) { - push(stack, b); - } else { - std::stringstream error_str; - error_str << "could not convert string " - << "to float: '" << s->string() << "'"; - throw std::runtime_error(error_str.str()); - } - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::Complex.Scalar(Scalar a) -> complex"), - [](Stack* stack) { - IValue scalar; - pop(stack, scalar); - if (scalar.isComplexDouble()) { - push(stack, std::move(scalar)); - } else if (scalar.isDouble()) { - push(stack, c10::complex(scalar.toDouble(), 0)); - } else { - push(stack, c10::complex(scalar.toInt(), 0)); - } - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA( - "aten::Complex.Tensor_Tensor(Tensor a, Tensor b) -> complex"), - [](Stack* stack) { - at::Tensor a, b; - pop(stack, a, b); - push( - stack, c10::complex(a.item(), b.item())); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::format(str self, ...) -> str"), - [](Stack* stack) { - size_t num_inputs = pop(stack).toInt(); - format(*stack, num_inputs); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA( - "aten::einsum.sublist(Tensor a, ...) -> Tensor"), - [](Stack* stack) { - size_t num_inputs = pop(stack).toInt(); - einsum(*stack, num_inputs); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("prim::NumToTensor.Scalar(Scalar a) -> Tensor"), - [](Stack* stack) { - at::Scalar s; - pop(stack, s); - push(stack, at::scalar_to_tensor(s)); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("prim::RaiseException(str msg) -> ()"), - [](Stack* stack) { throw JITException(pop(stack).toStringRef()); }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::Size(int[] sizes) -> int[]"), - [](Stack* stack) {}, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::size(Tensor self) -> int[]"), - [](Stack* stack) { - auto t = std::move(pop(stack)).toTensor(); - pack(stack, t.sizes().vec()); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("prim::EnumName(AnyEnumType enum) -> str"), - [](Stack* stack) { - IValue e = pop(stack); - push(stack, e.toEnumHolder()->name()); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("prim::EnumValue.int(AnyEnumType enum) -> int"), - [](Stack* stack) { - IValue e = pop(stack); - push(stack, e.toEnumHolder()->value()); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA( - "prim::EnumValue.float(AnyEnumType enum) -> float"), - [](Stack* stack) { - IValue e = pop(stack); - push(stack, e.toEnumHolder()->value()); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("prim::EnumValue.str(AnyEnumType enum) -> str"), - [](Stack* stack) { - IValue e = pop(stack); - push(stack, e.toEnumHolder()->value()); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - // note the compiler knows to type TupleIndex more accurately than it - // is listed here. - TORCH_SELECTIVE_SCHEMA("prim::TupleIndex(Any tup, int i) -> Any"), - [](Stack* stack) { - int64_t index = pop(stack).toInt(); - auto tuple = pop(stack).toTuple(); - auto norm_index = normalizeIndex(index, tuple->elements().size()); - if (norm_index < 0 || - norm_index > static_cast(tuple->elements().size())) { - throw std::out_of_range("Tuple list index out of range"); - } - stack->emplace_back(tuple->elements()[norm_index]); - }, - aliasAnalysisSpecialCase()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::ne.int_list(int[] a, int[] b) -> bool"), - listNe, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA( - "prim::unchecked_unwrap_optional(t(a)? optional) -> t(a)"), - noop, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("prim::device(Tensor a) -> Device"), - [](Stack* stack) { push(stack, pop(stack).toTensor().device()); }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("prim::dtype(Tensor a) -> int"), - [](Stack* stack) { - at::Tensor a; - pop(stack, a); - push(stack, static_cast(a.scalar_type())); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::__not__(bool self) -> bool"), - [](Stack* stack) { push(stack, !pop(stack).toBool()); }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::__is__(t1 self, t2 obj) -> bool"), - [](Stack* stack) { - IValue self, obj; - pop(stack, self, obj); - push(stack, self.is(obj)); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::__isnot__(t1 self, t2 obj) -> bool"), - [](Stack* stack) { - IValue self, obj; - pop(stack, self, obj); - push(stack, !self.is(obj)); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::element_size(Tensor self) -> int"), - [](Stack* stack) { - at::Tensor arg = pop(stack).toTensor(); - push(stack, arg.element_size()); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::numel(Tensor self) -> int"), - [](Stack* stack) { - at::Tensor arg = pop(stack).toTensor(); - push(stack, arg.numel()); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::dim(Tensor self) -> int"), - [](Stack* stack) { - at::Tensor arg = pop(stack).toTensor(); - push(stack, arg.dim()); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::get_device(Tensor self) -> int"), - [](Stack* stack) { - RECORD_FUNCTION("get_device", std::vector()); - auto result = - at::get_device((std::move(peek(stack, 0, 1))).toTensor()); - drop(stack, 1); - pack(stack, result); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::storage_offset(Tensor self) -> int"), - [](Stack* stack) { - RECORD_FUNCTION("storage_offset", std::vector()); - auto result = - ((std::move(peek(stack, 0, 1))).toTensor()).storage_offset(); - drop(stack, 1); - pack(stack, result); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::is_contiguous(Tensor self) -> bool"), - [](Stack* stack) { - RECORD_FUNCTION("is_contiguous", std::vector()); - auto result = - ((std::move(peek(stack, 0, 1))).toTensor()).is_contiguous(); - drop(stack, 1); - pack(stack, result); - }, - aliasAnalysisFromSchema()), - // these ops are generic over the list element type. - // CREATING GENERIC_LIST_OPS - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::select.t(t[](a) list, int idx) -> t(*)"), - listSelect, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA( - "aten::__getitem__.t(t[](a) list, int idx) -> t(*)"), - listSelect, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA( - "aten::append.t(t[](a!) self, t(c -> *) el) -> t[](a!)"), - listAppend, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::reverse.t(t[](a!) self) -> ()"), - listReverse, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA( - "aten::extend.t(t[](a!) self, t[] other) -> ()"), - listExtend, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::copy.t(t[](a) self) -> t[]"), - listCopy, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA( - "aten::_set_item.t(t [](a!) l, int idx, t(b -> *) el) -> t[](a!)"), - listSetItem, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::clear.t(t[](a!) self) -> ()"), - listClear, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::Delete.t(t[](a!) self, int idx) -> ()"), - listDelete, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA( - "aten::insert.t(t[](a!) self, int idx, t(b -> *) el) -> ()"), - listInsert, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA( - "aten::pop.t(t[](a!) self, int idx=-1) -> t(*)"), - listPop, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::add.t(t[] a, t[] b) -> t[]"), - listAdd, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::add_.t(t[](a!) self, t[] b) -> t[]"), - listInplaceAdd, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA( - "aten::slice.t(t[] l, int? start=None, int? end=None, int step=1) -> t[]"), - listSlice, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::list.t(t[] l) -> t[]"), - listList, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::mul.left_t(t[] l, int n) -> t[]"), - listMulIntLeft, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::mul.right_(int n, t[] l) -> t[]"), - listMulIntRight, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::mul_.t(t[](a!) l, int n) -> t[](a!)"), - listMulIntLeftInPlace, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::len.t(t[] a) -> int"), - listLen, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::eq.int_list(int[] a, int[] b) -> bool"), - listEq, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::eq.device(Device a, Device b) -> bool"), - [](Stack* stack) { - auto a = pop(stack).toDevice(); - auto b = pop(stack).toDevice(); - push(stack, a == b); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::ne.device(Device a, Device b) -> bool"), - [](Stack* stack) { - auto a = pop(stack).toDevice(); - auto b = pop(stack).toDevice(); - push(stack, a != b); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::eq.bool(bool a, bool b) -> bool"), - [](Stack* stack) { - auto a = pop(stack); - auto b = pop(stack); - push(stack, a == b); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::ne.bool(bool a, bool b) -> bool"), - [](Stack* stack) { - auto a = pop(stack); - auto b = pop(stack); - push(stack, a != b); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("prim::Uninitialized() -> Any"), - [](Stack* stack) { push(stack, IValue::uninitialized()); }, - aliasAnalysisSpecialCase()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("prim::Print(...) -> ()"), - [](Stack* stack) { - auto num_inputs = pop(stack).toInt(); - std::stringstream ss; - bool first = true; - for (const IValue& i : last(stack, num_inputs)) { - if (!first) - ss << " "; - first = false; - ss << i; - } - drop(stack, num_inputs); - ss << std::endl; - auto* handler = getPrintHandler(); - TORCH_INTERNAL_ASSERT(handler); - handler(ss.str()); - }, - aliasAnalysisSpecialCase()), - // This is an alternative to aten::cat op that takes variable number of - // parameters as input. - // Format: - // prim::VarConcat(Tensors..., dim) -> Tensor - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("prim::VarConcat(...) -> Tensor"), - [](Stack* stack) { - auto num_inputs = pop(stack).toInt(); - auto dim = pop(stack).toInt(); - std::vector inputs(num_inputs - 1); - for (int i = 0; i < num_inputs - 1; ++i) { - inputs[num_inputs - 2 - i] = pop(stack).toTensor(); - } - push(stack, at::cat(inputs, dim)); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA( - "aten::eq.enum(AnyEnumType a, AnyEnumType b) -> bool"), - [](Stack* stack) { - IValue x = pop(stack); - IValue y = pop(stack); - push(stack, x == y); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA( - "aten::ne.enum(AnyEnumType a, AnyEnumType b) -> bool"), - [](Stack* stack) { - IValue x = pop(stack); - IValue y = pop(stack); - push(stack, x != y); - }, - aliasAnalysisFromSchema()), - // We define aten::dequantize in both native_functions.yaml and here, - // however, aten::dequantize.any defined here overrides - // aten::dequantize.tensors in native_functions.yaml. The variants here - // are only for graph mode quantization, and they should be removed once - // we deprecate graph mode quantization, and use the variants in - // native_functions.yaml. - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA( - "aten::dequantize.tensor(Tensor qtensor) -> Tensor"), - [](Stack* stack) { - at::Tensor qtensor; - pop(stack, qtensor); - push(stack, at::dequantize(qtensor)); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA( - "aten::dequantize.list(Tensor[] qtensors) -> Tensor[]"), - [](Stack* stack) { - auto qtensors = pop(stack).toTensorVector(); - push(stack, at::dequantize(qtensors)); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::dequantize.any(Any tensors) -> Any"), - [](Stack* stack) { dequantize(*stack); }, - aliasAnalysisFromSchema()), - DEFINE_UNARY_OP_WITH_COMPLEX(aten::log, std::log(a), float, float), - DEFINE_STRING_OP(aten::add, a + b, str), - DEFINE_COMPARISON_OP_WITH_COMPLEX(aten::eq, a == b), - DEFINE_COMPARISON_OP_WITH_COMPLEX(aten::ne, a != b), - DEFINE_GENERIC_OP( - aten::polar, - c10::polar(static_cast(a), static_cast(b)), - c10::polar(static_cast(a), static_cast(b)), - complex, - complex), - DEFINE_INT_FLOAT_OP( - aten::polar, - c10::polar(static_cast(a), static_cast(b)), - complex), - DEFINE_SCALAR_BINARY_OP_AVOID_COLLISION( - aten::polar, - c10::polar(static_cast(a), static_cast(b)), - c10::polar(static_cast(a), static_cast(b)), - Scalar), - DEFINE_COMPARISON_OP(aten::lt, a < b), - DEFINE_COMPARISON_OP(aten::gt, a > b), - DEFINE_COMPARISON_OP(aten::le, a <= b), - DEFINE_COMPARISON_OP(aten::ge, a >= b), - DEFINE_BINARY_OP_WITH_COMPLEX(aten::add, a + b), - DEFINE_BINARY_OP_WITH_COMPLEX(aten::sub, a - b), - DEFINE_BINARY_OP_WITH_COMPLEX(aten::mul, a* b), - DEFINE_BOOL_OP(aten::__and__, a&& b), - DEFINE_BOOL_OP(aten::__or__, a || b), - DEFINE_BOOL_OP(aten::__xor__, a != b), - DEFINE_UNARY_OP(aten::round, round_to_even(a), float, float), - DEFINE_UNARY_OP(aten::floor, floor(a), int, int), - DEFINE_UNARY_OP(aten::ceil, ceil(a), int, int), - DEFINE_UNARY_OP_WITH_COMPLEX(aten::neg, -a, int, float), - DEFINE_UNARY_OP_WITH_COMPLEX(aten::exp, std::exp(a), float, float), - // Pass in two ops for handling int and float separately as % in C++ only - // works for int The modulus calculation is different between C++ and - // Python (on negative), we preserve the python behavior as it's more - // common and match python syntax, hence the conversion. - DEFINE_GENERIC_OP( - aten::remainder, - (b + (a % b)) % b, - fmod((b + fmod(a, b)), b), - int, - float), - DEFINE_INT_FLOAT_OP(aten::remainder, fmod((b + fmod(a, b)), b), float), - DEFINE_SCALAR_BINARY_OP( - aten::remainder, - (b + (a % b)) % b, - fmod((b + fmod(a, b)), b), - Scalar), - // NB: This is the python truediv operation - DEFINE_GENERIC_OP_WITH_COMPLEX( - aten::div, - static_cast(a) / static_cast(b), - a / b, - a / b, - float, - float, - complex), - DEFINE_SCALAR_BINARY_OP( - aten::div, - static_cast(a) / static_cast(b), - a / b, - float), - DEFINE_GENERIC_OP( - aten::floordiv, - floordiv(a, b), - std::floor(a / b), - int, - float), - DEFINE_INT_FLOAT_OP(aten::floordiv, std::floor(a / b), float), - DEFINE_SCALAR_BINARY_OP( - aten::floordiv, - floordiv(a, b), - std::floor(a / b), - Scalar), - // int ** int produces a float, because negative exponents produce float - // results - DEFINE_GENERIC_OP_WITH_COMPLEX( - aten::pow, - static_cast(powWrapper(a, b)), - static_cast(powWrapper(a, b)), - static_cast>(pow(a, b)), - float, - float, - complex), - DEFINE_INT_FLOAT_OP( - aten::pow, - static_cast(powWrapper(a, b)), - float), - DEFINE_FLOAT_COMPLEX_OP(aten::pow, pow(a, b), complex), - DEFINE_SCALAR_BINARY_OP_AVOID_COLLISION( - aten::pow, - static_cast(pow(a, b)), - static_cast(pow(a, b)), - float), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::pow.int_to_int(int a, int b) -> int"), - [](Stack* stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t a, b; - pop(stack, a, b); - push(stack, powWrapper(a, b)); - }, - aliasAnalysisFromSchema()), - // min and max are in prim:: because there is a difference between - // the python builtin 'min' and 'torch.min' - DEFINE_BINARY_OP(prim::min, a < b ? a : b), - DEFINE_BINARY_OP(prim::max, a > b ? a : b), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("prim::type(Device self) -> str"), - [](Stack* stack) { - auto d = pop(stack); - push( - stack, - DeviceTypeName(d.toDevice().type(), /* lower_case=*/true)); - }, - aliasAnalysisFromSchema()), - // tensor length op (size of 1st dimension) - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::len.Tensor(Tensor t) -> int"), - [](Stack* stack) { - at::Tensor t = pop(stack).toTensor(); - if (t.dim() == 0) { - AT_ERROR("len() of a 0-d tensor"); - } - push(stack, t.sizes()[0]); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::ord(str string) -> int"), - [](Stack* stack) { - auto string = pop(stack).toStringRef(); - TORCH_CHECK( - string.size() == 1, - "String for ord() must be 1 character, found ", - string.size()); - uint8_t ord = string.at(0); - push(stack, int64_t(ord)); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::lower(str self) -> str"), - [](Stack* stack) { - auto string = pop(stack).toStringRef(); - std::stringstream ss; - for (char c : string) { - ss << static_cast(::tolower(c)); - } - push(stack, ss.str()); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA( - "aten::__contains__.int_list(int[] l, int item) -> bool"), - listContains, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA( - "aten::__contains__.str_list(str[] l, str item) -> bool"), - listContains, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::len.str(str s) -> int"), - [](Stack* stack) { - auto string = pop(stack).toStringRef(); - push(stack, static_cast(string.size())); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::dict() -> Dict(str, Tensor)"), - [](Stack* stack) { - auto dict = - c10::impl::GenericDict(StringType::get(), TensorType::get()); - push(stack, dict); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA( - "aten::__getitem__.str(str s, int index) -> str"), - [](Stack* stack) { - auto index = pop(stack).toInt(); - auto string = pop(stack).toStringRef(); - auto norm_index = normalizeIndex(index, string.size()); - char c = string.at(norm_index); - push(stack, std::string(&c, 1)); - }, - aliasAnalysisFromSchema()), +static const OperatorGeneratorArgs opGenArgs[] = { + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::str(t elem) -> str"), + [](Stack& stack) { + std::stringstream ss; + ss << pop(stack); + push(stack, ss.str()); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::list(str t) -> str[]"), + [](Stack& stack) { + auto str = pop(stack).toStringRef(); + c10::List chars; + chars.reserve(str.size()); + for (auto c : str) { + chars.push_back(std::string(1, c)); + } + push(stack, std::move(chars)); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::cpu(Tensor(a) self) -> Tensor(a|b)"), + [](Stack& stack) { + at::Tensor a; + pop(stack, a); + push(stack, a.cpu()); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("prim::layout(Tensor a) -> int"), + [](Stack& stack) { + at::Tensor a; + pop(stack, a); + push(stack, a.layout()); + }, + aliasAnalysisFromSchema()), + + // only used internally in range() translation + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA( + "aten::__range_length(int lo, int hi, int step) -> int"), + [](Stack& stack) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + int64_t lo, hi, step; + pop(stack, lo, hi, step); + // error handling when step_val = 0 during runtime + if (step == 0) { + throw std::runtime_error("range() arg 3 must not be zero"); + } + if (step > 0 && lo < hi) { + push(stack, 1 + (hi - 1 - lo) / step); + } else if (step < 0 && lo > hi) { + push(stack, 1 + (lo - 1 - hi) / (0 - step)); + } else { + push(stack, 0); + } + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA( + "aten::__derive_index(int index, int start, int step) -> int"), + [](Stack& stack) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + int64_t index, start, step; + pop(stack, index, start, step); + push(stack, start + index * step); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("prim::TupleUnpack(Any tup) -> ..."), + [](Stack& stack) { tupleUnpack(stack); }, + aliasAnalysisSpecialCase()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("prim::unchecked_cast(t x) -> t"), + noop, + aliasAnalysisSpecialCase()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::IntImplicit(Tensor a) -> int"), + [](Stack& stack) { + at::Tensor a; + pop(stack, a); + checkImplicitTensorToNum(a, /*to int*/ true); + push(stack, a.item()); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::ComplexImplicit(Tensor a) -> complex"), + [](Stack& stack) { + at::Tensor a; + pop(stack, a); + checkImplicitTensorToNum(a, /*to int*/ false); + push(stack, a.item>()); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::FloatImplicit(Tensor a) -> float"), + [](Stack& stack) { + at::Tensor a; + pop(stack, a); + checkImplicitTensorToNum(a, /*to int*/ false); + push(stack, a.item()); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::ScalarImplicit(Tensor a) -> Scalar"), + [](Stack& stack) { + at::Tensor a; + pop(stack, a); + checkImplicitTensorToNum(a, /*to int*/ false); + push(stack, a.item()); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::Bool.Tensor(Tensor a) -> bool"), + [](Stack& stack) { + at::Tensor a; + pop(stack, a); + push(stack, a.is_nonzero()); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::Bool.int(int a) -> bool"), + [](Stack& stack) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + int64_t i; + pop(stack, i); + push(stack, (bool)i); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::Bool.float(float a) -> bool"), + [](Stack& stack) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + double d; + pop(stack, d); + push(stack, (bool)d); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::Int.Tensor(Tensor a) -> int"), + [](Stack& stack) { + at::Tensor a; + pop(stack, a); + push(stack, a.item()); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::Int.bool(bool a) -> int"), + [](Stack& stack) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + bool b; + pop(stack, b); + push(stack, static_cast(b)); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::Int.float(float a) -> int"), + [](Stack& stack) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + double d; + pop(stack, d); + push(stack, static_cast(d)); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::Int.Scalar(Scalar a) -> int"), + [](Stack& stack) { + IValue scalar; + pop(stack, scalar); + if (scalar.isInt()) { + push(stack, std::move(scalar)); + } else { + // toScalar() needed to avoid strict type check in IValue::toInt. + push(stack, static_cast(scalar.toScalar().toInt())); + } + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::Int.str(str a) -> int"), + [](Stack& stack) { + auto s = pop(stack).toString(); + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + std::string::size_type sz; + int64_t val = static_cast(c10::stoll(s->string(), &sz)); + if (sz == s->string().size()) { + push(stack, val); + } else { + std::stringstream error_str; + error_str << "invalid literal for int() " + << "with base 10: '" << s->string() << "'"; + throw std::runtime_error(error_str.str()); + } + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::Float.Tensor(Tensor a) -> float"), + [](Stack& stack) { + at::Tensor a; + pop(stack, a); + push(stack, a.item()); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::Float.Scalar(Scalar a) -> float"), + [](Stack& stack) { + IValue scalar; + pop(stack, scalar); + if (scalar.isDouble()) { + push(stack, std::move(scalar)); + } else if (scalar.isComplexDouble()) { + push(stack, scalar.toComplexDouble().real()); + } else { + push(stack, static_cast(scalar.toInt())); + } + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::Float.int(int a) -> float"), + [](Stack& stack) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + int64_t i; + pop(stack, i); + push(stack, (float)i); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::Float.bool(bool a) -> float"), + [](Stack& stack) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + bool b; + pop(stack, b); + push(stack, (float)b); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::Float.str(str a) -> float"), + [](Stack& stack) { + auto s = pop(stack).toString(); + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + std::string::size_type sz; + double b = c10::stod(s->string(), &sz); + if (sz == s->string().size()) { + push(stack, b); + } else { + std::stringstream error_str; + error_str << "could not convert string " + << "to float: '" << s->string() << "'"; + throw std::runtime_error(error_str.str()); + } + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::Complex.Scalar(Scalar a) -> complex"), + [](Stack& stack) { + IValue scalar; + pop(stack, scalar); + if (scalar.isComplexDouble()) { + push(stack, std::move(scalar)); + } else if (scalar.isDouble()) { + push(stack, c10::complex(scalar.toDouble(), 0)); + } else { + push(stack, c10::complex(scalar.toInt(), 0)); + } + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA( + "aten::Complex.Tensor_Tensor(Tensor a, Tensor b) -> complex"), + [](Stack& stack) { + at::Tensor a, b; + pop(stack, a, b); + push(stack, c10::complex(a.item(), b.item())); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::format(str self, ...) -> str"), + [](Stack& stack) { + size_t num_inputs = pop(stack).toInt(); + format(stack, num_inputs); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::einsum.sublist(Tensor a, ...) -> Tensor"), + [](Stack& stack) { + size_t num_inputs = pop(stack).toInt(); + einsum(stack, num_inputs); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("prim::NumToTensor.Scalar(Scalar a) -> Tensor"), + [](Stack& stack) { + at::Scalar s; + pop(stack, s); + push(stack, at::scalar_to_tensor(s)); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("prim::RaiseException(str msg) -> ()"), + [](Stack& stack) { throw JITException(pop(stack).toStringRef()); }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::Size(int[] sizes) -> int[]"), + [](Stack& stack) {}, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::size(Tensor self) -> int[]"), + [](Stack& stack) { + auto t = std::move(pop(stack)).toTensor(); + pack(stack, t.sizes().vec()); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("prim::EnumName(AnyEnumType enum) -> str"), + [](Stack& stack) { + IValue e = pop(stack); + push(stack, e.toEnumHolder()->name()); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("prim::EnumValue.int(AnyEnumType enum) -> int"), + [](Stack& stack) { + IValue e = pop(stack); + push(stack, e.toEnumHolder()->value()); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA( + "prim::EnumValue.float(AnyEnumType enum) -> float"), + [](Stack& stack) { + IValue e = pop(stack); + push(stack, e.toEnumHolder()->value()); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("prim::EnumValue.str(AnyEnumType enum) -> str"), + [](Stack& stack) { + IValue e = pop(stack); + push(stack, e.toEnumHolder()->value()); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + // note the compiler knows to type TupleIndex more accurately than it + // is listed here. + TORCH_SELECTIVE_SCHEMA("prim::TupleIndex(Any tup, int i) -> Any"), + [](Stack& stack) { + int64_t index = pop(stack).toInt(); + auto tuple = pop(stack).toTuple(); + auto norm_index = normalizeIndex(index, tuple->elements().size()); + if (norm_index < 0 || + norm_index > static_cast(tuple->elements().size())) { + throw std::out_of_range("Tuple list index out of range"); + } + stack.emplace_back(tuple->elements()[norm_index]); + }, + aliasAnalysisSpecialCase()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::ne.int_list(int[] a, int[] b) -> bool"), + listNe, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA( + "prim::unchecked_unwrap_optional(t(a)? optional) -> t(a)"), + noop, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("prim::device(Tensor a) -> Device"), + [](Stack& stack) { push(stack, pop(stack).toTensor().device()); }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("prim::dtype(Tensor a) -> int"), + [](Stack& stack) { + at::Tensor a; + pop(stack, a); + push(stack, static_cast(a.scalar_type())); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::__not__(bool self) -> bool"), + [](Stack& stack) { push(stack, !pop(stack).toBool()); }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::__is__(t1 self, t2 obj) -> bool"), + [](Stack& stack) { + IValue self, obj; + pop(stack, self, obj); + push(stack, self.is(obj)); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::__isnot__(t1 self, t2 obj) -> bool"), + [](Stack& stack) { + IValue self, obj; + pop(stack, self, obj); + push(stack, !self.is(obj)); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::element_size(Tensor self) -> int"), + [](Stack& stack) { + at::Tensor arg = pop(stack).toTensor(); + push(stack, arg.element_size()); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::numel(Tensor self) -> int"), + [](Stack& stack) { + at::Tensor arg = pop(stack).toTensor(); + push(stack, arg.numel()); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::dim(Tensor self) -> int"), + [](Stack& stack) { + at::Tensor arg = pop(stack).toTensor(); + push(stack, arg.dim()); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::get_device(Tensor self) -> int"), + [](Stack& stack) { + RECORD_FUNCTION("get_device", std::vector()); + auto result = + at::get_device((std::move(peek(stack, 0, 1))).toTensor()); + drop(stack, 1); + pack(stack, result); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::storage_offset(Tensor self) -> int"), + [](Stack& stack) { + RECORD_FUNCTION("storage_offset", std::vector()); + auto result = + ((std::move(peek(stack, 0, 1))).toTensor()).storage_offset(); + drop(stack, 1); + pack(stack, result); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::is_contiguous(Tensor self) -> bool"), + [](Stack& stack) { + RECORD_FUNCTION("is_contiguous", std::vector()); + auto result = + ((std::move(peek(stack, 0, 1))).toTensor()).is_contiguous(); + drop(stack, 1); + pack(stack, result); + }, + aliasAnalysisFromSchema()), + // these ops are generic over the list element type. + // CREATING GENERIC_LIST_OPS + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::select.t(t[](a) list, int idx) -> t(*)"), + listSelect, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA( + "aten::__getitem__.t(t[](a) list, int idx) -> t(*)"), + listSelect, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA( + "aten::append.t(t[](a!) self, t(c -> *) el) -> t[](a!)"), + listAppend, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::reverse.t(t[](a!) self) -> ()"), + listReverse, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::extend.t(t[](a!) self, t[] other) -> ()"), + listExtend, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::copy.t(t[](a) self) -> t[]"), + listCopy, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA( + "aten::_set_item.t(t [](a!) l, int idx, t(b -> *) el) -> t[](a!)"), + listSetItem, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::clear.t(t[](a!) self) -> ()"), + listClear, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::Delete.t(t[](a!) self, int idx) -> ()"), + listDelete, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA( + "aten::insert.t(t[](a!) self, int idx, t(b -> *) el) -> ()"), + listInsert, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::pop.t(t[](a!) self, int idx=-1) -> t(*)"), + listPop, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::add.t(t[] a, t[] b) -> t[]"), + listAdd, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::add_.t(t[](a!) self, t[] b) -> t[]"), + listInplaceAdd, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA( + "aten::slice.t(t[] l, int? start=None, int? end=None, int step=1) -> t[]"), + listSlice, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::list.t(t[] l) -> t[]"), + listList, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::mul.left_t(t[] l, int n) -> t[]"), + listMulIntLeft, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::mul.right_(int n, t[] l) -> t[]"), + listMulIntRight, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::mul_.t(t[](a!) l, int n) -> t[](a!)"), + listMulIntLeftInPlace, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::len.t(t[] a) -> int"), + listLen, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::eq.int_list(int[] a, int[] b) -> bool"), + listEq, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::eq.device(Device a, Device b) -> bool"), + [](Stack& stack) { + auto a = pop(stack).toDevice(); + auto b = pop(stack).toDevice(); + push(stack, a == b); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::ne.device(Device a, Device b) -> bool"), + [](Stack& stack) { + auto a = pop(stack).toDevice(); + auto b = pop(stack).toDevice(); + push(stack, a != b); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::eq.bool(bool a, bool b) -> bool"), + [](Stack& stack) { + auto a = pop(stack); + auto b = pop(stack); + push(stack, a == b); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::ne.bool(bool a, bool b) -> bool"), + [](Stack& stack) { + auto a = pop(stack); + auto b = pop(stack); + push(stack, a != b); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("prim::Uninitialized() -> Any"), + [](Stack& stack) { push(stack, IValue::uninitialized()); }, + aliasAnalysisSpecialCase()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("prim::Print(...) -> ()"), + [](Stack& stack) { + auto num_inputs = pop(stack).toInt(); + std::stringstream ss; + bool first = true; + for (const IValue& i : last(stack, num_inputs)) { + if (!first) + ss << " "; + first = false; + ss << i; + } + drop(stack, num_inputs); + ss << std::endl; + auto* handler = getPrintHandler(); + TORCH_INTERNAL_ASSERT(handler); + handler(ss.str()); + }, + aliasAnalysisSpecialCase()), + // This is an alternative to aten::cat op that takes variable number of + // parameters as input. + // Format: + // prim::VarConcat(Tensors..., dim) -> Tensor + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("prim::VarConcat(...) -> Tensor"), + [](Stack& stack) { + auto num_inputs = pop(stack).toInt(); + auto dim = pop(stack).toInt(); + std::vector inputs(num_inputs - 1); + for (int i = 0; i < num_inputs - 1; ++i) { + inputs[num_inputs - 2 - i] = pop(stack).toTensor(); + } + push(stack, at::cat(inputs, dim)); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("prim::VarStack(...) -> Tensor"), + [](Stack& stack) { + auto num_inputs = pop(stack).toInt(); + auto dim = pop(stack).toInt(); + std::vector inputs(num_inputs - 1); + for (int i = 0; i < num_inputs - 1; ++i) { + inputs[num_inputs - 2 - i] = pop(stack).toTensor(); + } + push(stack, at::stack(inputs, dim)); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA( + "aten::eq.enum(AnyEnumType a, AnyEnumType b) -> bool"), + [](Stack& stack) { + IValue x = pop(stack); + IValue y = pop(stack); + push(stack, x == y); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA( + "aten::ne.enum(AnyEnumType a, AnyEnumType b) -> bool"), + [](Stack& stack) { + IValue x = pop(stack); + IValue y = pop(stack); + push(stack, x != y); + }, + aliasAnalysisFromSchema()), + // We define aten::dequantize in both native_functions.yaml and here, + // however, aten::dequantize.any defined here overrides + // aten::dequantize.tensors in native_functions.yaml. The variants here + // are only for graph mode quantization, and they should be removed once + // we deprecate graph mode quantization, and use the variants in + // native_functions.yaml. + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA( + "aten::dequantize.tensor(Tensor qtensor) -> Tensor"), + [](Stack& stack) { + at::Tensor qtensor; + pop(stack, qtensor); + push(stack, at::dequantize(qtensor)); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA( + "aten::dequantize.list(Tensor[] qtensors) -> Tensor[]"), + [](Stack& stack) { + auto qtensors = pop(stack).toTensorVector(); + push(stack, at::dequantize(qtensors)); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::dequantize.any(Any tensors) -> Any"), + [](Stack& stack) { dequantize(stack); }, + aliasAnalysisFromSchema()), + DEFINE_UNARY_OP_WITH_COMPLEX(aten::log, std::log(a), float, float), + DEFINE_STRING_OP(aten::add, a + b, str), + DEFINE_COMPARISON_OP_WITH_COMPLEX(aten::eq, a == b), + DEFINE_COMPARISON_OP_WITH_COMPLEX(aten::ne, a != b), + DEFINE_GENERIC_OP( + aten::polar, + c10::polar(static_cast(a), static_cast(b)), + c10::polar(static_cast(a), static_cast(b)), + complex, + complex), + DEFINE_INT_FLOAT_OP( + aten::polar, + c10::polar(static_cast(a), static_cast(b)), + complex), + DEFINE_SCALAR_BINARY_OP_AVOID_COLLISION( + aten::polar, + c10::polar(static_cast(a), static_cast(b)), + c10::polar(static_cast(a), static_cast(b)), + Scalar), + DEFINE_COMPARISON_OP(aten::lt, a < b), + DEFINE_COMPARISON_OP(aten::gt, a > b), + DEFINE_COMPARISON_OP(aten::le, a <= b), + DEFINE_COMPARISON_OP(aten::ge, a >= b), + DEFINE_BINARY_OP_WITH_COMPLEX(aten::add, a + b), + DEFINE_BINARY_OP_WITH_COMPLEX(aten::sub, a - b), + DEFINE_BINARY_OP_WITH_COMPLEX(aten::mul, a* b), + DEFINE_BOOL_OP(aten::__and__, a&& b), + DEFINE_BOOL_OP(aten::__or__, a || b), + DEFINE_BOOL_OP(aten::__xor__, a != b), + DEFINE_UNARY_OP(aten::round, round_to_even(a), float, float), + DEFINE_UNARY_OP(aten::floor, floor(a), int, int), + DEFINE_UNARY_OP(aten::ceil, ceil(a), int, int), + DEFINE_UNARY_OP_WITH_COMPLEX(aten::neg, -a, int, float), + DEFINE_UNARY_OP_WITH_COMPLEX(aten::exp, std::exp(a), float, float), + // Pass in two ops for handling int and float separately as % in C++ only + // works for int The modulus calculation is different between C++ and + // Python (on negative), we preserve the python behavior as it's more + // common and match python syntax, hence the conversion. + DEFINE_GENERIC_OP( + aten::remainder, + (b + (a % b)) % b, + fmod((b + fmod(a, b)), b), + int, + float), + DEFINE_INT_FLOAT_OP(aten::remainder, fmod((b + fmod(a, b)), b), float), + DEFINE_SCALAR_BINARY_OP( + aten::remainder, + (b + (a % b)) % b, + fmod((b + fmod(a, b)), b), + Scalar), + // NB: This is the python truediv operation + DEFINE_GENERIC_OP_WITH_COMPLEX( + aten::div, + static_cast(a) / static_cast(b), + a / b, + a / b, + float, + float, + complex), + DEFINE_SCALAR_BINARY_OP( + aten::div, + static_cast(a) / static_cast(b), + a / b, + float), + DEFINE_GENERIC_OP( + aten::floordiv, + floordiv(a, b), + std::floor(a / b), + int, + float), + DEFINE_INT_FLOAT_OP(aten::floordiv, std::floor(a / b), float), + DEFINE_SCALAR_BINARY_OP( + aten::floordiv, + floordiv(a, b), + std::floor(a / b), + Scalar), + // int ** int produces a float, because negative exponents produce float + // results + DEFINE_GENERIC_OP_WITH_COMPLEX( + aten::pow, + static_cast(powWrapper(a, b)), + static_cast(powWrapper(a, b)), + static_cast>(pow(a, b)), + float, + float, + complex), + DEFINE_INT_FLOAT_OP( + aten::pow, + static_cast(powWrapper(a, b)), + float), + DEFINE_FLOAT_COMPLEX_OP(aten::pow, pow(a, b), complex), + DEFINE_SCALAR_BINARY_OP_AVOID_COLLISION( + aten::pow, + static_cast(pow(a, b)), + static_cast(pow(a, b)), + float), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::pow.int_to_int(int a, int b) -> int"), + [](Stack& stack) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + int64_t a, b; + pop(stack, a, b); + push(stack, powWrapper(a, b)); + }, + aliasAnalysisFromSchema()), + // min and max are in prim:: because there is a difference between + // the python builtin 'min' and 'torch.min' + DEFINE_BINARY_OP(prim::min, a < b ? a : b), + DEFINE_BINARY_OP(prim::max, a > b ? a : b), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("prim::type(Device self) -> str"), + [](Stack& stack) { + auto d = pop(stack); + push( + stack, DeviceTypeName(d.toDevice().type(), /* lower_case=*/true)); + }, + aliasAnalysisFromSchema()), + // tensor length op (size of 1st dimension) + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::len.Tensor(Tensor t) -> int"), + [](Stack& stack) { + at::Tensor t = pop(stack).toTensor(); + if (t.dim() == 0) { + AT_ERROR("len() of a 0-d tensor"); + } + push(stack, t.sizes()[0]); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::ord(str string) -> int"), + [](Stack& stack) { + auto string = pop(stack).toStringRef(); + TORCH_CHECK( + string.size() == 1, + "String for ord() must be 1 character, found ", + string.size()); + uint8_t ord = string.at(0); + push(stack, int64_t(ord)); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::lower(str self) -> str"), + [](Stack& stack) { + auto string = pop(stack).toStringRef(); + std::stringstream ss; + for (char c : string) { + ss << static_cast(::tolower(c)); + } + push(stack, ss.str()); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA( + "aten::__contains__.int_list(int[] l, int item) -> bool"), + listContains, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA( + "aten::__contains__.str_list(str[] l, str item) -> bool"), + listContains, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::len.str(str s) -> int"), + [](Stack& stack) { + auto string = pop(stack).toStringRef(); + push(stack, static_cast(string.size())); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::dict() -> Dict(str, Tensor)"), + [](Stack& stack) { + auto dict = + c10::impl::GenericDict(StringType::get(), TensorType::get()); + push(stack, dict); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA( + "aten::__getitem__.str(str s, int index) -> str"), + [](Stack& stack) { + auto index = pop(stack).toInt(); + auto string = pop(stack).toStringRef(); + auto norm_index = normalizeIndex(index, string.size()); + char c = string.at(norm_index); + push(stack, std::string(&c, 1)); + }, + aliasAnalysisFromSchema()), #define CREATE_COPY_OP(other_type, c_type) \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA("aten::copy_." #other_type \ "(Tensor(a!) self, " #other_type \ " other) -> Tensor(a!)"), \ - [](Stack* stack) { \ + [](Stack& stack) { \ at::Tensor t; \ c_type other; \ pop(stack, t, other); \ @@ -1023,172 +950,170 @@ RegisterOperators reg( }, \ aliasAnalysisFromSchema()) - CREATE_COPY_OP(Tensor, at::Tensor), - CREATE_COPY_OP(int, int64_t), - CREATE_COPY_OP(float, double), + CREATE_COPY_OP(Tensor, at::Tensor), + CREATE_COPY_OP(int, int64_t), + CREATE_COPY_OP(float, double), #undef CREATE_COPY_OP - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA( - "aten::backward(Tensor self, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()"), - [](Stack* stack) { - bool create_graph = pop(stack).toBool(); - auto retain_graph = pop(stack).toOptional(); - IValue gradient_ivalue = pop(stack); - at::Tensor gradient = gradient_ivalue.isNone() - ? at::Tensor() - : gradient_ivalue.toTensor(); - at::Tensor self = pop(stack).toTensor(); - bool keep_graph = retain_graph ? retain_graph.value() : create_graph; - self.backward(gradient, keep_graph, create_graph); - }, - aliasAnalysisConservative()), - // - // create a clone of these declarations with a _hacked_twin overload name - // and nullability scrubbed from TensorList arg types - // TOOD find out why this exists and how to do it without the hack - // - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA( - "aten::index.Tensor_hacked_twin(Tensor self, Tensor[] indices) -> Tensor"), - [](Stack* stack) { - auto indices = pop(stack).to>>(); - auto self = pop(stack).toTensor(); - auto result = at::index(self, indices); - push(stack, std::move(result)); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA( - "aten::_index_put_impl_.hacked_twin(Tensor(a!) self, Tensor[] indices, Tensor values, bool accumulate=False, bool unsafe=False) -> Tensor(a!)"), - [](Stack* stack) { - auto unsafe = pop(stack).toBool(); - auto accumulate = pop(stack).toBool(); - auto values = pop(stack).toTensor(); - auto indices = pop(stack).to>>(); - auto self = pop(stack).toTensor(); - auto result = - at::_index_put_impl_(self, indices, values, accumulate, unsafe); - push(stack, std::move(result)); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA( - "aten::index_put_.hacked_twin(Tensor(a!) self, Tensor[] indices, Tensor values, bool accumulate=False) -> Tensor(a!)"), - [](Stack* stack) { - auto accumulate = pop(stack).toBool(); - auto values = pop(stack).toTensor(); - auto indices = pop(stack).to>>(); - auto self = pop(stack).toTensor(); - auto result = at::index_put_(self, indices, values, accumulate); - push(stack, std::move(result)); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA( - "aten::index_put.hacked_twin(Tensor self, Tensor[] indices, Tensor values, bool accumulate=False) -> Tensor"), - [](Stack* stack) { - auto accumulate = pop(stack).toBool(); - auto values = pop(stack).toTensor(); - auto indices = pop(stack).to>>(); - auto self = pop(stack).toTensor(); - auto result = at::index_put_(self, indices, values, accumulate); - push(stack, std::move(result)); - }, - aliasAnalysisFromSchema()), - // reference function parse_to_conversion in python_arg_parsing.h - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA( - "aten::to.prim_Device(Tensor(a) self, Device? device, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor(a|b)"), - [](Stack* stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool non_blocking; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool copy; - pop(stack, non_blocking, copy); - c10::optional scalarType = - pop(stack).toOptional(); - c10::optional device = - pop(stack).toOptional(); - at::Tensor self = pop(stack).toTensor(); - push( - stack, - to_dispatch(self, device, scalarType, non_blocking, copy)); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA( - "aten::to.prim_dtype(Tensor(a) self, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor(a|b)"), - [](Stack* stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool non_blocking; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool copy; - pop(stack, non_blocking, copy); - c10::optional scalarType = - pop(stack).toOptional(); - c10::optional device = c10::nullopt; - at::Tensor self = pop(stack).toTensor(); - push( - stack, - to_dispatch(self, device, scalarType, non_blocking, copy)); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("prim::is_cuda(Tensor a) -> bool"), - [](Stack* stack) { - at::Tensor a; - pop(stack, a); - push(stack, a.is_cuda()); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("prim::is_xpu(Tensor a) -> bool"), - [](Stack* stack) { - at::Tensor a; - pop(stack, a); - push(stack, a.is_xpu()); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("prim::data(Tensor(a) a) -> Tensor(a)"), - [](Stack* stack) { - at::Tensor a; - pop(stack, a); - push(stack, autograd::Variable(a).variable_data()); - }, - aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA( + "aten::backward(Tensor self, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()"), + [](Stack& stack) { + bool create_graph = pop(stack).toBool(); + auto retain_graph = pop(stack).toOptional(); + IValue gradient_ivalue = pop(stack); + at::Tensor gradient = gradient_ivalue.isNone() + ? at::Tensor() + : gradient_ivalue.toTensor(); + at::Tensor self = pop(stack).toTensor(); + bool keep_graph = retain_graph ? retain_graph.value() : create_graph; + self.backward(gradient, keep_graph, create_graph); + }, + aliasAnalysisConservative()), + // + // create a clone of these declarations with a _hacked_twin overload name + // and nullability scrubbed from TensorList arg types + // TOOD find out why this exists and how to do it without the hack + // + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA( + "aten::index.Tensor_hacked_twin(Tensor self, Tensor[] indices) -> Tensor"), + [](Stack& stack) { + auto indices = pop(stack).to>>(); + auto self = pop(stack).toTensor(); + auto result = at::index(self, indices); + push(stack, std::move(result)); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA( + "aten::_index_put_impl_.hacked_twin(Tensor(a!) self, Tensor[] indices, Tensor values, bool accumulate=False, bool unsafe=False) -> Tensor(a!)"), + [](Stack& stack) { + auto unsafe = pop(stack).toBool(); + auto accumulate = pop(stack).toBool(); + auto values = pop(stack).toTensor(); + auto indices = pop(stack).to>>(); + auto self = pop(stack).toTensor(); + auto result = + at::_index_put_impl_(self, indices, values, accumulate, unsafe); + push(stack, std::move(result)); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA( + "aten::index_put_.hacked_twin(Tensor(a!) self, Tensor[] indices, Tensor values, bool accumulate=False) -> Tensor(a!)"), + [](Stack& stack) { + auto accumulate = pop(stack).toBool(); + auto values = pop(stack).toTensor(); + auto indices = pop(stack).to>>(); + auto self = pop(stack).toTensor(); + auto result = at::index_put_(self, indices, values, accumulate); + push(stack, std::move(result)); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA( + "aten::index_put.hacked_twin(Tensor self, Tensor[] indices, Tensor values, bool accumulate=False) -> Tensor"), + [](Stack& stack) { + auto accumulate = pop(stack).toBool(); + auto values = pop(stack).toTensor(); + auto indices = pop(stack).to>>(); + auto self = pop(stack).toTensor(); + auto result = at::index_put_(self, indices, values, accumulate); + push(stack, std::move(result)); + }, + aliasAnalysisFromSchema()), + // reference function parse_to_conversion in python_arg_parsing.h + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA( + "aten::to.prim_Device(Tensor(a) self, Device? device, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor(a|b)"), + [](Stack& stack) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + bool non_blocking; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + bool copy; + pop(stack, non_blocking, copy); + c10::optional scalarType = + pop(stack).toOptional(); + c10::optional device = + pop(stack).toOptional(); + at::Tensor self = pop(stack).toTensor(); + push( + stack, to_dispatch(self, device, scalarType, non_blocking, copy)); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA( + "aten::to.prim_dtype(Tensor(a) self, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor(a|b)"), + [](Stack& stack) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + bool non_blocking; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + bool copy; + pop(stack, non_blocking, copy); + c10::optional scalarType = + pop(stack).toOptional(); + c10::optional device = c10::nullopt; + at::Tensor self = pop(stack).toTensor(); + push( + stack, to_dispatch(self, device, scalarType, non_blocking, copy)); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("prim::is_cuda(Tensor a) -> bool"), + [](Stack& stack) { + at::Tensor a; + pop(stack, a); + push(stack, a.is_cuda()); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("prim::is_xpu(Tensor a) -> bool"), + [](Stack& stack) { + at::Tensor a; + pop(stack, a); + push(stack, a.is_xpu()); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("prim::data(Tensor(a) a) -> Tensor(a)"), + [](Stack& stack) { + at::Tensor a; + pop(stack, a); + push(stack, autograd::Variable(a).variable_data()); + }, + aliasAnalysisFromSchema()), // these ops are not defined for Tensor #define CREATE_COMPARATOR_LIST_OPS_SPECIALIZED(decl_type, value_type) \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA("prim::min." decl_type "_list(" decl_type \ "[] l, " decl_type "[] r) -> " decl_type "[]"), \ minList, \ aliasAnalysisFromSchema()), \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA("prim::max." decl_type "_list(" decl_type \ "[] l, " decl_type "[] r) -> " decl_type \ "[]"), \ maxList, \ aliasAnalysisFromSchema()), \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA("prim::min.self_" decl_type "(" decl_type \ "[] self) -> " decl_type), \ listMin, \ aliasAnalysisFromSchema()), \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA("prim::max.self_" decl_type "(" decl_type \ "[] self) -> " decl_type), \ listMax, \ aliasAnalysisFromSchema()), - CREATE_COMPARATOR_LIST_OPS_SPECIALIZED("int", int64_t) - CREATE_COMPARATOR_LIST_OPS_SPECIALIZED("float", double) - CREATE_COMPARATOR_LIST_OPS_SPECIALIZED("bool", bool) + CREATE_COMPARATOR_LIST_OPS_SPECIALIZED("int", int64_t) + CREATE_COMPARATOR_LIST_OPS_SPECIALIZED("float", double) + CREATE_COMPARATOR_LIST_OPS_SPECIALIZED("bool", bool) #undef CREATE_COMPARATOR_LIST_OPS_SPECIALIZED // python string is methods return false if empty #define DEFINE_STRING_IS_OP(op_name, char_op) \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA(#op_name "(str self) -> bool"), \ - [](Stack* stack) { \ + [](Stack& stack) { \ auto string = pop(stack).toStringRef(); \ push( \ stack, \ @@ -1199,17 +1124,17 @@ RegisterOperators reg( }, \ aliasAnalysisFromSchema()) - DEFINE_STRING_IS_OP(aten::isdigit, ::isdigit), - DEFINE_STRING_IS_OP(aten::isspace, ::isspace), - DEFINE_STRING_IS_OP(aten::isalnum, ::isalnum), - DEFINE_STRING_IS_OP(aten::isalpha, ::isalpha), - DEFINE_STRING_IS_OP(aten::isdecimal, ::isdigit), - DEFINE_STRING_IS_OP(aten::isnumeric, ::isdigit), + DEFINE_STRING_IS_OP(aten::isdigit, ::isdigit), + DEFINE_STRING_IS_OP(aten::isspace, ::isspace), + DEFINE_STRING_IS_OP(aten::isalnum, ::isalnum), + DEFINE_STRING_IS_OP(aten::isalpha, ::isalpha), + DEFINE_STRING_IS_OP(aten::isdecimal, ::isdigit), + DEFINE_STRING_IS_OP(aten::isnumeric, ::isdigit), #define DEFINE_STRING_CHAR_MAP_OP(op_name, char_op) \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA(#op_name "(str self) -> str"), \ - [](Stack* stack) { \ + [](Stack& stack) { \ auto string = pop(stack).toStringRef(); \ std::stringstream ss; \ for (char c : string) { \ @@ -1219,28 +1144,135 @@ RegisterOperators reg( }, \ aliasAnalysisFromSchema()) - DEFINE_STRING_CHAR_MAP_OP(aten::upper, ::toupper), - DEFINE_STRING_CHAR_MAP_OP(aten::swapcase, ([](char c) { - if (c == static_cast(::toupper(c))) { - return static_cast(::tolower(c)); - } else { - return static_cast(::toupper(c)); - } - }))}); + DEFINE_STRING_CHAR_MAP_OP(aten::upper, ::toupper), + DEFINE_STRING_CHAR_MAP_OP(aten::swapcase, ([](char c) { + if (c == static_cast(::toupper(c))) { + return static_cast(::tolower(c)); + } else { + return static_cast(::toupper(c)); + } + }))}; + +static std::vector> createOperators( + const OperatorGeneratorArgs* args, + int length) { + std::vector> result; + result.reserve(length); + for (int ii = 0; ii < length; ++ii) { + if (args[ii].schema_str) { + if (args[ii].isOperationCreator) { + result.push_back(OperatorGenerator( + args[ii].schema_str, + args[ii].operationCreator, + args[ii].aliasAnalysis)); + } else { + result.push_back(OperatorGenerator( + args[ii].schema_str, args[ii].operation, args[ii].aliasAnalysis)); + } + } + } + return result; +} + +RegisterOperators reg(([]() { + auto v = createOperators(opGenArgs, sizeof(opGenArgs) / sizeof(opGenArgs[0])); + v.push_back(Operator( + prim::tolist, + // This operator has to be unschematized because the return type + // depends on the type hint and input. The implementation of this + // operator below is intended to be as close to the Python + // implementation in torch/csrc/utils/tensor_list.cpp as possible. + [](const Node* /*node*/) -> Operation { + return [](Stack& stack) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + int elem_ty_val; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + int dim_val; + at::Tensor t; + + pop(stack, elem_ty_val); + pop(stack, dim_val); + pop(stack, t); + + // If the Tensor is not on the CPU, transfer it. + if (!t.device().is_cpu()) { + t = t.cpu(); + } -void dictSetItem(Stack* stack) { + // Rebuild the output type using elem_ty_val and dim_val. Start + // with the element type corresponding to elem_ty_val. + TypePtr out_ty; + if (elem_ty_val == 0) { + out_ty = IntType::get(); + } else if (elem_ty_val == 1) { + out_ty = FloatType::get(); + } else if (elem_ty_val == 2) { + out_ty = BoolType::get(); + } else if (elem_ty_val == 3) { + out_ty = ComplexType::get(); + } else { + TORCH_CHECK( + false, + "Unsupported element type for tolist; only int, float, complex and bool are supported"); + } + + // Check that type of the Tensor matches that of the annotation. + // Make an exception for the case in which the annotated type is + // float/complex and the Tensor data type is also float/complex; + // the elements will be casted to double/c10::complex + // later. + TORCH_CHECK( + (out_ty == FloatType::get() && t.is_floating_point()) || + (out_ty == ComplexType::get() && t.is_complex()) || + tryScalarTypeFromJitType(out_ty) == t.scalar_type(), + "Output annotation element type and runtime tensor element type must match for tolist()"); + + // Check that the dimension of the Tensor matches that of the + // annotation. + TORCH_CHECK( + dim_val == t.dim(), + "Output annotation list dimension and runtime tensor dimension must match for tolist()"); + + // Wrap out_ty in a ListType dim times. + for (const auto i : c10::irange(dim_val)) { + (void)i; // Suppress unused variable warning + out_ty = ListType::create(out_ty); + } + + int64_t dim = t.dim(); + auto sizes = t.sizes(); + auto strides = t.strides(); + size_t element_size = t.element_size(); + char* data = static_cast(t.data_ptr()); + auto result = tensorToListRecursive( + data, + 0, + dim, + out_ty, + t.scalar_type(), + sizes, + strides, + element_size); + push(stack, std::move(result)); + }; + }, + aliasAnalysisSpecialCase())); + return v; +})()); + +void dictSetItem(Stack& stack) { auto value = pop(stack); auto idx = pop(stack); auto dict = pop(stack).toGenericDict(); dict.insert_or_assign(std::move(idx), std::move(value)); } -void dictLen(Stack* stack) { +void dictLen(Stack& stack) { auto dict = pop(stack).toGenericDict(); push(stack, int64_t(dict.size())); } -void dictValues(Stack* stack) { +void dictValues(Stack& stack) { auto dict = pop(stack).toGenericDict(); auto values = c10::impl::GenericList(dict.valueType()); for (const auto& entry : dict) { @@ -1249,7 +1281,7 @@ void dictValues(Stack* stack) { push(stack, values); } -void dictKeys(Stack* stack) { +void dictKeys(Stack& stack) { auto dict = pop(stack).toGenericDict(); auto keys = c10::impl::GenericList(dict.keyType()); for (const auto& entry : dict) { @@ -1258,7 +1290,7 @@ void dictKeys(Stack* stack) { push(stack, keys); } -void dictIndex(Stack* stack) { +void dictIndex(Stack& stack) { auto key = pop(stack); auto dict = pop(stack).toGenericDict(); auto value = dict.find(key); @@ -1269,7 +1301,7 @@ void dictIndex(Stack* stack) { } template -void dictGet(Stack* stack) { +void dictGet(Stack& stack) { IValue default_value; if (has_default) { default_value = pop(stack); @@ -1286,7 +1318,7 @@ void dictGet(Stack* stack) { // If the key is in the dict, return it. Else set it to the default value and // return that. -void dictSetDefault(Stack* stack) { +void dictSetDefault(Stack& stack) { auto default_value = pop(stack); auto key = pop(stack); auto dict = pop(stack).toGenericDict(); @@ -1300,7 +1332,7 @@ void dictSetDefault(Stack* stack) { } template -void dictPop(Stack* stack) { +void dictPop(Stack& stack) { IValue default_value; if (has_default) { default_value = pop(stack); @@ -1323,13 +1355,13 @@ void dictPop(Stack* stack) { } } -void dictDelete(Stack* stack) { +void dictDelete(Stack& stack) { dictPop(stack); // pop pushes an item on the stack but delete does not, so get rid of it pop(stack); } -void dictPopItem(Stack* stack) { +void dictPopItem(Stack& stack) { auto dict = pop(stack).toGenericDict(); if (dict.size() == 0) { AT_ERROR("popitem(): dictionary is empty"); @@ -1344,18 +1376,18 @@ void dictPopItem(Stack* stack) { push(stack, tuple); } -void dictContains(Stack* stack) { +void dictContains(Stack& stack) { auto key = pop(stack); auto dict = pop(stack).toGenericDict(); push(stack, dict.contains(key)); } -void dictClear(Stack* stack) { +void dictClear(Stack& stack) { auto dict = pop(stack).toGenericDict(); dict.clear(); } -void dictUpdate(Stack* stack) { +void dictUpdate(Stack& stack) { auto to_add = pop(stack).toGenericDict(); auto dict = pop(stack).toGenericDict(); @@ -1364,7 +1396,7 @@ void dictUpdate(Stack* stack) { } } -void dictItems(Stack* stack) { +void dictItems(Stack& stack) { auto dict = pop(stack).toGenericDict(); auto key_type = dict.keyType(); auto value_type = dict.valueType(); @@ -1377,11 +1409,11 @@ void dictItems(Stack* stack) { push(stack, std::move(items)); } -void dictCopy(Stack* stack) { +void dictCopy(Stack& stack) { push(stack, pop(stack).toGenericDict().copy()); } -void dictConstructFromList(Stack* stack) { +void dictConstructFromList(Stack& stack) { auto input_list = pop(stack); auto list = input_list.toList(); auto tup_type = list.elementType()->expect(); @@ -1396,123 +1428,125 @@ void dictConstructFromList(Stack* stack) { } #define CREATE_DICT_OPS(key_type) \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA("aten::len.Dict_" key_type "(Dict(" key_type \ ", t) self) -> int"), \ dictLen, \ aliasAnalysisFromSchema()), \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA("aten::keys." key_type "(Dict(" key_type \ ", t) self) -> " key_type "[](*)"), \ dictKeys, \ aliasAnalysisFromSchema()), \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA("aten::values." key_type "(Dict(" key_type \ ", t) self) -> t[](*)"), \ dictValues, \ aliasAnalysisFromSchema()), \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA("aten::__getitem__.Dict_" key_type \ "(Dict(" key_type ", t) self, " key_type \ " key) -> t(*)"), \ dictIndex, \ aliasAnalysisFromSchema()), \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA("aten::get." key_type "(Dict(" key_type \ ", t) self, " key_type " key) -> t(*)?"), \ dictGet, \ aliasAnalysisFromSchema()), \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA("aten::get.default_" key_type \ "(Dict(" key_type ", t) self, " key_type \ " key, t default_value) -> t(*)"), \ dictGet, \ aliasAnalysisFromSchema()), \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA( \ "aten::setdefault." key_type "(Dict(" key_type \ ", t)(a!) self, " key_type \ "(b -> *) key, t(c -> *) default_value) -> t(*)"), \ dictSetDefault, \ aliasAnalysisFromSchema()), \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA("aten::Delete.Dict_" key_type \ "(Dict(" key_type ", t)(a!) self, " key_type \ " key) -> ()"), \ dictDelete, \ aliasAnalysisFromSchema()), \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA("aten::pop.Dict_" key_type "(Dict(" key_type \ ", t)(a!) self, " key_type " key) -> t(*)"), \ dictPop, \ aliasAnalysisFromSchema()), \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA("aten::pop.Dict_default_" key_type \ "(Dict(" key_type ", t)(a!) self, " key_type \ " key, t default_value) -> t(*)"), \ dictPop, \ aliasAnalysisFromSchema()), \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA("aten::popitem." key_type "(Dict(" key_type \ ", t)(a!) self) -> ((" key_type ", t))"), \ dictPopItem, \ aliasAnalysisFromSchema()), \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA("aten::clear." key_type "(Dict(" key_type \ ", t)(a!) self) -> ()"), \ dictClear, \ aliasAnalysisFromSchema()), \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA("aten::update." key_type "(Dict(" key_type \ ", t)(a!) self, Dict(" key_type \ ", t)(a!) to_add) -> ()"), \ dictUpdate, \ aliasAnalysisFromSchema()), \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA("aten::items." key_type "(Dict(" key_type \ ", t) self) -> ((" key_type ", t)[])"), \ dictItems, \ aliasAnalysisFromSchema()), \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA("aten::copy.Dict_" key_type "(Dict(" key_type \ ", t)(a) self) -> Dict(" key_type ", t)"), \ dictCopy, \ aliasAnalysisFromSchema()), \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA("aten::__contains__." key_type \ "(Dict(" key_type ", t) dict, " key_type \ " key) -> bool"), \ dictContains, \ aliasAnalysisFromSchema()), \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA("aten::_set_item." key_type "(Dict(" key_type \ ", t)(a!) l, " key_type \ "(b -> *) idx, t(c -> *) v) -> ()"), \ dictSetItem, \ aliasAnalysisFromSchema()), \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA("aten::dict." key_type "((" key_type \ ", tVal)[] inputs) -> Dict(" key_type \ ", tVal)"), \ dictConstructFromList, \ aliasAnalysisFromSchema()), \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA("aten::dict.Dict_" key_type "(Dict(" key_type \ ", t)(a) self) -> Dict(" key_type ", t)"), \ dictCopy, \ aliasAnalysisFromSchema()) -RegisterOperators reg_dict_ops({ +static const OperatorGeneratorArgs dict_ops[] = { CREATE_DICT_OPS("str"), CREATE_DICT_OPS("int"), CREATE_DICT_OPS("bool"), CREATE_DICT_OPS("float"), CREATE_DICT_OPS("complex"), CREATE_DICT_OPS("Tensor"), -}); +}; +RegisterOperators reg_dict_ops( + createOperators(dict_ops, sizeof(dict_ops) / sizeof(dict_ops[0]))); // NOLINTNEXTLINE(clang-diagnostic-unused-function) -c10::AliasAnalysisKind aliasAnalysisFromSchema() { +constexpr c10::AliasAnalysisKind aliasAnalysisFromSchema() { return c10::AliasAnalysisKind::FROM_SCHEMA; } @@ -2083,385 +2117,394 @@ TORCH_LIBRARY_IMPL(aten, CatchAll, m) { }); } +static const OperatorGeneratorArgs opGenArgs1[] = { + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("prim::rangelist(int n) -> int[]"), + [](Stack& stack) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + int64_t n; + pop(stack, n); + c10::List elems; + elems.reserve(n); + for (const auto i : c10::irange(n)) { + elems.push_back(i); + } + push(stack, std::move(elems)); + }, + aliasAnalysisFromSchema()), + // note: this op needs to share a name with the Scalar -> Tensor conversion + // because all _to_tensor conversion have to have the same operator namet + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("prim::NumToTensor.bool(bool a) -> Tensor"), + [](Stack& stack) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + bool b; + pop(stack, b); + push(stack, at::scalar_to_tensor(b)); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::device(str a) -> Device"), + [](Stack& stack) { + push(stack, c10::Device(pop(stack).toStringRef())); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::percentFormat(str self, ...) -> str"), + [](Stack& stack) { + size_t num_inputs = pop(stack).toInt(); + percentFormat(stack, num_inputs); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA( + "aten::to.prim_other(Tensor(a) self, bool non_blocking=False, bool copy=False) -> Tensor(a|b)"), + [](Stack& stack) { + at::Tensor self; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + bool non_blocking; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + bool copy; + pop(stack, self, non_blocking, copy); + c10::optional device = c10::nullopt; + c10::optional scalarType = c10::nullopt; + push( + stack, to_dispatch(self, device, scalarType, non_blocking, copy)); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("prim::requires_grad(Tensor a) -> bool"), + [](Stack& stack) { + at::Tensor a; + pop(stack, a); + push(stack, a.requires_grad()); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("prim::grad(Tensor a) -> Tensor(*)"), + [](Stack& stack) { + at::Tensor a; + pop(stack, a); + push(stack, a.grad()); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("prim::is_sparse(Tensor a) -> bool"), + [](Stack& stack) { + at::Tensor a; + pop(stack, a); + push(stack, a.is_sparse()); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("prim::is_sparse_csr(Tensor a) -> bool"), + [](Stack& stack) { + at::Tensor a; + pop(stack, a); + push(stack, a.is_sparse_csr()); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("prim::is_mkldnn(Tensor a) -> bool"), + [](Stack& stack) { + at::Tensor a; + pop(stack, a); + push(stack, a.is_mkldnn()); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("prim::is_mlc(Tensor a) -> bool"), + [](Stack& stack) { + at::Tensor a; + pop(stack, a); + push(stack, a.is_mlc()); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("prim::is_vulkan(Tensor a) -> bool"), + [](Stack& stack) { + at::Tensor a; + pop(stack, a); + push(stack, a.is_vulkan()); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("prim::is_quantized(Tensor a) -> bool"), + [](Stack& stack) { + at::Tensor a; + pop(stack, a); + push(stack, a.is_quantized()); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("prim::is_meta(Tensor a) -> bool"), + [](Stack& stack) { + at::Tensor a; + pop(stack, a); + push(stack, a.is_meta()); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("prim::is_ort(Tensor a) -> bool"), + [](Stack& stack) { + at::Tensor a; + pop(stack, a); + push(stack, a.is_ort()); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("prim::name(Tensor a) -> str?"), + [](Stack& stack) { + at::Tensor a; + pop(stack, a); + if (a.name() == "") { + push(stack, IValue()); + } else { + push(stack, a.name()); + } + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("prim::index(Device self) -> int?"), + [](Stack& stack) { + auto d = pop(stack).toDevice(); + if (d.has_index()) { + push(stack, d.index()); + } else { + push(stack, IValue()); + } + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + // TODO return generator object when torchscript supports RNG + // first-class + TORCH_SELECTIVE_SCHEMA("aten::manual_seed(int seed) -> ()"), + [](Stack& stack) { at::manual_seed(pop(stack).toInt()); }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("aten::cuda(Tensor(a) self) -> Tensor(a|b)"), + [](Stack& stack) { + at::Tensor a; + pop(stack, a); + push(stack, a.cuda()); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("prim::AutogradZero() -> Tensor"), + [](Stack& stack) { stack.emplace_back(at::Tensor()); }, + aliasAnalysisSpecialCase()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA( + "prim::ReductionSizes(int[] size, int[] red_axes, bool keepdim = False) -> int[]"), + [](Stack& stack) { + bool keepdim = pop(stack).toBool(); + c10::List axes = pop(stack).toIntList(); + c10::List size = pop(stack).toIntList(); + if (keepdim) { + for (const auto& axis : axes) { + size.set(axis, 1); + } + } else { + int64_t index = 0; + auto iter = size.begin(); + std::sort(axes.begin(), axes.end()); + for (const auto& axis : axes) { + // move iter to the next axis + iter += axis - index; + + // input iter points to axis and is updated to axis + 1 + iter = size.erase(iter); + + // update current index for iter + index = axis + 1; + } + } + push(stack, IValue(std::move(size))); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("prim::BroadcastSizes(...) -> int[]"), + [](Stack& stack) { + auto num_inputs = pop(stack).toInt(); + std::vector size; + size.reserve(8); + for (const auto i : c10::irange(num_inputs)) { + size = + at::infer_size(size, peek(stack, i, num_inputs).toIntVector()); + } + drop(stack, num_inputs); + push(stack, IValue(size)); + }, + aliasAnalysisSpecialCase()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA( + "aten::warn(str message, int stacklevel=2) -> ()"), + [](Stack& stack) { + TORCH_CHECK(false, "warn is implemented directly in the interpreter"); + }, + aliasAnalysisFromSchema()), + + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA( + "onnx::Reshape(Tensor input, Tensor shape) -> Tensor"), + [](Stack& stack) { + at::Tensor input, shape; + pop(stack, input, shape); + shape = shape.contiguous(); + AT_ASSERT(shape.ndimension() == 1); + at::IntArrayRef shape_list(shape.data_ptr(), shape.size(0)); + push(stack, input.reshape(shape_list)); + }, + aliasAnalysisSpecialCase()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("onnx::Shape(Tensor t) -> Tensor"), + [](Stack& stack) { + auto t = pop(stack).toTensor(); + at::IntArrayRef sizes = t.sizes(); + auto sizes_tensor = torch::empty( + {static_cast(sizes.size())}, at::dtype(at::kLong)); + auto accessor = sizes_tensor.accessor(); + for (const auto i : c10::irange(sizes.size())) { + accessor[i] = sizes[i]; + } + stack.emplace_back(sizes_tensor); + }, + aliasAnalysisSpecialCase()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("prim::AutogradAnyNonZero(...) -> bool"), + [](Stack& stack) { + auto num_inputs = pop(stack).toInt(); + bool result = false; + for (const IValue& v : last(stack, num_inputs)) { + if (v.isTensor()) { + if (v.toTensor().defined()) { + result = true; + break; + } + } else if (v.isTensorList()) { + for (const at::Tensor& t : v.toTensorVector()) { + if (t.defined()) { + result = true; + } + } + if (result) { + break; + } + } else { + TORCH_INTERNAL_ASSERT(false); + } + } + drop(stack, num_inputs); + stack.emplace_back(result); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("prim::AutogradAllZero(...) -> bool"), + [](Stack& stack) { + auto num_inputs = pop(stack).toInt(); + bool result = true; + for (const IValue& v : last(stack, num_inputs)) { + TORCH_INTERNAL_ASSERT(v.isTensor()); + if (v.toTensor().defined()) { + result = false; + break; + } + } + drop(stack, num_inputs); + stack.emplace_back(result); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("prim::AutogradAllNonZero(...) -> bool"), + [](Stack& stack) { + auto num_inputs = pop(stack).toInt(); + bool result = true; + for (const IValue& v : last(stack, num_inputs)) { + TORCH_INTERNAL_ASSERT(v.isTensor()); + if (!v.toTensor().defined()) { + result = false; + break; + } + } + drop(stack, num_inputs); + stack.emplace_back(result); + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA("prim::AutogradAdd(Any a, Any b) -> Any"), + [](Stack& stack) { + at::Tensor a, b; + pop(stack, a, b); + // NOLINTNEXTLINE(bugprone-branch-clone) + if (!a.defined() && !b.defined()) { + // undef + undef == undef + stack.emplace_back(a); + } else if (!a.defined()) { + stack.emplace_back(b); + } else if (!b.defined()) { + stack.emplace_back(a); + } else { + stack.emplace_back(a + b); + } + }, + aliasAnalysisSpecialCase()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA( + "aten::_size_if_not_equal(int[] self_size, int[] other_size) -> int[]?"), + [](Stack& stack) { + IValue self_size, other_size; + pop(stack, self_size, other_size); + auto s = self_size.toIntVector(); + auto o = other_size.toIntVector(); + if (s == o) { + push(stack, IValue()); + } else { + push(stack, s); + } + }, + aliasAnalysisFromSchema()), + OperatorGeneratorArgs( + TORCH_SELECTIVE_SCHEMA( + "aten::_unwrap_optional(t(a)? optional) -> t(a)"), + [](Stack& stack) { + auto val = pop(stack); + TORCH_CHECK(!val.isNone(), "Unwrapping null optional"); + push(stack, std::move(val)); + }, + aliasAnalysisFromSchema())}; + RegisterOperators reg1( - {OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("prim::rangelist(int n) -> int[]"), - [](Stack* stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t n; - pop(stack, n); - c10::List elems; - elems.reserve(n); - for (const auto i : c10::irange(n)) { - elems.push_back(i); - } - push(stack, std::move(elems)); - }, - aliasAnalysisFromSchema()), - // note: this op needs to share a name with the Scalar -> Tensor conversion - // because all _to_tensor conversion have to have the same operator namet - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("prim::NumToTensor.bool(bool a) -> Tensor"), - [](Stack* stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool b; - pop(stack, b); - push(stack, at::scalar_to_tensor(b)); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::device(str a) -> Device"), - [](Stack* stack) { - push(stack, c10::Device(pop(stack).toStringRef())); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::percentFormat(str self, ...) -> str"), - [](Stack* stack) { - size_t num_inputs = pop(stack).toInt(); - percentFormat(*stack, num_inputs); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA( - "aten::to.prim_other(Tensor(a) self, bool non_blocking=False, bool copy=False) -> Tensor(a|b)"), - [](Stack* stack) { - at::Tensor self; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool non_blocking; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool copy; - pop(stack, self, non_blocking, copy); - c10::optional device = c10::nullopt; - c10::optional scalarType = c10::nullopt; - push( - stack, - to_dispatch(self, device, scalarType, non_blocking, copy)); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("prim::requires_grad(Tensor a) -> bool"), - [](Stack* stack) { - at::Tensor a; - pop(stack, a); - push(stack, a.requires_grad()); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("prim::grad(Tensor a) -> Tensor(*)"), - [](Stack* stack) { - at::Tensor a; - pop(stack, a); - push(stack, a.grad()); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("prim::is_sparse(Tensor a) -> bool"), - [](Stack* stack) { - at::Tensor a; - pop(stack, a); - push(stack, a.is_sparse()); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("prim::is_sparse_csr(Tensor a) -> bool"), - [](Stack* stack) { - at::Tensor a; - pop(stack, a); - push(stack, a.is_sparse_csr()); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("prim::is_mkldnn(Tensor a) -> bool"), - [](Stack* stack) { - at::Tensor a; - pop(stack, a); - push(stack, a.is_mkldnn()); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("prim::is_mlc(Tensor a) -> bool"), - [](Stack* stack) { - at::Tensor a; - pop(stack, a); - push(stack, a.is_mlc()); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("prim::is_vulkan(Tensor a) -> bool"), - [](Stack* stack) { - at::Tensor a; - pop(stack, a); - push(stack, a.is_vulkan()); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("prim::is_quantized(Tensor a) -> bool"), - [](Stack* stack) { - at::Tensor a; - pop(stack, a); - push(stack, a.is_quantized()); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("prim::is_meta(Tensor a) -> bool"), - [](Stack* stack) { - at::Tensor a; - pop(stack, a); - push(stack, a.is_meta()); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("prim::name(Tensor a) -> str?"), - [](Stack* stack) { - at::Tensor a; - pop(stack, a); - if (a.name() == "") { - push(stack, IValue()); - } else { - push(stack, a.name()); - } - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("prim::index(Device self) -> int?"), - [](Stack* stack) { - auto d = pop(stack).toDevice(); - if (d.has_index()) { - push(stack, d.index()); - } else { - push(stack, IValue()); - } - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - // TODO return generator object when torchscript supports RNG - // first-class - TORCH_SELECTIVE_SCHEMA("aten::manual_seed(int seed) -> ()"), - [](Stack* stack) { at::manual_seed(pop(stack).toInt()); }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("aten::cuda(Tensor(a) self) -> Tensor(a|b)"), - [](Stack* stack) { - at::Tensor a; - pop(stack, a); - push(stack, a.cuda()); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("prim::AutogradZero() -> Tensor"), - [](Stack* stack) { stack->emplace_back(at::Tensor()); }, - aliasAnalysisSpecialCase()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA( - "prim::ReductionSizes(int[] size, int[] red_axes, bool keepdim = False) -> int[]"), - [](Stack* stack) { - bool keepdim = pop(stack).toBool(); - c10::List axes = pop(stack).toIntList(); - c10::List size = pop(stack).toIntList(); - if (keepdim) { - for (const auto& axis : axes) { - size.set(axis, 1); - } - } else { - int64_t index = 0; - auto iter = size.begin(); - std::sort(axes.begin(), axes.end()); - for (const auto& axis : axes) { - // move iter to the next axis - iter += axis - index; - - // input iter points to axis and is updated to axis + 1 - iter = size.erase(iter); - - // update current index for iter - index = axis + 1; - } - } - push(stack, IValue(std::move(size))); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("prim::BroadcastSizes(...) -> int[]"), - [](Stack* stack) { - auto num_inputs = pop(stack).toInt(); - std::vector size; - size.reserve(8); - for (const auto i : c10::irange(num_inputs)) { - size = - at::infer_size(size, peek(stack, i, num_inputs).toIntVector()); - } - drop(stack, num_inputs); - push(stack, IValue(size)); - }, - aliasAnalysisSpecialCase()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA( - "aten::warn(str message, int stacklevel=2) -> ()"), - [](Stack* stack) { - TORCH_CHECK( - false, "warn is implemented directly in the interpreter"); - }, - aliasAnalysisFromSchema()), - - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA( - "onnx::Reshape(Tensor input, Tensor shape) -> Tensor"), - [](Stack* stack) { - at::Tensor input, shape; - pop(stack, input, shape); - shape = shape.contiguous(); - AT_ASSERT(shape.ndimension() == 1); - at::IntArrayRef shape_list(shape.data_ptr(), shape.size(0)); - push(stack, input.reshape(shape_list)); - }, - aliasAnalysisSpecialCase()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("onnx::Shape(Tensor t) -> Tensor"), - [](Stack* stack) { - auto t = pop(stack).toTensor(); - at::IntArrayRef sizes = t.sizes(); - auto sizes_tensor = torch::empty( - {static_cast(sizes.size())}, at::dtype(at::kLong)); - auto accessor = sizes_tensor.accessor(); - for (const auto i : c10::irange(sizes.size())) { - accessor[i] = sizes[i]; - } - stack->emplace_back(sizes_tensor); - }, - aliasAnalysisSpecialCase()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("prim::AutogradAnyNonZero(...) -> bool"), - [](Stack* stack) { - auto num_inputs = pop(stack).toInt(); - bool result = false; - for (const IValue& v : last(stack, num_inputs)) { - if (v.isTensor()) { - if (v.toTensor().defined()) { - result = true; - break; - } - } else if (v.isTensorList()) { - for (const at::Tensor& t : v.toTensorVector()) { - if (t.defined()) { - result = true; - } - } - if (result) { - break; - } - } else { - TORCH_INTERNAL_ASSERT(false); - } - } - drop(stack, num_inputs); - stack->emplace_back(result); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("prim::AutogradAllZero(...) -> bool"), - [](Stack* stack) { - auto num_inputs = pop(stack).toInt(); - bool result = true; - for (const IValue& v : last(stack, num_inputs)) { - TORCH_INTERNAL_ASSERT(v.isTensor()); - if (v.toTensor().defined()) { - result = false; - break; - } - } - drop(stack, num_inputs); - stack->emplace_back(result); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("prim::AutogradAllNonZero(...) -> bool"), - [](Stack* stack) { - auto num_inputs = pop(stack).toInt(); - bool result = true; - for (const IValue& v : last(stack, num_inputs)) { - TORCH_INTERNAL_ASSERT(v.isTensor()); - if (!v.toTensor().defined()) { - result = false; - break; - } - } - drop(stack, num_inputs); - stack->emplace_back(result); - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA("prim::AutogradAdd(Any a, Any b) -> Any"), - [](Stack* stack) { - at::Tensor a, b; - pop(stack, a, b); - // NOLINTNEXTLINE(bugprone-branch-clone) - if (!a.defined() && !b.defined()) { - // undef + undef == undef - stack->emplace_back(a); - } else if (!a.defined()) { - stack->emplace_back(b); - } else if (!b.defined()) { - stack->emplace_back(a); - } else { - stack->emplace_back(a + b); - } - }, - aliasAnalysisSpecialCase()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA( - "aten::_size_if_not_equal(int[] self_size, int[] other_size) -> int[]?"), - [](Stack* stack) { - IValue self_size, other_size; - pop(stack, self_size, other_size); - auto s = self_size.toIntVector(); - auto o = other_size.toIntVector(); - if (s == o) { - push(stack, IValue()); - } else { - push(stack, s); - } - }, - aliasAnalysisFromSchema()), - OperatorGenerator( - TORCH_SELECTIVE_SCHEMA( - "aten::_unwrap_optional(t(a)? optional) -> t(a)"), - [](Stack* stack) { - auto val = pop(stack); - TORCH_CHECK(!val.isNone(), "Unwrapping null optional"); - push(stack, std::move(val)); - }, - aliasAnalysisFromSchema())}); - -void hashValue(Stack* stack) { + createOperators(opGenArgs1, sizeof(opGenArgs1) / sizeof(opGenArgs1[0]))); + +void hashValue(Stack& stack) { auto value = pop(stack); push(stack, value.hash()); } -RegisterOperators reg2({ +static const OperatorGeneratorArgs opGenArgs2[] = { // registered as Any[] so that heterogenous tuples can be called with len() - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::len.any(Any[] a) -> int"), listLen, aliasAnalysisFromSchema()), // these ops have a specialized implementation for the list element type #define CREATE_SPECIALIZED_LIST_OPS(decl_type, value_type) \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA( \ "aten::remove." decl_type "(" decl_type \ "[](a!) self, \ " decl_type " el) -> ()"), \ listRemove, \ aliasAnalysisFromSchema()), \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA( \ "aten::index.list_" decl_type "(" decl_type \ "[] self, \ " decl_type " el) -> int"), \ listIndex, \ aliasAnalysisFromSchema()), \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA( \ "aten::count." decl_type "(" decl_type \ "[] self, \ @@ -2480,102 +2523,102 @@ RegisterOperators reg2({ // `listContains` is not implemented for non-primitive types // TODO: Add List[bool] once .to> doesn't throw an error - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA( "aten::__contains__.float_list(float[] l, float item) -> bool"), listContains, aliasAnalysisFromSchema()), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA( "aten::sort.int(int[](a!) self, bool reverse=False) -> ()"), listSort, aliasAnalysisFromSchema()), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA( "aten::sort.float(float[](a!) self, bool reverse=False) -> ()"), listSort, aliasAnalysisFromSchema()), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA( "aten::sort.Tensor(Tensor[](a!) self, bool reverse=False) -> ()"), listSort, aliasAnalysisFromSchema()), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA( "aten::sort.bool(bool[](a!) self, bool reverse=False) -> ()"), listSort, aliasAnalysisFromSchema()), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA( "aten::sort.str(str[](a!) self, bool reverse=False) -> ()"), listSort, aliasAnalysisFromSchema()), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::sorted.int(int[](a) input) -> (int[])"), listCopyAndSort, aliasAnalysisFromSchema()), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA( "aten::sorted.float(float[](a) input) -> (float[])"), listCopyAndSort, aliasAnalysisFromSchema()), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA( "aten::sorted.Tensor(Tensor[](a) input) -> (Tensor[])"), listCopyAndSort, aliasAnalysisFromSchema()), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA( "aten::sorted.bool(bool[](a) input) -> (bool[])"), listCopyAndSort, aliasAnalysisFromSchema()), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::sorted.str(str[](a) input) -> (str[])"), listCopyAndSort, aliasAnalysisFromSchema()), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA( "aten::eq.float_list(float[] a, float[] b) -> bool"), listEq, aliasAnalysisFromSchema()), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA( "aten::eq.Tensor_list(Tensor[] a, Tensor[] b) -> bool"), listEq, aliasAnalysisFromSchema()), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA( "aten::eq.bool_list(bool[] a, bool[] b) -> bool"), listEq, aliasAnalysisFromSchema()), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::eq.str_list(str[] a, str[] b) -> bool"), listEq, aliasAnalysisFromSchema()), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA( "aten::ne.float_list(float[] a, float[] b) -> bool"), listNe, aliasAnalysisFromSchema()), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA( "aten::ne.Tensor_list(Tensor[] a, Tensor[] b) -> bool"), listNe, aliasAnalysisFromSchema()), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA( "aten::ne.bool_list(bool[] a, bool[] b) -> bool"), listNe, aliasAnalysisFromSchema()), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::ne.str_list(str[] a, str[] b) -> bool"), listNe, aliasAnalysisFromSchema()), #define DEFINE_CONVERT_BASE_OP(op_name, prefix, char_op) \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA(#op_name "(int i) -> str"), \ - [](Stack* stack) { \ + [](Stack& stack) { \ auto i = pop(stack).toInt(); \ std::stringstream ss; \ if (i < 0) { \ @@ -2590,9 +2633,9 @@ RegisterOperators reg2({ DEFINE_CONVERT_BASE_OP(aten::hex, "x", std::hex), DEFINE_CONVERT_BASE_OP(aten::oct, "o", std::oct), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::bin(int i) -> str"), - [](Stack* stack) { + [](Stack& stack) { auto i = pop(stack).toInt(); std::stringstream ss; if (i == 0) { @@ -2610,10 +2653,10 @@ RegisterOperators reg2({ }, aliasAnalysisFromSchema()), // TODO: deprecate this in favor of aten::getelem - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA( "prim::StringIndex(str string, int index) -> str"), - [](Stack* stack) { + [](Stack& stack) { auto index = pop(stack).toInt(); auto string = pop(stack).toStringRef(); auto norm_index = normalizeIndex(index, string.size()); @@ -2621,9 +2664,9 @@ RegisterOperators reg2({ push(stack, std::string(&c, 1)); }, aliasAnalysisFromSchema()), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::chr(int i) -> str"), - [](Stack* stack) { + [](Stack& stack) { auto i = pop(stack).toInt(); std::stringstream ss; TORCH_CHECK( @@ -2639,9 +2682,9 @@ RegisterOperators reg2({ // only used in loop unrolling, not exposed to end users DEFINE_INT_OP(aten::__round_to_zero_floordiv, a / b), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::modf(float a) -> (float, float)"), - [](Stack* stack) { + [](Stack& stack) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) double a; pop(stack, a); @@ -2651,9 +2694,9 @@ RegisterOperators reg2({ push(stack, b, c); }, aliasAnalysisFromSchema()), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::frexp(float a) -> (float, int)"), - [](Stack* stack) { + [](Stack& stack) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) double a; pop(stack, a); @@ -2665,9 +2708,9 @@ RegisterOperators reg2({ push(stack, m, e); }, aliasAnalysisFromSchema()), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::ldexp(float x, int i) -> float"), - [](Stack* stack) { + [](Stack& stack) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) double a; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) @@ -2765,9 +2808,9 @@ RegisterOperators reg2({ float, float, float), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("prim::abs(Tensor x) -> Tensor"), - [](Stack* stack) { + [](Stack& stack) { at::Tensor x; pop(stack, x); push(stack, x.abs()); @@ -2788,9 +2831,9 @@ RegisterOperators reg2({ std::copysign(a, b), std::copysign(a, b), float), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::_tensor_to_list(Tensor self) -> int[]"), - [](Stack* stack) { + [](Stack& stack) { at::Tensor t; pop(stack, t); c10::List elems; @@ -2801,9 +2844,9 @@ RegisterOperators reg2({ push(stack, std::move(elems)); }, aliasAnalysisFromSchema()), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::_list_to_tensor(int[] self) -> Tensor"), - [](Stack* stack) { + [](Stack& stack) { c10::List l = pop(stack).toIntList(); auto t = torch::empty( {static_cast(l.size())}, at::dtype(at::kInt)); @@ -2813,9 +2856,9 @@ RegisterOperators reg2({ push(stack, std::move(t)); }, aliasAnalysisFromSchema()), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::sum.int(int[] self) -> int"), - [](Stack* stack) { + [](Stack& stack) { c10::List l = pop(stack).toIntList(); auto sum = 0; for (const auto& elem : l) { @@ -2824,9 +2867,9 @@ RegisterOperators reg2({ push(stack, sum); }, aliasAnalysisFromSchema()), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::sum.float(float[] self) -> float"), - [](Stack* stack) { + [](Stack& stack) { c10::List l = pop(stack).toDoubleList(); auto sum = 0.0; for (const auto& elem : l) { @@ -2835,9 +2878,9 @@ RegisterOperators reg2({ push(stack, sum); }, aliasAnalysisFromSchema()), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::sum.complex(complex[] self) -> complex"), - [](Stack* stack) { + [](Stack& stack) { c10::List> l = pop(stack).toComplexDoubleList(); c10::complex sum = 0.0; for (const auto i : c10::irange(l.size())) { @@ -2846,9 +2889,9 @@ RegisterOperators reg2({ push(stack, sum); }, aliasAnalysisFromSchema()), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::sum.bool(bool[] self) -> int"), - [](Stack* stack) { + [](Stack& stack) { c10::List l = pop(stack).toBoolList(); auto sum = 0; for (const auto& elem : l) { @@ -2859,9 +2902,9 @@ RegisterOperators reg2({ push(stack, sum); }, aliasAnalysisFromSchema()), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::any.str(str[] self) -> bool"), - [](Stack* stack) { + [](Stack& stack) { auto l = pop(stack).toList(); for (const auto& elem : l) { if (elem != "") { @@ -2872,9 +2915,9 @@ RegisterOperators reg2({ push(stack, false); }, aliasAnalysisFromSchema()), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::any.int(int[] self) -> bool"), - [](Stack* stack) { + [](Stack& stack) { c10::List l = pop(stack).toIntList(); for (const auto& elem : l) { if (elem) { @@ -2885,9 +2928,9 @@ RegisterOperators reg2({ push(stack, false); }, aliasAnalysisFromSchema()), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::any.float(float[] self) -> bool"), - [](Stack* stack) { + [](Stack& stack) { c10::List l = pop(stack).toDoubleList(); for (const auto& elem : l) { if (elem) { @@ -2898,9 +2941,9 @@ RegisterOperators reg2({ push(stack, false); }, aliasAnalysisFromSchema()), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::any.bool(bool[] self) -> bool"), - [](Stack* stack) { + [](Stack& stack) { c10::List l = pop(stack).toBoolList(); for (const auto& elem : l) { if (elem) { @@ -2911,9 +2954,9 @@ RegisterOperators reg2({ push(stack, false); }, aliasAnalysisFromSchema()), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::all.int(int[] self) -> bool"), - [](Stack* stack) { + [](Stack& stack) { c10::List l = pop(stack).toIntList(); for (const auto& elem : l) { if (!elem) { @@ -2924,9 +2967,9 @@ RegisterOperators reg2({ push(stack, true); }, aliasAnalysisFromSchema()), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::all.float(float[] self) -> bool"), - [](Stack* stack) { + [](Stack& stack) { c10::List l = pop(stack).toDoubleList(); for (const auto& elem : l) { if (!elem) { @@ -2937,9 +2980,9 @@ RegisterOperators reg2({ push(stack, true); }, aliasAnalysisFromSchema()), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::all.bool(bool[] self) -> bool"), - [](Stack* stack) { + [](Stack& stack) { c10::List l = pop(stack).toBoolList(); for (const auto& elem : l) { if (!elem) { @@ -2950,9 +2993,9 @@ RegisterOperators reg2({ push(stack, true); }, aliasAnalysisFromSchema()), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::divmod.int(int x, int y) -> (int, int)"), - [](Stack* stack) { + [](Stack& stack) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) int64_t a, b; lldiv_t divresult = {}; @@ -2972,10 +3015,10 @@ RegisterOperators reg2({ static_cast(divresult.rem)); }, aliasAnalysisFromSchema()), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA( "aten::divmod.float(float x, float y) -> (float, float)"), - [](Stack* stack) { + [](Stack& stack) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) double a, b; pop(stack, a, b); @@ -2990,9 +3033,9 @@ RegisterOperators reg2({ push(stack, (a - rem) / b, rem); }, aliasAnalysisFromSchema()), - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("prim::id(AnyClassType? x) -> int"), - [](Stack* stack) { + [](Stack& stack) { IValue a; pop(stack, a); if (a.isNone()) { @@ -3004,10 +3047,10 @@ RegisterOperators reg2({ aliasAnalysisFromSchema()), #define DEFINE_DIVMOD_MIXED_OP(type_a, type_b) \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA("aten::divmod." #type_a "_" #type_b "(" #type_a \ " x," #type_b " y) -> (float, float)"), \ - [](Stack* stack) { \ + [](Stack& stack) { \ type_a a; \ type_b b; \ pop(stack, a, b); \ @@ -3024,16 +3067,16 @@ RegisterOperators reg2({ DEFINE_DIVMOD_MIXED_OP(float, int), #undef DEFINE_DIVMOD_MIXED_OP - OperatorGenerator( + OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::hash.generic(t value) -> int"), hashValue, aliasAnalysisFromSchema()), #define DEFINE_COMPLEX_OP(type_a, type_b, actual_type_a, actual_type_b) \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA("aten::Complex." #type_a "_" #type_b "(" #type_a \ " x," #type_b " y) -> complex"), \ - [](Stack* stack) { \ + [](Stack& stack) { \ actual_type_a a; \ actual_type_b b; \ pop(stack, a, b); \ @@ -3044,10 +3087,10 @@ RegisterOperators reg2({ #define DEFINE_COMPLEX_OP_WITH_TENSOR_ARG( \ type_a, type_b, actual_type_a, actual_type_b) \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA("aten::Complex." #type_a "_" #type_b "(" #type_a \ " x," #type_b " y) -> complex"), \ - [](Stack* stack) { \ + [](Stack& stack) { \ actual_type_a a; \ actual_type_b b; \ pop(stack, a, b); \ @@ -3055,10 +3098,10 @@ RegisterOperators reg2({ push(stack, comp); \ }, \ aliasAnalysisFromSchema()), \ - OperatorGenerator( \ + OperatorGeneratorArgs( \ TORCH_SELECTIVE_SCHEMA("aten::Complex." #type_b "_" #type_a \ "(" #type_b " x," #type_a " y) -> complex"), \ - [](Stack* stack) { \ + [](Stack& stack) { \ actual_type_b a; \ actual_type_a b; \ pop(stack, a, b); \ @@ -3079,7 +3122,10 @@ RegisterOperators reg2({ DEFINE_COMPLEX_OP_WITH_TENSOR_ARG(Tensor, float, at::Tensor, double), DEFINE_COMPLEX_OP_WITH_TENSOR_ARG(Tensor, int, at::Tensor, int), DEFINE_COMPLEX_OP_WITH_TENSOR_ARG(Tensor, bool, at::Tensor, bool), -}); +}; + +RegisterOperators reg2( + createOperators(opGenArgs2, sizeof(opGenArgs2) / sizeof(opGenArgs2[0]))); } // namespace } // namespace jit diff --git a/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp index 43c278be474fd..e43c7c052a673 100644 --- a/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp @@ -31,7 +31,7 @@ RegisterOperators reg( {Operator( prim::profile, [](const Node* node) -> Operation { - return [](Stack* stack) { + return [](Stack& stack) { AT_ERROR( "Must be lowered to Interpreter's PROFILE instruction"); // NOLINT }; @@ -40,7 +40,7 @@ RegisterOperators reg( Operator( prim::profile_ivalue, [](const Node* node) -> Operation { - return [](Stack* stack) { + return [](Stack& stack) { AT_ERROR( "Must be lowered to Interpreter's PROFILE instruction"); // NOLINT }; @@ -50,9 +50,9 @@ RegisterOperators reg( prim::FusionGroup, [](const Node* node) -> Operation { const auto key = registerFusion(node); - return [key](Stack* stack) { + return [key](Stack& stack) { RECORD_FUNCTION("FusionGroup", std::vector()); - runFusion(key, *stack); + runFusion(key, stack); }; }, aliasAnalysisSpecialCase()), @@ -67,7 +67,7 @@ RegisterOperators reg( t->castRaw()->requiresGrad().has_value()); return *t->castRaw()->requiresGrad(); }); - return [rg_props](Stack* stack) { + return [rg_props](Stack& stack) { auto num_inputs = rg_props.size(); // Check every input's shape against profiled (expected) shape. for (const auto i : c10::irange(num_inputs)) { @@ -91,14 +91,14 @@ RegisterOperators reg( auto outputs_used = fmap(node->outputs(), [](const Value* v) { return v->uses().size() > 0; }); - return [=](Stack* stack) { + return [=](Stack& stack) { RECORD_FUNCTION("chunk", last(stack, 1)); at::Tensor t; pop(stack, t); auto result = at::chunk(t, chunks, dim); - stack->insert( - stack->end(), + stack.insert( + stack.end(), std::make_move_iterator(result.begin()), std::make_move_iterator(result.end())); // NB: Chunk can sometimes return a smaller number of outputs. @@ -121,7 +121,7 @@ RegisterOperators reg( num_results); // We know that the output is unused, so it's ok to push // anything on the stack. - stack->emplace_back(); + stack.emplace_back(); } } }; @@ -132,7 +132,7 @@ RegisterOperators reg( [](const Node* node) -> Operation { int64_t raw_dim = node->i(attr::dim); int64_t chunks = node->i(attr::chunks); - return [raw_dim, chunks](Stack* stack) { + return [raw_dim, chunks](Stack& stack) { c10::List shape = pop(stack).toIntList(); c10::List regular_shape = shape.copy(); c10::List last_shape = shape.copy(); @@ -158,7 +158,7 @@ RegisterOperators reg( aliasAnalysisSpecialCase()), Operator( "aten::_grad_sum_to_size(Tensor(a) self, int[]? size) -> Tensor(a)", - [](Stack* stack) { + [](Stack& stack) { RECORD_FUNCTION("_grad_sum_to_size", std::vector()); IValue self, size; pop(stack, self, size); @@ -175,7 +175,7 @@ RegisterOperators reg( OperatorGenerator( TORCH_SELECTIVE_SCHEMA( "prim::ModuleContainerIndex.list(Any self, int ind) -> Any"), - [](Stack* stack) { + [](Stack& stack) { IValue ind = pop(stack); IValue module_dict = pop(stack); std::stringstream ss; @@ -189,7 +189,7 @@ RegisterOperators reg( OperatorGenerator( TORCH_SELECTIVE_SCHEMA( "prim::ModuleContainerIndex.dict(Any self, str ind) -> Any"), - [](Stack* stack) { + [](Stack& stack) { IValue ind = pop(stack); IValue module_dict = pop(stack); push(stack, module_dict.toModule().attr(ind.toStringRef())); @@ -198,7 +198,7 @@ RegisterOperators reg( Operator( prim::TypeCheck /* (...) -> (..., bool) */, [](const Node* /* node */) -> Operation { - return [](Stack* /* stack */) { + return [](Stack& /* stack */) { AT_ERROR("prim::TypeCheck not yet implemented"); // NOLINT }; }, @@ -206,7 +206,7 @@ RegisterOperators reg( Operator( prim::FallbackGraph, [](const Node* node) -> Operation { - return [](Stack* stack) { + return [](Stack& stack) { AT_ERROR( "Must be converted to prim::FunctionCall by replaceFallbackGraphWithFallbackFunction"); // NOLINT }; @@ -214,17 +214,17 @@ RegisterOperators reg( aliasAnalysisSpecialCase()), Operator( "prim::Guard(Tensor(a) t) -> Tensor(a)", - [](Stack* stack) { AT_ERROR("Should be replaced by prim::BailOut"); }, + [](Stack& stack) { AT_ERROR("Should be replaced by prim::BailOut"); }, aliasAnalysisFromSchema()), Operator( "prim::BailOut(...) -> Tensor(a)", - [](Stack* /* stack */) { + [](Stack& /* stack */) { AT_ERROR("prim::BailOut not yet implemented"); // NOLINT }, aliasAnalysisFromSchema()), Operator( "prim::BailoutTemplate() -> int", - [](Stack* stack) { + [](Stack& stack) { // TODO: today, we put a single bailout template at the front to // carry the un-optimized graph for bailout nodes to use. Ideally // this should never run, but we haven't written the code to remove @@ -237,7 +237,7 @@ RegisterOperators reg( aliasAnalysisFromSchema()), Operator( "aten::grad(Tensor[] outputs, Tensor[] inputs, Tensor?[]? grad_outputs=None, bool? retain_graph=None, bool create_graph=False, bool allow_unused=False) -> Tensor?[]", - [](Stack* stack) { + [](Stack& stack) { bool allow_unused = pop(stack).toBool(); bool create_graph = pop(stack).toBool(); auto retain_graph = pop(stack).toOptional(); @@ -277,7 +277,7 @@ RegisterOperators reg( // create_graph=True so we use aliasAnalysisConservative for these two OPs Operator( "aten::backward.TensorList(Tensor[] tensors, Tensor?[]? grad_tensors=None, bool? retain_graph=None, bool create_graph=False) -> ()", - [](Stack* stack) { + [](Stack& stack) { bool create_graph = pop(stack).toBool(); auto retain_graph = pop(stack).toOptional(); auto grad_tensors = pop(stack); @@ -298,7 +298,7 @@ RegisterOperators reg( aliasAnalysisConservative()), Operator( "aten::save(t item, str filename) -> ()", - [](Stack* stack) { + [](Stack& stack) { auto filename = pop(stack).toStringRef(); auto ivalue = pop(stack); @@ -312,7 +312,7 @@ RegisterOperators reg( aliasAnalysisFromSchema()), Operator( "prim::IgnoredPythonOp(...) -> None", - [](Stack* stack) { + [](Stack& stack) { throw JITException( "This Python function is annotated to be ignored" " and cannot be and has not been included in the exported" @@ -323,7 +323,7 @@ RegisterOperators reg( aliasAnalysisFromSchema()), Operator( "aten::wait(Future(t) self) -> t", - [](Stack* stack) { + [](Stack& stack) { TORCH_CHECK( false, "wait is implemented directly in the interpreter"); }, @@ -332,7 +332,7 @@ RegisterOperators reg( RegisterOperators logging_operators( {Operator( "prim::AddStatValue(str key, int val) -> ()", - [](Stack* stack) { + [](Stack& stack) { auto val = pop(stack).toInt(); auto key = pop(stack).toString(); @@ -353,7 +353,7 @@ RegisterOperators logging_operators( aliasAnalysisFromSchema()), Operator( "prim::TimePoint() -> int", - [](Stack* stack) { + [](Stack& stack) { auto schema = parseSchema("prim::TimePoint() -> int"); Node* node = nullptr; // TODO: remove this custom tracing code once the custom op bugfix @@ -372,7 +372,7 @@ RegisterOperators logging_operators( }, aliasAnalysisFromSchema())}); -C10_UNUSED void hashValue(Stack* stack) { +C10_UNUSED void hashValue(Stack& stack) { auto value = pop(stack); push(stack, value.hash()); } @@ -453,7 +453,7 @@ bool isSortableListOfObjectsOrTuples( } template -void sort_op(Stack* stack) { +void sort_op(Stack& stack) { bool reverse = has_reverse_arg ? pop(stack).toBool() : false; auto g_list = pop(stack).toList(); @@ -697,7 +697,7 @@ at::Tensor interpolate( ") "); } -void interpolate_op(Stack* stack) { +void interpolate_op(Stack& stack) { at::Tensor input; IValue size; IValue scale_factors; @@ -743,7 +743,7 @@ IValue convert_scale_factor_to_double(const IValue& int_ivalue) { return scale_factor_double; } -void upsample_nearest_op(Stack* stack) { +void upsample_nearest_op(Stack& stack) { at::Tensor input; IValue size; IValue scale_factor_int; @@ -754,7 +754,7 @@ void upsample_nearest_op(Stack* stack) { push(stack, std::move(res)); } -void upsample_op(Stack* stack) { +void upsample_op(Stack& stack) { at::Tensor input; IValue size; IValue scale_factor_int; @@ -772,7 +772,7 @@ void upsample_op(Stack* stack) { push(stack, std::move(res)); } -void upsample_bilinear_op(Stack* stack) { +void upsample_bilinear_op(Stack& stack) { at::Tensor input; IValue size; IValue scale_factor_int; diff --git a/torch/csrc/jit/runtime/register_special_ops.cpp b/torch/csrc/jit/runtime/register_special_ops.cpp index ace87f20b9c35..015d607044ddb 100644 --- a/torch/csrc/jit/runtime/register_special_ops.cpp +++ b/torch/csrc/jit/runtime/register_special_ops.cpp @@ -184,7 +184,7 @@ void recursiveStore( } template -void createTensorFromList(Stack* stack) { +void createTensorFromList(Stack& stack) { // torch.tensor has a fourth requires_grad arg but torch.as_tensor not, so // we use the template arg to distinguish between these two cases // NOLINTNEXTLINE(cppcoreguidelines-init-variables) @@ -246,7 +246,7 @@ RegisterOperators reg({ OperatorGenerator( TORCH_SELECTIVE_SCHEMA( "aten::split(Tensor self, int[] split_sizes, int dim=0) -> Tensor[]"), - [](Stack* stack) { + [](Stack& stack) { RECORD_FUNCTION("split_with_sizes", last(stack, 3)); auto result = at::split_with_sizes( @@ -264,7 +264,7 @@ RegisterOperators reg({ "aten::tensor." #operator_type "(" #operator_type \ " t, *, ScalarType? dtype=None, Device? device=None" \ ", bool requires_grad=False) -> Tensor"), \ - [](Stack* stack) { \ + [](Stack& stack) { \ c_type scalar_val; \ IValue dtype; \ IValue device; \ @@ -280,7 +280,7 @@ RegisterOperators reg({ TORCH_SELECTIVE_SCHEMA( \ "aten::as_tensor." #operator_type "(" #operator_type \ " t, *, ScalarType? dtype=None, Device? device=None) -> Tensor"), \ - [](Stack* stack) { \ + [](Stack& stack) { \ c_type scalar_val; \ IValue dtype; \ IValue device; \ @@ -319,7 +319,7 @@ RegisterOperators reg({ // tensor_new.cpp OperatorGenerator( TORCH_SELECTIVE_SCHEMA("aten::_infer_size(int[] a, int[] b) -> int[]"), - [](Stack* stack) { + [](Stack& stack) { auto a = pop(stack); auto b = pop(stack); push(stack, at::infer_size(a.toIntVector(), b.toIntVector())); @@ -328,7 +328,7 @@ RegisterOperators reg({ OperatorGenerator( TORCH_SELECTIVE_SCHEMA( "aten::_no_grad_embedding_renorm_(Tensor weight, Tensor input, float max_norm, float norm_type) -> Tensor"), - [](Stack* stack) { + [](Stack& stack) { at::Tensor weight; at::Tensor input; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) @@ -353,7 +353,7 @@ RegisterOperators reg({ OperatorGenerator( TORCH_SELECTIVE_SCHEMA( "aten::as_tensor(Tensor(a) data, *, ScalarType? dtype=None, Device? device=None) -> Tensor(a|b)"), - [](Stack* stack) { + [](Stack& stack) { auto device = pop(stack).toOptional(); auto dtype = pop(stack).toOptional(); at::Tensor data = pop(stack).toTensor(); @@ -377,24 +377,24 @@ RegisterOperators reg({ TORCH_SELECTIVE_SCHEMA( "aten::_pack_sequence(Tensor output, Tensor batch_sizes, Tensor? sorted_indices, " "Tensor? unsorted_indices) -> (Tensor, Tensor, Tensor?, Tensor?)"), - [](Stack* stack) {}, + [](Stack& stack) {}, aliasAnalysisFromSchema()), OperatorGenerator( TORCH_SELECTIVE_SCHEMA("aten::_get_tracing_state() -> bool"), - [](Stack* stack) { push(stack, false); }, + [](Stack& stack) { push(stack, false); }, aliasAnalysisFromSchema()), OperatorGenerator( TORCH_SELECTIVE_SCHEMA("aten::is_scripting() -> bool"), - [](Stack* stack) { push(stack, true); }, + [](Stack& stack) { push(stack, true); }, aliasAnalysisFromSchema()), OperatorGenerator( TORCH_SELECTIVE_SCHEMA("aten::has_torch_function(...) -> bool"), - [](Stack* stack) { push(stack, false); }, + [](Stack& stack) { push(stack, false); }, aliasAnalysisFromSchema()), OperatorGenerator( TORCH_SELECTIVE_SCHEMA( "aten::_no_grad_uniform_(Tensor(a!) tensor, float a, float b) -> Tensor(a!)"), - [](Stack* stack) { + [](Stack& stack) { // TODO: remove when script supports setting grad mode torch::NoGradGuard no_grad; @@ -410,7 +410,7 @@ RegisterOperators reg({ OperatorGenerator( TORCH_SELECTIVE_SCHEMA( "aten::_no_grad_normal_(Tensor(a!) tensor, float mean, float std) -> Tensor(a!)"), - [](Stack* stack) { + [](Stack& stack) { // TODO: remove when script supports setting grad mode torch::NoGradGuard no_grad; @@ -426,7 +426,7 @@ RegisterOperators reg({ OperatorGenerator( TORCH_SELECTIVE_SCHEMA( "aten::_no_grad_fill_(Tensor(a!) tensor, float val) -> Tensor(a!)"), - [](Stack* stack) { + [](Stack& stack) { // TODO: remove when script supports setting grad mode torch::NoGradGuard no_grad; @@ -440,7 +440,7 @@ RegisterOperators reg({ OperatorGenerator( TORCH_SELECTIVE_SCHEMA( "aten::_no_grad_zero_(Tensor(a!) tensor) -> Tensor(a!)"), - [](Stack* stack) { + [](Stack& stack) { // TODO: remove when script supports setting grad mode torch::NoGradGuard no_grad; @@ -451,11 +451,11 @@ RegisterOperators reg({ aliasAnalysisFromSchema()), Operator( "aten::is_grad_enabled() -> bool", - [](Stack* stack) { push(stack, torch::GradMode::is_enabled()); }, + [](Stack& stack) { push(stack, torch::GradMode::is_enabled()); }, aliasAnalysisConservative()), Operator( "aten::set_grad_enabled(bool val) -> ()", - [](Stack* stack) { torch::GradMode::set_enabled(pop(stack).toBool()); }, + [](Stack& stack) { torch::GradMode::set_enabled(pop(stack).toBool()); }, aliasAnalysisConservative()), }); } // namespace diff --git a/torch/csrc/jit/runtime/static/fusion.cpp b/torch/csrc/jit/runtime/static/fusion.cpp index b08b59fc6890a..0b41b8e48a345 100644 --- a/torch/csrc/jit/runtime/static/fusion.cpp +++ b/torch/csrc/jit/runtime/static/fusion.cpp @@ -39,7 +39,7 @@ Operation createStaticSubgraphRuntime(const Node* node) { auto g = node->g(attr::Subgraph); auto module = std::make_shared(g); auto num_inputs = module->num_inputs(); - return [module, num_inputs](Stack* stack) { + return [module, num_inputs](Stack& stack) { RECORD_FUNCTION("Static Runtime", std::vector()); auto inps = torch::jit::last(stack, num_inputs); // TODO maybe avoid call to vec @@ -48,10 +48,10 @@ Operation createStaticSubgraphRuntime(const Node* node) { if (module->num_outputs() > 1) { for (auto& o : outputs.toTuple()->elements()) { - push_one(*stack, std::move(o)); + push_one(stack, std::move(o)); } } else { - push_one(*stack, std::move(outputs)); + push_one(stack, std::move(outputs)); } return 0; }; diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp index a0c3bac2bbc83..7697613e79573 100644 --- a/torch/csrc/jit/runtime/static/impl.cpp +++ b/torch/csrc/jit/runtime/static/impl.cpp @@ -9,16 +9,21 @@ #include #include #include -#include #include #include #include #include +#include #include #include #include #include +#ifdef FBCODE_CAFFE2 +#include +#include +#endif + namespace torch { namespace jit { @@ -66,6 +71,7 @@ void OptimizeGraph( EliminateDeadCode(graph); FuseInferenceOpsForSparseNN(graph); UseVariadicCat(graph); + UseVariadicStack(graph); // TODO: we can avoid this guard by moving operations // to exposed folders. @@ -73,6 +79,7 @@ void OptimizeGraph( if (opts.enable_out_variant) { FuseListUnpack(graph); ReplaceWithCopy(graph); + EnableStaticRuntimeLayerNorm(graph); } #endif ConstantPropagation(graph); @@ -103,8 +110,8 @@ bool mayContainAlias(AliasDb& db, const Value* a, const Value* b) { bool mayContainAlias( AliasDb& db, - const std::unordered_set& a, - const std::unordered_set& b) { + const FastSet& a, + const FastSet& b) { std::vector as; std::vector bs; as.reserve(a.size()); @@ -121,11 +128,11 @@ bool mayContainAlias( } // Get set of all inputs/outputs/constants (always alive) and their aliases -std::unordered_set GetAlwaysAliveValues( +FastSet GetAlwaysAliveValues( const std::shared_ptr& graph, AliasDb& db) { // a set of Values whose live-range exceed current inference - std::unordered_set always_alive; + FastSet always_alive; // mark inputs, constants, outputs as always_alive for (const auto* input : graph->inputs()) { @@ -147,7 +154,7 @@ std::unordered_set GetAlwaysAliveValues( // constants are already in the always_alive set if (node->kind() != prim::Constant) { for (const auto* v : node->outputs()) { - if (mayContainAlias(db, ValueSet{v}, always_alive)) { + if (mayContainAlias(db, {v}, always_alive)) { always_alive.insert(v); } } @@ -157,22 +164,22 @@ std::unordered_set GetAlwaysAliveValues( } // Map each value to all values that are alive at the same time. -using LivenessMap = std::unordered_map>; +using LivenessMap = FastMap>; // The algorithm does a traversal of the execution graph // while keeping track of the live values. LivenessMap GetLivenessMap( const std::shared_ptr& graph, - const std::unordered_set& always_alive, + const FastSet& always_alive, AliasDb& db) { // map a Value to a set of Values that overlap live-ranges with the Value's - std::unordered_map> liveness_map; + FastMap> liveness_map; // map Values to its creation order in graph (Note: only traverse top-level // nodes such that nodes under control-flows are represented by top-level // block nodes) std::vector values_in_creation_order; - std::unordered_map values_to_idx_in_creation_order; + FastMap values_to_idx_in_creation_order; for (const auto* node : graph->nodes()) { for (const auto* v : node->outputs()) { values_to_idx_in_creation_order[v] = values_in_creation_order.size(); @@ -183,10 +190,10 @@ LivenessMap GetLivenessMap( // presence of a Value in live_values_use_chain means the Value alive // Value mapped to set of Nodes that may use the Value (i.e., use-chain of // Value) - std::unordered_map> live_values_use_chain; + FastMap> live_values_use_chain; // Node mapped to set of Values that the Node may use (i.e., def-chain of node // inputs) - std::unordered_map> live_nodes_def_chain; + FastMap> live_nodes_def_chain; // add v to the current liveness_map std::function add_live_value_fn = [&](const Value* v) { @@ -317,16 +324,19 @@ LivenessMap GetLivenessMap( // first: Values that are candidates for memory planning // second: A deterministc order of all values std::pair, std::vector> -GetMemoryPlanningCandidates(const std::shared_ptr& graph) { +GetMemoryPlanningCandidates( + const std::shared_ptr& graph, + const FastMap& node_has_out_variant) { // for determinism - std::unordered_set seen_values; + FastSet seen_values; std::vector all_values; - std::unordered_set can_reuse; + FastSet can_reuse; // values used by unsupported ops (as either inputs or outputs) // these need to be removed from "can_reuse" after analyzing all nodes - std::unordered_set cannot_reuse; + FastSet cannot_reuse; for (auto* n : graph->nodes()) { - bool can_reuse_inputs_outputs = canReuseInputsOutputs(n); + bool can_reuse_inputs_outputs = + canReuseInputsOutputs(n, node_has_out_variant); for (const auto* v : n->inputs()) { if (!seen_values.count(v)) { all_values.emplace_back(v); @@ -387,10 +397,9 @@ GetMemoryPlanningCandidates(const std::shared_ptr& graph) { // // NB: This is a deterministic implementation, which makes it easier to tune // and debug. -std::unordered_map> -GenerateSameStorageValues( +FastMap> GenerateSameStorageValues( const LivenessMap& alive_during, - const std::unordered_set& always_alive, + const FastSet& always_alive, const std::pair, std::vector>& optimizable, AliasDb& db) { @@ -398,8 +407,7 @@ GenerateSameStorageValues( const auto& all_values = optimizable.second; // map Value* to a set Value* that can share the same storage with it - std::unordered_map> - same_storage_values; + FastMap> same_storage_values; // make new_v and old_v map to the same storage (i.e., add to each other's // same_storage_values set) @@ -548,7 +556,7 @@ PrepareForStaticModule( StaticModule::StaticModule( std::shared_ptr g, const StaticModuleOptions& opts) - : StaticModule(PrepareForStaticModule(g, opts), opts) {} + : StaticModule(PrepareForStaticModule(g->copy(), opts), opts) {} StaticModule::StaticModule( const torch::jit::Module& m, @@ -588,9 +596,9 @@ StaticModule::StaticModule( } // map Value* to IValue (from inputs or prim::Constant) or null - std::unordered_map value_to_ivalue; + FastMap value_to_ivalue; // map Value* to its SSA definition IR - std::unordered_map value_to_ssa_def; + FastMap value_to_ssa_def; // N inputs map to the first N entries in storage for (const auto i : c10::irange(graph_->inputs().size())) { @@ -628,6 +636,7 @@ StaticModule::StaticModule( // construct SSA definition for non-constant nodes int node_idx = 0; + FastMap node_has_out_variant; for (Node* node : graph_->nodes()) { if (node->kind() == prim::Constant) { continue; @@ -639,14 +648,22 @@ StaticModule::StaticModule( input_ssa_defs.emplace_back(value_to_ssa_def.at(input)); } node_inputs_ssa_def_map_[node_idx] = input_ssa_defs; - nodes_.emplace_back( - ProcessedNode(node, std::move(ivalue_inputs), opts.enable_out_variant)); + auto pnode = + ProcessedNode(node, std::move(ivalue_inputs), opts.enable_out_variant); + node_has_out_variant.emplace(node, pnode.has_out_variant()); + nodes_.emplace_back(std::move(pnode)); for (const auto i : c10::irange(node->outputs().size())) { value_to_ivalue[node->outputs()[i]] = nullptr; value_to_ssa_def[node->outputs()[i]] = std::make_pair(node_idx, i); } node_idx++; } + for (auto& pnode : nodes_) { + if (pnode.outputs().size() == 1 && + isOptimizableContainerType(pnode.node(), node_has_out_variant)) { + node_is_optimizable_container_type_.emplace(pnode.node()); + } + } for (auto output : graph_->outputs()) { output_ssa_defs_.emplace_back(value_to_ssa_def[output]); } @@ -657,7 +674,7 @@ StaticModule::StaticModule( if (opts_.optimize_memory) { auto lm = GetLivenessMap(graph_, external_values_, alias_db); - auto values = GetMemoryPlanningCandidates(graph_); + auto values = GetMemoryPlanningCandidates(graph_, node_has_out_variant); value_to_same_storage_values_ = GenerateSameStorageValues(lm, external_values_, values, alias_db); } @@ -861,12 +878,30 @@ c10::IValue StaticRuntime::operator()( return std::move(*outputs_[0]); } +namespace { + +std::string generate_node_time_json(const std::string& kind, float millis) { +#ifdef FBCODE_CAFFE2 + folly::dynamic json = folly::dynamic::object(); + json["type"] = kind; + json["metric"] = "latency"; + json["unit"] = "ms"; + json["value"] = millis; + return folly::toJson(json); +#else + return ""; +#endif +} + +} // namespace + void StaticRuntime::benchmark( const std::vector& args, const std::unordered_map& kwargs, const int warmup_runs, const int main_runs, - bool print_per_node_time) { + bool print_per_node_time, + bool generate_ai_pep_output) { float time_per_iter = benchmark_model(args, kwargs, warmup_runs, main_runs); std::cout << "Static runtime ms per iter: " << time_per_iter << ". Iters per second: " << 1000.0 / time_per_iter << std::endl; @@ -897,10 +932,16 @@ void StaticRuntime::benchmark( std::cout << std::setw(15) << ms << " ms. " << std::setw(10) << results.percent_per_node_type[kind] << "%. " << kind << " (" << results.instances_per_node_type[kind] << " nodes"; - if (results.out_nodes.count(kind) == 0) { - std::cout << ")" << std::endl; - } else { + if (results.out_nodes.count(kind)) { std::cout << ", out variant)" << std::endl; + } else if (results.native_nodes.count(kind)) { + std::cout << ", native)" << std::endl; + } else { + std::cout << ")" << std::endl; + } + + if (generate_ai_pep_output) { + LOG(INFO) << "PyTorchObserver " << generate_node_time_json(kind, ms); } } std::cout << std::setw(15) << results.total_time << " ms. in Total" @@ -1136,6 +1177,8 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops( if (nodes_[i].has_out_variant()) { results.out_nodes.insert(kind); results.out_nodes_count++; + } else if (nodes_[i].has_native()) { + results.native_nodes.insert(kind); } results.total_time += results.time_per_node[i]; } @@ -1160,8 +1203,7 @@ void StaticRuntime::check_for_memory_leak(bool output_returned) { TORCH_CHECK(inputs_[i].isNone(), "Input ", i, " was not cleaned up"); } - std::unordered_set output_ivalues( - outputs_.begin(), outputs_.end()); + FastSet output_ivalues(outputs_.begin(), outputs_.end()); for (const auto n : c10::irange(nodes_.size())) { auto& pnode = nodes_[n]; for (const auto i : c10::irange(pnode.outputs().size())) { @@ -1174,7 +1216,8 @@ void StaticRuntime::check_for_memory_leak(bool output_returned) { // check for intermediates if (!ival->isNone()) { TORCH_CHECK( - ival->isTensor() || isOptimizableContainerType(pnode.node()), + ival->isTensor() || + static_module_.is_optimizable_container_type(pnode.node()), error_msg); if (ival->isTensor()) { const auto& t = ival->toTensor(); @@ -1197,13 +1240,13 @@ void StaticRuntime::check_for_memory_leak(bool output_returned) { static void assign_storage_to_managed_tensors( StaticRuntime* runtime, - const std::unordered_set& managed_tensor_values, - const std::unordered_map>& + const FastSet& managed_tensor_values, + const FastMap>& value_to_same_storage_values, std::vector>>& managed_tensors) { // map Value to index to managed_storage, where multiple values can // map to the same index (i.e., sharing the same storage) - std::unordered_map value_to_storage_idx; + FastMap value_to_storage_idx; // Snapshot of the current memory state for (auto& pnode : runtime->nodes()) { @@ -1213,19 +1256,21 @@ static void assign_storage_to_managed_tensors( if (managed_tensor_values.count(val)) { TORCH_CHECK(ival.isTensor()); at::Tensor* tensor = &ival.toTensor(); - - if (value_to_storage_idx.count(val)) { - managed_tensors[value_to_storage_idx[val]].second.emplace_back( - tensor); + auto f = value_to_storage_idx.find(val); + if (f != value_to_storage_idx.end()) { + auto storage_idx = f->second; + managed_tensors[storage_idx].second.emplace_back(tensor); } else { auto p = std::make_pair>(0, {tensor}); managed_tensors.emplace_back(std::move(p)); // first of a group, update the value_to_storage_idx map with the // index - if (value_to_same_storage_values.count(val)) { + auto f = value_to_same_storage_values.find(val); + if (f != value_to_same_storage_values.end()) { auto storage_idx = managed_tensors.size() - 1; - for (const auto* v : value_to_same_storage_values.at(val)) { + const auto& same_storage_values = f->second; + for (const auto* v : same_storage_values) { value_to_storage_idx[v] = storage_idx; } } @@ -1237,14 +1282,14 @@ static void assign_storage_to_managed_tensors( MemoryPlanner::MemoryPlanner( StaticRuntime* runtime, - const std::unordered_map>& + const FastMap>& value_to_same_storage_values, - const std::unordered_set& external_values, + const FastSet& external_values, bool enable_out_variant, bool manage_graph_output_memory) { // collect register indices of outputs of ops with out variant - std::unordered_set managed_tensor_values; - std::unordered_set leaked_values; + FastSet managed_tensor_values; + FastSet leaked_values; if (enable_out_variant) { for (ProcessedNode& pnode : runtime->nodes()) { if (pnode.has_out_variant()) { @@ -1255,11 +1300,11 @@ MemoryPlanner::MemoryPlanner( } // Types are stored in the underlying TorchScript IR const auto& type = out_v->type(); - if (type->cast()) { + if (type->castRaw()) { managed_tensor_values.insert(out_v); - } else if (isOptimizableContainerType(pnode.node())) { - // We "leak" certain container types because their allocations take - // a long time + } else if (runtime->is_optimizable_container_type(pnode.node())) { + // We "leak" certain container types because their allocations + // take a long time leaked_values.insert(out_v); } } @@ -1268,7 +1313,7 @@ MemoryPlanner::MemoryPlanner( } // collect unmanaged output ivalues - std::unordered_set unmanaged_ivalues; + FastSet unmanaged_ivalues; for (ProcessedNode& pnode : runtime->nodes()) { for (const auto i : c10::irange(pnode.outputs().size())) { // Types are stored in the underlying TorchScript IR @@ -1290,9 +1335,11 @@ MemoryPlanner::MemoryPlanner( } // copy to unmanaged_ivalues_ - for (IValue* out : unmanaged_ivalues) { - unmanaged_ivalues_.emplace_back(out); - } + unmanaged_ivalues_.reserve(unmanaged_ivalues.size()); + unmanaged_ivalues_.insert( + unmanaged_ivalues_.begin(), + unmanaged_ivalues.begin(), + unmanaged_ivalues.end()); if (enable_out_variant) { ::torch::jit::assign_storage_to_managed_tensors( @@ -1420,7 +1467,7 @@ void ProcessedNode::run() { } DCHECK(op_); - op_->operator()(&stack); + op_->operator()(stack); DCHECK_EQ(stack.size(), node_->outputs().size()); for (const auto i : c10::irange(node_->outputs().size())) { diff --git a/torch/csrc/jit/runtime/static/impl.h b/torch/csrc/jit/runtime/static/impl.h index cc36df037b02d..0d2378760f270 100644 --- a/torch/csrc/jit/runtime/static/impl.h +++ b/torch/csrc/jit/runtime/static/impl.h @@ -9,9 +9,26 @@ #include #include +#ifdef FBCODE_CAFFE2 +#include +#include +#endif + namespace torch { namespace jit { +#ifdef FBCODE_CAFFE2 +template +using FastMap = folly::F14FastMap; +template +using FastSet = folly::F14FastSet; +#else +template +using FastMap = std::unordered_map; +template +using FastSet = std::unordered_set; +#endif + TORCH_API bool canEnableStaticRuntime( const std::shared_ptr& graph); @@ -127,7 +144,7 @@ class TORCH_API StaticModule { size_t num_inputs() const; size_t num_outputs() const; - const std::unordered_map>& index_map() const { + const FastMap>& index_map() const { return node_inputs_ssa_def_map_; } @@ -143,16 +160,21 @@ class TORCH_API StaticModule { return nodes_; } + bool is_optimizable_container_type(Node* n) const { + auto it = node_is_optimizable_container_type_.find(n); + return it != node_is_optimizable_container_type_.end(); + } + const c10::optional& schema() const { return schema_; } - const std::unordered_map>& + const FastMap>& values_share_same_storage() const { return value_to_same_storage_values_; } - const std::unordered_set& external_values() const { + const FastSet& external_values() const { return external_values_; } @@ -178,15 +200,17 @@ class TORCH_API StaticModule { // a vector of ssa_defs corresponding to graph->outputs() std::vector output_ssa_defs_; // map a node idx (in graph order) to a vector of ssa_defs for node inputs - std::unordered_map> node_inputs_ssa_def_map_; + FastMap> node_inputs_ssa_def_map_; // Bookkeeping for MemoryPlanner in StaticRuntime // values whose live-time exceeds that of running one inference (e.g., input, // output, prim::Constants, and their aliases) - std::unordered_set external_values_; + FastSet external_values_; // map a value to the set of values that may share the same storage with it - std::unordered_map> + FastMap> value_to_same_storage_values_; + + FastSet node_is_optimizable_container_type_; }; class TORCH_API StaticRuntime { @@ -210,7 +234,8 @@ class TORCH_API StaticRuntime { const std::unordered_map& kwargs, const int warmup_runs, const int main_runs, - bool print_per_node_time = false); + bool print_per_node_time = false, + bool generate_ai_pep_output = false); float benchmark_model( const std::vector& args, @@ -231,6 +256,7 @@ class TORCH_API StaticRuntime { std::unordered_map percent_per_node_type; std::unordered_map instances_per_node_type; std::unordered_set out_nodes; + std::unordered_set native_nodes; }; IndividualMetrics benchmark_individual_ops( @@ -269,6 +295,10 @@ class TORCH_API StaticRuntime { void check_for_memory_leak(bool output_returned = true); + bool is_optimizable_container_type(Node* n) const { + return static_module_.is_optimizable_container_type(n); + } + private: // helper method for copying input args/kwargs into inputs_ void set_inputs( @@ -322,8 +352,8 @@ class MemoryPlanner { public: explicit MemoryPlanner( StaticRuntime* runtime, - const std::unordered_map>&, - const std::unordered_set& external_values, + const FastMap>&, + const FastSet& external_values, bool enable_out_variant, bool manage_graph_output_memory); // disable copying and moving @@ -410,6 +440,10 @@ class TORCH_API ProcessedNode { return static_cast(fn_); } + bool has_native() const { + return static_cast(native_fn_); + } + bool verify_outputs_not_overlapping_with_immutable_inputs() const; private: diff --git a/torch/csrc/jit/runtime/static/native_ops.cpp b/torch/csrc/jit/runtime/static/native_ops.cpp index d84b1cd8b28d2..7a1558dd70a00 100644 --- a/torch/csrc/jit/runtime/static/native_ops.cpp +++ b/torch/csrc/jit/runtime/static/native_ops.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include namespace torch { @@ -100,17 +101,25 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR( if (n->inputs().size() != 2) { return nullptr; } - // TODO: make __getitem__ work for other container types - if (n->input(0)->type()->castRaw() == nullptr) { - return nullptr; + + if (n->input(0)->type()->castRaw()) { + return [](ProcessedNode* p_node) { + auto dict = p_node->Input(0).toGenericDict(); + auto key = p_node->Input(1); + auto value = dict.find(key); + TORCH_CHECK(value != dict.end(), "Key not in dict: ", key); + p_node->Output(0) = value->value(); + }; + } else if (n->input(0)->type()->castRaw()) { + return [](ProcessedNode* p_node) { + auto list = p_node->Input(0).toList(); + auto idx = p_node->Input(1).toInt(); + p_node->Output(0) = getItem(list, idx); + }; } - return [](ProcessedNode* p_node) { - auto dict = p_node->Input(0).toGenericDict(); - auto key = p_node->Input(1); - auto value = dict.find(key); - TORCH_CHECK(value != dict.end(), "Key not in dict: ", key); - p_node->Output(0) = value->value(); - }; + + // TODO(T98581096): make __getitem__ work for other container types + return nullptr; }); REGISTER_NATIVE_OPERATOR_FUNCTOR( @@ -346,6 +355,37 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::to, aten_to, [](Node* n) -> SROperator { }; }); +REGISTER_NATIVE_OPERATOR_FUNCTOR( + aten::detach, + aten_detach, + [](Node* n) -> SROperator { + if (!n->matches( + torch::schema("aten::detach(Tensor(a) self) -> Tensor(a)"))) { + LogAndDumpSchema(n); + return nullptr; + } + return [](ProcessedNode* p_node) { + const auto& in0_t = p_node->Input(0).toTensor(); + p_node->Output(0) = at::native::alias(in0_t); + }; + }); + +REGISTER_NATIVE_OPERATOR_FUNCTOR( + aten::expand_as, + aten_expand_as, + [](Node* n) -> SROperator { + if (!n->matches(torch::schema( + "aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)"))) { + LogAndDumpSchema(n); + return nullptr; + } + return [](ProcessedNode* p_node) { + const auto& self = p_node->Input(0).toTensor(); + const auto& other = p_node->Input(1).toTensor(); + p_node->Output(0) = self.expand(other.sizes()); + }; + }); + REGISTER_NATIVE_OPERATOR_FUNCTOR( prim::isinstance, prim_isinstance, diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp index eef5595cee7b2..62f5bb28c1553 100644 --- a/torch/csrc/jit/runtime/static/ops.cpp +++ b/torch/csrc/jit/runtime/static/ops.cpp @@ -9,18 +9,24 @@ #include #include #include +#include #include #include #include #include +#include +#include #include #include +#include +#include #include #include #include #include #include #include +#include C10_DEFINE_bool( static_runtime_enable_fast_math, @@ -177,6 +183,94 @@ Tensor& linear_out( return output; } +Tensor& c2_argmin_out( + Tensor& output, + const Tensor& input, + const int64_t dim, + const bool keepdim) { + const auto ndim = input.dim(); + int64_t dim_ = maybe_wrap_dim(dim, ndim); + TORCH_CHECK(dim_ >= 0 && dim_ < ndim); + + const auto in_dims = input.sizes(); + + c10::SmallVector out_dims; + out_dims.reserve(ndim); + int prev_size = 1; + int next_size = 1; + for (int i = 0; i < dim_; ++i) { + out_dims.push_back(in_dims[i]); + prev_size *= in_dims[i]; + } + if (keepdim) { + out_dims.push_back(1); + } + for (auto i = dim_ + 1; i < ndim; ++i) { + out_dims.push_back(in_dims[i]); + next_size *= in_dims[i]; + } + at::native::resize_(output, out_dims, c10::nullopt); + + const auto n = in_dims[dim_]; + + if (next_size == 1) { + AT_DISPATCH_ALL_TYPES_AND2( + kHalf, kBFloat16, input.scalar_type(), "argmin_input", [&]() { + const auto in_ptr = input.data_ptr(); + const auto out_ptr = output.data_ptr(); + // input is a [prev_size, n] tensor. + // output is a [prev_size,] tensor. + // Thus, access is contiguous/coalesced. + for (int i = 0; i < prev_size; ++i) { + auto v = std::min_element( + in_ptr + i * n, + in_ptr + (i + 1) * n, + [](scalar_t a, scalar_t b) { + // if a is nan, then a is *less* than b with LessOrNan + // semantics + if (at::_isnan(a)) { + return true; + } + // if a is not nan and b is nan, then a is not less than b + // with LessOrNan semantics otherwise, act normally. If `b` is + // NaN then a < b will always return false, so this is + // equivalent to the first snippet. + return a < b; + }); + out_ptr[i] = std::distance(in_ptr + i * n, v); + } + }); + } else { + AT_DISPATCH_ALL_TYPES_AND2( + kHalf, kBFloat16, input.scalar_type(), "argmin_input", [&]() { + const auto less_or_nan = native::detail::LessOrNan{}; + + const auto in_ptr = input.data_ptr(); + const auto out_ptr = output.data_ptr(); + + std::memset(out_ptr, 0, prev_size * next_size * sizeof(int64_t)); + + for (int i = 0; i < prev_size; ++i) { + const scalar_t* cur_in_ptr = in_ptr + i * n * next_size + next_size; + for (int k = 1; k < n; ++k) { + for (int j = 0; j < next_size; ++j) { + int64_t* cur_out_ptr = out_ptr + i * next_size + j; + if (less_or_nan( + *cur_in_ptr, + in_ptr + [i * n * next_size + *cur_out_ptr * next_size + j], + *cur_out_ptr, + k)) { + *cur_out_ptr = k; + } + ++cur_in_ptr; + } + } + } + }); + } + return output; +} } // namespace native } // namespace at @@ -198,7 +292,7 @@ bool disableUnsafeMathOp(const char* op_name) { // not guarantee bit exactness vs the jit interpreter. Note aten::relu is not // included even though it uses NNC because the results of relu should always // match. - static const std::unordered_set fast_ops{ + static const FastSet fast_ops{ "aten::add", "aten::tanh", "aten::sigmoid", "aten::logit"}; return fast_ops.count(op_name) > 0; } @@ -214,33 +308,39 @@ std::function getOutOfPlaceOperation(Node* n) { // Returns true if the node represents an op with variadic arguments. bool hasVarArgs(Node* n) { - if (n->kind() == prim::VarConcat) { + if (n->kind() == prim::VarConcat || n->kind() == prim::VarStack) { return true; } return false; } -// Expensive check, use sparingly. -// This is needed to make sure that we only switch to out variants for the -// supported overloads, which is checked in the `Generate` step in -// `SROperatorRegistry()->Create(op_name)->Generate(n)` -bool canReuseInputsOutputs(Node* n) { +bool canReuseInputsOutputs( + Node* n, + const FastMap& node_has_out_variant) { + auto it = node_has_out_variant.find(n); + if (it != node_has_out_variant.end()) { + return it->second; + } return getOutOfPlaceOperation(n) != nullptr; } // returns true if the producers of the inputs // to this operations are out of place. // This means the IValues will not change run to run -bool inputsCanRunOutOfPlace(Node* n) { +bool inputsCanRunOutOfPlace( + Node* n, + const FastMap& node_has_out_variant) { for (auto* input : n->inputs()) { - if (!canReuseInputsOutputs(input->node())) { + if (!canReuseInputsOutputs(input->node(), node_has_out_variant)) { return false; } } return true; } -bool isOptimizableContainerType(Node* n) { +bool isOptimizableContainerType( + Node* n, + const FastMap& node_has_out_variant) { const auto& type = n->output()->type(); bool is_supported_type = false; if (type->kind() == TypeKind::ListType) { @@ -256,7 +356,7 @@ bool isOptimizableContainerType(Node* n) { }); is_supported_type = iter != types.end(); } - return is_supported_type && inputsCanRunOutOfPlace(n); + return is_supported_type && inputsCanRunOutOfPlace(n, node_has_out_variant); } REGISTER_OPERATOR_FUNCTOR( @@ -264,7 +364,7 @@ REGISTER_OPERATOR_FUNCTOR( prim_ListConstruct, [](Node* n) -> SROperator { const auto& type = n->output()->type()->expectRef(); - bool can_optimize = isOptimizableContainerType(n); + bool can_optimize = isOptimizableContainerType(n, FastMap()); return [can_optimize, &type](ProcessedNode* p_node) { const auto& out_l = p_node->Output(0); if (!out_l.isNone() && can_optimize) { @@ -284,7 +384,7 @@ REGISTER_OPERATOR_FUNCTOR( prim::TupleConstruct, prim_TupleConstruct, [](Node* n) -> SROperator { - bool can_optimize = isOptimizableContainerType(n); + bool can_optimize = isOptimizableContainerType(n, FastMap()); return [can_optimize](ProcessedNode* p_node) { const auto& out_l = p_node->Output(0); if (!out_l.isNone() && can_optimize) { @@ -301,6 +401,23 @@ REGISTER_OPERATOR_FUNCTOR( }; }); +REGISTER_OPERATOR_FUNCTOR(aten::abs, aten_abs, [](Node* n) -> SROperator { + if (!n->matches(torch::schema("aten::abs(Tensor self) -> Tensor"))) { + LogAndDumpSchema(n); + return nullptr; + } + return [](ProcessedNode* p_node) { + const auto& in0_t = p_node->Input(0).toTensor(); + if (p_node->Output(0).isNone()) { + p_node->Output(0) = at::native::abs(in0_t); + } else { + auto& out_t = p_node->Output(0).toTensor(); + fastResizeToZero(out_t); + at::native::abs_out(in0_t, out_t); + } + }; +}); + REGISTER_OPERATOR_FUNCTOR(aten::mul, aten_mul, [](Node* n) -> SROperator { if (!n->matches(torch::schema( "aten::mul.Tensor(Tensor self, Tensor other) -> Tensor"))) { @@ -435,6 +552,29 @@ SROperator aten_stack(Node* n) { REGISTER_OPERATOR_FUNCTOR(aten::stack, aten_stack, aten_stack); +REGISTER_OPERATOR_FUNCTOR( + prim::VarStack, + prim_VarStack, + [](Node* n) -> SROperator { + return [](ProcessedNode* p_node) { + const size_t num_inputs = p_node->inputs().size(); + + std::vector inputs(num_inputs - 1); + for (size_t i = 0; i < num_inputs - 1; ++i) { + inputs[i] = p_node->Input(i).toTensor(); + } + + const auto dim = p_node->Input(num_inputs - 1).toInt(); + if (p_node->Output(0).isNone()) { + p_node->Output(0) = at::native::_stack_cpu(inputs, dim); + } else { + auto& out_t = p_node->Output(0).toTensor(); + fastResizeToZero(out_t); + at::native::_stack_out_cpu(inputs, dim, out_t); + } + }; + }); + REGISTER_OPERATOR_FUNCTOR(aten::leaky_relu, aten_leaky_relu, [](Node* n) -> SROperator { if (!n->matches(torch::schema( "aten::leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor"))) { @@ -453,202 +593,6 @@ REGISTER_OPERATOR_FUNCTOR(aten::leaky_relu, aten_leaky_relu, [](Node* n) -> SROp }; }); -namespace { - -// Use the width of an AVX-512 vector by default; this happens to work OK for -// AVX2 as well. Some ops benefit from using multiple AVX ports, in which case -// they are vectorized by twice this constant. An exception is logit, since it -// contains FP divide, which is single-ported. -static constexpr int kVectorWidth = 16; - -#ifdef TORCH_ENABLE_LLVM - -struct TEWrapper { - tensorexpr::KernelArena ka; - tensorexpr::KernelScope ks; - std::unique_ptr cg; - TEWrapper() = default; - void update(std::unique_ptr&& cg_) { - cg = std::move(cg_); - } - - void call(const std::vector& args) { - cg->call_raw(args); - } - - inline bool supports(const at::Tensor& t) { - return t.is_contiguous() && t.dtype().Match(); - } -}; - -void optimizePointwise( - tensorexpr::LoopNest* ln, - tensorexpr::Tensor* target, - int width) { - using namespace torch::jit::tensorexpr; - std::vector loops = ln->getLoopStmtsFor(target); - ForPtr inner, tail; - TORCH_CHECK(loops.size() > 0, "No loops created for pointwise op"); - ln->splitWithTail(loops[0], width, &inner, &tail); - ln->vectorize(inner); -} - -std::shared_ptr wrapTECompute( - std::shared_ptr wrap, - tensorexpr::Placeholder& in, - tensorexpr::Tensor* out, - tensorexpr::VarHandle& dim, - int width = kVectorWidth) { - using namespace torch::jit::tensorexpr; - LoopNest ln({out}); - optimizePointwise(&ln, out, width); - ln.prepareForCodegen(); - StmtPtr s = ln.root_stmt(); - s = tensorexpr::IRSimplifier::simplify(s); - std::vector args; - args.emplace_back(out); - args.emplace_back(in); - args.emplace_back(dim); - auto cg = std::make_unique(s, args); - wrap->update(std::move(cg)); - return wrap; -}; - -#else - -struct TEWrapper { - tensorexpr::KernelArena ka; - tensorexpr::KernelScope ks; - TEWrapper() = default; - template - void operator()(const Ts&... ts) { - DCHECK(0 && "Invalid call"); - } - void call(const std::vector& args) { - DCHECK(0 && "Invalid call"); - } - - inline bool supports(const at::Tensor& t) { - return false; - } -}; - -std::shared_ptr wrapTECompute( - std::shared_ptr wrap, - tensorexpr::Placeholder& in, - tensorexpr::Tensor* out, - tensorexpr::VarHandle& dim, - int width = kVectorWidth) { - return wrap; -}; - -#endif - -std::mutex& getNNCCacheMutex() { - static std::mutex nncCacheMutex; - return nncCacheMutex; -} - -std::unordered_map>& getNNCCache() { - static std::unordered_map> nncCache; - return nncCache; -} - -std::shared_ptr lookupNNCCache(NodeKind kind) { - std::lock_guard lock(getNNCCacheMutex()); - auto it = getNNCCache().find(kind); - if (it != getNNCCache().end()) { - return it->second; - } - return nullptr; -} - -void updateNNCCache(NodeKind kind, std::shared_ptr code) { - std::lock_guard lock(getNNCCacheMutex()); - getNNCCache()[kind] = code; -} - -} // namespace - -std::shared_ptr createLogit(c10::optional clamp) { - using namespace torch::jit::tensorexpr; - // TODO: Use NNC cache for this op. - auto wrap = std::make_shared(); - auto N = VarHandle("N", kInt); - Placeholder A("A", kFloat, {N}); - tensorexpr::Tensor* B = Compute("B", {N}, [&](const VarHandle& i) { - auto A_elem = [&]() { - if (!clamp) { - return A.load(i); - } else { - auto elem = A.load(i); - auto min = FloatImm::make(*clamp); - auto max = FloatImm::make(1.0f - *clamp); - elem = CompareSelect::make(elem, min, min, elem, kLT); - return CompareSelect::make(elem, max, max, elem, kGT); - } - }(); - return log_vml(A_elem / (FloatImm::make(1.0f) - A_elem)); - }); - return wrapTECompute(wrap, A, B, N); -} - -std::shared_ptr createRelu() { - using namespace torch::jit::tensorexpr; - auto wrap = lookupNNCCache(aten::relu); - if (wrap) { - return wrap; - } - wrap = std::make_shared(); - auto N = VarHandle("N", kInt); - Placeholder A("A", kFloat, {N}); - tensorexpr::Tensor* B = Compute("B", {N}, [&](const VarHandle& i) { - auto zero = FloatImm::make(0.f); - auto a = A.load(i); - return ifThenElse(a < zero, zero, a); - }); - wrap = wrapTECompute(wrap, A, B, N); - updateNNCCache(aten::relu, wrap); - return wrap; -} - -std::shared_ptr createTanh() { - using namespace torch::jit::tensorexpr; - auto wrap = lookupNNCCache(aten::tanh); - if (wrap) { - return wrap; - } - wrap = std::make_shared(); - auto N = VarHandle("N", kInt); - Placeholder A("A", kFloat, {N}); - tensorexpr::Tensor* B = Compute("B", {N}, [&](const VarHandle& i) { - auto a = A.load(i); - return fast_tanh(a); - }); - wrap = wrapTECompute(wrap, A, B, N); - updateNNCCache(aten::tanh, wrap); - return wrap; -} - -std::shared_ptr createSigmoid() { - using namespace torch::jit::tensorexpr; - auto wrap = lookupNNCCache(aten::sigmoid); - if (wrap) { - return wrap; - } - wrap = std::make_shared(); - auto N = VarHandle("N", kInt); - Placeholder A("A", kFloat, {N}); - Tensor* B = - Compute("B", {N}, [&](const VarHandle& i) { return sigmoid(A.load(i)); }); - // NNC uses sleef for vectorizing sigmoid, which comes in an 8-wide flavor - // (Sleef_expf8). - constexpr int kSleefWidth = 8; - wrap = wrapTECompute(wrap, A, B, N, kSleefWidth); - updateNNCCache(aten::sigmoid, wrap); - return wrap; -} - REGISTER_OPERATOR_FUNCTOR(aten::relu, aten_relu, [](Node* n) -> SROperator { if (!n->matches(torch::schema("aten::relu(Tensor self) -> Tensor"))) { LogAndDumpSchema(n); @@ -734,8 +678,9 @@ REGISTER_OPERATOR_FUNCTOR(aten::logit, aten_logit, [](Node* n) -> SROperator { ? c10::make_optional(static_cast(clamp_d.value())) : c10::nullopt; } - auto te = clamp ? createLogit(clamp) : nullptr; - return [te](ProcessedNode* p_node) { + auto te = clamp ? createLogit() : nullptr; + float clamp_value = clamp ? *clamp : 0.0f; + return [te, clamp_value](ProcessedNode* p_node) { const auto& in0_t = p_node->Input(0).toTensor(); if (p_node->Output(0).isNone()) { p_node->Output(0) = create_empty_from(in0_t); @@ -749,16 +694,18 @@ REGISTER_OPERATOR_FUNCTOR(aten::logit, aten_logit, [](Node* n) -> SROperator { } else { at::native::resize_(out_t, in0_t.sizes(), c10::nullopt); int64_t nn = in0_t.numel(); - te->call({out_t.data_ptr(), in0_t.data_ptr(), &nn}); + float c = clamp_value; + te->call({out_t.data_ptr(), in0_t.data_ptr(), &nn, &c}); } }; }); +// TODO(T98923825): Uncomment this once the bug in this gets fixed. +/* REGISTER_OPERATOR_FUNCTOR(aten::clone, aten_clone, [](Node* n) -> SROperator { if (!n->matches(torch::schema( - "aten::clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor"))) { - LogAndDumpSchema(n); - return nullptr; + "aten::clone(Tensor self, *, MemoryFormat? memory_format=None) -> +Tensor"))) { LogAndDumpSchema(n); return nullptr; } return [](ProcessedNode* p_node) { const auto& src = p_node->Input(0).toTensor(); @@ -784,6 +731,8 @@ REGISTER_OPERATOR_FUNCTOR(aten::clone, aten_clone, [](Node* n) -> SROperator { at::native::copy_(out_t, src, false); }; }); +*/ + REGISTER_OPERATOR_FUNCTOR( quantized::embedding_bag_byte_rowwise_offsets, quantized_embedding_bag_byte_rowwise_offsets, @@ -821,6 +770,7 @@ REGISTER_OPERATOR_FUNCTOR( include_last_offset); }; }); + REGISTER_OPERATOR_FUNCTOR( quantized::embedding_bag_4bit_rowwise_offsets, embedding_bag_4bit_rowwise_offsets, @@ -859,6 +809,27 @@ REGISTER_OPERATOR_FUNCTOR( }; }); +REGISTER_OPERATOR_FUNCTOR( + quantized::embedding_bag_byte_prepack, + embedding_bag_byte_prepack, + [](Node* n) -> SROperator { + if (!n->matches(torch::schema( + "quantized::embedding_bag_byte_prepack(Tensor weight) -> Tensor"))) { + LogAndDumpSchema(n); + return nullptr; + } + return [](ProcessedNode* p_node) { + const auto& weight = p_node->Input(0).toTensor(); + if (p_node->Output(0).isNone()) { + p_node->Output(0) = at::native::qembeddingbag_byte_prepack(weight); + return; + } + auto& out_t = p_node->Output(0).toTensor(); + fastResizeToZero(out_t); + at::native::qembeddingbag_byte_prepack_out(out_t, weight); + }; + }); + // The out variant takes precedence over native REGISTER_OPERATOR_FUNCTOR(aten::narrow_copy, aten_narrow_copy, [](Node* n) -> SROperator { if (!n->matches(torch::schema( @@ -1359,61 +1330,109 @@ REGISTER_OPERATOR_FUNCTOR(aten::argmin, aten_argmin, [](Node* n) -> SROperator { } else { auto& out_t = p_node->Output(0).toTensor(); fastResizeToZero(out_t); + if (in0_t.is_contiguous() && dim.has_value()) { + at::native::c2_argmin_out(out_t, in0_t, dim.value(), keepdim); + return; + } at::cpu::argmin_out(out_t, in0_t, dim, keepdim); } }; }); -REGISTER_OPERATOR_FUNCTOR(aten::layer_norm, aten_layer_norm, [](Node* n) -> SROperator { +REGISTER_OPERATOR_FUNCTOR(aten::softmax, aten_softmax, [](Node* n) -> SROperator { if (!n->matches(torch::schema( - "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor"))) { + "aten::softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"))) { LogAndDumpSchema(n); return nullptr; } return [](ProcessedNode* p_node) { - // ignore Input(5): `bool cudnn_enable=True` - const auto& input = p_node->Input(0).toTensor(); - const auto normalized_shape = p_node->Input(1).toIntVector(); - auto weight_opt = p_node->Input(2).toOptional(); - auto bias_opt = p_node->Input(3).toOptional(); - float eps = p_node->Input(4).toDouble(); - - c10::MaybeOwned weight_maybe_owned = - at::borrow_from_optional_tensor(weight_opt); - const at::Tensor& weight = *weight_maybe_owned; - c10::MaybeOwned bias_maybe_owned = - at::borrow_from_optional_tensor(bias_opt); - const at::Tensor& bias = *bias_maybe_owned; - - auto M_N = at::native::_check_layer_norm_inputs( - input, normalized_shape, weight, bias); - auto M = M_N.first; - auto N = M_N.second; - auto X = input.expect_contiguous(); - auto gamma = weight.expect_contiguous(); - auto beta = bias.expect_contiguous(); - + const auto& in_t = p_node->Input(0).toTensor(); + const auto& dim = p_node->Input(1).toInt(); + const auto& dtype = p_node->Input(2).toOptional(); if (p_node->Output(0).isNone()) { - p_node->Output(0) = at::native::empty_like( - *X, - c10::nullopt /* dtype */, - c10::nullopt /* layout */, - c10::nullopt /* device */, - c10::nullopt /* pin_memory */, - at::MemoryFormat::Contiguous); + p_node->Output(0) = at::native::softmax(in_t, dim, dtype); } else { - at::native::resize_( - p_node->Output(0).toTensor(), X->sizes(), c10::nullopt); - } - at::Tensor& output = p_node->Output(0).toTensor(); - at::Tensor mean = create_empty_from({M}, *X); - at::Tensor rstd = create_empty_from({M}, *X); + auto& out_t = p_node->Output(0).toTensor(); + fastResizeToZero(out_t); - at::native::layer_norm_cpu_out( - output, mean, rstd, input, normalized_shape, *gamma, *beta, eps, M, N); + auto half_to_float = in_t.scalar_type() == at::ScalarType::Half && + dtype == at::ScalarType::Float; + at::cpu::_softmax_out(out_t, in_t, dim, half_to_float); + } }; }); +REGISTER_OPERATOR_FUNCTOR( + static_runtime::layer_norm, + aten_layer_norm, + [](Node* n) -> SROperator { + if (!n->matches(torch::schema( + "static_runtime::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> (Tensor,Tensor,Tensor)"))) { + LogAndDumpSchema(n); + return nullptr; + } + return [](ProcessedNode* p_node) { + // ignore Input(5): `bool cudnn_enable=True` + const auto& input = p_node->Input(0).toTensor(); + const auto normalized_shape = p_node->Input(1).toIntVector(); + auto weight_opt = p_node->Input(2).toOptional(); + auto bias_opt = p_node->Input(3).toOptional(); + float eps = p_node->Input(4).toDouble(); + + c10::MaybeOwned weight_maybe_owned = + at::borrow_from_optional_tensor(weight_opt); + const at::Tensor& weight = *weight_maybe_owned; + c10::MaybeOwned bias_maybe_owned = + at::borrow_from_optional_tensor(bias_opt); + const at::Tensor& bias = *bias_maybe_owned; + + auto M_N = at::native::_check_layer_norm_inputs( + input, normalized_shape, weight, bias); + auto M = M_N.first; + auto N = M_N.second; + auto X = input.expect_contiguous(); + auto gamma = weight.expect_contiguous(); + auto beta = bias.expect_contiguous(); + + if (p_node->Output(0).isNone()) { + p_node->Output(0) = at::native::empty_like( + *X, + c10::nullopt /* dtype */, + c10::nullopt /* layout */, + c10::nullopt /* device */, + c10::nullopt /* pin_memory */, + at::MemoryFormat::Contiguous); + } else { + at::native::resize_( + p_node->Output(0).toTensor(), X->sizes(), c10::nullopt); + } + if (p_node->Output(1).isNone()) { + p_node->Output(1) = create_empty_from({M}, *X); + } else { + at::native::resize_(p_node->Output(1).toTensor(), {M}, c10::nullopt); + } + if (p_node->Output(2).isNone()) { + p_node->Output(2) = create_empty_from({M}, *X); + } else { + at::native::resize_(p_node->Output(2).toTensor(), {M}, c10::nullopt); + } + at::Tensor& output = p_node->Output(0).toTensor(); + at::Tensor mean = p_node->Output(1).toTensor(); + at::Tensor rstd = p_node->Output(2).toTensor(); + at::native::layer_norm_cpu_out( + output, + mean, + rstd, + input, + normalized_shape, + *gamma, + *beta, + eps, + M, + N); + }; + }); + REGISTER_OPERATOR_FUNCTOR(aten::norm, aten_norm, [](Node* n) -> SROperator { if (!n->matches(torch::schema( "aten::norm.ScalarOpt_dtype(Tensor self, Scalar? p, *, ScalarType dtype) -> Tensor")) && @@ -1528,6 +1547,53 @@ REGISTER_OPERATOR_FUNCTOR(quantized::linear, quantized_linear, [](Node* n) -> SR }; }); +REGISTER_OPERATOR_FUNCTOR( + fb::quantized_linear, + fb_quantized_linear, + [](Node* n) -> SROperator { + if (!n->matches(torch::schema( + "fb::quantized_linear(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase w_prepack, Tensor Y_scale_i, Tensor Y_zero_point_i) -> Tensor"))) { + LogAndDumpSchema(n); + return nullptr; + } + const auto w = toIValue(n->inputs()[1]); + c10::intrusive_ptr packed_weight; + if (w) { + packed_weight = w->toCustomClass(); + } + return [packed_weight](ProcessedNode* p_node) { + const auto& input = p_node->Input(0).toTensor(); + const auto output_scale = p_node->Input(2).toTensor().item().toFloat(); + const auto output_zero_point = + p_node->Input(3).toTensor().item().toLong(); + + if (p_node->Output(0).isNone()) { + p_node->Output(0) = at::native::empty_affine_quantized( + {0}, + c10::kQUInt8, + c10::nullopt, + c10::kCPU, + false, + output_scale, + output_zero_point, + c10::nullopt); + } + auto& out_t = p_node->Output(0).toTensor(); + fastResizeToZero(out_t); + + if (packed_weight) { + packed_weight->apply_out( + input, output_scale, output_zero_point, out_t); + } else { + // Weights could be quantized on the fly + auto packed_weight_tmp = + p_node->Input(1).toCustomClass(); + packed_weight_tmp->apply_out( + input, output_scale, output_zero_point, out_t); + } + }; + }); + REGISTER_OPERATOR_FUNCTOR(aten::full, aten_full, [](Node* n) -> SROperator { if (!n->matches(torch::schema( "aten::full(int[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"))) { @@ -1599,6 +1665,141 @@ REGISTER_OPERATOR_FUNCTOR(aten::linear, aten_linear, [](Node* n) -> SROperator { }; }); +REGISTER_OPERATOR_FUNCTOR(aten::fmod, aten_fmod, [](Node* n) -> SROperator { + if (!n->matches(torch::schema( + "aten::fmod.Scalar(Tensor self, Scalar other) -> Tensor")) && + !n->matches(torch::schema( + "aten::fmod.Tensor(Tensor self, Tensor other) -> Tensor"))) { + LogAndDumpSchema(n); + return nullptr; + } + return [](ProcessedNode* p_node) { + const auto& in0_t = p_node->Input(0).toTensor(); + const auto& in1_t = p_node->Input(1).isTensor() + ? p_node->Input(1).toTensor() + : at::native::wrapped_scalar_tensor(p_node->Input(1).toScalar()); + + if (p_node->Output(0).isNone()) { + p_node->Output(0) = at::cpu::fmod(in0_t, in1_t); + } else { + auto& out_t = p_node->Output(0).toTensor(); + fastResizeToZero(out_t); + + at::cpu::fmod_out(out_t, in0_t, in1_t); + } + }; +}); + +REGISTER_OPERATOR_FUNCTOR(aten::linalg_norm, aten_linalg_norm, [](Node* n) -> SROperator { + if (!n->matches(torch::schema( + "aten::linalg_norm(Tensor self, Scalar? ord=None, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor")) && + !n->matches(torch::schema( + "aten::linalg_norm.ord_str(Tensor self, str ord, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"))) { + LogAndDumpSchema(n); + return nullptr; + } + return [](ProcessedNode* p_node) { + const auto& input = p_node->Input(0).toTensor(); + const auto dim = p_node->Input(2).toIntVector(); + const auto keepdim = p_node->Input(3).toBool(); + const auto dtype = p_node->Input(4).toOptional(); + + if (p_node->Output(0).isNone()) { + if (p_node->Input(1).isScalar()) { + p_node->Output(0) = at::native::linalg_norm( + input, + p_node->Input(1).toOptional(), + dim, + keepdim, + dtype); + } else { + p_node->Output(0) = at::native::linalg_norm( + input, p_node->Input(1).toStringView(), dim, keepdim, dtype); + } + return; + } + + auto& output = p_node->Output(0).toTensor(); + fastResizeToZero(output); + + if (p_node->Input(1).isScalar()) { + at::native::linalg_norm_out( + input, + p_node->Input(1).toOptional(), + dim, + keepdim, + dtype, + output); + } else { + at::native::linalg_norm_out( + input, p_node->Input(1).toStringRef(), dim, keepdim, dtype, output); + } + }; +}); + +REGISTER_OPERATOR_FUNCTOR(aten::cat, aten_cat, [](Node* n) -> SROperator { + if (!n->matches( + torch::schema("aten::cat(Tensor[] tensors, int dim=0) -> Tensor"))) { + LogAndDumpSchema(n); + return nullptr; + } + return [](ProcessedNode* p_node) { + const auto inputs = p_node->Input(0).toTensorVector(); + const auto dim = p_node->Input(1).toInt(); + if (p_node->Output(0).isNone()) { + p_node->Output(0) = at::native::_cat_cpu(inputs, dim); + return; + } + + auto& output = p_node->Output(0).toTensor(); + fastResizeToZero(output); + at::native::_cat_out_cpu(inputs, dim, output); + }; +}); + +REGISTER_OPERATOR_FUNCTOR(aten::cumsum, aten_cumsum, [](Node* n) -> SROperator { + if (!n->matches(torch::schema( + "aten::cumsum(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"))) { + LogAndDumpSchema(n); + return nullptr; + } + return [](ProcessedNode* p_node) { + const auto& input = p_node->Input(0).toTensor(); + const auto dim = p_node->Input(1).toInt(); + const auto dtype = p_node->Input(2).toOptional(); + + if (p_node->Output(0).isNone()) { + p_node->Output(0) = at::cpu::cumsum(input, dim, dtype); + return; + } + + auto& output = p_node->Output(0).toTensor(); + fastResizeToZero(output); + at::cpu::cumsum_out(output, input, dim, dtype); + }; +}); + +REGISTER_OPERATOR_FUNCTOR( + aten::nonzero, + aten_nonzero, + [](Node* n) -> SROperator { + if (!n->matches(torch::schema("aten::nonzero(Tensor self) -> Tensor"))) { + LogAndDumpSchema(n); + return nullptr; + } + return [](ProcessedNode* p_node) { + const auto& input = p_node->Input(0).toTensor(); + if (p_node->Output(0).isNone()) { + p_node->Output(0) = at::native::nonzero_cpu(input); + return; + } + + auto& output = p_node->Output(0).toTensor(); + fastResizeToZero(output); + at::native::nonzero_out_cpu(input, output); + }; + }); + namespace { void check_cat_no_zero_dim(const std::vector& tensors) { @@ -1637,5 +1838,67 @@ REGISTER_OPERATOR_FUNCTOR( }; }); +namespace { + +// This template and its specialization help us avoid compiler warnings +// about taking the absolute value of an unsigned type in signed_log1p +template +T abs_if_signed(T val) { + return std::abs(val); +} + +template <> +unsigned char abs_if_signed(unsigned char val) { + return val; +} + +// Computes f(x) = sign(x) * ln(|1 + x|) for each x in the input tensor +void signed_log1p_out(at::Tensor& out, const at::Tensor& input) { + at::native::resize_(out, input.sizes(), c10::nullopt); + + const auto input_contig = input.expect_contiguous(); + auto output_contig = out.expect_contiguous(); + + AT_DISPATCH_ALL_TYPES(input.scalar_type(), "signed_log1p_kernel", [&]() { + const auto input_data = input_contig->data_ptr(); + auto output_data = output_contig->data_ptr(); + const auto N = input.numel(); + + for (const auto i : c10::irange(N)) { + const int sign = input_data[i] < 0 ? -1 : 1; + output_data[i] = std::log1p(abs_if_signed(input_data[i])) * sign; + } + }); +} + +at::Tensor signed_log1p(const at::Tensor& input) { + auto out = create_empty_from(input); + signed_log1p_out(out, input); + return out; +} + +} // namespace + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_OPERATOR_FUNCTOR( + static_runtime::signed_log1p, + static_runtime_signed_log1p, + [](Node* n) -> SROperator { + if (!n->matches(torch::schema( + "static_runtime::signed_log1p(Tensor x) -> Tensor"))) { + LogAndDumpSchema(n); + return nullptr; + } + return [](ProcessedNode* p_node) { + const auto& input = p_node->Input(0).toTensor(); + if (p_node->Output(0).isNone()) { + p_node->Output(0) = signed_log1p(input); + } else { + auto& out = p_node->Output(0).toTensor(); + fastResizeToZero(out); + signed_log1p_out(out, input); + } + }; + }); } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/runtime/static/ops.h b/torch/csrc/jit/runtime/static/ops.h index ff5d69e1cb895..311143ca7392f 100644 --- a/torch/csrc/jit/runtime/static/ops.h +++ b/torch/csrc/jit/runtime/static/ops.h @@ -133,8 +133,12 @@ bool opIsRegistered(const c10::Symbol& op_name); // as native ops in Static Runtime bool nativeOpIsRegistered(const c10::Symbol& op_name); -bool canReuseInputsOutputs(Node* n); -bool isOptimizableContainerType(Node* n); +bool canReuseInputsOutputs( + Node* n, + const FastMap& node_has_out_variant); +bool isOptimizableContainerType( + Node* n, + const FastMap& node_has_out_variant); std::function getOutOfPlaceOperation(Node* n); std::function getNativeOperation(Node* n); diff --git a/torch/csrc/jit/runtime/static/passes.cpp b/torch/csrc/jit/runtime/static/passes.cpp index bbd7dd17f2feb..0eaebfdf0e7aa 100644 --- a/torch/csrc/jit/runtime/static/passes.cpp +++ b/torch/csrc/jit/runtime/static/passes.cpp @@ -12,7 +12,9 @@ namespace { bool HasInplaceOp(Block* block, const AliasDb& alias_db) { for (auto* node : block->nodes()) { for (Block* sub_block : node->blocks()) { - return HasInplaceOp(sub_block, alias_db); + if (HasInplaceOp(sub_block, alias_db)) { + return true; + } } auto inputs = node->inputs(); // check if node modifies inputs (both inplace ops and certain out variants @@ -165,125 +167,57 @@ C10_UNUSED void ClipRangesGather(std::shared_ptr& graph) { fuse.runOnGraph(graph); } +C10_UNUSED void PrecomputeMultiplierShiftForSigridHash( + std::shared_ptr& graph) { + std::string pattern = R"IR( + graph(%a, %b, %c, %d): + %y0 : Tensor = fb::sigrid_hash(%a, %b, %c, %d) + return (%y0) + )IR"; + std::string split_pattern = R"IR( + graph(%a, %b, %c, %d): + %y0 : Tensor = fb::sigrid_hash_compute_multipler_shift(%c) + %y2 : Tensor = fb::sigrid_hash_precompute(%a, %b, %c, %y0, %d) + return (%y2) + )IR"; + SubgraphRewriter fuse; + fuse.RegisterRewritePattern(pattern, split_pattern); + fuse.runOnGraph(graph); +} + C10_UNUSED void ClipRangesGatherSigridHash(std::shared_ptr& graph) { // TODO:: check restrictions for inputs; outputs not used elsewhere - std::string pattern_1 = R"IR( - graph(%a, %b, %c, %d, %e, %f, %g): - %y0 : Tensor, %y1 : Tensor = fb::clip_ranges_gather_lengths_to_offsets(%a, %b, %c, %d) - %y2 : Tensor = fb::sigrid_hash(%y0, %e, %f, %g) - return (%y2, %y1))IR"; - std::string fused_pattern_1 = R"IR( - graph(%a, %b, %c, %d, %e, %f, %g): - %off : Tensor, %out : Tensor = fb::clip_ranges_gather_sigrid_hash_offsets(%b, %a, %c, %e, %f, %g, %d) - return (%out, %off))IR"; - - std::string pattern_2 = R"IR( + std::string pattern = R"IR( graph(%a, %b, %c, %d, %e, %f, %g, %h): %y0 : Tensor, %y1 : Tensor = fb::clip_ranges_gather_lengths_to_offsets(%a, %b, %c, %d) %y2 : Tensor = fb::sigrid_hash_precompute(%y0, %e, %f, %g, %h) return (%y2, %y1))IR"; - std::string fused_pattern_2 = R"IR( + std::string fused_pattern = R"IR( graph(%a, %b, %c, %d, %e, %f, %g, %h): %off : Tensor, %out : Tensor = fb::clip_ranges_gather_sigrid_hash_precompute_offsets(%b, %a, %c, %e, %f, %g, %h, %d) return (%out, %off))IR"; SubgraphRewriter fuse; - fuse.RegisterRewritePattern(pattern_1, fused_pattern_1); - fuse.runOnGraph(graph); - - fuse.RegisterRewritePattern(pattern_2, fused_pattern_2); + fuse.RegisterRewritePattern(pattern, fused_pattern); fuse.runOnGraph(graph); } C10_UNUSED void ClipRangesGatherRangesSigridHash( std::shared_ptr& graph) { - std::string pattern_1 = R"IR( - graph(%a, %b, %c, %d, %e, %f): - %y0 : Tensor = fb::clip_ranges(%b, %c) - %y1 : Tensor, %y2 : Tensor = fb::gather_ranges(%a, %y0) - %y3 : Tensor = fb::sigrid_hash(%y1, %d, %e, %f) - return (%y3, %y2))IR"; - std::string fused_pattern_1 = R"IR( - graph(%a, %b, %c, %d, %e, %f): - %off : Tensor, %out : Tensor = fb::clip_ranges_gather_sigrid_hash_v3(%b, %a, %c, %d, %e, %f) - return (%out, %off))IR"; - - std::string pattern_2 = R"IR( + std::string pattern = R"IR( graph(%a, %b, %c, %d, %e, %f, %g): %y0 : Tensor = fb::clip_ranges(%b, %c) %y1 : Tensor, %y2 : Tensor = fb::gather_ranges(%a, %y0) %y3 : Tensor = fb::sigrid_hash_precompute(%y1, %d, %e, %f, %g) return (%y3, %y2))IR"; - std::string fused_pattern_2 = R"IR( + std::string fused_pattern = R"IR( graph(%a, %b, %c, %d, %e, %f, %g): %off : Tensor, %out : Tensor = fb::clip_ranges_gather_sigrid_hash_precompute_v3(%b, %a, %c, %d, %e, %f, %g) return (%out, %off))IR"; - SubgraphRewriter fuse; - fuse.RegisterRewritePattern(pattern_1, fused_pattern_1); - fuse.runOnGraph(graph); - - fuse.RegisterRewritePattern(pattern_2, fused_pattern_2); - fuse.runOnGraph(graph); -} - -C10_UNUSED void PrecomputeMultiplierShiftForSigridHash( - std::shared_ptr& graph) { - std::string pattern = R"IR( - graph(%a, %b, %c, %d): - %y0 : Tensor = fb::sigrid_hash(%a, %b, %c, %d) - return (%y0) - )IR"; - std::string split_pattern = R"IR( - graph(%a, %b, %c, %d): - %y0 : Tensor = fb::sigrid_hash_compute_multipler_shift(%c) - %y2 : Tensor = fb::sigrid_hash_precompute(%a, %b, %c, %y0, %d) - return (%y2) - )IR"; - SubgraphRewriter fuse; - fuse.RegisterRewritePattern(pattern, split_pattern); - fuse.runOnGraph(graph); -} - -C10_UNUSED void ClipRangesGatherRangesX2SigridHash( - std::shared_ptr& graph) { - // Placeholder is a dummy op used to capture the first subgraph - std::string pattern = R"IR( - graph(%ranges, %values, %max_length, %salt, %max_value, %hash_into_int32): - %clipped : Tensor = fb::clip_ranges(%ranges, %max_length) - %output : Tensor, %unused : Tensor = fb::gather_ranges(%values, %clipped) - %sigrid_hash_out : Tensor = fb::sigrid_hash(%output, %salt, %max_value, %hash_into_int32) - return (%sigrid_hash_out, %clipped))IR"; - std::string fused_pattern = R"IR( - graph(%ranges, %values, %max_length, %salt, %max_value, %hash_into_int32): - %sigrid_hash_out : Tensor, %clipped : Tensor = fb::placeholder(%ranges, %values, %max_length, %salt, %max_value, %hash_into_int32) - return (%sigrid_hash_out, %clipped))IR"; - - // the second gather_ranges can be eliminated because the `lengths` is - // produces is identical to the lengths produced by - // clip_ranges_gather_sigrid_hash_v3 (caveat, the fused ops makes some - // simplifying assumptions about the ranges input) - std::string pattern2 = R"IR( - graph(%gather2_values, %ranges, %values, %max_length, %salt, %max_value, %hash_into_int32): - %sigrid_hash_out : Tensor, %clipped : Tensor = fb::placeholder(%ranges, %values, %max_length, %salt, %max_value, %hash_into_int32) - %unused : Tensor, %lengths : Tensor = fb::gather_ranges(%gather2_values, %clipped) - return (%lengths, %sigrid_hash_out))IR"; - - std::string fused_pattern2 = R"IR( - graph(%gather2_values, %ranges, %values, %max_length, %salt, %max_value, %hash_into_int32): - %lengths : Tensor, %sigrid_hash_out : Tensor = fb::clip_ranges_gather_sigrid_hash_v3(%ranges, %values, %max_length, %salt, %max_value, %hash_into_int32) - return (%lengths, %sigrid_hash_out))IR"; - SubgraphRewriter fuse; fuse.RegisterRewritePattern(pattern, fused_pattern); fuse.runOnGraph(graph); - - fuse.RegisterRewritePattern(pattern2, fused_pattern2); - fuse.runOnGraph(graph); - - // reverse the ops that got fused in step 1 but not in step2 - fuse.RegisterRewritePattern(fused_pattern, pattern); - fuse.runOnGraph(graph); } C10_UNUSED void ClipRangesGatherRangesX2SigridHashPrecompute( @@ -349,7 +283,6 @@ void FuseInferenceOpsForSparseNN(std::shared_ptr& graph) { ClipRangesGatherSigridHash(graph); ClipRangesGatherRangesSigridHash(graph); - ClipRangesGatherRangesX2SigridHash(graph); ClipRangesGatherRangesX2SigridHashPrecompute(graph); // prioritize clip_ranges+gather_ranges+sigrid_hash fusion over @@ -370,6 +303,31 @@ TORCH_LIBRARY_FRAGMENT(static_runtime, m) { "static_runtime::to_copy.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor"); m.def( "static_runtime::to_copy.other(Tensor self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor"); + m.def(torch::schema( + "static_runtime::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> (Tensor, Tensor, Tensor)", + c10::AliasAnalysisKind::PURE_FUNCTION)); + m.def("static_runtime::signed_log1p(Tensor input) -> Tensor"); +} + +void FuseSignLog1P(std::shared_ptr& graph) { + std::string pattern = R"IR( + graph(%input): + %0 : Tensor = aten::sign(%input) + %1 : Tensor = aten::abs(%input) + %2 : Tensor = aten::log1p(%1) + %res : Tensor = aten::mul(%0, %2) + return (%res) + )IR"; + + std::string fused_pattern = R"IR( + graph(%input): + %res : Tensor = static_runtime::signed_log1p(%input) + return (%res) + )IR"; + + SubgraphRewriter fuse; + fuse.RegisterRewritePattern(pattern, fused_pattern); + fuse.runOnGraph(graph); } bool HasInplaceOp(std::shared_ptr& graph, const AliasDb& alias_db) { @@ -479,6 +437,7 @@ void ReplaceWithCopy( // c10::AliasAnalysisKind::PURE_FUNCTION to make alias analysis work. void FuseListUnpack(std::shared_ptr& graph) { auto nodes = graph->nodes(); + std::vector equally_splits_to_remove; for (auto it = nodes.begin(); it != nodes.end(); ++it) { Node* node = *it; const char* node_qual_string = node->kind().toQualString(); @@ -512,8 +471,22 @@ void FuseListUnpack(std::shared_ptr& graph) { it_next.destroyCurrent(); // remove list_unpack node->eraseOutput(0); + + if (strcmp(node_qual_string, "fb::equally_split") == 0 && + node->outputs().size() == 1) { + // This captures a case of `y = fb::equally_split(x, 1, _)` where y + // becomes just an alias of x. + // If this case is found, replace y with x to avoid executing this op. + equally_splits_to_remove.push_back(node); + } } } + + for (Node* node : equally_splits_to_remove) { + node->output(0)->replaceAllUsesWith(node->input(0)); + node->destroy(); + } + #ifndef NDEBUG graph->lint(); AliasDb db2(graph); @@ -521,5 +494,35 @@ void FuseListUnpack(std::shared_ptr& graph) { #endif } +void EnableStaticRuntimeLayerNorm(std::shared_ptr& graph) { + const c10::Symbol static_runtime_layer_norm_symbol = + c10::Symbol::fromQualString("static_runtime::layer_norm"); + auto nodes = graph->nodes(); + std::vector> replacement; + for (auto it = nodes.begin(); it != nodes.end(); ++it) { + Node* old_node = *it; + if (!old_node->matches(torch::schema( + "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor"))) { + continue; + } + TORCH_CHECK(old_node->outputs().size() == 1); + auto* new_node = graph->create( + static_runtime_layer_norm_symbol, + /*layer_norm*/ 1 + /*mean*/ 1 + /*rst=*/1); + new_node->insertBefore(old_node); + for (auto* input : old_node->inputs()) { + new_node->addInput(input); + } + replacement.emplace_back(old_node, new_node); + } + for (const auto& p : replacement) { + auto* old_node = p.first; + auto* new_node = p.second; + new_node->output(0)->copyMetadata(old_node->output(0)); + old_node->output(0)->replaceAllUsesWith(new_node->output(0)); + old_node->destroy(); + } +} + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/runtime/static/passes.h b/torch/csrc/jit/runtime/static/passes.h index 11ab4bdc7c46a..0904d37fb02c4 100644 --- a/torch/csrc/jit/runtime/static/passes.h +++ b/torch/csrc/jit/runtime/static/passes.h @@ -13,9 +13,14 @@ TORCH_API void ReplaceWithCopy( std::shared_ptr& graph, bool outputs_are_immutable = true); +TORCH_API void EnableStaticRuntimeLayerNorm( + std::shared_ptr& graph); + TORCH_API bool HasInplaceOp( std::shared_ptr& graph, const AliasDb& alias_db); +TORCH_API void FuseSignLog1P(std::shared_ptr& graph); + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/runtime/static/te_wrapper.cpp b/torch/csrc/jit/runtime/static/te_wrapper.cpp new file mode 100644 index 0000000000000..acd1fb758da0a --- /dev/null +++ b/torch/csrc/jit/runtime/static/te_wrapper.cpp @@ -0,0 +1,183 @@ +#include + +#include +#include +#include + +namespace torch { +namespace jit { + +using namespace torch::jit::tensorexpr; + +// Use the width of an AVX-512 vector by default; this happens to work OK for +// AVX2 as well. Some ops benefit from using multiple AVX ports, in which case +// they are vectorized by twice this constant. An exception is logit, since it +// contains FP divide, which is single-ported. +static constexpr int kVectorWidth = 16; + +#ifdef TORCH_ENABLE_LLVM + +void TEWrapper::update(std::unique_ptr&& cg_) { + cg = std::move(cg_); +} + +void TEWrapper::call(const std::vector& args) { + cg->call_raw(args); +} + +bool TEWrapper::supports(const at::Tensor& t) { + return t.is_contiguous() && t.dtype().Match(); +} + +void optimizePointwise(LoopNest* ln, Tensor target, int width) { + std::vector loops = ln->getLoopStmtsFor(target); + ForPtr inner, tail; + TORCH_CHECK(loops.size() > 0, "No loops created for pointwise op"); + ln->splitWithTail(loops[0], width, &inner, &tail); + ln->vectorize(inner); +} + +std::shared_ptr wrapTECompute( + std::shared_ptr wrap, + Tensor out, + std::vector args, + int width = kVectorWidth) { + LoopNest ln({out}); + optimizePointwise(&ln, out, width); + ln.prepareForCodegen(); + StmtPtr s = ln.root_stmt(); + s = IRSimplifier::simplify(s); + args.insert(args.begin(), out); + auto cg = std::make_unique(s, args); + wrap->update(std::move(cg)); + return wrap; +} + +#else + +void TEWrapper::call(const std::vector& args) { + DCHECK(0 && "Invalid call"); +} + +bool TEWrapper::supports(const at::Tensor& t) { + return false; +} + +std::shared_ptr wrapTECompute( + std::shared_ptr wrap, + Tensor out, + std::vector args, + int width = kVectorWidth) { + return wrap; +} + +#endif + +namespace { + +std::mutex& getNNCCacheMutex() { + static std::mutex nncCacheMutex; + return nncCacheMutex; +} + +FastMap>& getNNCCache() { + static FastMap> nncCache; + return nncCache; +} + +std::shared_ptr lookupNNCCache(NodeKind kind) { + std::lock_guard lock(getNNCCacheMutex()); + auto it = getNNCCache().find(kind); + if (it != getNNCCache().end()) { + return it->second; + } + return nullptr; +} + +void updateNNCCache(NodeKind kind, std::shared_ptr code) { + std::lock_guard lock(getNNCCacheMutex()); + getNNCCache()[kind] = code; +} + +} // namespace + +std::shared_ptr createLogit() { + auto wrap = lookupNNCCache(aten::logit); + if (wrap) { + return wrap; + } + wrap = std::make_shared(); + auto N = VarHandle("N", kInt); + auto C = VarHandle("C", kFloat); + Placeholder A("A", kFloat, {N}); + Tensor B = Compute("B", {N}, [&](const VarHandle& i) { + auto A_elem = [&]() { + auto elem = A.load(i); + auto one = FloatImm::make(1.0f); + const auto& min = C; + auto max = one - C; + elem = CompareSelect::make(elem, min, min, elem, kLT); + return CompareSelect::make(elem, max, max, elem, kGT); + }(); + return log_vml(A_elem / (FloatImm::make(1.0f) - A_elem)); + }); + wrap = wrapTECompute(wrap, B, {A, N, C}); + updateNNCCache(aten::logit, wrap); + return wrap; +} + +std::shared_ptr createRelu() { + auto wrap = lookupNNCCache(aten::relu); + if (wrap) { + return wrap; + } + wrap = std::make_shared(); + auto N = VarHandle("N", kInt); + Placeholder A("A", kFloat, {N}); + Tensor B = Compute("B", {N}, [&](const VarHandle& i) { + auto zero = FloatImm::make(0.f); + auto a = A.load(i); + return ifThenElse(a < zero, zero, a); + }); + wrap = wrapTECompute(wrap, B, {A, N}); + updateNNCCache(aten::relu, wrap); + return wrap; +} + +std::shared_ptr createTanh() { + auto wrap = lookupNNCCache(aten::tanh); + if (wrap) { + return wrap; + } + wrap = std::make_shared(); + auto N = VarHandle("N", kInt); + Placeholder A("A", kFloat, {N}); + Tensor B = Compute("B", {N}, [&](const VarHandle& i) { + auto a = A.load(i); + return fast_tanh(a); + }); + wrap = wrapTECompute(wrap, B, {A, N}); + updateNNCCache(aten::tanh, wrap); + return wrap; +} + +std::shared_ptr createSigmoid() { + auto wrap = lookupNNCCache(aten::sigmoid); + if (wrap) { + return wrap; + } + wrap = std::make_shared(); + auto N = VarHandle("N", kInt); + Placeholder A("A", kFloat, {N}); + Tensor B = + Compute("B", {N}, [&](const VarHandle& i) { return sigmoid(A.load(i)); }); + // NNC uses sleef for vectorizing sigmoid, which comes in an 8-wide flavor + // (Sleef_expf8). + constexpr int kSleefWidth = 8; + wrap = wrapTECompute(wrap, B, {A, N}, kSleefWidth); + updateNNCCache(aten::sigmoid, wrap); + return wrap; +} + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/runtime/static/te_wrapper.h b/torch/csrc/jit/runtime/static/te_wrapper.h new file mode 100644 index 0000000000000..0a5f3d8532990 --- /dev/null +++ b/torch/csrc/jit/runtime/static/te_wrapper.h @@ -0,0 +1,33 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +class TEWrapper { + public: + TEWrapper() = default; + void call(const std::vector& args); + bool supports(const at::Tensor& t); +#ifdef TORCH_ENABLE_LLVM + void update(std::unique_ptr&& cg_); +#endif + + private: +#ifdef TORCH_ENABLE_LLVM + std::unique_ptr cg; +#endif +}; + +std::shared_ptr createLogit(); +std::shared_ptr createRelu(); +std::shared_ptr createTanh(); +std::shared_ptr createSigmoid(); + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/runtime/symbolic_script.cpp b/torch/csrc/jit/runtime/symbolic_script.cpp index 453a83cd4507e..6f2acca134738 100644 --- a/torch/csrc/jit/runtime/symbolic_script.cpp +++ b/torch/csrc/jit/runtime/symbolic_script.cpp @@ -1117,7 +1117,7 @@ const std::vector functions = { return result, backward )", R"( - def batch_norm_disabled(input : Tensor, + def batch_norm(input : Tensor, weight : Optional[Tensor], bias : Optional[Tensor], running_mean : Optional[Tensor], @@ -1141,7 +1141,7 @@ const std::vector functions = { return output, backward - def layer_norm(input : Tensor, + def layer_norm_disabled(input : Tensor, normalized_shape : List[int], weight : Optional[Tensor], bias : Optional[Tensor], diff --git a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp index ffc2f44e16dac..871b65d75f6b7 100644 --- a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp +++ b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include namespace torch { @@ -36,10 +37,17 @@ const std::string shape_compute_functions = return expandedSizes def adaptive_avg_pool2d(self: List[int], out: List[int]): - # TODO: return out directly, list len refiner would need to - # annotate the List Type with len directly in IR assert len(out) == 2 - return [out[0], out[1]] + assert len(self) == 3 or len(self) == 4 + for i in range (1, len(self)): + assert self[i] != 0 + + shape: List[int] = [] + for i in range(0, len(self) -2): + shape.append(self[i]) + for elem in out: + shape.append(elem) + return shape # TODO: maybe make it customary that extra arguments are unused ? # TODO: return self directly @@ -285,7 +293,21 @@ const std::string shape_compute_functions = for i in range(end_dim + 1, len(input)): shape.append(input[i]) return shape - )"; + )" +#ifdef USE_XNNPACK + R"( + def prepacked_conv2d_clamp_run(input: List[int], conv2dOpContext: Any): + assert isinstance(conv2dOpContext, __torch__.torch.classes.xnnpack.Conv2dOpContext) + (weight, bias, stride, padding, dilation, groups) = ops.prepacked.unpack_prepacked_sizes_conv2d(conv2dOpContext) + return conv2d(input, weight, bias, stride, padding, dilation, groups) + + def prepacked_linear_clamp_run(input: List[int], linearOpContext: Any): + assert isinstance(linearOpContext, __torch__.torch.classes.xnnpack.LinearOpContext) + (weight, bias) = ops.prepacked.unpack_prepacked_sizes_linear(linearOpContext) + return linear(input, weight, bias) + )" +#endif + ; // mapping function schema to shape compute graphs allows multiple functions to // share the same shape compute graph, which is memory efficient and also will @@ -310,8 +332,11 @@ static const OperatorMap& get_schema_to_function_graph() { {"aten::div.Scalar(Tensor self, Scalar other) -> Tensor", "unary_one_unused_input"}, {"aten::gt.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"}, {"aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor", "broadcast_one_unused_input"}, + {"aten::add_.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor", "broadcast_one_unused_input"}, {"aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", "unary_two_unused_inputs"}, {"aten::hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> Tensor", "unary_two_unused_inputs"}, + {"aten::hardswish_(Tensor self) -> Tensor", "unary"}, + {"aten::hardsigmoid_(Tensor self) -> Tensor", "unary"}, {"aten::adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor", "adaptive_avg_pool2d"}, {"aten::mm(Tensor self, Tensor mat2) -> Tensor", "mm"}, {"aten::dot(Tensor self, Tensor tensor) -> Tensor", "dot"}, @@ -328,6 +353,10 @@ static const OperatorMap& get_schema_to_function_graph() { {"aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)", "view"}, {"aten::mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "mean_dim"}, {"aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor", "addmm"}, +#ifdef USE_XNNPACK + {"prepacked::conv2d_clamp_run(Tensor X, __torch__.torch.classes.xnnpack.Conv2dOpContext W_prepack) -> Tensor Y", "prepacked_conv2d_clamp_run"}, + {"prepacked::linear_clamp_run(Tensor X, __torch__.torch.classes.xnnpack.LinearOpContext W_prepack) -> Tensor Y", "prepacked_linear_clamp_run"}, +#endif }; // clang-format on return schema_to_function_graph; @@ -337,7 +366,7 @@ std::unordered_map> cached_schema_to_graph; // CompilationUnit that holds all these Functions and keeps them alive. -CompilationUnit compilation_unit; +auto compilation_unit = std::make_shared(); void loadModule(const CompilationUnit& module) { std::unordered_map> reused_functions; @@ -364,9 +393,16 @@ void loadModule(const CompilationUnit& module) { } void loadFunctions() { - compilation_unit.define( - c10::nullopt, shape_compute_functions, nativeResolver(), nullptr); - loadModule(compilation_unit); + auto src = std::make_shared(shape_compute_functions); + std::vector constantTable; + auto resolver = std::make_shared( + compilation_unit, + &constantTable, + [&](const std::string& name) -> std::shared_ptr { return src; }, + 1); + compilation_unit->define( + c10::nullopt, shape_compute_functions, resolver, nullptr); + loadModule(*compilation_unit); } } // anonymous namespace diff --git a/torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp b/torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp index c26c7e575c547..93da38ad768c5 100644 --- a/torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp +++ b/torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp @@ -47,15 +47,11 @@ c10::IValue InlinedCallStackSerializer::serialize( } else { elements.emplace_back(c10::IValue()); } - if (cs_ptr->function()) { - elements.emplace_back(cs_ptr->function()->name()); + auto fn_name = cs_ptr->function_name(); + if (!fn_name.empty()) { + elements.emplace_back(fn_name); } else { - auto fn_name = cs_ptr->function_name(); - if (!fn_name.empty()) { - elements.emplace_back(fn_name); - } else { - elements.emplace_back("FunctionName_UNKNOWN"); - } + elements.emplace_back("FunctionName_UNKNOWN"); } c10::IValue serialized_cs = c10::ivalue::Tuple::create(elements); serialized_inlined_callstack_[cs_ptr] = serialized_cs; diff --git a/torch/csrc/jit/serialization/import.cpp b/torch/csrc/jit/serialization/import.cpp index 1f70c3cad8a5e..86aa6e3909e14 100644 --- a/torch/csrc/jit/serialization/import.cpp +++ b/torch/csrc/jit/serialization/import.cpp @@ -47,7 +47,8 @@ void postSetStateValidate(const IValue& v) { // const auto attrType = objType->getAttribute(i); // Verify that all the non-optional attributes have been initialized // TODO: Issue #20497 - if (attrType->kind() != TypeKind::OptionalType && + if (attrType->kind() != TypeKind::UnionType && + attrType->kind() != TypeKind::OptionalType && attrType->kind() != TypeKind::NoneType) { TORCH_CHECK( !slot.isNone(), diff --git a/torch/csrc/jit/serialization/import_source.cpp b/torch/csrc/jit/serialization/import_source.cpp index fb1de17a54eea..918b0d4338c73 100644 --- a/torch/csrc/jit/serialization/import_source.cpp +++ b/torch/csrc/jit/serialization/import_source.cpp @@ -91,629 +91,616 @@ struct ConstantTableValue : public SugaredValue { const std::vector* constants_; }; -struct SourceImporterImpl : public Resolver, - std::enable_shared_from_this { - SourceImporterImpl( - std::shared_ptr cu, - const std::vector* constant_table, - SourceLoader source_loader, - size_t version) - : cu_(std::move(cu)), source_loader_(std::move(source_loader)) { - env_ = { - {"torch", std::make_shared("aten", version)}, - {"ops", std::make_shared(version)}, - // Constants present in the model. Used to resolve "CONSTANTS.n" to the - // actual value - {"CONSTANTS", std::make_shared(constant_table)}, - {"fork", SpecialFormValue::create(prim::fork)}, - {"annotate", SpecialFormValue::create(prim::annotate)}, - {"unchecked_cast", SpecialFormValue::create(prim::unchecked_cast)}, - {"uninitialized", SpecialFormValue::create(prim::Uninitialized)}, - }; - } - - TypePtr findNamedType(const QualifiedName& name) { - if (auto custom_class = getCustomClass(name.qualifiedName())) { - return custom_class; - } - parseSourceIfNeeded(name.prefix()); - auto it = to_be_defined_.find(name); - if (it != to_be_defined_.end() && it->second->kind() == TK_CLASS_DEF) { - ClassDef cd(it->second); - to_be_defined_.erase(it); - importNamedType(name.prefix(), cd); - } - return cu_->get_type(name); +SourceImporterImpl::SourceImporterImpl( + std::shared_ptr cu, + const std::vector* constant_table, + SourceLoader source_loader, + size_t version) + : cu_(std::move(cu)), source_loader_(std::move(source_loader)) { + env_ = { + {"torch", std::make_shared("aten", version)}, + {"ops", std::make_shared(version)}, + // Constants present in the model. Used to resolve "CONSTANTS.n" to the + // actual value + {"CONSTANTS", std::make_shared(constant_table)}, + {"fork", SpecialFormValue::create(prim::fork)}, + {"annotate", SpecialFormValue::create(prim::annotate)}, + {"unchecked_cast", SpecialFormValue::create(prim::unchecked_cast)}, + {"uninitialized", SpecialFormValue::create(prim::Uninitialized)}, + }; +} + +TypePtr SourceImporterImpl::findNamedType(const QualifiedName& name) { + if (auto custom_class = getCustomClass(name.qualifiedName())) { + return custom_class; + } + parseSourceIfNeeded(name.prefix()); + auto it = to_be_defined_.find(name); + if (it != to_be_defined_.end() && it->second->kind() == TK_CLASS_DEF) { + ClassDef cd(it->second); + to_be_defined_.erase(it); + importNamedType(name.prefix(), cd); } + return cu_->get_type(name); +} - Function* findFunction(const QualifiedName& name) { - parseSourceIfNeeded(name.prefix()); - auto it = to_be_defined_.find(name); - if (it != to_be_defined_.end() && it->second->kind() == TK_DEF) { - Def d(it->second); - to_be_defined_.erase(it); - importFunction(name.prefix(), d); - } - return cu_->find_function(name); +Function* SourceImporterImpl::findFunction(const QualifiedName& name) { + parseSourceIfNeeded(name.prefix()); + auto it = to_be_defined_.find(name); + if (it != to_be_defined_.end() && it->second->kind() == TK_DEF) { + Def d(it->second); + to_be_defined_.erase(it); + importFunction(name.prefix(), d); } + return cu_->find_function(name); +} - void parseSourceIfNeeded(const std::string& qualifier) { - // qualifier may be blank, for instance checking if __torch__ is a class. - if (qualifier == "" || loaded_sources_.count(qualifier)) { - return; - } - loaded_sources_.insert(qualifier); - std::shared_ptr src = source_loader_(qualifier); - - // The importer, when looking for classes/functions doesn't know if 'foo' - // contains definitions or if it is a prefix of 'foo.bar', we only figure it - // out by testing if `foo.py` exists in the source loader. If it doesn't - // then there is nothing to load here - if (!src) { - return; - } - Parser p(src); - parsePossibleVersionNumber(p.lexer()); - - auto& L = p.lexer(); - - while (L.cur().kind != TK_EOF) { - parseImports(L); - auto tk = L.cur(); - auto kind = tk.kind; - switch (kind) { - case TK_CLASS_DEF: { - auto parsed_treeref = ClassDef(p.parseClass()); - to_be_defined_[QualifiedName( - qualifier, parsed_treeref.name().name())] = parsed_treeref; - } break; - case TK_DEF: { - auto parsed_treeref = Def(p.parseFunction(/*is_method=*/false)); - to_be_defined_[QualifiedName( - qualifier, parsed_treeref.name().name())] = parsed_treeref; - } break; - default: - throw ErrorReport(L.cur().range) - << "Unexpected token in code import: " << kindToString(kind); - } +void SourceImporterImpl::parseSourceIfNeeded(const std::string& qualifier) { + // qualifier may be blank, for instance checking if __torch__ is a class. + if (qualifier == "" || loaded_sources_.count(qualifier)) { + return; + } + loaded_sources_.insert(qualifier); + std::shared_ptr src = source_loader_(qualifier); + + // The importer, when looking for classes/functions doesn't know if 'foo' + // contains definitions or if it is a prefix of 'foo.bar', we only figure it + // out by testing if `foo.py` exists in the source loader. If it doesn't + // then there is nothing to load here + if (!src) { + return; + } + Parser p(src); + parsePossibleVersionNumber(p.lexer()); + + auto& L = p.lexer(); + + while (L.cur().kind != TK_EOF) { + parseImports(L); + auto tk = L.cur(); + auto kind = tk.kind; + switch (kind) { + case TK_CLASS_DEF: { + auto parsed_treeref = ClassDef(p.parseClass()); + to_be_defined_[QualifiedName(qualifier, parsed_treeref.name().name())] = + parsed_treeref; + } break; + case TK_DEF: { + auto parsed_treeref = Def(p.parseFunction(/*is_method=*/false)); + to_be_defined_[QualifiedName(qualifier, parsed_treeref.name().name())] = + parsed_treeref; + } break; + default: + throw ErrorReport(L.cur().range) + << "Unexpected token in code import: " << kindToString(kind); } } +} - void LEGACY_import_methods( - const Module& mod, - const std::shared_ptr& src) { - auto self = SimpleSelf(mod.type()); - c10::QualifiedName prefix = *mod.type()->name(); - Parser p(src); +void SourceImporterImpl::LEGACY_import_methods( + const Module& mod, + const std::shared_ptr& src) { + auto self = SimpleSelf(mod.type()); + c10::QualifiedName prefix = *mod.type()->name(); + Parser p(src); - parsePossibleVersionNumber(p.lexer()); + parsePossibleVersionNumber(p.lexer()); - parseImports(p.lexer()); + parseImports(p.lexer()); - std::vector definitions; - std::vector resolvers; - while (p.lexer().cur().kind != TK_EOF) { - auto def = Def(p.parseFunction(/*is_method=*/true)); - definitions.emplace_back(def); - resolvers.emplace_back(shared_from_this()); - } - cu_->define( - prefix, - /*properties=*/{}, - /*propResolvers=*/{}, - definitions, - resolvers, - &self); + std::vector definitions; + std::vector resolvers; + while (p.lexer().cur().kind != TK_EOF) { + auto def = Def(p.parseFunction(/*is_method=*/true)); + definitions.emplace_back(def); + resolvers.emplace_back(shared_from_this()); } + cu_->define( + prefix, + /*properties=*/{}, + /*propResolvers=*/{}, + definitions, + resolvers, + &self); +} - std::shared_ptr resolveValue( - const std::string& name, - Function& m, - const SourceRange& loc) override { - auto it = env_.find(name); - if (it != env_.end()) { - return it->second; - } - auto graph = m.graph(); - if (name == "inf") { - return std::make_shared( - graph->insertConstant(std::numeric_limits::infinity(), loc)); - } - if (name == "nan") { - return std::make_shared( - graph->insertConstant(std::numeric_limits::quiet_NaN(), loc)); - } - if (name == "infj") { - return std::make_shared(graph->insertConstant( - c10::complex(0, std::numeric_limits::infinity()), - loc)); - } - if (name == "nanj") { - return std::make_shared(graph->insertConstant( - c10::complex(0, std::numeric_limits::quiet_NaN()), - loc)); - } - if (name == "__torch__") { - return std::make_shared( - c10::QualifiedName(name), shared_from_this()); - } - return nullptr; +std::shared_ptr SourceImporterImpl::resolveValue( + const std::string& name, + Function& m, + const SourceRange& loc) { + auto it = env_.find(name); + if (it != env_.end()) { + return it->second; + } + auto graph = m.graph(); + if (name == "inf") { + return std::make_shared( + graph->insertConstant(std::numeric_limits::infinity(), loc)); + } + if (name == "nan") { + return std::make_shared( + graph->insertConstant(std::numeric_limits::quiet_NaN(), loc)); } + if (name == "infj") { + return std::make_shared(graph->insertConstant( + c10::complex(0, std::numeric_limits::infinity()), loc)); + } + if (name == "nanj") { + return std::make_shared(graph->insertConstant( + c10::complex(0, std::numeric_limits::quiet_NaN()), + loc)); + } + if (name == "__torch__") { + return std::make_shared( + c10::QualifiedName(name), shared_from_this()); + } + return nullptr; +} + +TypePtr SourceImporterImpl::resolveType( + const std::string& name, + const SourceRange& loc) { + return findNamedType(QualifiedName(name)); +} - TypePtr resolveType(const std::string& name, const SourceRange& loc) - override { - return findNamedType(QualifiedName(name)); +void SourceImporterImpl::importFunction( + const std::string& qualifier, + const Def& def) { + std::vector definitions{def}; + std::vector resolvers{shared_from_this()}; + cu_->define( + qualifier, + /*properties=*/{}, + /*propResolvers=*/{}, + definitions, + resolvers, + nullptr); +} + +void SourceImporterImpl::importNamedType( + const std::string& qualifier, + const ClassDef& class_def) { + const auto qualified_name = + QualifiedName(QualifiedName(qualifier), class_def.name().name()); + if (!class_def.superclass().present()) { + return importClass(qualified_name, class_def, /*is_module=*/false); + } + const auto& superclass_name = Var(class_def.superclass().get()).name().name(); + if (superclass_name == "Module") { + importClass(qualified_name, class_def, /*is_module=*/true); + } else if (superclass_name == "NamedTuple") { + // NamedTuples have special rules (since they are TupleTypes and not + // ClassTypes) + return importNamedTuple(qualified_name, class_def); + } else if (superclass_name == "Interface") { + cu_->define_interface( + qualified_name, class_def, shared_from_this(), /*is_module=*/false); + } else if (superclass_name == "ModuleInterface") { + cu_->define_interface( + qualified_name, class_def, shared_from_this(), /*is_module=*/true); + } else if (superclass_name == "Enum") { + importEnum(qualified_name, class_def); + } else { + throw ErrorReport(class_def.range()) + << "Torchscript does not support class inheritance."; } +} - private: - void importFunction(const std::string& qualifier, const Def& def) { - std::vector definitions{def}; - std::vector resolvers{shared_from_this()}; - cu_->define( - qualifier, - /*properties=*/{}, - /*propResolvers=*/{}, - definitions, - resolvers, - nullptr); - } - - void importNamedType( - const std::string& qualifier, - const ClassDef& class_def) { - const auto qualified_name = - QualifiedName(QualifiedName(qualifier), class_def.name().name()); - if (!class_def.superclass().present()) { - return importClass(qualified_name, class_def, /*is_module=*/false); +c10::optional SourceImporterImpl:: + attributeAssignmentSpecialHandlingHack( + const QualifiedName& qualified_classname, + const Assign& assign) { + struct AttrTypeReplacementDescr { + std::string attr_name; + std::string expected_type; + std::string replacement_type; + }; + + // module demangled qualname -> ReplacementDescr + static std::unordered_map replacements{ + {"__torch__.torch.nn.quantized.modules.linear.LinearPackedParams", + {"_packed_params", + "Tensor", + "__torch__.torch.classes.quantized.LinearPackedParamsBase"}}, + {"__torch__.torch.nn.quantized.modules.linear.Linear", + {"_packed_params", + "Tensor", + "__torch__.torch.classes.quantized.LinearPackedParamsBase"}}, + {"__torch__.torch.nn.quantized.dynamic.modules.linear.Linear", + {"_packed_params", + "Tensor", + "__torch__.torch.classes.quantized.LinearPackedParamsBase"}}, + {"__torch__.torch.nn.quantized.modules.conv.Conv2d", + {"_packed_params", + "Tensor", + "__torch__.torch.classes.quantized.Conv2dPackedParamsBase"}}, + {"__torch__.torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU2d", + {"_packed_params", + "Tensor", + "__torch__.torch.classes.quantized.Conv2dPackedParamsBase"}}, + {"__torch__.torch.nn.quantized.modules.conv.Conv3d", + {"_packed_params", + "Tensor", + "__torch__.torch.classes.quantized.Conv3dPackedParamsBase"}}, + {"__torch__.torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU3d", + {"_packed_params", + "Tensor", + "__torch__.torch.classes.quantized.Conv3dPackedParamsBase"}}}; + // @lint-ignore-every CLANGTIDY facebook-hte-StdRegexIsAwful + static std::regex mangle_re("\\.___torch_mangle_\\d+"); + auto demangled_classname = + std::regex_replace(qualified_classname.qualifiedName(), mangle_re, ""); + if (replacements.count(demangled_classname)) { + auto lhs = Var(assign.lhs()); + if (!assign.type().present() || assign.type().get().kind() != TK_VAR) { + return c10::nullopt; } - const auto& superclass_name = - Var(class_def.superclass().get()).name().name(); - if (superclass_name == "Module") { - importClass(qualified_name, class_def, /*is_module=*/true); - } else if (superclass_name == "NamedTuple") { - // NamedTuples have special rules (since they are TupleTypes and not - // ClassTypes) - return importNamedTuple(qualified_name, class_def); - } else if (superclass_name == "Interface") { - cu_->define_interface( - qualified_name, class_def, shared_from_this(), /*is_module=*/false); - } else if (superclass_name == "ModuleInterface") { - cu_->define_interface( - qualified_name, class_def, shared_from_this(), /*is_module=*/true); - } else if (superclass_name == "Enum") { - importEnum(qualified_name, class_def); - } else { - throw ErrorReport(class_def.range()) - << "Torchscript does not support class inheritance."; + auto type = Var(assign.type().get()); + + auto& attr_name = replacements.at(demangled_classname).attr_name; + auto& expected_type = replacements.at(demangled_classname).expected_type; + auto& replacement_type = + replacements.at(demangled_classname).replacement_type; + if (lhs.name().name() == attr_name && type.name().name() == expected_type) { + Parser p(std::make_shared(replacement_type)); + auto typename_expr = p.parseExp(); + auto maybe_typename = + Maybe::create(typename_expr.range(), typename_expr); + return Assign::create( + assign.range(), assign.lhs_list(), assign.rhs(), maybe_typename); } } + return c10::nullopt; +} - c10::optional attributeAssignmentSpecialHandlingHack( - const QualifiedName& qualified_classname, - const Assign& assign) { - struct AttrTypeReplacementDescr { - std::string attr_name; - std::string expected_type; - std::string replacement_type; - }; - - // module demangled qualname -> ReplacementDescr - static std::unordered_map replacements{ - {"__torch__.torch.nn.quantized.modules.linear.LinearPackedParams", - {"_packed_params", - "Tensor", - "__torch__.torch.classes.quantized.LinearPackedParamsBase"}}, - {"__torch__.torch.nn.quantized.modules.linear.Linear", - {"_packed_params", - "Tensor", - "__torch__.torch.classes.quantized.LinearPackedParamsBase"}}, - {"__torch__.torch.nn.quantized.dynamic.modules.linear.Linear", - {"_packed_params", - "Tensor", - "__torch__.torch.classes.quantized.LinearPackedParamsBase"}}, - {"__torch__.torch.nn.quantized.modules.conv.Conv2d", - {"_packed_params", - "Tensor", - "__torch__.torch.classes.quantized.Conv2dPackedParamsBase"}}, - {"__torch__.torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU2d", - {"_packed_params", - "Tensor", - "__torch__.torch.classes.quantized.Conv2dPackedParamsBase"}}, - {"__torch__.torch.nn.quantized.modules.conv.Conv3d", - {"_packed_params", - "Tensor", - "__torch__.torch.classes.quantized.Conv3dPackedParamsBase"}}, - {"__torch__.torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU3d", - {"_packed_params", - "Tensor", - "__torch__.torch.classes.quantized.Conv3dPackedParamsBase"}}}; - static std::regex mangle_re("\\.___torch_mangle_\\d+"); - auto demangled_classname = - std::regex_replace(qualified_classname.qualifiedName(), mangle_re, ""); - if (replacements.count(demangled_classname)) { - auto lhs = Var(assign.lhs()); - if (!assign.type().present() || assign.type().get().kind() != TK_VAR) { - return c10::nullopt; - } - auto type = Var(assign.type().get()); - - auto& attr_name = replacements.at(demangled_classname).attr_name; - auto& expected_type = replacements.at(demangled_classname).expected_type; - auto& replacement_type = - replacements.at(demangled_classname).replacement_type; - if (lhs.name().name() == attr_name && - type.name().name() == expected_type) { - Parser p(std::make_shared(replacement_type)); - auto typename_expr = p.parseExp(); - auto maybe_typename = - Maybe::create(typename_expr.range(), typename_expr); - return Assign::create( - assign.range(), assign.lhs_list(), assign.rhs(), maybe_typename); - } - } - return c10::nullopt; - } - - void importClass( - const QualifiedName& qualified_classname, - const ClassDef& class_def, - bool is_module) { - // BC for TorchBind classes - // - // Previously we would serialize TorchBind classes as actual - // classes with methods that delegate to things in the - // torch.ops.* namespace. We've switched away from this and - // now just rely on those classes being present in the binary - // and emit code for them based on the ClassType in memory. - // - // TODO: remove this once we no longer have old TorchBind code - // in production models - { - static QualifiedName torch_classes_qualname("__torch__.torch.classes"); - if (torch_classes_qualname.isPrefixOf(qualified_classname)) { - return; - } +void SourceImporterImpl::importClass( + const QualifiedName& qualified_classname, + const ClassDef& class_def, + bool is_module) { + // BC for TorchBind classes + // + // Previously we would serialize TorchBind classes as actual + // classes with methods that delegate to things in the + // torch.ops.* namespace. We've switched away from this and + // now just rely on those classes being present in the binary + // and emit code for them based on the ClassType in memory. + // + // TODO: remove this once we no longer have old TorchBind code + // in production models + { + static QualifiedName torch_classes_qualname("__torch__.torch.classes"); + if (torch_classes_qualname.isPrefixOf(qualified_classname)) { + return; } - auto class_type = ClassType::create( - c10::QualifiedName(qualified_classname), cu_, is_module); - - std::vector methods; - std::vector method_resolvers; - std::map pre_hook_def_map; - std::map hook_def_map; - std::map pre_hook_resolver_map; - std::map hook_resolver_map; - std::vector attributes; - std::vector constants; - - // Module-specific: which attrs are parameters? - std::unordered_set parameter_names; - std::unordered_set buffer_names; - std::unordered_set pre_hook_names; - std::unordered_set hook_names; - // used to keep track of original ordering of hooks and prehooks - // in case any are called more than once - std::vector pre_hooks_order; - std::vector hooks_order; - // Process statements, splitting things into attribute and method - // definitions. - for (const auto& statement : class_def.body()) { - switch (statement.kind()) { - case TK_ASSIGN: { - const auto assign = Assign(statement); - switch (assign.lhs().kind()) { - case TK_VAR: { - const auto name = Var(assign.lhs()).name().name(); - if (name == "__parameters__") { - // Populate the module parameter list. This is a field that - // looks like: - // __parameters__ = ["foo", "bar", "baz"] - // which tells us which attributes are module parameters. - TORCH_INTERNAL_ASSERT( - is_module, - "Assignments in class body only " - "supported on modules right now"); - const auto param_list = - ListLiteral(assign.rhs().get()).inputs(); - for (const auto& param : param_list) { - parameter_names.insert(StringLiteral(param).text()); - } - } else if (name == "__annotations__") { - // This is to initialize the annotations dict, just ignore. - continue; - } else if (name == "__buffers__") { - TORCH_INTERNAL_ASSERT( - is_module, "Buffers only exist on modules at the moment"); - const auto buffer_list = - ListLiteral(assign.rhs().get()).inputs(); - for (const auto& buffer : buffer_list) { - buffer_names.insert(StringLiteral(buffer).text()); - } - } else if (name == "__forward_pre_hooks__") { - TORCH_INTERNAL_ASSERT( - is_module, - "Forward pre hooks only exist on modules at the moment"); - const auto pre_hook_list = - ListLiteral(assign.rhs().get()).inputs(); - for (const auto& pre_hook : pre_hook_list) { - std::string pre_hook_name = StringLiteral(pre_hook).text(); - pre_hook_names.insert(pre_hook_name); - pre_hooks_order.emplace_back(pre_hook_name); - } - } else if (name == "__forward_hooks__") { - TORCH_INTERNAL_ASSERT( - is_module, - "Forward hooks only exist on modules at the moment"); - const auto hook_list = ListLiteral(assign.rhs().get()).inputs(); - for (const auto& hook : hook_list) { - std::string hook_name = StringLiteral(hook).text(); - hook_names.insert(hook_name); - hooks_order.emplace_back(hook_name); - } - } else { - if (auto fixed_up = attributeAssignmentSpecialHandlingHack( - qualified_classname, assign)) { - attributes.push_back(std::move(*fixed_up)); - } else if (assign.rhs().present()) { - // This is a constant assignment, of the form: - // foo : Final[int] = 3 - constants.push_back(assign); - } else { - // This is a regular attribute assignment, of the form: - // foo : Tensor - attributes.push_back(assign); - } + } + auto class_type = ClassType::create( + c10::QualifiedName(qualified_classname), cu_, is_module); + + std::vector methods; + std::vector method_resolvers; + std::map pre_hook_def_map; + std::map hook_def_map; + std::map pre_hook_resolver_map; + std::map hook_resolver_map; + std::vector attributes; + std::vector constants; + + // Module-specific: which attrs are parameters? + std::unordered_set parameter_names; + std::unordered_set buffer_names; + std::unordered_set pre_hook_names; + std::unordered_set hook_names; + // used to keep track of original ordering of hooks and prehooks + // in case any are called more than once + std::vector pre_hooks_order; + std::vector hooks_order; + // Process statements, splitting things into attribute and method + // definitions. + for (const auto& statement : class_def.body()) { + switch (statement.kind()) { + case TK_ASSIGN: { + const auto assign = Assign(statement); + switch (assign.lhs().kind()) { + case TK_VAR: { + const auto name = Var(assign.lhs()).name().name(); + if (name == "__parameters__") { + // Populate the module parameter list. This is a field that + // looks like: + // __parameters__ = ["foo", "bar", "baz"] + // which tells us which attributes are module parameters. + TORCH_INTERNAL_ASSERT( + is_module, + "Assignments in class body only " + "supported on modules right now"); + const auto param_list = ListLiteral(assign.rhs().get()).inputs(); + for (const auto& param : param_list) { + parameter_names.insert(StringLiteral(param).text()); } - } break; - case TK_SUBSCRIPT: { - // This is a special attribute assignment where the attribute - // is not a valid python, identifier. Looks like: - // __annotations__["0"] = Tensor - const auto lhs = Subscript(assign.lhs()); + } else if (name == "__annotations__") { + // This is to initialize the annotations dict, just ignore. + continue; + } else if (name == "__buffers__") { TORCH_INTERNAL_ASSERT( - Var(lhs.value()).name().name() == "__annotations__"); - TORCH_INTERNAL_ASSERT(lhs.subscript_exprs().size() == 1); - attributes.push_back(assign); - } break; - default: { + is_module, "Buffers only exist on modules at the moment"); + const auto buffer_list = ListLiteral(assign.rhs().get()).inputs(); + for (const auto& buffer : buffer_list) { + buffer_names.insert(StringLiteral(buffer).text()); + } + } else if (name == "__forward_pre_hooks__") { TORCH_INTERNAL_ASSERT( - false, - "Unexpected statement kind in module metadata: ", - kindToString(statement.kind())); + is_module, + "Forward pre hooks only exist on modules at the moment"); + const auto pre_hook_list = + ListLiteral(assign.rhs().get()).inputs(); + for (const auto& pre_hook : pre_hook_list) { + std::string pre_hook_name = StringLiteral(pre_hook).text(); + pre_hook_names.insert(pre_hook_name); + pre_hooks_order.emplace_back(pre_hook_name); + } + } else if (name == "__forward_hooks__") { + TORCH_INTERNAL_ASSERT( + is_module, + "Forward hooks only exist on modules at the moment"); + const auto hook_list = ListLiteral(assign.rhs().get()).inputs(); + for (const auto& hook : hook_list) { + std::string hook_name = StringLiteral(hook).text(); + hook_names.insert(hook_name); + hooks_order.emplace_back(hook_name); + } + } else { + if (auto fixed_up = attributeAssignmentSpecialHandlingHack( + qualified_classname, assign)) { + attributes.push_back(std::move(*fixed_up)); + } else if (assign.rhs().present()) { + // This is a constant assignment, of the form: + // foo : Final[int] = 3 + constants.push_back(assign); + } else { + // This is a regular attribute assignment, of the form: + // foo : Tensor + attributes.push_back(assign); + } } + } break; + case TK_SUBSCRIPT: { + // This is a special attribute assignment where the attribute + // is not a valid python, identifier. Looks like: + // __annotations__["0"] = Tensor + const auto lhs = Subscript(assign.lhs()); + TORCH_INTERNAL_ASSERT( + Var(lhs.value()).name().name() == "__annotations__"); + TORCH_INTERNAL_ASSERT(lhs.subscript_exprs().size() == 1); + attributes.push_back(assign); + } break; + default: { + TORCH_INTERNAL_ASSERT( + false, + "Unexpected statement kind in module metadata: ", + kindToString(statement.kind())); } - } break; - case TK_DEF: { - Def def = Def(statement); - if (pre_hook_names.find(def.name().name()) != pre_hook_names.end()) { - pre_hook_def_map.emplace(def.name().name(), def); - pre_hook_resolver_map.emplace( - def.name().name(), shared_from_this()); - } else if (hook_names.find(def.name().name()) != hook_names.end()) { - hook_def_map.emplace(def.name().name(), def); - hook_resolver_map.emplace(def.name().name(), shared_from_this()); - } else { - methods.emplace_back(def); - method_resolvers.push_back(shared_from_this()); - } - } break; - default: { - TORCH_INTERNAL_ASSERT( - false, - "Unexpected statement kind in class body: ", - kindToString(statement.kind())); } - } - } - - // Populate class attributes - ScriptTypeParser type_parser(shared_from_this()); - for (const auto& assign : attributes) { - switch (assign.lhs().kind()) { - case TK_VAR: { - const auto name = Var(assign.lhs()).name().name(); - TORCH_INTERNAL_ASSERT(name != "__parameters__"); - const auto type = type_parser.parseTypeFromExpr(assign.type().get()); - const bool is_parameter = parameter_names.count(name); - const bool is_buffer = buffer_names.count(name); - class_type->addAttribute(name, type, is_parameter, is_buffer); - } break; - case TK_SUBSCRIPT: { - const auto name = - StringLiteral(Subscript(assign.lhs()).subscript_exprs()[0]) - .text(); - const auto type = type_parser.parseTypeFromExpr(assign.rhs().get()); - const bool is_parameter = parameter_names.count(name); - const bool is_buffer = buffer_names.count(name); - class_type->addAttribute(name, type, is_parameter, is_buffer); + } break; + case TK_DEF: { + Def def = Def(statement); + const auto def_name = def.name().name(); + if (pre_hook_names.find(def_name) != pre_hook_names.end()) { + pre_hook_def_map.emplace(def_name, def); + pre_hook_resolver_map.emplace(def_name, shared_from_this()); + } else if (hook_names.find(def_name) != hook_names.end()) { + hook_def_map.emplace(def_name, def); + hook_resolver_map.emplace(def_name, shared_from_this()); + } else { + methods.emplace_back(def); + method_resolvers.push_back(shared_from_this()); } + } break; + default: { + TORCH_INTERNAL_ASSERT( + false, + "Unexpected statement kind in class body: ", + kindToString(statement.kind())); } } + } - // Populate class constants - for (const auto& assign : constants) { - auto const_val = type_parser.parseClassConstant(assign); - const auto name = Var(assign.lhs()).name().name(); - class_type->addConstant(name, const_val); + // Populate class attributes + ScriptTypeParser type_parser(shared_from_this()); + for (const auto& assign : attributes) { + switch (assign.lhs().kind()) { + case TK_VAR: { + const auto name = Var(assign.lhs()).name().name(); + TORCH_INTERNAL_ASSERT(name != "__parameters__"); + const auto type = type_parser.parseTypeFromExpr(assign.type().get()); + const bool is_parameter = parameter_names.count(name); + const bool is_buffer = buffer_names.count(name); + class_type->addAttribute(name, type, is_parameter, is_buffer); + } break; + case TK_SUBSCRIPT: { + const auto name = + StringLiteral(Subscript(assign.lhs()).subscript_exprs()[0]).text(); + const auto type = type_parser.parseTypeFromExpr(assign.rhs().get()); + const bool is_parameter = parameter_names.count(name); + const bool is_buffer = buffer_names.count(name); + class_type->addAttribute(name, type, is_parameter, is_buffer); + } } + } - // build pre hook and hook def/resolver pairs - // pairs are dedupped in ir_emitter.cpp's CompilationUnit::define_hooks() - // ordering here is call order for hooks - std::vector hooks; - std::vector hook_resolvers; - for (const std::string& hook_name : hooks_order) { - hooks.emplace_back(hook_def_map.find(hook_name)->second); - hook_resolvers.push_back(hook_resolver_map.find(hook_name)->second); - } - std::vector pre_hooks; - std::vector pre_hook_resolvers; - for (const std::string& pre_hook_name : pre_hooks_order) { - pre_hooks.emplace_back(pre_hook_def_map.find(pre_hook_name)->second); - pre_hook_resolvers.push_back( - pre_hook_resolver_map.find(pre_hook_name)->second); - } + // Populate class constants + for (const auto& assign : constants) { + auto const_val = type_parser.parseClassConstant(assign); + const auto name = Var(assign.lhs()).name().name(); + class_type->addConstant(name, const_val); + } - cu_->register_type(class_type); - const auto self = SimpleSelf(class_type); - cu_->define( - qualified_classname, - /*properties=*/{}, - /*propResolvers=*/{}, - methods, - method_resolvers, - &self); - cu_->define_hooks( - qualified_classname, - hooks, - hook_resolvers, - pre_hooks, - pre_hook_resolvers, - &self); - } - - void importEnum( - const QualifiedName& qualified_name, - const ClassDef& enum_def) { - std::vector names_values; - - TypePtr value_type = nullptr; - auto set_or_check_type = [&value_type]( - const TypePtr& t, const SourceRange& loc) { - if (!value_type) { - value_type = t; - } else if (value_type != t) { - throw ErrorReport(loc) - << "Enum class with varying value types are not supported."; - } - }; + // build pre hook and hook def/resolver pairs + // pairs are dedupped in ir_emitter.cpp's CompilationUnit::define_hooks() + // ordering here is call order for hooks + std::vector hooks; + std::vector hook_resolvers; + for (const std::string& hook_name : hooks_order) { + hooks.emplace_back(hook_def_map.find(hook_name)->second); + hook_resolvers.push_back(hook_resolver_map.find(hook_name)->second); + } + std::vector pre_hooks; + std::vector pre_hook_resolvers; + for (const std::string& pre_hook_name : pre_hooks_order) { + pre_hooks.emplace_back(pre_hook_def_map.find(pre_hook_name)->second); + pre_hook_resolvers.push_back( + pre_hook_resolver_map.find(pre_hook_name)->second); + } - for (const auto& statement : enum_def.body()) { - if (statement.kind() != TK_ASSIGN) { - throw ErrorReport(statement.range()) - << "Unexpected statement in Enum class body: " - "only enum attribute definitions are currently supported."; - } + cu_->register_type(class_type); + const auto self = SimpleSelf(class_type); + cu_->define( + qualified_classname, + /*properties=*/{}, + /*propResolvers=*/{}, + methods, + method_resolvers, + &self); + cu_->define_hooks( + qualified_classname, + hooks, + hook_resolvers, + pre_hooks, + pre_hook_resolvers, + &self); +} - const auto assign = Assign(statement); - const auto name = Var(assign.lhs()).name().name(); - - IValue ivalue; - auto rhs = assign.rhs().get(); - switch (rhs.kind()) { - case TK_STRINGLITERAL: - ivalue = IValue(StringLiteral(rhs).text()); - set_or_check_type(StringType::get(), statement.range()); - break; - case TK_CONST: { - auto numeric_const = Const(rhs); - if (numeric_const.isFloatingPoint()) { - ivalue = IValue(numeric_const.asFloatingPoint()); - set_or_check_type(FloatType::get(), statement.range()); - } else if (numeric_const.isIntegral()) { - ivalue = IValue(numeric_const.asIntegral()); - set_or_check_type(IntType::get(), statement.range()); - } - break; - } - default: - throw ErrorReport(rhs.range()) - << "Unsupported enum value type: " << rhs.kind() - << ". Only Integers, Floats and Strings are supported."; - } +void SourceImporterImpl::importEnum( + const QualifiedName& qualified_name, + const ClassDef& enum_def) { + std::vector names_values; - names_values.emplace_back(std::make_pair(name, ivalue)); + TypePtr value_type = nullptr; + auto set_or_check_type = [&value_type]( + const TypePtr& t, const SourceRange& loc) { + if (!value_type) { + value_type = t; + } else if (value_type != t) { + throw ErrorReport(loc) + << "Enum class with varying value types are not supported."; } + }; - if (!value_type) { - throw ErrorReport(enum_def.range()) - << "No enum values defined for " << qualified_name.qualifiedName(); + for (const auto& statement : enum_def.body()) { + if (statement.kind() != TK_ASSIGN) { + throw ErrorReport(statement.range()) + << "Unexpected statement in Enum class body: " + "only enum attribute definitions are currently supported."; } - auto enum_type = EnumType::create( - qualified_name, std::move(value_type), std::move(names_values), cu_); - cu_->register_type(enum_type); - } - - void importNamedTuple( - const QualifiedName& qualified_name, - const ClassDef& named_tuple_def) { - ScriptTypeParser type_parser(shared_from_this()); - std::vector field_names; - std::vector field_types; - std::vector field_defaults; - for (const auto& statement : named_tuple_def.body()) { - if (statement.kind() != TK_ASSIGN) { - throw ErrorReport(statement.range()) - << "Unexpected statement in NamedTuple body: " - "only attribute annotations are currently supported."; - } - const auto assign = Assign(statement); - - auto name = Var(Assign(statement).lhs()).name().name(); - c10::optional default_val; - if (assign.rhs().present()) { - std::vector parsed = type_parser.evaluateDefaults( - assign.rhs().range(), {assign.rhs().get()}, {assign.type().get()}); - TORCH_INTERNAL_ASSERT(parsed.size() == 1); - default_val = parsed[0]; + const auto assign = Assign(statement); + const auto name = Var(assign.lhs()).name().name(); + + IValue ivalue; + auto rhs = assign.rhs().get(); + switch (rhs.kind()) { + case TK_STRINGLITERAL: + ivalue = IValue(StringLiteral(rhs).text()); + set_or_check_type(StringType::get(), statement.range()); + break; + case TK_CONST: { + auto numeric_const = Const(rhs); + if (numeric_const.isFloatingPoint()) { + ivalue = IValue(numeric_const.asFloatingPoint()); + set_or_check_type(FloatType::get(), statement.range()); + } else if (numeric_const.isIntegral()) { + ivalue = IValue(numeric_const.asIntegral()); + set_or_check_type(IntType::get(), statement.range()); + } + break; } + default: + throw ErrorReport(rhs.range()) + << "Unsupported enum value type: " << rhs.kind() + << ". Only Integers, Floats and Strings are supported."; + } - auto type = type_parser.parseTypeFromExpr(assign.type().get()); + names_values.emplace_back(std::make_pair(name, ivalue)); + } - field_names.emplace_back(std::move(name)); - field_types.emplace_back(std::move(type)); - if (default_val) { - field_defaults.emplace_back(std::move(*default_val)); - } + if (!value_type) { + throw ErrorReport(enum_def.range()) + << "No enum values defined for " << qualified_name.qualifiedName(); + } + + auto enum_type = EnumType::create( + qualified_name, std::move(value_type), std::move(names_values), cu_); + cu_->register_type(enum_type); +} + +void SourceImporterImpl::importNamedTuple( + const QualifiedName& qualified_name, + const ClassDef& named_tuple_def) { + ScriptTypeParser type_parser(shared_from_this()); + std::vector field_names; + std::vector field_types; + std::vector field_defaults; + for (const auto& statement : named_tuple_def.body()) { + if (statement.kind() != TK_ASSIGN) { + throw ErrorReport(statement.range()) + << "Unexpected statement in NamedTuple body: " + "only attribute annotations are currently supported."; + } + const auto assign = Assign(statement); + + auto name = Var(Assign(statement).lhs()).name().name(); + c10::optional default_val; + if (assign.rhs().present()) { + std::vector parsed = type_parser.evaluateDefaults( + assign.rhs().range(), {assign.rhs().get()}, {assign.type().get()}); + TORCH_INTERNAL_ASSERT(parsed.size() == 1); + default_val = parsed[0]; } - auto tt = TupleType::createNamed( - qualified_name, field_names, field_types, field_defaults); - cu_->register_type(tt); - } + auto type = type_parser.parseTypeFromExpr(assign.type().get()); - void parsePossibleVersionNumber(Lexer& L) { - // Older versions of serialization produced an op_version_set string - // per-file We now just use a single version which is handled by - // PyTorchStreamReader. We used to check if op_version_set was _newer_ for - // forward compatibility reasons but now that it doesn't exist there can't - // be a newer one, so we just discard this. - if (L.cur().kind == TK_IDENT && L.cur().text() == "op_version_set") { - auto range = L.cur().range; - L.next(); - L.expect('='); - std::string version_text = L.expect(TK_NUMBER).text(); - L.expect(TK_NEWLINE); + field_names.emplace_back(std::move(name)); + field_types.emplace_back(std::move(type)); + if (default_val) { + field_defaults.emplace_back(std::move(*default_val)); } } - // older versions of serialization required import statements, - // and defined classes file-at-a-time in import order. - // The problem is that in Python - // it is possible to construct cyclic dependencies between files even - // when there are none between individual classes. New versions of loading - // just compile class-at-a-time, so we no longer need to follow the import - // order. Future serialization may stop producing the import code. - void parseImports(Lexer& L) { - while (L.nextIf(TK_IMPORT)) { - std::ostringstream s; - while (L.cur().kind != TK_NEWLINE) { - s << L.cur().text(); - L.next(); - } - L.expect(TK_NEWLINE); - } + auto tt = TupleType::createNamed( + qualified_name, field_names, field_types, field_defaults); + cu_->register_type(tt); +} + +void SourceImporterImpl::parsePossibleVersionNumber(Lexer& L) { + // Older versions of serialization produced an op_version_set string + // per-file We now just use a single version which is handled by + // PyTorchStreamReader. We used to check if op_version_set was _newer_ for + // forward compatibility reasons but now that it doesn't exist there can't + // be a newer one, so we just discard this. + if (L.cur().kind == TK_IDENT && L.cur().text() == "op_version_set") { + auto range = L.cur().range; + L.next(); + L.expect('='); + std::string version_text = L.expect(TK_NUMBER).text(); + L.expect(TK_NEWLINE); } +} - std::shared_ptr cu_; - std::unordered_map> env_; - SourceLoader source_loader_; - std::unordered_set loaded_sources_; - // named types and functions loaded from a file but not yet defined because - // their type has not been requested yet. - std::unordered_map to_be_defined_; -}; +// older versions of serialization required import statements, +// and defined classes file-at-a-time in import order. +// The problem is that in Python +// it is possible to construct cyclic dependencies between files even +// when there are none between individual classes. New versions of loading +// just compile class-at-a-time, so we no longer need to follow the import +// order. Future serialization may stop producing the import code. +void SourceImporterImpl::parseImports(Lexer& L) { + while (L.nextIf(TK_IMPORT)) { + std::ostringstream s; + while (L.cur().kind != TK_NEWLINE) { + s << L.cur().text(); + L.next(); + } + L.expect(TK_NEWLINE); + } +} std::shared_ptr ClassNamespaceValue::attr( const SourceRange& loc, diff --git a/torch/csrc/jit/serialization/import_source.h b/torch/csrc/jit/serialization/import_source.h index e87ab59271594..f52f38afe6b15 100644 --- a/torch/csrc/jit/serialization/import_source.h +++ b/torch/csrc/jit/serialization/import_source.h @@ -1,22 +1,79 @@ #pragma once +#include +#include #include +#include +#include +#include #include +#include +#include #include #include +#include #include #include namespace torch { namespace jit { -struct SourceImporterImpl; +using SourceLoader = std::function(const std::string&)>; + +struct SourceImporterImpl : public Resolver, + std::enable_shared_from_this { + SourceImporterImpl( + std::shared_ptr cu, + const std::vector* constant_table, + SourceLoader source_loader, + size_t version); + TypePtr findNamedType(const QualifiedName& name); + Function* findFunction(const QualifiedName& name); + void parseSourceIfNeeded(const std::string& qualifier); + void LEGACY_import_methods( + const Module& mod, + const std::shared_ptr& src); + + std::shared_ptr resolveValue( + const std::string& name, + Function& m, + const SourceRange& loc) override; + TypePtr resolveType(const std::string& name, const SourceRange& loc) override; + + private: + void importFunction(const std::string& qualifier, const Def& def); + void importNamedType(const std::string& qualifier, const ClassDef& class_def); + c10::optional attributeAssignmentSpecialHandlingHack( + const QualifiedName& qualified_classname, + const Assign& assign); + void importClass( + const QualifiedName& qualified_classname, + const ClassDef& class_def, + bool is_module); + void importEnum( + const QualifiedName& qualified_name, + const ClassDef& enum_def); + void importNamedTuple( + const QualifiedName& qualified_name, + const ClassDef& named_tuple_def); + + void parsePossibleVersionNumber(Lexer& L); + + void parseImports(Lexer& L); + + std::shared_ptr cu_; + std::unordered_map> env_; + SourceLoader source_loader_; + std::unordered_set loaded_sources_; + // named types and functions loaded from a file but not yet defined because + // their type has not been requested yet. + std::unordered_map to_be_defined_; +}; // Given a directory of serialized TorchScript sources, // This class allows the loading of individual named types in source. // Resolves the dependencies between source files and parses // the source files as necessary. -using SourceLoader = std::function(const std::string&)>; struct TORCH_API SourceImporter { SourceImporter( diff --git a/torch/csrc/jit/serialization/pickler.cpp b/torch/csrc/jit/serialization/pickler.cpp index 4a4e8663b3838..f465eaf4dff00 100644 --- a/torch/csrc/jit/serialization/pickler.cpp +++ b/torch/csrc/jit/serialization/pickler.cpp @@ -353,6 +353,44 @@ void Pickler::pushTensor(const IValue& ivalue) { } } +void Pickler::pushLiteralSparseTensor(const at::Tensor& tensor) { + pushGlobal("torch._utils", "_rebuild_sparse_tensor"); + push(PickleOpCode::MARK); + // layout + auto layout = static_cast(tensor.layout()); + pushInt(layout); + switch (layout) { + case static_cast(c10::Layout::Sparse): + // size + push(PickleOpCode::MARK); + for (auto size : tensor.sizes()) { + pushInt(size); + } + push(PickleOpCode::TUPLE); + // requires grad + pushIValue(tensor.requires_grad()); + // indices + pushTensor(tensor._indices()); + // values + pushTensor(tensor._values()); + break; + default: + TORCH_CHECK( + false, + "Unsupported sparse tensor layout type in serialization ", + static_cast(layout)); + break; + } + // backward_hooks + pushGlobal("collections", "OrderedDict"); + push(PickleOpCode::EMPTY_TUPLE); + // Construct the collections.OrderedDict for the backward_hooks + push(PickleOpCode::REDUCE); + push(PickleOpCode::TUPLE); + // Call torch._utils._rebuild_sparse_coo_tensor + push(PickleOpCode::REDUCE); +} + void Pickler::pushLiteralTensor(const IValue& ivalue) { // In contrast to tensor references, literal tensors are included in the // pickle program binary blob. They are written to the file after the STOP @@ -362,6 +400,12 @@ void Pickler::pushLiteralTensor(const IValue& ivalue) { // The format here is the same one used by `torch.save()`. The code for the // format can be found in `torch/serialization.py`. auto& tensor = ivalue.toTensor(); + + if (tensor.is_sparse() || tensor.is_sparse_csr()) { + pushLiteralSparseTensor(tensor); + return; + } + bool quantized = tensor.is_quantized(); // The arguments to this function are: // storage, storage_offset, size, stride, requires_grad, backward_hooks diff --git a/torch/csrc/jit/serialization/pickler.h b/torch/csrc/jit/serialization/pickler.h index ac54ac45a2886..3dc6bef9d9131 100644 --- a/torch/csrc/jit/serialization/pickler.h +++ b/torch/csrc/jit/serialization/pickler.h @@ -172,6 +172,7 @@ class TORCH_API Pickler { void pushTensor(const IValue& ivalue); void pushTensorReference(const IValue& ivalue); void pushLiteralTensor(const IValue& ivalue); + void pushLiteralSparseTensor(const at::Tensor& tensor); void pushTuple(const IValue& ivalue); void pushString(const std::string& string); void pushDevice(const IValue& ivalue); diff --git a/torch/csrc/jit/serialization/python_print.cpp b/torch/csrc/jit/serialization/python_print.cpp index 1ab968967392f..6b1bf15304624 100644 --- a/torch/csrc/jit/serialization/python_print.cpp +++ b/torch/csrc/jit/serialization/python_print.cpp @@ -511,13 +511,31 @@ struct PythonPrintImpl { } indent(); printValueList(body_, lhs); + // We need to preserve Union/Optional type annotations, but only if + // we're not assigning values as part of a tuple unpacking statement + // (Python doesn't allow type annotations in multiple assignment) + if (lhs.size() == 1) { + Value* v = lhs.at(0); + if (!annotated_unions_.count(v) && !expr_table_.count(v) && + (v->type()->kind() == UnionType::Kind || + v->type()->kind() == OptionalType::Kind)) { + body_ << " : " << v->type()->annotation_str(); + annotated_unions_.insert(v); + } + } body_ << " = "; + // or if value is being assigned to something of a union type printValueList(body_, rhs); body_ << "\n"; } bool requiresAnnotation(Value* lhs, Value* rhs) { - return *lhs->type() != *rhs->type(); + if (lhs->type()->kind() == UnionType::Kind || + lhs->type()->kind() == OptionalType::Kind) { + return annotated_unions_.insert(lhs).second; + } else { + return *lhs->type() != *rhs->type(); + } } void printAnnotatedAssignment( @@ -1162,23 +1180,47 @@ struct PythonPrintImpl { // calculate how many args are specified. // see (https://github.com/pytorch/pytorch/pull/56079) for more // details. - size_t necessary_args = - CalculateNecessaryArgs(schema.arguments(), node->inputs()); - for (const auto i : c10::irange(necessary_args)) { - if (i > 0) + size_t num_schema_args = schema.arguments().size(); + + // we only want to do this extra logic only when necessary. + if (num_schema_args > 0) { + // calculate how many args are specified. + // see (https://github.com/pytorch/pytorch/pull/56079) for more + // details. + auto specified_args = + CalculateNecessaryArgs(schema.arguments(), node->inputs(), true); + + auto num_necessary = specified_args.first; + auto num_out = specified_args.second; + + for (size_t i = 0; i < num_necessary; ++i) { + if (i > 0) + stmt << ", "; + auto v = useOf(node->inputs().at(i)); + // print the kwarg name if it is a kwarg only argument. + if (i < num_schema_args) { + auto arg = schema.arguments().at(i); + if (arg.kwarg_only()) { + stmt << arg.name() << "="; + } + } else { + // vararg functions like format can have extra arguments + AT_ASSERT(schema.is_vararg()); + } + stmt << *v; + } + + // print out args + for (size_t i = num_schema_args - num_out; i < num_schema_args; i++) { stmt << ", "; - auto v = useOf(node->inputs().at(i)); - // print the kwarg name if it is a kwarg only argument. - if (i < schema.arguments().size()) { auto arg = schema.arguments().at(i); - if (arg.kwarg_only()) { - stmt << arg.name() << "="; + TORCH_INTERNAL_ASSERT(arg.is_out()); + // figure out the corresponding input at this index + auto input_idx = node->inputs().size() - (num_schema_args - i); + if (input_idx < node->inputs().size()) { + stmt << arg.name() << "=" << *useOf(node->inputs().at(input_idx)); } - } else { - // vararg functions like format can have extra arguments - AT_ASSERT(schema.is_vararg()); } - stmt << *v; } stmt << ")"; } break; @@ -1278,10 +1320,12 @@ struct PythonPrintImpl { body_ << arg_name; if (print_first_argument_type) { body_ << ": " << arg.type()->annotation_str(type_printer_); + annotated_unions_.insert(*param_it); } } else { body_ << ",\n " << arg_name << ": " << arg.type()->annotation_str(type_printer_); + annotated_unions_.insert(*param_it); } if (arg.default_value()) { printDefaultValue(arg, body_, *arg.default_value()); @@ -1535,6 +1579,12 @@ struct PythonPrintImpl { // table. PrintDepsTable& deps_table_; + // We need to preserve Union/Optional type annotations, but we should + // only print the annotation on variable declaration (not on any + // following uses). This set tracks the Value*s that we've already + // printed with annotations + std::unordered_set annotated_unions_; + // A function that, given a named type, returns us the correct string to print // for it. c10::TypePrinter type_printer_; diff --git a/torch/csrc/jit/serialization/unpickler.cpp b/torch/csrc/jit/serialization/unpickler.cpp index 581b94978c459..e0e556ecbbde3 100644 --- a/torch/csrc/jit/serialization/unpickler.cpp +++ b/torch/csrc/jit/serialization/unpickler.cpp @@ -23,8 +23,8 @@ static void restoreAccurateTypeTagsIfPossible(const IValue& root) { // Pickled objects are stored in a form compatible with Python pickling. // In torchscript List[T]/Dict[K, V] are statically typed and contain -// dynamic type tags allow T, K, and V to be recovered. But this info -// is not stored in the Python pickling information. However, we +// dynamic type tags that allow T, K, and V to be recovered. But this +// info is not stored in the Python pickling information. However, we // can recover this information from the static type of the top-level // object being unpickled, because we have a record of the type of the // objects it contains as attributes. @@ -108,6 +108,19 @@ void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) { to_process.emplace_back(std::move(elem)); } } break; + case UnionType::Kind: { + auto t = w.static_type->expect(); + if (t->containedTypes().size() == 2 && + t->canHoldType(NoneType::get())) { + if (!w.value.isNone()) { + auto inner = t->containedTypes()[0] != NoneType::get() + ? t->containedTypes()[0] + : t->containedTypes()[1]; + Work elem = {inner, w.value}; + to_process.emplace_back(std::move(elem)); + } + } + } break; case ListType::Kind: { // specialized lists do not need their type refined, so we can exit // early here @@ -318,7 +331,7 @@ PickleOpCode Unpickler::readInstruction() { tuple->elements().reserve(stack_.size() - start); auto start_it = stack_.begin() + start; for (auto it = start_it; it != stack_.end(); ++it) { - tuple->elements().emplace_back(*it); + tuple->elements().emplace_back(std::move(*it)); } stack_.erase(start_it, stack_.end()); stack_.emplace_back(std::move(tuple)); @@ -550,6 +563,9 @@ void Unpickler::readGlobal( // Unpickle a tensor bool quantized = class_name == "_rebuild_qtensor"; rebuildTensor(quantized); + } else if ( + module_name == "torch._utils" && class_name == "_rebuild_sparse_tensor") { + rebuildSparseTensor(); } else if (module_name == "builtins" && class_name == "complex") { globals_.emplace_back([this] { auto elems = pop(stack_).toTuple()->elements(); @@ -647,6 +663,38 @@ void Unpickler::readGlobal( stack_.emplace_back(int64_t(globals_.size() - 1)); } +void Unpickler::rebuildSparseTensor() { + globals_.emplace_back([this] { + auto tup = pop(stack_).toTuple(); + const auto& elements = tup->elements(); + size_t idx = 0; + auto layout = elements.at(idx++).toInt(); + at::Tensor result; + switch (layout) { + case static_cast(c10::Layout::Sparse): { + std::vector size = tupleToIntList(elements.at(idx++)); + bool requires_grad = elements.at(idx++).toBool(); + auto& indices_tensor = elements.at(idx++).toTensor(); + auto& values_tensor = elements.at(idx++).toTensor(); + auto options = values_tensor.options() + .layout(c10::Layout::Sparse) + .requires_grad(requires_grad); + result = at::_sparse_coo_tensor_unsafe( + indices_tensor, values_tensor, size, options); + result = autograd::make_variable(result, options.requires_grad()); + break; + } + default: + TORCH_CHECK( + false, + "Unsupported sparse tensor layout type in serialization ", + static_cast(layout)); + break; + } + stack_.emplace_back(std::move(result)); + }); +} + void Unpickler::rebuildTensor(bool quantized) { globals_.emplace_back([this, quantized] { auto tup = pop(stack_).toTuple(); diff --git a/torch/csrc/jit/serialization/unpickler.h b/torch/csrc/jit/serialization/unpickler.h index f404deee848be..586ff9cc4ae59 100644 --- a/torch/csrc/jit/serialization/unpickler.h +++ b/torch/csrc/jit/serialization/unpickler.h @@ -108,6 +108,7 @@ class TORCH_API Unpickler { const std::string& module_name, const std::string& class_name); void rebuildTensor(bool quantized); + void rebuildSparseTensor(); #ifdef USE_DISTRIBUTED void rebuildRRef(); #endif diff --git a/torch/csrc/jit/tensorexpr/block_codegen.cpp b/torch/csrc/jit/tensorexpr/block_codegen.cpp index 1ae3330799c64..b42d37428208b 100644 --- a/torch/csrc/jit/tensorexpr/block_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/block_codegen.cpp @@ -16,6 +16,8 @@ std::string blockDtypeCppString(const Dtype& dtype) { return "1"; case ScalarType::Half: return "2"; + case ScalarType::BFloat16: + return "2"; // NOLINTNEXTLINE(bugprone-branch-clone) case ScalarType::Char: return "1"; @@ -76,7 +78,7 @@ void BlockAnalysis::visit(ForPtr v) { v->body()->accept(this); } else if (loop_options.is_gpu_thread_index()) { auto block_size = v->stop(); - block_size_ = to(block_size)->value(); + block_size_ = *intValue(block_size); v->body()->accept(this); } else { IRVisitor::visit(v); @@ -185,15 +187,14 @@ void BlockPrinter::PrintArguments(const std::unordered_set& bufs) { // The dims for the multi-dim tensors for (unsigned long d = 0; d < num_dims; d++) { - auto dim_val = to(multidimbuf->dim(d)); - this->dim_values_map.emplace(this->dim_names[d], dim_val->value()); + auto dim_val = *intValue(multidimbuf->dim(d)); + this->dim_values_map.emplace(this->dim_names[d], dim_val); } // The dimensions for the flattened tensors - auto val = to(buf->dim(0)); + auto val = *intValue(buf->dim(0)); if (block_analysis_->is_buf_store_target(buf)) { - this->dim_values_map.emplace( - this->flat_dim_names[num_dims - 1], val->value()); + this->dim_values_map.emplace(this->flat_dim_names[num_dims - 1], val); } } diff --git a/torch/csrc/jit/tensorexpr/bounds_inference.cpp b/torch/csrc/jit/tensorexpr/bounds_inference.cpp index 55dbacf087899..649fd0e69da8e 100644 --- a/torch/csrc/jit/tensorexpr/bounds_inference.cpp +++ b/torch/csrc/jit/tensorexpr/bounds_inference.cpp @@ -185,7 +185,7 @@ std::vector getBoundExtents( std::vector extents; for (size_t i = 0; i < starts.size(); ++i) { ExprPtr dim = IRSimplifier::simplify( - alloc(alloc(stops[i], starts[i]), alloc(1))); + alloc(alloc(stops[i], starts[i]), immLike(stops[i], 1))); extents.push_back(dim); } diff --git a/torch/csrc/jit/tensorexpr/bounds_overlap.cpp b/torch/csrc/jit/tensorexpr/bounds_overlap.cpp index 4ac5c6b96fb9a..fdfff12ad7666 100644 --- a/torch/csrc/jit/tensorexpr/bounds_overlap.cpp +++ b/torch/csrc/jit/tensorexpr/bounds_overlap.cpp @@ -130,8 +130,8 @@ std::vector subtractBound(Bound a, Bound b, OverlapKind overlap) { auto vars = VarFinder::find(lowDiff); if (vars.size() == 1) { lowDiff = IRSimplifier::simplify(alloc( - SubstituteInClone(b.start, {{*vars.begin(), alloc(1)}}), - SubstituteInClone(a.start, {{*vars.begin(), alloc(1)}}))); + SubstituteInClone(b.start, {{*vars.begin(), immLike(b.start, 1)}}), + SubstituteInClone(a.start, {{*vars.begin(), immLike(a.start, 1)}}))); } } @@ -139,8 +139,8 @@ std::vector subtractBound(Bound a, Bound b, OverlapKind overlap) { auto vars = VarFinder::find(highDiff); if (vars.size() == 1) { highDiff = IRSimplifier::simplify(alloc( - SubstituteInClone(b.end, {{*vars.begin(), alloc(1)}}), - SubstituteInClone(a.end, {{*vars.begin(), alloc(1)}}))); + SubstituteInClone(b.end, {{*vars.begin(), immLike(b.end, 1)}}), + SubstituteInClone(a.end, {{*vars.begin(), immLike(a.end, 1)}}))); } } @@ -157,12 +157,13 @@ std::vector subtractBound(Bound a, Bound b, OverlapKind overlap) { if (hasHead) { res.emplace_back( - a.start, IRSimplifier::simplify(alloc(b.start, alloc(1)))); + a.start, + IRSimplifier::simplify(alloc(b.start, immLike(b.start, 1)))); } if (hasTail) { ExprPtr tailStart = - IRSimplifier::simplify(alloc(b.end, alloc(1))); + IRSimplifier::simplify(alloc(b.end, immLike(b.end, 1))); res.emplace_back(tailStart, a.end); } diff --git a/torch/csrc/jit/tensorexpr/codegen.cpp b/torch/csrc/jit/tensorexpr/codegen.cpp index 0bbc3378b0323..b2b077b9771d1 100644 --- a/torch/csrc/jit/tensorexpr/codegen.cpp +++ b/torch/csrc/jit/tensorexpr/codegen.cpp @@ -67,7 +67,7 @@ void* CodeGen::argToPtr(const BufferArg& bufferArg, const CallArg& callArg) { case ScalarType::Name: \ return callArg.Name##Ptr(); - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE); #undef TYPE_CASE default: diff --git a/torch/csrc/jit/tensorexpr/codegen.h b/torch/csrc/jit/tensorexpr/codegen.h index 77ba8e173631e..0504f9a8b0b0b 100644 --- a/torch/csrc/jit/tensorexpr/codegen.h +++ b/torch/csrc/jit/tensorexpr/codegen.h @@ -46,6 +46,10 @@ class TORCH_API CodeGen { stmt_ = stmt_->accept_mutator(mutator); } + void apply_visitor(IRVisitor* visitor) { + stmt_->accept(visitor); + } + std::vector& buffer_args() { return buffer_args_; } @@ -104,7 +108,7 @@ class TORCH_API CodeGen { class CodeGen::BufferArg { public: BufferArg(const Placeholder& buffer) : buf_(buffer.data()) {} - BufferArg(Tensor* tensor) : buf_(tensor->buf()) {} + BufferArg(Tensor tensor) : buf_(tensor.buf()) {} BufferArg(const VarHandle& var) : var_(var.node()), isVar_(true) {} BufferArg(const BufHandle& buf) : buf_(buf.node()) {} @@ -149,7 +153,7 @@ class CodeGen::CallArg { memcpy(&data_, &v, sizeof(Type)); \ } // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, ARG_TYPE_CTOR); + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, ARG_TYPE_CTOR); #undef ARG_TYPE_CTOR void* data() const { @@ -161,7 +165,7 @@ class CodeGen::CallArg { return (Type*)&data_; \ } // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, ARG_PTR_DEFINE); + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, ARG_PTR_DEFINE); #undef ARG_PTR_DEFINE private: diff --git a/torch/csrc/jit/tensorexpr/cpp_codegen.cpp b/torch/csrc/jit/tensorexpr/cpp_codegen.cpp index 39a5615a97545..6c02f7f7e09df 100644 --- a/torch/csrc/jit/tensorexpr/cpp_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/cpp_codegen.cpp @@ -1,44 +1,406 @@ +#include +#include +#include + #include +#include +#include +#include namespace torch { namespace jit { namespace tensorexpr { -void CppPrinter::visit(AllocatePtr alloc) { - constexpr size_t kAllocOnStackThresholdSize = 512; +// Rewrites the variables' name according to valid C++ naming convention. +// E.g. in Graph IR, variable name may contain '.', in C++, they are replaced +// with '_'. +class CppVarNameRewriter : public IRVisitor { + public: + void visit(VarPtr v) override { + constexpr char kDot = '.'; + constexpr char kUnderscore = '_'; + if (v->name_hint().find(kDot) == std::string::npos) { + return; + } + std::string name = v->name_hint(); + std::replace(name.begin(), name.end(), kDot, kUnderscore); + v->set_name_hint(std::move(name)); + } + + void visit(BufPtr v) override { + v->base_handle()->accept(this); + } +}; + +static std::string declareExternalFunction(const std::string& func_name) { + return "void " + func_name + + "(" + "int64_t bufs_num, " + "void** buf_data, " + "int64_t* buf_ranks, " + "int64_t* buf_dims, " + "int8_t* buf_dtypes, " + "int64_t args_num, " + "int64_t* extra_args);"; +} + +CppPrinter::CppPrinter(std::ostream* os) : IRPrinter(*os), lane_(0) {} + +CppPrinter::~CppPrinter() = default; + +void CppPrinter::printPrologue() { + os() << "#include " << std::endl; + os() << "#include " << std::endl; + os() << "#include " << std::endl; + os() << "#include " << std::endl; + os() << std::endl; + + os() << "#define POS_INFINITY INFINITY" << std::endl; + os() << "#define NEG_INFINITY -INFINITY" << std::endl; + os() << std::endl; + + os() << cpp_intrinsics_definition << std::endl; + os() << std::endl; + + os() << "namespace torch {" << std::endl; + os() << "namespace jit {" << std::endl; + os() << "namespace tensorexpr {" << std::endl; + for (auto const& it : getNNCFunctionRegistry()) { + os() << declareExternalFunction(it.first) << std::endl; + } + os() << "} // namespace tensorexpr" << std::endl; + os() << "} // namespace jit" << std::endl; + os() << "} // namespace torch" << std::endl; + os() << std::endl; + + os() << "using namespace torch::jit::tensorexpr;" << std::endl; + os() << std::endl; +} + +template +inline typename std::enable_if::value, void>::type +visit_mod(std::ostream& os, const ExprPtr lhs, const ExprPtr rhs) { + os << *lhs << " % " << *rhs; +} + +template +inline typename std::enable_if::value, void>::type +visit_mod(std::ostream& os, const ExprPtr lhs, const ExprPtr rhs) { + os << "std::fmod(" << *lhs << ", " << *rhs << ")"; +} + +template +inline typename std::enable_if< + std::is_floating_point::value || std::is_integral::value, + void>::type +visit_max(std::ostream& os, const ExprPtr lhs, const ExprPtr rhs) { + os << "std::max(" << *lhs << ", " << *rhs << ")"; +} - size_t size = 1; - for (auto dim : alloc->dims()) { - IntImmPtr v = to(dim); - if (v) { - size *= v->value(); +template +inline typename std::enable_if< + !std::is_floating_point::value && !std::is_integral::value, + void>::type +visit_max(std::ostream& os, const ExprPtr lhs, const ExprPtr rhs) { + os << "(" << *lhs << " < " << *rhs << ") ? " << *rhs << " : " << *lhs; +} + +template +inline typename std::enable_if< + std::is_floating_point::value || std::is_integral::value, + void>::type +visit_min(std::ostream& os, const ExprPtr lhs, const ExprPtr rhs) { + os << "std::min(" << *lhs << ", " << *rhs << ")"; +} + +template +inline typename std::enable_if< + !std::is_floating_point::value && !std::is_integral::value, + void>::type +visit_min(std::ostream& os, const ExprPtr lhs, const ExprPtr rhs) { + os << *lhs << " < " << *rhs << " ? " << *lhs << " : " << *rhs; +} + +template +void visit_binary_op( + std::ostream& os, + const ExprPtr lhs, + const ExprPtr rhs, + IRNodeType op_type) { + switch (op_type) { + case IRNodeType::kMod: + visit_mod(os, lhs, rhs); + break; + case IRNodeType::kMax: + visit_max(os, lhs, rhs); + break; + case IRNodeType::kMin: + visit_min(os, lhs, rhs); + break; + default: + throw std::runtime_error("invalid op type"); + } +} + +template +void dispatch_binary_op(std::ostream& os, const BinaryOpNode* v) { + switch (v->lhs()->dtype().scalar_type()) { +#define TYPE_CASE(Type, Name) \ + case ScalarType::Name: \ + visit_binary_op(os, v->lhs(), v->rhs(), v->expr_type()); \ + break; + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE); +#undef TYPE_CASE + default: + throw unsupported_dtype(); + } +} + +void CppPrinter::visit(RampPtr v) { + visit(alloc(v->base(), alloc(alloc(lane_), v->stride()))); +} + +void CppPrinter::visit(BroadcastPtr v) { + v->value()->accept(this); +} + +void CppPrinter::visit(ModPtr v) { + dispatch_binary_op(os(), v.get()); +} + +void CppPrinter::visit(MaxPtr v) { + dispatch_binary_op(os(), v.get()); +} + +void CppPrinter::visit(MinPtr v) { + dispatch_binary_op(os(), v.get()); +} + +void CppPrinter::visit(CompareSelectPtr v) { + os() << "((" << *v->lhs() << " " + << IRPrinter::to_string(v->compare_select_op()) << " " << *v->rhs() + << ") ? " << *v->ret_val1() << " : " << *v->ret_val2() << ")"; +} + +void CppPrinter::visit(IfThenElsePtr v) { + os() << "((" << *v->condition() << ") ? " << *v->true_value() << " : " + << *v->false_value() << ")"; +} + +void CppPrinter::visit(AllocatePtr v) { + size_t size = v->dtype().byte_size(); + for (const auto& dim : v->dims()) { + IntImmPtr d = to(dim); + if (d) { + size *= d->value(); } else { throw std::runtime_error("Only IntImm dimensions are supported for now"); } } emitIndent(); - if (size <= kAllocOnStackThresholdSize) { - os() << alloc->dtype().ToCppString() << " " << (*alloc->buffer_var()) << "[" - << size << "];" << std::endl; - } else { - size *= alloc->dtype().byte_size(); - os() << alloc->dtype().ToCppString() << "* " << (*alloc->buffer_var()) - << " = static_cast<" << alloc->dtype().ToCppString() << "*>(malloc(" - << size << "));" << std::endl; - allocated_on_heap_.insert(alloc->buffer_var()); + os() << v->dtype().ToCppString() << "* " << (*v->buffer_var()) + << " = static_cast<" << v->dtype().ToCppString() << "*>(malloc(" << size + << "));" << std::endl; +} + +void CppPrinter::visit(FreePtr v) { + emitIndent(); + os() << "free(" << *v->buffer_var() << ");" << std::endl; +} + +void CppPrinter::visit(LoadPtr v) { + auto flat_idx = flatten_index(v->buf()->dims(), v->indices()); + os() << *v->base_handle() << "[" << *flat_idx << "]"; +} + +void CppPrinter::visit(StorePtr v) { + auto flat_idx = flatten_index(v->buf()->dims(), v->indices()); + const int lanes = v->value()->dtype().lanes(); + for (int lane = 0; lane < lanes; lane++) { + lane_ = lane; + emitIndent(); + os() << *v->base_handle() << "[" << *flat_idx << "] = " << *v->value() + << ";" << std::endl; + } +} + +void CppPrinter::visit(CastPtr v) { + os() << "static_cast<" << v->dtype().ToCppString() << ">(" << *v->src_value() + << ")"; +} + +void CppPrinter::visit(BitCastPtr v) { + os() << "std::bitcast<" << v->src_value()->dtype().ToCppString() << ", " + << v->dtype().ToCppString() << ">(" << *v->src_value() << ")"; +} + +void CppPrinter::visit(IntrinsicsPtr v) { + if (v->op_type() == kRand || v->op_type() == kSigmoid) { + throw std::runtime_error("kRand and kSigmoid are not supported"); + } + + os() << "std::" << v->func_name() << "("; + for (int i = 0; i < v->nparams(); i++) { + if (i > 0) { + os() << ", "; + } + os() << *v->param(i); } + os() << ")"; } -void CppPrinter::visit(FreePtr free) { - VarPtr var = free->buffer_var(); - if (allocated_on_heap_.count(var)) { +void CppPrinter::visit(ExternalCallPtr v) { + // The generated code needs to link against functions defined + // in external_functions.cpp. + + auto& func_registry = getNNCFunctionRegistry(); + if (!func_registry.count(v->func_name())) { + throw unimplemented_lowering(v); + } + + std::vector bufs(v->buf_args()); + bufs.insert(bufs.begin(), v->buf()); + auto for_buf = [&](const std::function& print_buf) { + for (size_t i = 0; i < bufs.size(); i++) { + if (i > 0) { + os() << ", "; + } + print_buf(bufs[i]); + } + }; + + emitIndent(); + os() << "{" << std::endl; + indent_++; + + emitIndent(); + os() << "void* buf_ptrs[]{"; + for_buf([&](const BufPtr b) { os() << *b->base_handle(); }); + os() << "};" << std::endl; + + emitIndent(); + os() << "int64_t buf_ranks[]{"; + for_buf([&](const BufPtr b) { os() << b->ndim(); }); + os() << "};" << std::endl; + + emitIndent(); + os() << "int64_t buf_dims[]{"; + for_buf([&](const BufPtr buf) { + for (size_t i = 0; i < buf->ndim(); i++) { + if (i > 0) { + os() << ", "; + } + os() << *buf->dim(i); + } + }); + os() << "};" << std::endl; + + emitIndent(); + os() << "int8_t buf_dtypes[]{"; + for_buf([&](const BufPtr buf) { + os() << static_cast(buf->dtype().scalar_type()); + }); + os() << "};" << std::endl; + + emitIndent(); + os() << "int64_t extra_args[]{"; + for (size_t i = 0; i < v->args().size(); i++) { + if (i > 0) { + os() << ", "; + } + os() << *v->args()[i]; + } + os() << "};" << std::endl; + + emitIndent(); + os() << v->func_name() << "(" << std::endl; + emitIndent(); + os() << " " << bufs.size() << "," << std::endl; + emitIndent(); + os() << " buf_ptrs," << std::endl; + emitIndent(); + os() << " buf_ranks," << std::endl; + emitIndent(); + os() << " buf_dims," << std::endl; + emitIndent(); + os() << " buf_dtypes," << std::endl; + emitIndent(); + os() << " " << v->args().size() << "," << std::endl; + emitIndent(); + os() << " extra_args);" << std::endl; + + indent_--; + emitIndent(); + os() << "}" << std::endl; +} + +void CppPrinter::visit(LetPtr v) { + if (v->dtype().lanes() == 1) { emitIndent(); - os() << "free(" << name_manager()->get_unique_name(var) << ");" - << std::endl; + os() << v->dtype().ToCppString() << " " << *v->var() << " = " << *v->value() + << ";" << std::endl; + } else { + vector_vars_[v->var()] = v->value(); + } +} + +void CppPrinter::visit(VarPtr v) { + if (v->dtype().lanes() == 1) { + os() << name_manager()->get_unique_name(v); + } else { + os() << *vector_vars_.at(v); } } +CppCodeGen::CppCodeGen( + StmtPtr stmt, + const std::vector& buffer_args, + at::Device device, + const std::string& kernel_func_name) + : CodeGen(stmt, buffer_args, device, kernel_func_name) { + init(); +} + +void CppCodeGen::init() { + printer_ = std::make_unique(&oss_); + var_name_rewriter_ = std::make_unique(); + + apply_visitor(var_name_rewriter_.get()); + + printer_->printPrologue(); + os() << "void " << kernel_func_name() << "("; + const std::vector buffer_args = this->buffer_args(); + for (size_t i = 0; i < buffer_args.size(); i++) { + if (i > 0) { + os() << ", "; + } + const BufferArg& buffer_arg = buffer_args[i]; + const VarPtr var = buffer_arg.var(); + Dtype dtype = buffer_arg.dtype(); + os() << dtype.ToCppString() << (buffer_arg.isVar() ? " " : "* ") << *var; + } + os() << ")"; + stmt()->accept(printer_.get()); + os() << std::endl; +} + +CppCodeGen::~CppCodeGen() = default; + +void CppCodeGen::call(const std::vector& args) { + // TODO: compile the generated C++ kernel into a library, + // and call the library here. + os() << "int main() {}" << std::endl; +} + +void CppCodeGen::call_raw(const std::vector& args) { + // TODO: compile the generated C++ kernel into a library, + // and call the library here. + os() << "int main() {}" << std::endl; +} + +RegisterCodeGen cpp_codegen_reg("cpp_codegen"); + } // namespace tensorexpr } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/tensorexpr/cpp_codegen.h b/torch/csrc/jit/tensorexpr/cpp_codegen.h index 1cf15658716e6..a6d583ed4efb7 100644 --- a/torch/csrc/jit/tensorexpr/cpp_codegen.h +++ b/torch/csrc/jit/tensorexpr/cpp_codegen.h @@ -1,24 +1,100 @@ #pragma once +#include #include -#include - namespace torch { namespace jit { namespace tensorexpr { +class CppVarNameRewriter; + // Generates C++ code from the IR. +// +// Vector operations are unrolled. +// For example: +// C[Ramp(0, 1, 3)] = A[Ramp(0, 2, 3)] + B[Ramp(0, 3, 3)]; +// is unrolled into: +// C[0] = A[0] + B[0]; +// C[1] = A[2] + B[3]; +// C[2] = A[4] + B[6]; class TORCH_API CppPrinter : public IRPrinter { public: - explicit CppPrinter(std::ostream* os) : IRPrinter(*os) {} + explicit CppPrinter(std::ostream* os); + ~CppPrinter() override; + + void printPrologue(); using IRPrinter::visit; + + // Binary expressions. + void visit(ModPtr) override; + void visit(MaxPtr) override; + void visit(MinPtr) override; + + // Conditional expressions. + void visit(CompareSelectPtr) override; + void visit(IfThenElsePtr) override; + + // Tensor operations. void visit(AllocatePtr) override; void visit(FreePtr) override; + void visit(LoadPtr) override; + void visit(StorePtr) override; + + // Casts. + void visit(CastPtr) override; + void visit(BitCastPtr) override; + + // Calls. + void visit(IntrinsicsPtr) override; + void visit(ExternalCallPtr) override; + + // Vars. + void visit(LetPtr) override; + void visit(VarPtr) override; + + // Vector data types. + void visit(RampPtr) override; + void visit(BroadcastPtr) override; private: - std::unordered_set allocated_on_heap_; + int lane_; + std::unordered_map vector_vars_; +}; + +class TORCH_API CppCodeGen : public CodeGen { + public: + CppCodeGen( + StmtPtr stmt, + const std::vector& buffer_args, + at::Device device = at::kCPU, + const std::string& kernel_func_name = "func"); + + ~CppCodeGen() override; + + void call(const std::vector& args) override; + void call_raw(const std::vector& args) override; + + template + void operator()(const Ts&... ts) { + call(std::vector({CallArg(ts)...})); + } + + std::string getCodeText(const std::string& attr = "") override { + return oss_.str(); + } + + private: + void init(); + + std::ostream& os() { + return printer_->os(); + } + + std::ostringstream oss_; + std::unique_ptr printer_; + std::unique_ptr var_name_rewriter_; }; } // namespace tensorexpr diff --git a/torch/csrc/jit/tensorexpr/cpp_intrinsics.h b/torch/csrc/jit/tensorexpr/cpp_intrinsics.h new file mode 100644 index 0000000000000..caeeed693ff38 --- /dev/null +++ b/torch/csrc/jit/tensorexpr/cpp_intrinsics.h @@ -0,0 +1,36 @@ +#pragma once + +namespace torch { +namespace jit { +namespace tensorexpr { + +constexpr auto cpp_intrinsics_definition = R"( +namespace std { + +template ::value, int>::type = 0> +T rsqrt(T v) { + return 1.0f / std::sqrt(v); +} + +template ::value, int>::type = 0> +T frac(T v) { + T intpart; + return std::modf(v, &intpart); +} + +template +To bitcast(const From& v) { + assert(sizeof(To) == sizeof(From)); + To res; + std::memcpy(&res, &v, sizeof(From)); + return res; +} + +} // namespace std +)"; + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp index 2d00b1e4ab481..c23eda31204de 100644 --- a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp @@ -45,18 +45,9 @@ class ScopedVarName { VarPtr var_ = nullptr; }; -static int as_int(ExprPtr expr) { - auto v = to(expr); - if (!v) { - throw malformed_input( - "cuda_codegen: non Int expr interpreted as int", expr); - } - - return v->value(); -} - static bool is_zero(ExprPtr expr) { - return as_int(expr) == 0; + auto v = intValue(expr); + return v && *v == 0; } static const at::cuda::NVRTC& nvrtc() { @@ -107,6 +98,8 @@ std::string CudaPrinter::dtypeToCppString(const Dtype& dtype) { return "bool"; case ScalarType::Half: return "half"; + case ScalarType::BFloat16: + return "__nv_bfloat16"; case ScalarType::Char: return "char"; case ScalarType::Byte: @@ -222,11 +215,11 @@ void CudaPrinter::print_flat_alloc(AllocatePtr alloc) { // TODO: this should be merged with the storage flattener. int64_t flat_size = 1; for (auto dim : dims) { - IntImmPtr dim_i = to(dim); + auto dim_i = intValue(dim); if (dim_i) { - flat_size *= dim_i->value(); + flat_size *= *dim_i; } else { - throw std::runtime_error("Only IntImm dimensions are supported for now"); + throw std::runtime_error("Only integer dimensions are supported for now"); } } os() << dtypeToCppString(alloc->dtype()) << " " << (*alloc->buffer_var()) @@ -260,20 +253,15 @@ void CudaPrinter::visit(ForPtr v) { } void CudaPrinter::visit(CastPtr v) { - if (v->dtype().scalar_type() == ScalarType::Half) { - os() << "__float2half("; - v->src_value()->accept(this); - os() << ")"; - return; - } else if (v->src_value()->dtype().scalar_type() == ScalarType::Half) { - os() << "__half2float("; - v->src_value()->accept(this); - os() << ")"; - return; - } - - os() << "(" << dtypeToCppString(v->dtype()) << ")"; - os() << "("; + std::string castFn = v->dtype().scalar_type() == ScalarType::Half + ? "__float2half" + : v->dtype().scalar_type() == ScalarType::BFloat16 ? "__float2bfloat16" + : v->src_value()->dtype().scalar_type() == ScalarType::Half + ? "__half2float" + : v->src_value()->dtype().scalar_type() == ScalarType::BFloat16 + ? "__bfloat162float" + : ("(" + dtypeToCppString(v->dtype()) + ")"); + os() << castFn << "("; v->src_value()->accept(this); os() << ")"; } @@ -329,7 +317,8 @@ void CudaPrinter::visit(LoadPtr v) { return; } if (v->dtype().scalar_type() == ScalarType::Bool || - v->dtype().scalar_type() == ScalarType::Half) { + v->dtype().scalar_type() == ScalarType::Half || + v->dtype().scalar_type() == ScalarType::BFloat16) { // There's no __ldg overload for bool or half. os() << *v->base_handle() << "[" << *v->flat_index() << "]"; return; @@ -389,34 +378,33 @@ class AtomicAddFuser : public IRMutator { StmtPtr mutate(StorePtr v) override { BufPtr buf = v->buf(); - StorePtr orig = const_cast(v); // NOLINT // Thread locals never need to be atomic. if (thread_local_bufs_.count(buf->base_handle()) != 0) { - return orig; + return v; } ScalarType dtype = v->value()->dtype().scalar_type(); if (dtype != ScalarType::Float && dtype != ScalarType::Double) { - return orig; + return v; } AddPtr add_v = to(v->value()); if (!add_v) { - return orig; + return v; } LoadPtr load_v = to(add_v->lhs()); if (!load_v) { - return orig; + return v; } if (v->base_handle() != load_v->base_handle()) { - return orig; + return v; } if (v->indices().empty() && load_v->indices().empty()) { - return orig; + return v; } bool index_equal = CheckEqual(v->flat_index(), load_v->flat_index()); if (!index_equal) { - return orig; + return v; } // TODO: this checks that the metavars occur directly as an index, but this @@ -431,7 +419,7 @@ class AtomicAddFuser : public IRMutator { if (vars_to_find.empty()) { // All metavars accounted for. - return orig; + return v; } return alloc(buf, v->indices(), add_v->rhs()); @@ -609,23 +597,21 @@ class PrioritizeLoad : public IRMutator { } StmtPtr mutate(BlockPtr v) override { - BlockPtr v1 = const_cast(v); // NOLINT - assert(v1); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - std::list stmts = v1->stmts(); + std::list stmts = v->stmts(); for (StmtPtr stmt : stmts) { PushList(); StmtPtr stmt_new = stmt->accept_mutator(this); - AddMemLoadsFromList(v1, stmt); + AddMemLoadsFromList(v, stmt); PopList(); if (stmt_new == stmt) { continue; } - v1->replace_stmt(stmt, stmt_new); + v->replace_stmt(stmt, stmt_new); } - return v1; + return v; } ExprPtr mutate(IfThenElsePtr v) override { @@ -821,7 +807,7 @@ StmtPtr GPUMetaVarRewriter::mutate(BlockPtr v) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector stmts; for (auto& v : innerSegments) { - for (auto* s : v.stmts()) { + for (auto s : v.stmts()) { stmts.push_back(s); } } @@ -956,6 +942,9 @@ void CudaCodeGen::Initialize() { if (halfChecker.hasHalf()) { os() << fuser::cuda::half_support_literal << std::endl; } + if (halfChecker.hasBFloat16()) { + os() << fuser::cuda::bfloat16_support_literal << std::endl; + } std::string func_name = GetUniqueFuncName(kernel_func_name()); os() << "extern \"C\" __global__" << std::endl; diff --git a/torch/csrc/jit/tensorexpr/eval.cpp b/torch/csrc/jit/tensorexpr/eval.cpp index c7a28bdbb23ac..4582433d95697 100644 --- a/torch/csrc/jit/tensorexpr/eval.cpp +++ b/torch/csrc/jit/tensorexpr/eval.cpp @@ -10,6 +10,17 @@ namespace tensorexpr { RegisterCodeGen ir_eval_codegen_reg("simple_ir_eval"); +int64_t Value::intValue() const { +#define TYPE_CASE(Type, Name) \ + if (dtype_ == k##Name) { \ + return int64_t{Name##values[0]}; \ + } + AT_FORALL_INT_TYPES(TYPE_CASE); +#undef TYPE_CASE + throw unsupported_dtype(); + return 0; +} + template inline typename std::enable_if::value, T>::type mod_value( T lhs, @@ -51,6 +62,10 @@ inline c10::Half div_value(c10::Half lhs, c10::Half rhs) { return lhs / rhs; } +inline c10::BFloat16 div_value(c10::BFloat16 lhs, c10::BFloat16 rhs) { + return lhs / rhs; +} + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) class SimpleIREvaluatorImpl : public IRVisitor { public: @@ -281,8 +296,12 @@ class SimpleIREvaluatorImpl : public IRVisitor { return Value(result_v); } - template - void visit_binary_op(BinaryOpNode* v, bool option = false) { + template < + typename D, + typename std::enable_if())), + void>::value>::type* = nullptr> + void visit_binary_op(NodePtr v, bool option = false) { v->lhs()->accept(this); Value lhs_v = value_; v->rhs()->accept(this); @@ -332,7 +351,7 @@ class SimpleIREvaluatorImpl : public IRVisitor { case ScalarType::Name: \ value_ = binary_op(lhs_v, rhs_v, expr_type); \ break; - AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE); + AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, TYPE_CASE); #undef TYPE_CASE case ScalarType::Bool: value_ = binary_op(lhs_v, rhs_v, expr_type); @@ -355,7 +374,7 @@ class SimpleIREvaluatorImpl : public IRVisitor { case ScalarType::Name: \ value = compare_select_op(lhs, rhs, retval1, retval2, cmp_op); \ break; - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE); #undef TYPE_CASE default: throw unsupported_dtype(); @@ -387,7 +406,7 @@ class SimpleIREvaluatorImpl : public IRVisitor { value_ = compare_select_op_helper( \ lhs_v, rhs_v, ret_val1_v, ret_val2_v, cmp_op); \ break; - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE); #undef TYPE_CASE default: throw unsupported_dtype(); @@ -398,7 +417,7 @@ class SimpleIREvaluatorImpl : public IRVisitor { TORCH_API void visit(Name##ImmPtr v) override { \ value_ = Value(v->value()); \ } - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_VISIT); + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_VISIT); #undef IMM_VISIT TORCH_API void visit(BlockPtr v) override { @@ -449,7 +468,7 @@ class SimpleIREvaluatorImpl : public IRVisitor { case ScalarType::Name: \ this->value_ = Value(castValues(src_dtype, v)); \ break; - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, DST_TYPE_CASE); + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, DST_TYPE_CASE); #undef DST_TYPE_CASE default: throw unsupported_dtype(); @@ -471,7 +490,7 @@ class SimpleIREvaluatorImpl : public IRVisitor { case ScalarType::Name: \ doCastFromSrc(src_dtype, dst_dtype, value_); \ break; - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, SRC_TYPE_CASE); + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, SRC_TYPE_CASE); #undef SRC_TYPE_CASE default: throw unsupported_dtype(); @@ -533,15 +552,16 @@ class SimpleIREvaluatorImpl : public IRVisitor { TORCH_API void visit(ForPtr v) override { ExprPtr var_node = v->var(); v->start()->accept(this); - int start = value_.as(); + auto dtype = value_.dtype(); + auto start = value_.intValue(); v->stop()->accept(this); - int stop = value_.as(); + auto stop = value_.intValue(); if (eval_context_.count(var_node)) { throw malformed_input("could not find var_node in For context", v); } - for (int i = start; i < stop; i++) { - eval_context_[var_node] = Value(i); + for (auto i = start; i < stop; i++) { + eval_context_[var_node] = Value(dtype, i); if (v->body()) { v->body()->accept(this); } @@ -551,9 +571,9 @@ class SimpleIREvaluatorImpl : public IRVisitor { TORCH_API void visit(RampPtr v) override { v->base()->accept(this); - int base = value().as(); + auto base = value().intValue(); v->stride()->accept(this); - int stride = value().as(); + auto stride = value().intValue(); int lanes = v->lanes(); std::vector values(lanes); @@ -574,7 +594,7 @@ class SimpleIREvaluatorImpl : public IRVisitor { std::vector v(lanes, value.as()); \ value_ = Value(v); \ } break; - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE); #undef TYPE_CASE default: throw unsupported_dtype(); @@ -594,6 +614,9 @@ class SimpleIREvaluatorImpl : public IRVisitor { #undef TYPE_CASE case ScalarType::Half: throw unsupported_dtype("IfThenElse condition can't have Half dtype"); + case ScalarType::BFloat16: + throw unsupported_dtype( + "IfThenElse condition can't have BFloat16 dtype"); default: throw unsupported_dtype(); } @@ -605,6 +628,24 @@ class SimpleIREvaluatorImpl : public IRVisitor { } } + template + std::vector toLongVec(T&& t) { + return std::vector{std::begin(t), std::end(t)}; + } + + std::vector indexVec(const Value& v) { + switch (v.dtype().scalar_type()) { +#define TYPE_CASE(Type, Name) \ + case ScalarType::Name: \ + return toLongVec(v.as_vec()); + AT_FORALL_INT_TYPES(TYPE_CASE); +#undef TYPE_CASE + default: + throw unsupported_dtype(); + } + return {}; + } + TORCH_API void visit(LoadPtr v) override { auto iter = buffer_mapping_.find(v->buf()); if (iter == buffer_mapping_.end()) { @@ -614,7 +655,7 @@ class SimpleIREvaluatorImpl : public IRVisitor { ExprPtr flat_idx = flatten_index(v->buf()->dims(), v->indices()); flat_idx->accept(this); - std::vector index = value().as_vec(); + auto index = indexVec(value()); ScalarType v_sdtype = v->dtype().scalar_type(); switch (v_sdtype) { #define TYPE_CASE(Type, Name) \ @@ -626,7 +667,7 @@ class SimpleIREvaluatorImpl : public IRVisitor { } \ value_ = Value(v); \ } break; - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE); #undef TYPE_CASE default: throw unsupported_dtype(); @@ -643,7 +684,7 @@ class SimpleIREvaluatorImpl : public IRVisitor { ExprPtr flat_idx = flatten_index(v->buf()->dims(), v->indices()); flat_idx->accept(this); - std::vector index = value().as_vec(); + auto index = indexVec(value()); ScalarType v_sdtype = v->value()->dtype().scalar_type(); switch (v_sdtype) { @@ -659,7 +700,7 @@ class SimpleIREvaluatorImpl : public IRVisitor { ptr##Name[index[i]] = value[i]; \ } \ } break; - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE); #undef TYPE_CASE default: throw unsupported_dtype(); @@ -692,7 +733,7 @@ class SimpleIREvaluatorImpl : public IRVisitor { buf_dtypes.push_back((int8_t)b->dtype().scalar_type()); for (ExprPtr dim_expr : b->dims()) { dim_expr->accept(this); - buf_dims.push_back(value().as()); + buf_dims.push_back(value().intValue()); } } for (ExprPtr a : v->args()) { @@ -702,7 +743,7 @@ class SimpleIREvaluatorImpl : public IRVisitor { if (value().dtype() == kLong) { val = value().as(); } else if (value().dtype() == kInt) { - val = value().as(); + val = value().intValue(); } else { throw malformed_input( "extra_args in ExternalCalls must have int64 dtype", v); @@ -767,6 +808,8 @@ class SimpleIREvaluatorImpl : public IRVisitor { visit_intrinsics_helper(v); } else if (inp_dtype == ScalarType::Half) { throw unsupported_dtype(); // TODO + } else if (inp_dtype == ScalarType::BFloat16) { + throw unsupported_dtype(); // TODO } } else { switch (ty) { @@ -785,10 +828,10 @@ class SimpleIREvaluatorImpl : public IRVisitor { void visit(AllocatePtr v) override { BufPtr b = v->buf(); std::vector dims = b->dims(); - int total_byte_size = b->dtype().byte_size(); + int64_t total_byte_size = b->dtype().byte_size(); for (auto& dim : dims) { dim->accept(this); - total_byte_size *= value_.as(); + total_byte_size *= value_.intValue(); } auto int_count = (total_byte_size + sizeof(int) - 1) / sizeof(int); std::unique_ptr> buffer(new std::vector(int_count)); @@ -820,7 +863,7 @@ class SimpleIREvaluatorImpl : public IRVisitor { void visit(CondPtr v) override { v->condition()->accept(this); - if (value().as()) { + if (value().intValue()) { if (v->true_stmt()) { v->true_stmt()->accept(this); } @@ -1005,7 +1048,7 @@ void SimpleIREvaluator::bindArg(const BufferArg& bufArg, void* data) { impl_->bindVar(bufArg.var(), typed_data); \ break; \ } - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE); #undef TYPE_CASE default: throw unsupported_dtype(); diff --git a/torch/csrc/jit/tensorexpr/eval.h b/torch/csrc/jit/tensorexpr/eval.h index 38ec99bd431cf..e11bb169484f6 100644 --- a/torch/csrc/jit/tensorexpr/eval.h +++ b/torch/csrc/jit/tensorexpr/eval.h @@ -29,19 +29,31 @@ class Value { Intvalues.push_back(0); } + template + Value(Dtype dtype, T v) : dtype_(dtype) { +#define TYPE_CASE(Type, Name) \ + if (dtype == k##Name) { \ + Name##values.push_back(v); \ + return; \ + } + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE); +#undef TYPE_CASE + throw unsupported_dtype(); + } + #define VALUE_CTOR(Type, Name) \ Value(Type v) : dtype_(k##Name) { \ Name##values.push_back(v); \ } // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, VALUE_CTOR); + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, VALUE_CTOR); #undef VALUE_CTOR #define VALUE_VEC_CTOR(Type, Name) \ Value(const std::vector& v) \ : dtype_(Dtype(k##Name, v.size())), Name##values(v) {} // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, VALUE_VEC_CTOR); + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, VALUE_VEC_CTOR); #undef VALUE_VEC_CTOR template @@ -50,6 +62,8 @@ class Value { template const std::vector& as_vec() const; + int64_t intValue() const; + Dtype dtype() const { return dtype_; } @@ -58,7 +72,7 @@ class Value { Dtype dtype_; #define VALUE_STORAGE(Type, Name) std::vector Name##values; - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, VALUE_STORAGE); + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, VALUE_STORAGE); #undef VALUE_STORAGE void* ptr; }; @@ -71,7 +85,7 @@ class Value { } \ return Name##values[0]; \ } -AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, VALUE_AS_DISPATCH); +AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, VALUE_AS_DISPATCH); #undef VALUE_AS_DISPATCH #define VALUE_AS_VEC_DISPATCH(Type, Name) \ @@ -82,7 +96,7 @@ AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, VALUE_AS_DISPATCH); } \ return Name##values; \ } -AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, VALUE_AS_VEC_DISPATCH); +AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, VALUE_AS_VEC_DISPATCH); #undef VALUE_AS_VEC_DISPATCH template @@ -192,7 +206,7 @@ class ExprEval { ret_value_ = Value(ret_val_arg[0]); \ } break; // NOLINTNEXTLINE(modernize-use-emplace) - AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE); + AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, TYPE_CASE); #undef TYPE_CASE case ScalarType::Bool: { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) @@ -217,7 +231,7 @@ class ExprEval { codegen_->call_raw(args_extended); \ ret_value_ = Value(ret_val_arg[0]); \ } break; - AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE); + AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, TYPE_CASE); #undef TYPE_CASE case ScalarType::Bool: { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) diff --git a/torch/csrc/jit/tensorexpr/exceptions.h b/torch/csrc/jit/tensorexpr/exceptions.h index cf23bbc2289c4..7194dfe166aa8 100644 --- a/torch/csrc/jit/tensorexpr/exceptions.h +++ b/torch/csrc/jit/tensorexpr/exceptions.h @@ -84,6 +84,8 @@ class malformed_ir : public std::runtime_error { "MALFORMED IR: " + err + " - " + std::to_string(stmt)) {} }; +TORCH_API std::string buildErrorMessage(const std::string& s); + } // namespace tensorexpr } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/tensorexpr/expr.cpp b/torch/csrc/jit/tensorexpr/expr.cpp index cbf5ddd9f1d6d..c757d4b0ca201 100644 --- a/torch/csrc/jit/tensorexpr/expr.cpp +++ b/torch/csrc/jit/tensorexpr/expr.cpp @@ -89,7 +89,7 @@ ExprHandle ExprHandle::operator>>(const ExprHandle& other) const { // NOLINTNEXTLINE #define IMM_EXPR_DECLARE(Type, Name) \ ExprHandle::ExprHandle(Type v) : ExprHandle(Name##Imm::make(v)) {} -AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_EXPR_DECLARE); +AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_EXPR_DECLARE); #undef IMM_EXPR_DECLARE ExprHandle sin(const ExprHandle& v) { diff --git a/torch/csrc/jit/tensorexpr/expr.h b/torch/csrc/jit/tensorexpr/expr.h index fae24ec34be28..41ce99a085179 100644 --- a/torch/csrc/jit/tensorexpr/expr.h +++ b/torch/csrc/jit/tensorexpr/expr.h @@ -8,7 +8,6 @@ #include #include #include -#include #include namespace torch { @@ -36,10 +35,11 @@ enum IRNodeType { }; // The common base between all expression node. -class TORCH_API Expr : public KernelScopedObject { +class TORCH_API Expr : public std::enable_shared_from_this { public: explicit Expr(Dtype dtype, IRNodeType expr_type = kOther) : dtype_(dtype), expr_type_(expr_type) {} + virtual ~Expr() = default; Dtype dtype() const { return dtype_; } @@ -66,6 +66,11 @@ class TORCH_API Expr : public KernelScopedObject { */ static ExprPtr clone(ExprPtr s); + protected: + std::shared_ptr getptr() { + return shared_from_this(); + } + private: Dtype dtype_; IRNodeType expr_type_; @@ -78,7 +83,7 @@ class ExprNode : public Base { public: using ExprNodeBase = ExprNode; void accept(IRVisitor* visitor) override { - visitor->visit(static_to(this)); + visitor->visit(static_to(Base::getptr())); } ExprPtr accept_mutator(IRMutator* mutator) override; // pass the constructor to the base class @@ -105,7 +110,7 @@ class TORCH_API ExprHandle { } #define IMM_EXPR_DECLARE(Type, Name) ExprHandle(Type v); - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_EXPR_DECLARE); + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_EXPR_DECLARE); #undef IMM_EXPR_DECLARE template @@ -164,8 +169,12 @@ class TORCH_API Var : public ExprNode { return name_hint_; } - void set_name_hint(const std::string& name_hint) { - name_hint_ = name_hint; + void set_name_hint(const std::string& name) { + name_hint_ = name; + } + + void set_name_hint(std::string&& name) { + name_hint_ = name; } Var(std::string name_hint, Dtype dtype) @@ -310,11 +319,16 @@ class TORCH_API BufHandle : public ExprHandle { // object. For example: VarHandle x('x'); ExprHandle x2 = x; class TORCH_API VarHandle : public ExprHandle { public: - VarHandle() : ExprHandle(nullptr) {} + // Creates an empty VarHandle whose base Var is set to nullptr. + VarHandle() : ExprHandle() {} + explicit VarHandle(Dtype dtype) : ExprHandle(Var::make(dtype)) {} + VarHandle(const std::string& name_hint, Dtype dtype) : ExprHandle(Var::make(name_hint, dtype)) {} + explicit VarHandle(VarPtr node) : ExprHandle(node) {} + VarPtr node() const { return static_to(ExprHandle::node()); } @@ -335,7 +349,7 @@ class TORCH_API VarHandle : public ExprHandle { template ExprPtr ExprNode::accept_mutator(IRMutator* mutator) { - return mutator->mutate(static_to(this)); + return mutator->mutate(static_to(Base::getptr())); } inline bool same_node(const ExprHandle& expr1, const ExprHandle& expr2) { diff --git a/torch/csrc/jit/tensorexpr/fwd_decls.h b/torch/csrc/jit/tensorexpr/fwd_decls.h index 01a767067f620..119308b053442 100644 --- a/torch/csrc/jit/tensorexpr/fwd_decls.h +++ b/torch/csrc/jit/tensorexpr/fwd_decls.h @@ -1,26 +1,27 @@ #pragma once #include +#include namespace torch { namespace jit { namespace tensorexpr { template -using NodePtr = Node*; +using NodePtr = std::shared_ptr; template NodePtr to(NodePtr x) { - return dynamic_cast>(x); + return std::dynamic_pointer_cast(x); } template NodePtr static_to(NodePtr x) { - return static_cast>(x); + return std::static_pointer_cast(x); } template NodePtr alloc(Args&&... args) { - return new Node(std::forward(args)...); + return std::make_shared(std::forward(args)...); } class Buf; @@ -112,7 +113,7 @@ using SyncThreadsPtr = NodePtr; #define IMM_DECLARE(Type, Name) \ class Name##Imm; \ using Name##ImmPtr = NodePtr; -AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_DECLARE); +AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_DECLARE); #undef IMM_DECLARE } // namespace tensorexpr diff --git a/torch/csrc/jit/tensorexpr/graph_opt.cpp b/torch/csrc/jit/tensorexpr/graph_opt.cpp index 67f9a671bfa20..d55ea0559e5e1 100644 --- a/torch/csrc/jit/tensorexpr/graph_opt.cpp +++ b/torch/csrc/jit/tensorexpr/graph_opt.cpp @@ -26,14 +26,21 @@ Node* moveCatAfterUse(Node* cat, Node* user, std::shared_ptr subgraph) { // %4 = aten::cat(%3, ...) // return (%4) - TORCH_INTERNAL_ASSERT(cat->output()->hasUses()); - TORCH_INTERNAL_ASSERT(cat->output()->uses().size() == 1); - TORCH_INTERNAL_ASSERT(cat->input(0)->node()->kind() == prim::ListConstruct); + TORCH_INTERNAL_ASSERT( + cat->output()->hasUses(), + buildErrorMessage("aten::cat output is not used.")); + TORCH_INTERNAL_ASSERT( + cat->output()->uses().size() == 1, + buildErrorMessage("aten::cat output is used in multiple places.")); + TORCH_INTERNAL_ASSERT( + cat->input(0)->node()->kind() == prim::ListConstruct, + buildErrorMessage("aten::cat inputs are not expected.")); auto cat_list = cat->input(0)->node(); auto cat_inputs = cat_list->inputs(); auto user_tensor_type = user->output()->type()->cast(); - TORCH_INTERNAL_ASSERT(user_tensor_type); + TORCH_INTERNAL_ASSERT( + user_tensor_type, buildErrorMessage("Unexpected user tensor type")); std::unordered_map new_cat_inputs; for (auto inp : cat_inputs) { auto new_cat_input = subgraph->createClone( @@ -41,7 +48,8 @@ Node* moveCatAfterUse(Node* cat, Node* user, std::shared_ptr subgraph) { // Since we are cloning user, its result should be the same scalar type // as the user. But the dims should correspond to that of the input. auto input_tensor_type = inp->type()->cast(); - TORCH_INTERNAL_ASSERT(input_tensor_type); + TORCH_INTERNAL_ASSERT( + input_tensor_type, buildErrorMessage("Unexpected input tensor type")); auto new_input_type = input_tensor_type->withScalarType(user_tensor_type->scalarType()); new_cat_input->output()->setType(new_input_type); @@ -60,7 +68,9 @@ Node* moveCatAfterUse(Node* cat, Node* user, std::shared_ptr subgraph) { user->output()->replaceAllUsesWith(new_cat->output()); user->destroy(); - TORCH_INTERNAL_ASSERT(!cat->output()->hasUses()); + TORCH_INTERNAL_ASSERT( + !cat->output()->hasUses(), + buildErrorMessage("aten::cat output is not used.")); cat->destroy(); if (!cat_list->output()->hasUses()) { @@ -84,10 +94,15 @@ int numTensorInputs(Node* node) { // If the inputs to `cat` are of different types, then the implementation // of `cat` is expected to promote type. bool doesCatPromoteTypes(Node* node) { - TORCH_INTERNAL_ASSERT(node->kind() == aten::cat); - TORCH_INTERNAL_ASSERT(node->input(0)->node()->kind() == prim::ListConstruct); + TORCH_INTERNAL_ASSERT( + node->kind() == aten::cat, + buildErrorMessage("Graph node is not aten::cat.")); + TORCH_INTERNAL_ASSERT( + node->input(0)->node()->kind() == prim::ListConstruct, + buildErrorMessage("aten::cat inputs are not expected.")); auto inputs = node->input(0)->node()->inputs(); - TORCH_INTERNAL_ASSERT(!inputs.empty()); + TORCH_INTERNAL_ASSERT( + !inputs.empty(), buildErrorMessage("Empty inputs of ListConstruct")); auto scalar_type = inputs.front()->type()->cast()->scalarType(); for (size_t i = 1; i < inputs.size(); ++i) { @@ -122,14 +137,18 @@ bool doesCatPromoteTypes(Node* node) { // it user needs to reflect the original type. This is currently not // handled. TODO void moveCatOpToEnd(Node* cat, std::shared_ptr subgraph) { - TORCH_INTERNAL_ASSERT(cat->kind() == aten::cat); + TORCH_INTERNAL_ASSERT( + cat->kind() == aten::cat, + buildErrorMessage("Graph node is not aten::cat.")); if (cat->output()->uses().size() == 1) { auto use = cat->output()->uses().front(); if (use.user->isMemberOf(supported_eltwise_set()) && numTensorInputs(use.user) == 1) { if (!doesCatPromoteTypes(cat)) { TORCH_INTERNAL_ASSERT( - use.user->output()->owningGraph() == subgraph.get()); + use.user->output()->owningGraph() == subgraph.get(), + buildErrorMessage( + "aten::cat user graph does not math the given subgraph.")); auto new_cat = moveCatAfterUse(cat, use.user, subgraph); moveCatOpToEnd(new_cat, subgraph); } diff --git a/torch/csrc/jit/tensorexpr/half_support.h b/torch/csrc/jit/tensorexpr/half_support.h index 15d48cd8952e0..8ecf956d6d75b 100644 --- a/torch/csrc/jit/tensorexpr/half_support.h +++ b/torch/csrc/jit/tensorexpr/half_support.h @@ -18,17 +18,23 @@ class HalfChecker : public IRVisitor { } } - bool hasHalf() { + bool hasHalf() const { return hasHalf_; } + bool hasBFloat16() const { + return hasBFloat16_; + } + void visit(LoadPtr v) override { hasHalf_ |= v->dtype().scalar_type() == ScalarType::Half; + hasBFloat16_ |= v->dtype().scalar_type() == ScalarType::BFloat16; IRVisitor::visit(v); } void visit(StorePtr v) override { hasHalf_ |= v->buf()->dtype().scalar_type() == ScalarType::Half; + hasBFloat16_ |= v->buf()->dtype().scalar_type() == ScalarType::BFloat16; IRVisitor::visit(v); } @@ -36,20 +42,26 @@ class HalfChecker : public IRVisitor { hasHalf_ = true; } + void visit(BFloat16ImmPtr v) override { + hasBFloat16_ = true; + } + void visit(CastPtr v) override { hasHalf_ |= v->dtype().scalar_type() == ScalarType::Half; + hasBFloat16_ |= v->dtype().scalar_type() == ScalarType::BFloat16; IRVisitor::visit(v); } private: bool hasHalf_{false}; + bool hasBFloat16_{false}; }; // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) class HalfRewriter : public IRMutator { ExprPtr mutate(LoadPtr v) override { ExprPtr child = IRMutator::mutate(v); - if (child->dtype().scalar_type() != ScalarType::Half) { + if (!isHalf(child)) { return child; } @@ -63,27 +75,31 @@ class HalfRewriter : public IRMutator { StmtPtr mutate(StorePtr v) override { // Since mutation changes the `value()` expression in-place, we need to // get the dtype of the `value()` before that is mutated. - Dtype newType = v->value()->dtype(); + auto newType = v->value()->dtype(); ExprPtr new_val = v->value()->accept_mutator(this); - if (newType.scalar_type() == ScalarType::Half) { - new_val = - alloc(newType.cloneWithScalarType(ScalarType::Half), new_val); + if (isHalf(newType.scalar_type())) { + new_val = alloc(newType, new_val); inserted_half_casts_.insert(new_val); } - return alloc(v->buf(), v->indices(), new_val); + v->set_value(new_val); + return v; } ExprPtr mutate(HalfImmPtr v) override { return alloc(kFloat, v); } + ExprPtr mutate(BFloat16ImmPtr v) override { + return alloc(kFloat, v); + } + ExprPtr mutate(CastPtr v) override { ExprPtr child = v->src_value()->accept_mutator(this); // just don't allow half casts we didn't insert. - if (v->dtype().scalar_type() == ScalarType::Half) { + if (isHalf(v)) { if (inserted_half_casts_.count(v) < 1) { return child; } @@ -104,8 +120,9 @@ class HalfRewriter : public IRMutator { return alloc(v->dtype(), child); } + StmtPtr mutate(LetPtr v) override { - if (v->dtype().scalar_type() == ScalarType::Half) { + if (isHalf(v->dtype().scalar_type())) { VarPtr load_new_var = alloc(v->var()->name_hint(), kFloat); ExprPtr new_value = alloc( v->dtype().cloneWithScalarType(ScalarType::Float), @@ -127,7 +144,55 @@ class HalfRewriter : public IRMutator { return v; } + template + ExprPtr mutateArithmetic(T v) { + IRMutator::mutate(v); + if (isHalf(v)) { + v->set_dtype(v->dtype().cloneWithScalarType(c10::kFloat)); + } + return v; + } + + ExprPtr mutate(AddPtr v) override { + return mutateArithmetic(v); + } + ExprPtr mutate(SubPtr v) override { + return mutateArithmetic(v); + } + ExprPtr mutate(MulPtr v) override { + return mutateArithmetic(v); + } + ExprPtr mutate(DivPtr v) override { + return mutateArithmetic(v); + } + ExprPtr mutate(MaxPtr v) override { + return mutateArithmetic(v); + } + ExprPtr mutate(MinPtr v) override { + return mutateArithmetic(v); + } + ExprPtr mutate(CompareSelectPtr v) override { + return mutateArithmetic(v); + } + ExprPtr mutate(BroadcastPtr v) override { + return mutateArithmetic(v); + } + ExprPtr mutate(IfThenElsePtr v) override { + return mutateArithmetic(v); + } + ExprPtr mutate(IntrinsicsPtr v) override { + return mutateArithmetic(v); + } + private: + static bool isHalf(ScalarType st) { + return st == ScalarType::Half || st == ScalarType::BFloat16; + } + + static bool isHalf(ExprPtr v) { + return isHalf(v->dtype().scalar_type()); + } + std::unordered_set inserted_half_casts_; std::unordered_map var_map; }; diff --git a/torch/csrc/jit/tensorexpr/hash_provider.cpp b/torch/csrc/jit/tensorexpr/hash_provider.cpp index fbc257d1988df..dce25669bf323 100644 --- a/torch/csrc/jit/tensorexpr/hash_provider.cpp +++ b/torch/csrc/jit/tensorexpr/hash_provider.cpp @@ -63,6 +63,13 @@ void HashProvider::visit(ModPtr v) { putHash(v, hash_combine(hashOf(v->lhs()), "%", hashOf(v->rhs()))); } +void HashProvider::visit(RoundOffPtr v) { + CACHE_GUARD(); + v->lhs()->accept(this); + v->rhs()->accept(this); + putHash(v, hash_combine(hashOf(v->lhs()), "rof", hashOf(v->rhs()))); +} + void HashProvider::visit(MaxPtr v) { CACHE_GUARD(); v->lhs()->accept(this); diff --git a/torch/csrc/jit/tensorexpr/hash_provider.h b/torch/csrc/jit/tensorexpr/hash_provider.h index 5a33f048fec84..35d493a0025b4 100644 --- a/torch/csrc/jit/tensorexpr/hash_provider.h +++ b/torch/csrc/jit/tensorexpr/hash_provider.h @@ -59,12 +59,16 @@ class TORCH_API HashProvider : public IRVisitor { return hashOf(e); } - bool cachedHash(const KernelScopedObject* e) { + bool cachedHash(ExprPtr e) { return exprToHash_.find(e) != exprToHash_.end(); } + bool cachedHash(StmtPtr s) { + return stmtToHash_.find(s) != stmtToHash_.end(); + } void clearCache() { exprToHash_.clear(); + stmtToHash_.clear(); } void visit(AddPtr v) override; @@ -72,6 +76,7 @@ class TORCH_API HashProvider : public IRVisitor { void visit(MulPtr v) override; void visit(DivPtr v) override; void visit(ModPtr v) override; + void visit(RoundOffPtr v) override; void visit(MaxPtr v) override; void visit(MinPtr v) override; void visit(AndPtr v) override; @@ -87,7 +92,7 @@ class TORCH_API HashProvider : public IRVisitor { CACHE_GUARD(); \ putHash(v, hash_combine(#Name, v->value())); \ } - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_VISIT); + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_VISIT); #undef IMM_VISIT void visit(CastPtr v) override; @@ -133,8 +138,8 @@ class TORCH_API HashProvider : public IRVisitor { } SimplifierHashType hashOf(StmtPtr s) { - auto it = exprToHash_.find(s); - if (it != exprToHash_.end()) { + auto it = stmtToHash_.find(s); + if (it != stmtToHash_.end()) { return it->second; } @@ -182,15 +187,23 @@ class TORCH_API HashProvider : public IRVisitor { _hash_combine(seed, args...); } - void putHash(const KernelScopedObject* e, SimplifierHashType h) { + void putHash(ExprPtr e, SimplifierHashType h) { auto res = exprToHash_.emplace(e, h); if (res.second == false) { // This is always a logic bug since we should check the cache first. throw std::runtime_error("hash collision"); } } + void putHash(StmtPtr s, SimplifierHashType h) { + auto res = stmtToHash_.emplace(s, h); + if (res.second == false) { + // This is always a logic bug since we should check the cache first. + throw std::runtime_error("hash collision"); + } + } - std::unordered_map exprToHash_; + std::unordered_map exprToHash_; + std::unordered_map stmtToHash_; UniqueNameManager name_manager_; size_t te_hash(SimplifierHashType val) { @@ -274,6 +287,14 @@ class TORCH_API HashProvider : public IRVisitor { std::memcpy(&n, &d, sizeof d); return te_hash(n); } + + size_t te_hash(at::BFloat16 d) { + // memcpy as type punning. Should be optimized out. + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + int16_t n; + std::memcpy(&n, &d, sizeof d); + return te_hash(n); + } }; } // namespace tensorexpr diff --git a/torch/csrc/jit/tensorexpr/ir.cpp b/torch/csrc/jit/tensorexpr/ir.cpp index f66c0c5ba0701..439993c481903 100644 --- a/torch/csrc/jit/tensorexpr/ir.cpp +++ b/torch/csrc/jit/tensorexpr/ir.cpp @@ -88,17 +88,17 @@ ExprPtr flatten_index( throw malformed_input("dimensions mismatch in flatten_index"); } if (ndim == 0) { - return alloc(0); + return alloc(0); } std::vector strides(ndim); // stride[i] = stride[i+1]*dims[i+1], i < ndim-1 // stride[i] = 1, i = ndim-1 - strides[ndim - 1] = alloc(1); + strides[ndim - 1] = immLike(dims[ndim - 1], 1); for (size_t i = 1; i < ndim; i++) { strides[ndim - 1 - i] = alloc(strides[ndim - i], dims[ndim - i]); } - ExprPtr total_index = alloc(0); + ExprPtr total_index = immLike(indices[0], 0); for (const auto i : c10::irange(ndim)) { total_index = alloc(total_index, alloc(indices[i], strides[i])); } @@ -231,7 +231,7 @@ bool immediateIsNegative(ExprPtr e) { if (Name##ImmPtr imm = to(e)) { \ return imm->value() < 0; \ } - AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE); + AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, TYPE_CASE); #undef TYPE_CASE return false; } diff --git a/torch/csrc/jit/tensorexpr/ir.h b/torch/csrc/jit/tensorexpr/ir.h index 761b233fe8375..65a362ef023fe 100644 --- a/torch/csrc/jit/tensorexpr/ir.h +++ b/torch/csrc/jit/tensorexpr/ir.h @@ -178,6 +178,12 @@ class BinaryOpNode : public ExprNode { ExprPtr rhs_; }; +namespace detail { +template +void bin_op_deducer(BinaryOpNode); +bool bin_op_deducer(...); +} // namespace detail + class TORCH_API Add : public BinaryOpNode { public: Add(ExprPtr lhs, ExprPtr rhs) : BinaryOpNode(lhs, rhs, IRNodeType::kAdd) {} @@ -314,7 +320,7 @@ class Min : public BinaryOpNode { private: \ Type value_; \ }; -AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_DECLARE); +AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_DECLARE); #undef IMM_DECLARE // Get immediate by ScalarType. @@ -323,9 +329,9 @@ ExprPtr getImmediateByType(ScalarType immType, T initialVal) { switch (immType) { #define TYPE_CASE(Type, Name) \ case ScalarType::Name: \ - return alloc(initialVal); + return alloc(Type(initialVal)); // NOLINTNEXTLINE(bugprone-branch-clone) - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE); #undef TYPE_CASE default: throw unsupported_dtype(); @@ -338,13 +344,37 @@ ExprPtr getImmediateByType(Dtype dtype, T initialVal) { return getImmediateByType(dtype.scalar_type(), initialVal); } +template +ExprPtr immLike(ExprPtr e, T v) { + return getImmediateByType(e->dtype(), v); +} + +template +ExprPtr immLike(ExprHandle e, T v) { + return immLike(e.node(), v); +} + +inline c10::optional intValue(ExprPtr e) { +#define TYPE_CASE(Type, Name) \ + if (auto v = to(e)) { \ + return v->value(); \ + } + AT_FORALL_INT_TYPES(TYPE_CASE); +#undef TYPE_CASE + return c10::nullopt; +} + +inline c10::optional intValue(ExprHandle e) { + return intValue(e.node()); +} + template T immediateAs(ExprPtr e) { #define TYPE_CASE(Type, Name) \ if (Name##ImmPtr imm = to(e)) { \ return imm->value(); \ } - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE); #undef TYPE_CASE throw unsupported_dtype(); return 0; @@ -361,7 +391,7 @@ bool immediateEquals(ExprPtr e, T val) { if (Name##ImmPtr imm = to(e)) { \ return imm->value() == val; \ } - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE); #undef TYPE_CASE throw unsupported_dtype(); return false; @@ -678,6 +708,7 @@ enum IntrinsicsOp { kFrac, kIsNan, kRand, // We need more discussions on this. Should we consider stateful? + kMaxIntrinsicsOp, }; class TORCH_API Intrinsics : public ExprNode { @@ -858,8 +889,9 @@ class TORCH_API Intrinsics : public ExprNode { params_ = std::move(params); } - private: static int OpArgCount(IntrinsicsOp op_type); + + private: static Dtype IntrinsicsDtype(IntrinsicsOp op_type, Dtype dt1); static Dtype IntrinsicsDtype(IntrinsicsOp op_type, Dtype dt1, Dtype dt2); static Dtype IntrinsicsDtype( diff --git a/torch/csrc/jit/tensorexpr/ir_cloner.cpp b/torch/csrc/jit/tensorexpr/ir_cloner.cpp index f724f2cbeb16f..1144833c7990e 100644 --- a/torch/csrc/jit/tensorexpr/ir_cloner.cpp +++ b/torch/csrc/jit/tensorexpr/ir_cloner.cpp @@ -10,9 +10,13 @@ namespace torch { namespace jit { namespace tensorexpr { -template +template < + typename Op, + typename std::enable_if())), + void>::value>::type* = nullptr> static ExprPtr mutate_binary_op( - NodePtr> v, + NodePtr v, IRCloner* cloner, bool option = false) { ExprPtr lhs_new = v->lhs()->accept_mutator(cloner); @@ -115,7 +119,7 @@ ExprPtr IRCloner::mutate(CompareSelectPtr v) { ExprPtr IRCloner::mutate(Name##ImmPtr v) { \ return v; \ } -AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_MUTATE_DEFINE); +AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_MUTATE_DEFINE); #undef IMM_MUTATE_DEFINE ExprPtr IRCloner::mutate(CastPtr v) { diff --git a/torch/csrc/jit/tensorexpr/ir_cloner.h b/torch/csrc/jit/tensorexpr/ir_cloner.h index f03e12886eabe..5f516a02ffadb 100644 --- a/torch/csrc/jit/tensorexpr/ir_cloner.h +++ b/torch/csrc/jit/tensorexpr/ir_cloner.h @@ -26,7 +26,7 @@ class TORCH_API IRCloner : public IRMutator { ExprPtr mutate(RshiftPtr v) override; ExprPtr mutate(CompareSelectPtr v) override; #define IMM_MUTATE_DECLARE(Type, Name) ExprPtr mutate(Name##ImmPtr v) override; - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_MUTATE_DECLARE); + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_MUTATE_DECLARE); #undef IMM_MUTATE_DECLARE ExprPtr mutate(CastPtr v) override; ExprPtr mutate(BitCastPtr v) override; diff --git a/torch/csrc/jit/tensorexpr/ir_mutator.cpp b/torch/csrc/jit/tensorexpr/ir_mutator.cpp index 96635acab8c90..e2e9c46e133a5 100644 --- a/torch/csrc/jit/tensorexpr/ir_mutator.cpp +++ b/torch/csrc/jit/tensorexpr/ir_mutator.cpp @@ -11,9 +11,13 @@ namespace torch { namespace jit { namespace tensorexpr { -template +template < + typename Op, + typename std::enable_if())), + void>::value>::type* = nullptr> static ExprPtr mutate_binary_op( - BinaryOpNode* v, + NodePtr v, IRMutator* mutator, bool option = false) { ExprPtr lhs = v->lhs(); @@ -111,7 +115,7 @@ ExprPtr IRMutator::mutate(CompareSelectPtr v) { ExprPtr IRMutator::mutate(Name##ImmPtr v) { \ return v; \ } -AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_MUTATE_DEFINE); +AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_MUTATE_DEFINE); #undef IMM_MUTATE_DEFINE ExprPtr IRMutator::mutate(CastPtr v) { @@ -420,14 +424,16 @@ StmtPtr IRMutator::mutate(SyncThreadsPtr v) { StmtPtr IRMutator::mutate(ExternalCallPtr v) { BufPtr buf = v->buf(); BufPtr buf_new = to(buf->accept_mutator(this)); - TORCH_INTERNAL_ASSERT(buf_new); + TORCH_INTERNAL_ASSERT( + buf_new, buildErrorMessage("IRMutator produced null for Buf.")); bool buf_args_changed = false; std::vector buf_args_new; buf_args_new.reserve(v->buf_args().size()); for (BufPtr buf_arg : v->buf_args()) { BufPtr buf_arg_new = to(buf_arg->accept_mutator(this)); - TORCH_INTERNAL_ASSERT(buf_arg_new); + TORCH_INTERNAL_ASSERT( + buf_arg_new, buildErrorMessage("IRMutator produced null for Buf.")); buf_args_new.push_back(buf_arg_new); buf_args_changed |= buf_arg_new != buf_arg; } @@ -456,7 +462,8 @@ StmtPtr IRMutator::mutate(ExternalCallPtr v) { StmtPtr IRMutator::mutate(AllocatePtr v) { BufPtr buf = v->buf(); BufPtr buf_new = to(buf->accept_mutator(this)); - TORCH_INTERNAL_ASSERT(buf_new); + TORCH_INTERNAL_ASSERT( + buf_new, buildErrorMessage("IRMutator produced null for Buf.")); if (buf != buf_new) { v->set_buf(buf_new); } @@ -466,7 +473,8 @@ StmtPtr IRMutator::mutate(AllocatePtr v) { StmtPtr IRMutator::mutate(FreePtr v) { BufPtr buf = v->buf(); BufPtr buf_new = to(buf->accept_mutator(this)); - TORCH_INTERNAL_ASSERT(buf_new); + TORCH_INTERNAL_ASSERT( + buf_new, buildErrorMessage("IRMutator produced null for Buf.")); if (buf != buf_new) { v->set_buf(buf_new); } diff --git a/torch/csrc/jit/tensorexpr/ir_mutator.h b/torch/csrc/jit/tensorexpr/ir_mutator.h index fb6c420af46a0..0a96876606dfb 100644 --- a/torch/csrc/jit/tensorexpr/ir_mutator.h +++ b/torch/csrc/jit/tensorexpr/ir_mutator.h @@ -25,7 +25,7 @@ class TORCH_API IRMutator { virtual ExprPtr mutate(RshiftPtr v); virtual ExprPtr mutate(CompareSelectPtr v); #define IMM_MUTATE_DECLARE(Type, Name) virtual ExprPtr mutate(Name##ImmPtr v); - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_MUTATE_DECLARE); + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_MUTATE_DECLARE); #undef IMM_MUTATE_DECLARE virtual ExprPtr mutate(CastPtr v); virtual ExprPtr mutate(BitCastPtr v); diff --git a/torch/csrc/jit/tensorexpr/ir_printer.cpp b/torch/csrc/jit/tensorexpr/ir_printer.cpp index 23466f39160c8..4a10c282e60b1 100644 --- a/torch/csrc/jit/tensorexpr/ir_printer.cpp +++ b/torch/csrc/jit/tensorexpr/ir_printer.cpp @@ -25,12 +25,34 @@ void IRPrinter::print(Expr& expr) { void IRPrinter::print(Stmt& stmt) { stmt.accept(this); } +std::string IRPrinter::to_string(CompareSelectOperation op) { + switch (op) { + case CompareSelectOperation::kEQ: + return "=="; + case CompareSelectOperation::kNE: + return "!="; + case CompareSelectOperation::kGT: + return ">"; + case CompareSelectOperation::kGE: + return ">="; + case CompareSelectOperation::kLT: + return "<"; + case CompareSelectOperation::kLE: + return "<="; + default: + throw std::runtime_error("invalid compare select operator"); + } +} // TODO: change whether to include the parenthesis to the parent expression, // we need to look at the operator precedence to make the output simpler. -template +template < + typename Op, + typename std::enable_if())), + void>::value>::type* = nullptr> void visitBinaryOp( - BinaryOpNode* v, + NodePtr v, const std::string& op_str, IRPrinter* printer, bool parens = true) { @@ -133,28 +155,8 @@ void IRPrinter::visit(CompareSelectPtr v) { if (lhs_prec >= self_prec) { os() << ")"; } - switch (cmp_op) { - case CompareSelectOperation::kEQ: - os() << "=="; - break; - case CompareSelectOperation::kNE: - os() << "!="; - break; - case CompareSelectOperation::kGT: - os() << ">"; - break; - case CompareSelectOperation::kGE: - os() << ">="; - break; - case CompareSelectOperation::kLT: - os() << "<"; - break; - case CompareSelectOperation::kLE: - os() << "<="; - break; - default: - throw std::runtime_error("invalid compare select operator"); - } + + os() << to_string(cmp_op); if (rhs_prec >= self_prec) { os() << "("; @@ -204,11 +206,19 @@ static void formatImm(std::ostream& os, T v) { } } +static void formatIntSuffix(std::ostream& os, int64_t v) { + os << "ll"; +} + +template +static void formatIntSuffix(std::ostream& os, T v) {} + template < typename T, std::enable_if_t::value>* = nullptr> static void formatImm(std::ostream& os, T v) { os << +v; + formatIntSuffix(os, v); } // NOLINTNEXTLINE @@ -216,7 +226,7 @@ static void formatImm(std::ostream& os, T v) { void IRPrinter::visit(Name##ImmPtr v) { \ formatImm(os(), v->value()); \ } -AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_PRINT_VISIT); +AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_PRINT_VISIT); #undef IMM_PRINT_VISIT void IRPrinter::visit(CastPtr v) { @@ -226,6 +236,13 @@ void IRPrinter::visit(CastPtr v) { os() << ")"; } +void IRPrinter::visit(BitCastPtr v) { + auto dtype = v->dtype(); + os() << "BitCast<" << dtype.ToCppString() << ">("; + v->src_value()->accept(this); + os() << ")"; +} + void IRPrinter::visit(VarPtr v) { os() << name_manager_.get_unique_name(v); } @@ -435,7 +452,7 @@ void IRPrinter::visit(FreePtr v) { void IRPrinter::visit(LetPtr v) { os() << dtypeToCppString(v->dtype()) << " " << *v->var(); os() << " = " << *v->value(); - os() << ";"; + os() << ";" << std::endl; } void IRPrinter::visit(CondPtr v) { @@ -541,7 +558,7 @@ std::ostream& operator<<(std::ostream& stream, const Stmt& stmt) { } std::ostream& operator<<(std::ostream& stream, const Tensor& t) { - stream << std::to_string(&t); + stream << std::to_string(t); return stream; } @@ -564,7 +581,7 @@ void print(StmtPtr stmt) { } } -void print(const Tensor* t) { +void print(const Tensor& t) { std::cout << std::to_string(t); } @@ -585,20 +602,17 @@ std::string to_string(StmtPtr stmt) { return oss.str(); } -std::string to_string(const Tensor* t) { - if (!t) { - return "(null tensor)\n"; - } +std::string to_string(const Tensor& t) { std::ostringstream oss; // TODO: move this to Buf printer - oss << "Tensor " << t->buf()->name_hint() << "["; - for (const auto i : c10::irange(t->buf()->ndim())) { + oss << "Tensor " << t.buf()->name_hint() << "["; + for (const auto i : c10::irange(t.buf()->ndim())) { if (i != 0) { oss << ", "; } - oss << *t->buf()->dim(i); + oss << *t.buf()->dim(i); } - oss << "]:\n" << *t->stmt() << "\n"; + oss << "]:\n" << *t.stmt() << "\n"; return oss.str(); } } // namespace std diff --git a/torch/csrc/jit/tensorexpr/ir_printer.h b/torch/csrc/jit/tensorexpr/ir_printer.h index e76dccab846a1..fb357a8fb79fa 100644 --- a/torch/csrc/jit/tensorexpr/ir_printer.h +++ b/torch/csrc/jit/tensorexpr/ir_printer.h @@ -34,9 +34,10 @@ class TORCH_API IRPrinter : public IRVisitor { void visit(RshiftPtr v) override; void visit(CompareSelectPtr v) override; #define IMM_PRINT_VISIT(Type, Name) void visit(Name##ImmPtr v) override; - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_PRINT_VISIT); + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_PRINT_VISIT); #undef IMM_PRINT_VISIT void visit(CastPtr v) override; + void visit(BitCastPtr v) override; void visit(VarPtr v) override; void visit(RampPtr v) override; void visit(LoadPtr v) override; @@ -83,6 +84,8 @@ class TORCH_API IRPrinter : public IRVisitor { }; protected: + std::string to_string(CompareSelectOperation op); + UniqueNameManager* name_manager() { return &name_manager_; } @@ -103,7 +106,7 @@ TORCH_API std::ostream& operator<<(std::ostream& stream, const Tensor&); TORCH_API void print(ExprPtr expr); TORCH_API void print(StmtPtr stmt); -TORCH_API void print(const Tensor* t); +TORCH_API void print(const Tensor& t); } // namespace tensorexpr } // namespace jit @@ -119,5 +122,5 @@ using torch::jit::tensorexpr::Tensor; TORCH_API std::string to_string(ExprPtr expr); TORCH_API std::string to_string(StmtPtr stmt); -TORCH_API std::string to_string(const Tensor* t); +TORCH_API std::string to_string(const Tensor& t); } // namespace std diff --git a/torch/csrc/jit/tensorexpr/ir_simplifier.cpp b/torch/csrc/jit/tensorexpr/ir_simplifier.cpp index 3d849fec6d9db..3ce194325f08a 100644 --- a/torch/csrc/jit/tensorexpr/ir_simplifier.cpp +++ b/torch/csrc/jit/tensorexpr/ir_simplifier.cpp @@ -6,6 +6,70 @@ namespace torch { namespace jit { namespace tensorexpr { +// Creates a new Expr of the given type with the provided lhs and rhs. +inline ExprPtr newBinaryOpOfType( + IRNodeType expr_type, + ExprPtr lhs, + ExprPtr rhs, + bool option) { + switch (expr_type) { + // NOLINTNEXTLINE(bugprone-branch-clone) + case IRNodeType::kAdd: + return alloc(lhs, rhs); + case IRNodeType::kSub: + return alloc(lhs, rhs); + case IRNodeType::kMul: + return alloc(lhs, rhs); + case IRNodeType::kDiv: + return alloc
(lhs, rhs); + case IRNodeType::kMod: + return alloc(lhs, rhs); + case IRNodeType::kMax: + return alloc(lhs, rhs, option); + case IRNodeType::kMin: + return alloc(lhs, rhs, option); + case IRNodeType::kAnd: + return alloc(lhs, rhs); + case IRNodeType::kXor: + return alloc(lhs, rhs); + case IRNodeType::kLshift: + return alloc(lhs, rhs); + case IRNodeType::kRshift: + return alloc(lhs, rhs); + default: + LOG(FATAL) << "unsupported expr_type: " << static_cast(expr_type); + return nullptr; + } +} + +template < + typename Op, + typename std::enable_if())), + void>::value>::type* = nullptr> +static ExprPtr mutateBinaryOp( + NodePtr v, + IRMutator* mutator, + bool option = false) { + ExprPtr lhs = v->lhs(); + ExprPtr rhs = v->rhs(); + ExprPtr lhs_new = lhs->accept_mutator(mutator); + ExprPtr rhs_new = rhs->accept_mutator(mutator); + + ExprPtr node = v; + + if (lhs != lhs_new || rhs != rhs_new) { + node = newBinaryOpOfType(v->expr_type(), lhs_new, rhs_new, option); + } + + // Can only fold if both sides are constant. + if (!lhs_new->isConstant() || !rhs_new->isConstant()) { + return node; + } + + return evaluateOp(node); +} + // Simple recursive GCD. template T gcd(T a, T b) { @@ -35,8 +99,15 @@ void Term::sort() { if (dtype().is_floating_point()) { throw std::logic_error("reordering FP ops"); } + std::unordered_map str_repr_cache; std::sort(variables_.begin(), variables_.end(), [&](ExprPtr a, ExprPtr b) { - return hasher_.hash(a) < hasher_.hash(b); + if (!str_repr_cache.count(a)) { + str_repr_cache[a] = std::to_string(a); + } + if (!str_repr_cache.count(b)) { + str_repr_cache[b] = std::to_string(b); + } + return str_repr_cache.at(a) < str_repr_cache.at(b); }); } @@ -52,8 +123,15 @@ void Polynomial::sort() { if (dtype().is_floating_point()) { throw std::logic_error("reordering FP ops"); } + std::unordered_map str_repr_cache; std::sort(variables_.begin(), variables_.end(), [&](ExprPtr a, ExprPtr b) { - return hasher_.hash(a) < hasher_.hash(b); + if (!str_repr_cache.count(a)) { + str_repr_cache[a] = std::to_string(a); + } + if (!str_repr_cache.count(b)) { + str_repr_cache[b] = std::to_string(b); + } + return str_repr_cache.at(a) < str_repr_cache.at(b); }); } @@ -66,6 +144,18 @@ void MaxTerm::uniquefy() { return hasher_.hash(a) == hasher_.hash(b); }); variables_.resize(std::distance(variables_.begin(), it)); + + // Once we removed duplicates, sort terms alphabetically for stability. + std::unordered_map str_repr_cache; + std::sort(variables_.begin(), variables_.end(), [&](ExprPtr a, ExprPtr b) { + if (!str_repr_cache.count(a)) { + str_repr_cache[a] = std::to_string(a); + } + if (!str_repr_cache.count(b)) { + str_repr_cache[b] = std::to_string(b); + } + return str_repr_cache.at(a) < str_repr_cache.at(b); + }); } void MinTerm::uniquefy() { @@ -77,6 +167,18 @@ void MinTerm::uniquefy() { return hasher_.hash(a) == hasher_.hash(b); }); variables_.resize(std::distance(variables_.begin(), it)); + + // Once we removed duplicates, sort terms alphabetically for stability. + std::unordered_map str_repr_cache; + std::sort(variables_.begin(), variables_.end(), [&](ExprPtr a, ExprPtr b) { + if (!str_repr_cache.count(a)) { + str_repr_cache[a] = std::to_string(a); + } + if (!str_repr_cache.count(b)) { + str_repr_cache[b] = std::to_string(b); + } + return str_repr_cache.at(a) < str_repr_cache.at(b); + }); } // Handles optimization cases for Broadcast/Ramp +/- Broadcast/Ramp @@ -328,8 +430,7 @@ ExprPtr PolynomialTransformer::mutate(AddPtr v) { // Otherwise this is a new polynomial with no scalar and two variable // terms. - return alloc( - hasher_, getImmediateByType(v->dtype(), 0), lhsTerm, rhsTerm); + return alloc(hasher_, immLike(v, 0), lhsTerm, rhsTerm); } // Adds are commutative. @@ -350,19 +451,17 @@ ExprPtr PolynomialTransformer::mutate(AddPtr v) { // Simple Term with a scalar and variable type. if (scalar) { return alloc( - hasher_, - scalar, - alloc(hasher_, getImmediateByType(v->dtype(), 1), variable)); + hasher_, scalar, alloc(hasher_, immLike(v, 1), variable)); } // If LHS is neither Term not Polynomial, wrap it in a Term. if (!lhsTerm && !lhsPoly) { - lhsTerm = alloc(hasher_, getImmediateByType(v->dtype(), 1), lhs_new); + lhsTerm = alloc(hasher_, immLike(v, 1), lhs_new); } // Same for RHS. if (!rhsTerm && !rhsPoly) { - rhsTerm = alloc(hasher_, getImmediateByType(v->dtype(), 1), rhs_new); + rhsTerm = alloc(hasher_, immLike(v, 1), rhs_new); } // If we now have a poly and a term, we can insert. @@ -378,8 +477,7 @@ ExprPtr PolynomialTransformer::mutate(AddPtr v) { } // If all else fails we have a new Polynomial with two new variable Terms. - return alloc( - hasher_, getImmediateByType(v->dtype(), 0), lhsTerm, rhsTerm); + return alloc(hasher_, immLike(v, 0), lhsTerm, rhsTerm); } ExprPtr PolynomialTransformer::subTerms( @@ -388,7 +486,7 @@ ExprPtr PolynomialTransformer::subTerms( bool negated) { // If RHS not already negated, negate it. if (!negated) { - ExprPtr minusOne = getImmediateByType(rhs->dtype(), -1); + ExprPtr minusOne = immLike(rhs, -1); ExprPtr negateScalar = evaluateOp(alloc(minusOne, rhs->scalar())); rhs = alloc(hasher_, negateScalar, rhs->variables()); } @@ -427,8 +525,7 @@ ExprPtr PolynomialTransformer::subPolynomials( for (auto rt : rhs->variables()) { // Polynomials add their terms, so negate the RHS's Terms. - ExprPtr negated = evaluateOp( - alloc(getImmediateByType(rt->dtype(), -1), rt->scalar())); + ExprPtr negated = evaluateOp(alloc(immLike(rt, -1), rt->scalar())); TermPtr newRHS = alloc(hasher_, negated, rt->variables()); addOrUpdateTerm(varmap, newRHS); } @@ -492,7 +589,7 @@ ExprPtr PolynomialTransformer::mutate(SubPtr v) { auto ret = subPolynomials(lhsPoly, rhsPoly); if (!ret) { // Cancelled out completely. - return getImmediateByType(v->dtype(), 0); + return immLike(v, 0); } return ret; } @@ -503,8 +600,8 @@ ExprPtr PolynomialTransformer::mutate(SubPtr v) { // Polynomial - Term. if (lhsPoly && rhsTerm) { // Negate the term. - ExprPtr negate = evaluateOp(alloc( - getImmediateByType(rhsTerm->dtype(), -1), rhsTerm->scalar())); + ExprPtr negate = + evaluateOp(alloc(immLike(rhsTerm, -1), rhsTerm->scalar())); TermPtr newTerm = alloc(hasher_, negate, rhsTerm->variables()); return insertTerm(lhsPoly, newTerm); } @@ -512,7 +609,7 @@ ExprPtr PolynomialTransformer::mutate(SubPtr v) { // Term - Polynomial. if (rhsPoly && lhsTerm) { // Negate every part of the Polynomial. - ExprPtr minusOne = getImmediateByType(lhsTerm->dtype(), -1); + ExprPtr minusOne = immLike(lhsTerm, -1); ExprPtr negateScalar = evaluateOp(alloc(minusOne, rhsPoly->scalar())); std::vector variables; @@ -543,7 +640,7 @@ ExprPtr PolynomialTransformer::mutate(SubPtr v) { ExprPtr newScalar = evaluateOp(alloc(lhs_new, rhsPoly->scalar())); // Negate each term in the Polynomial RHS. - ExprPtr minusOne = getImmediateByType(rhsPoly->dtype(), -1); + ExprPtr minusOne = immLike(rhsPoly, -1); std::vector variables; for (auto t : rhsPoly->variables()) { ExprPtr negate = evaluateOp(alloc(minusOne, t->scalar())); @@ -555,15 +652,14 @@ ExprPtr PolynomialTransformer::mutate(SubPtr v) { if (lhsTerm && rhsScalar) { // Negate the constant. - ExprPtr negate = evaluateOp( - alloc(getImmediateByType(rhs_new->dtype(), -1), rhs_new)); + ExprPtr negate = evaluateOp(alloc(immLike(rhs_new, -1), rhs_new)); return alloc(hasher_, negate, lhsTerm); } if (lhsScalar && rhsTerm) { // Negate the RHS Term. - ExprPtr negate = evaluateOp(alloc( - getImmediateByType(rhsTerm->scalar()->dtype(), -1), rhsTerm->scalar())); + ExprPtr negate = evaluateOp( + alloc(immLike(rhsTerm->scalar(), -1), rhsTerm->scalar())); return alloc( hasher_, lhs_new, alloc(hasher_, negate, rhsTerm->variables())); @@ -573,29 +669,24 @@ ExprPtr PolynomialTransformer::mutate(SubPtr v) { if (lhsScalar) { // Create a negated term. return alloc( - hasher_, - lhs_new, - alloc(hasher_, getImmediateByType(v->dtype(), -1), rhs_new)); + hasher_, lhs_new, alloc(hasher_, immLike(v, -1), rhs_new)); } if (rhsScalar) { // Negate the scalar. - ExprPtr negate = evaluateOp( - alloc(getImmediateByType(rhs_new->dtype(), -1), rhs_new)); + ExprPtr negate = evaluateOp(alloc(immLike(rhs_new, -1), rhs_new)); return alloc( - hasher_, - negate, - alloc(hasher_, getImmediateByType(v->dtype(), 1), lhs_new)); + hasher_, negate, alloc(hasher_, immLike(v, 1), lhs_new)); } // no scalar... if (!lhsTerm && !lhsPoly) { - lhsTerm = alloc(hasher_, getImmediateByType(v->dtype(), 1), lhs_new); + lhsTerm = alloc(hasher_, immLike(v, 1), lhs_new); } bool createdRHSnegated = false; if (!rhsTerm && !rhsPoly) { - rhsTerm = alloc(hasher_, getImmediateByType(v->dtype(), -1), rhs_new); + rhsTerm = alloc(hasher_, immLike(v, -1), rhs_new); createdRHSnegated = true; } @@ -612,7 +703,7 @@ ExprPtr PolynomialTransformer::mutate(SubPtr v) { // Insert wrapper Term into negated RHS Poly. if (rhsPoly) { CHECK(lhsTerm); - ExprPtr minusOne = getImmediateByType(rhsPoly->dtype(), -1); + ExprPtr minusOne = immLike(rhsPoly, -1); ExprPtr newScalar = evaluateOp(alloc(minusOne, rhsPoly->scalar())); // Negate each term in the Polynomial RHS. @@ -626,8 +717,7 @@ ExprPtr PolynomialTransformer::mutate(SubPtr v) { return insertTerm(poly, lhsTerm); } - return alloc( - hasher_, getImmediateByType(v->dtype(), 0), lhsTerm, rhsTerm); + return alloc(hasher_, immLike(v, 0), lhsTerm, rhsTerm); } // Multiply two terms together, usually creating a new term with the variable @@ -828,7 +918,7 @@ ExprPtr PolynomialTransformer::mutate(MulPtr v) { // Handle special case mul by 0. if (scalar && immediateEquals(scalar, 0)) { - return getImmediateByType(v->dtype(), 0); + return immLike(v, 0); } // Catch cases of rounding (Div(A/B) * B). @@ -892,13 +982,11 @@ ExprPtr PolynomialTransformer::mutate(MulPtr v) { // Multiplying Polynomial by variable can be wrapped in a term and handled // by polyByTerm also. if (lhsPoly) { - auto term = - alloc(hasher_, getImmediateByType(rhs_new->dtype(), 1), rhs_new); + auto term = alloc(hasher_, immLike(rhs_new, 1), rhs_new); return polyByTerm(lhsPoly, term); } if (rhsPoly) { - auto term = - alloc(hasher_, getImmediateByType(lhs_new->dtype(), 1), lhs_new); + auto term = alloc(hasher_, immLike(lhs_new, 1), lhs_new); return polyByTerm(rhsPoly, term); } @@ -912,8 +1000,7 @@ ExprPtr PolynomialTransformer::mutate(MulPtr v) { } // Two variables, create a new Term. - return alloc( - hasher_, getImmediateByType(v->dtype(), 1), lhs_new, rhs_new); + return alloc(hasher_, immLike(v, 1), lhs_new, rhs_new); } ExprPtr factorizeDivision(ExprPtr lhs_new, ExprPtr rhs_new) { @@ -946,10 +1033,8 @@ ExprPtr factorizeDivision(ExprPtr lhs_new, ExprPtr rhs_new) { return nullptr; } - leftScalar = evaluateOp( - alloc
(leftScalar, getImmediateByType(leftScalar->dtype(), GCD))); - rightScalar = evaluateOp( - alloc
(rightScalar, getImmediateByType(rightScalar->dtype(), GCD))); + leftScalar = evaluateOp(alloc
(leftScalar, immLike(leftScalar, GCD))); + rightScalar = evaluateOp(alloc
(rightScalar, immLike(rightScalar, GCD))); if (lhsTerm) { lhs_new = alloc(lhsTerm->hasher(), leftScalar, lhsTerm->variables()); @@ -1025,12 +1110,12 @@ ExprPtr PolynomialTransformer::mutate(ModPtr v) { // x % 1 == 0. if (rhs_new->isConstant() && immediateEquals(rhs_new, 1)) { // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - return getImmediateByType(v->dtype(), 0); + return immLike(v, 0); } // x % x => 0. if (hasher_.hash(lhs_new) == hasher_.hash(rhs_new)) { - return getImmediateByType(v->dtype(), 0); + return immLike(v, 0); } TermPtr lhsTerm = to(lhs_new); @@ -1047,13 +1132,13 @@ ExprPtr PolynomialTransformer::mutate(ModPtr v) { if (rhs_new->isConstant() && immediateEquals( evaluateOp(alloc(lhsTerm->scalar(), rhs_new)), 0)) { - return getImmediateByType(v->dtype(), 0); + return immLike(v, 0); } // (x * y * z) % x => 0. for (auto component : lhsTerm->variables()) { if (hasher_.hash(component) == hasher_.hash(rhs_new)) { - return getImmediateByType(v->dtype(), 0); + return immLike(v, 0); } } @@ -1087,7 +1172,7 @@ ExprPtr PolynomialTransformer::mutate(ModPtr v) { immediateEquals( evaluateOp(alloc(lhsTerm->scalar(), rhsTerm->scalar())), 0)) { - return getImmediateByType(v->dtype(), 0); + return immLike(v, 0); } } } @@ -1461,6 +1546,22 @@ ExprPtr PolynomialTransformer::mutate(IfThenElsePtr v) { return alloc(condition_new, true_value_new, false_value_new); } +ExprPtr PolynomialTransformer::mutate(AndPtr v) { + return mutateBinaryOp(v, this); +} + +ExprPtr PolynomialTransformer::mutate(XorPtr v) { + return mutateBinaryOp(v, this); +} + +ExprPtr PolynomialTransformer::mutate(LshiftPtr v) { + return mutateBinaryOp(v, this); +} + +ExprPtr PolynomialTransformer::mutate(RshiftPtr v) { + return mutateBinaryOp(v, this); +} + StmtPtr PolynomialBase::mutate(CondPtr v) { ExprPtr cond_old = v->condition(); StmtPtr true_old = v->true_stmt(); @@ -1744,7 +1845,7 @@ ExprPtr polyGCD(PolynomialPtr poly) { return nullptr; } - return getImmediateByType(poly->dtype(), GCD); + return immLike(poly, GCD); } // A ModRound is a div-mod-mul in which the divisor in div and multiplier in mul @@ -1863,9 +1964,10 @@ c10::optional isModRound(TermPtr e) { } if (!scalar) { - scalar = getImmediateByType(multiplier->dtype(), 1); + scalar = immLike(multiplier, 1); } + // TODO: this leaks memory! return new ModRound(scalar, denom, divisor, mod_divisor); } @@ -2076,8 +2178,20 @@ ExprPtr TermExpander::mutate(PolynomialPtr v) { std::vector addTerms; std::vector subTerms; + auto vars = v->variables(); + std::unordered_map str_repr_cache; + std::sort(vars.begin(), vars.end(), [&](ExprPtr a, ExprPtr b) { + if (!str_repr_cache.count(a)) { + str_repr_cache[a] = std::to_string(a); + } + if (!str_repr_cache.count(b)) { + str_repr_cache[b] = std::to_string(b); + } + return str_repr_cache.at(a) < str_repr_cache.at(b); + }); + // partition the terms into a list to add and list to subtract. - for (auto node : v->variables()) { + for (auto node : vars) { if (immediateIsNegative(node->scalar())) { subTerms.push_back(node); } else if (!immediateEquals(node->scalar(), 0)) { @@ -2130,23 +2244,23 @@ ExprPtr TermExpander::mutate(PolynomialPtr v) { } // Negate the term back to positive since we'll be subtracting it. - ExprPtr negated = evaluateOp(alloc( - getImmediateByType(node->scalar()->dtype(), -1), node->scalar())); + ExprPtr negated = + evaluateOp(alloc(immLike(node->scalar(), -1), node->scalar())); TermPtr newRHS = alloc(node->hasher(), negated, node->variables()); lastNode = alloc(lastNode, newRHS->accept_mutator(this)); } if (scalarWritten || immediateEquals(v->scalar(), 0)) { if (!lastNode) { - return getImmediateByType(v->dtype(), 0); + return immLike(v, 0); } return lastNode; } if (immediateIsNegative(v->scalar())) { // Negate the scalar and subtract. - ExprPtr negated = evaluateOp( - alloc(getImmediateByType(lastNode->dtype(), -1), v->scalar())); + ExprPtr negated = + evaluateOp(alloc(immLike(lastNode, -1), v->scalar())); lastNode = alloc(lastNode, evaluateOp(negated)); } else { // we want to avoid a cast to the scalar if it would happen. @@ -2213,7 +2327,7 @@ ExprPtr TermExpander::mutate(MinTermPtr v) { ExprPtr TermExpander::mutate(RoundOffPtr v) { TermPtr term = alloc( simplifier_->hasher(), - getImmediateByType(v->dtype(), 1), + immLike(v, 1), alloc
(v->lhs(), v->rhs()), v->rhs()); return term->accept_mutator(this); @@ -2221,8 +2335,10 @@ ExprPtr TermExpander::mutate(RoundOffPtr v) { ExprPtr buf_flat_size(BufPtr v) { std::vector dims = v->dims(); - - ExprPtr flattened = getImmediateByType(kInt, 1); + if (dims.size() == 0) { + return alloc(1); + } + ExprPtr flattened = immLike(dims[0], 1); for (auto& dim : dims) { flattened = alloc(flattened, dim); } @@ -2235,7 +2351,9 @@ ExprPtr buf_flat_size(BufPtr v) { StmtPtr TermExpander::mutate(AllocatePtr v) { BufPtr buf = v->buf(); BufPtr buf_new = to(v->buf()->accept_mutator(this)); - TORCH_INTERNAL_ASSERT(buf_new); + TORCH_INTERNAL_ASSERT( + buf_new, + buildErrorMessage("TermExpander mutation produced null for Buf.")); ExprPtr flattened = buf_flat_size(buf_new); if (flattened->isConstant() && immediateEquals(flattened, 0)) { @@ -2252,7 +2370,9 @@ StmtPtr TermExpander::mutate(AllocatePtr v) { StmtPtr TermExpander::mutate(FreePtr v) { BufPtr buf = v->buf(); BufPtr buf_new = to(v->buf()->accept_mutator(this)); - TORCH_INTERNAL_ASSERT(buf_new); + TORCH_INTERNAL_ASSERT( + buf_new, + buildErrorMessage("TermExpander mutation produced null for Buf.")); if (eliminated_allocations_.count(buf_new->base_handle())) { eliminated_allocations_.erase(buf_new->base_handle()); @@ -2553,7 +2673,7 @@ ExprPtr distributeDiv(ExprPtr lhs, ExprPtr rhs, VarBoundInfo var_bound_info) { return nullptr; } ExprPtr check_n_value = IRSimplifier::simplify( - alloc(rhsScalar, alloc(0), kGT)); + alloc(rhsScalar, immLike(rhsScalar, 0), kGT)); if (!immediateEquals(check_n_value, 1)) { return nullptr; } @@ -2588,7 +2708,7 @@ ExprPtr distributeDiv(ExprPtr lhs, ExprPtr rhs, VarBoundInfo var_bound_info) { // range auto end = got->second.second; ExprPtr check_start = IRSimplifier::simplify( - alloc(start, alloc(0), kGE)); + alloc(start, immLike(start, 0), kGE)); ExprPtr check_end = IRSimplifier::simplify(alloc(end, rhsScalar, kLE)); if (!check_start->isConstant() || !check_end->isConstant() || @@ -2600,7 +2720,7 @@ ExprPtr distributeDiv(ExprPtr lhs, ExprPtr rhs, VarBoundInfo var_bound_info) { // simplify type 1) exprs: '(i+x)/n' => 'x/n' ExprPtr sign_check = - IRSimplifier::simplify(alloc(main, alloc(0), kGE)); + IRSimplifier::simplify(alloc(main, immLike(main, 0), kGE)); ExprPtr main_mod = IRSimplifier::simplify(alloc(main, rhsScalar)); ExprPtr mod_check = IRSimplifier::simplify( alloc(alloc(main_mod, end), rhsScalar, kLE)); @@ -2611,6 +2731,7 @@ ExprPtr distributeDiv(ExprPtr lhs, ExprPtr rhs, VarBoundInfo var_bound_info) { // simplify type 2 exprs: '(i+j*n)/n' => 'j' auto ret_var = to(ret); + // FIXME: Allow any integral type. if (ret_var && ret_var->dtype() == kInt) { // retrieve j's range info auto got = var_bound_info.find(ret_var); @@ -2619,8 +2740,8 @@ ExprPtr distributeDiv(ExprPtr lhs, ExprPtr rhs, VarBoundInfo var_bound_info) { } // check if j is not negative - sign_check = IRSimplifier::simplify( - alloc(got->second.first, alloc(0), kGE)); + sign_check = IRSimplifier::simplify(alloc( + got->second.first, immLike(got->second.first, 0), kGE)); if (sign_check->isConstant() && immediateEquals(sign_check, 1)) { return ret_var; } @@ -2670,7 +2791,7 @@ ExprPtr distributeMod(ExprPtr lhs, ExprPtr rhs, VarBoundInfo var_bound_info) { return nullptr; } ExprPtr check_n_value = IRSimplifier::simplify( - alloc(rhsScalar, alloc(0), kGT)); + alloc(rhsScalar, immLike(rhsScalar, 0), kGT)); if (!immediateEquals(check_n_value, 1)) { return nullptr; } @@ -2707,7 +2828,7 @@ ExprPtr distributeMod(ExprPtr lhs, ExprPtr rhs, VarBoundInfo var_bound_info) { // range auto end = got->second.second; ExprPtr check_start = IRSimplifier::simplify( - alloc(start, alloc(0), kGE)); + alloc(start, immLike(start, 0), kGE)); ExprPtr check_end = IRSimplifier::simplify(alloc(end, rhsScalar, kLE)); if (!check_start->isConstant() || !check_end->isConstant() || @@ -2717,7 +2838,7 @@ ExprPtr distributeMod(ExprPtr lhs, ExprPtr rhs, VarBoundInfo var_bound_info) { // simplify type 1) exprs: '(i+x)%n' => 'i+x%n' ExprPtr sign_check = - IRSimplifier::simplify(alloc(main, alloc(0), kGE)); + IRSimplifier::simplify(alloc(main, immLike(main, 0), kGE)); ExprPtr main_mod = IRSimplifier::simplify(alloc(main, rhsScalar)); ExprPtr mod_check = IRSimplifier::simplify( alloc(alloc(main_mod, end), rhsScalar, kLE)); @@ -2729,6 +2850,7 @@ ExprPtr distributeMod(ExprPtr lhs, ExprPtr rhs, VarBoundInfo var_bound_info) { // simplify type 2) exprs: '(i+j*n)%n' => 'i' ExprPtr main_div = IRSimplifier::simplify(alloc
(main, rhsScalar)); auto j_var = to(main_div); + // FIXME: Allow any integral type. if (j_var && j_var->dtype() == kInt) { // retrieve j's range info auto got = var_bound_info.find(j_var); @@ -2737,8 +2859,8 @@ ExprPtr distributeMod(ExprPtr lhs, ExprPtr rhs, VarBoundInfo var_bound_info) { } // check if j is not negative - sign_check = IRSimplifier::simplify( - alloc(got->second.first, alloc(0), kGE)); + sign_check = IRSimplifier::simplify(alloc( + got->second.first, immLike(got->second.first, 0), kGE)); if (sign_check->isConstant() && immediateEquals(sign_check, 1)) { return var_key; } @@ -2789,7 +2911,7 @@ ExprPtr SimplifierUnderContext::mutate(ModPtr v) { auto start = got->second.first; auto end = got->second.second; ExprPtr check_start = IRSimplifier::simplify( - alloc(start, alloc(0), kGE)); + alloc(start, immLike(start, 0), kGE)); ExprPtr check_end = IRSimplifier::simplify(alloc(end, rhsScalar, kLE)); if (check_start->isConstant() && check_end->isConstant() && @@ -2822,6 +2944,49 @@ bool exprEquals(ExprPtr A, ExprPtr B) { } } +ExprPtr IRSimplifier::simplify(ExprPtr e) { + GRAPH_DEBUG("(Simplifier) Original: ", std::to_string(e)); + SimplifierUnderContext ctxsimplifier; + e = e->accept_mutator(&ctxsimplifier); + + PolynomialTransformer simplifier; + e = e->accept_mutator(&simplifier); + + // There may be terms left in the IR, expand them. + TermExpander expander(&simplifier); + e = e->accept_mutator(&expander); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + if (!expander.check_safe()) { + throw malformed_input("eliminated null Allocation without free"); + } + + GRAPH_DEBUG("(Simplifier) Simplified: ", std::to_string(e)); + return e; +} + +StmtPtr IRSimplifier::simplify(StmtPtr s) { + GRAPH_DEBUG("(Simplifier) Original: ", std::to_string(s)); + SimplifierUnderContext ctxsimplifier; + s = s->accept_mutator(&ctxsimplifier); + + PolynomialTransformer simplifier; + s = s->accept_mutator(&simplifier); + if (s == nullptr) { + GRAPH_DEBUG("(Simplifier) Simplified: NULL"); + return nullptr; + } + + // There may be terms left in the IR, expand them. + TermExpander expander(&simplifier); + s = s->accept_mutator(&expander); + if (!expander.check_safe()) { + throw malformed_input("eliminated null Allocation without free"); + } + + GRAPH_DEBUG("(Simplifier) Simplified: ", std::to_string(s)); + return s; +} + } // namespace tensorexpr } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/tensorexpr/ir_simplifier.h b/torch/csrc/jit/tensorexpr/ir_simplifier.h index 6281b77349b37..11d004f395ed1 100644 --- a/torch/csrc/jit/tensorexpr/ir_simplifier.h +++ b/torch/csrc/jit/tensorexpr/ir_simplifier.h @@ -55,7 +55,7 @@ Dtype promoteTypesVec(std::vector& v) { template Dtype promoteTypesMap( ExprPtr s, - std::unordered_map& m) { + std::unordered_map& m) { Dtype t = s->dtype(); bool first = true; for (auto& e : m) { @@ -69,12 +69,12 @@ Dtype promoteTypesMap( } template -Dtype promoteTypesVar(ExprType* e) { +Dtype promoteTypesVar(ExprType e) { return e->dtype(); } template -Dtype promoteTypesVar(ExprType* e, Args... es) { +Dtype promoteTypesVar(ExprType e, Args... es) { Dtype lhs = e->dtype(); Dtype rhs = promoteTypesVar(es...); if (e->isConstant()) { @@ -84,42 +84,6 @@ Dtype promoteTypesVar(ExprType* e, Args... es) { return promoteTypes(lhs, rhs); } -// Creates a new Expr of the given type with the provided lhs and rhs. -inline ExprPtr newBinaryOpOfType( - IRNodeType expr_type, - ExprPtr lhs, - ExprPtr rhs, - bool option) { - switch (expr_type) { - // NOLINTNEXTLINE(bugprone-branch-clone) - case IRNodeType::kAdd: - return alloc(lhs, rhs); - case IRNodeType::kSub: - return alloc(lhs, rhs); - case IRNodeType::kMul: - return alloc(lhs, rhs); - case IRNodeType::kDiv: - return alloc
(lhs, rhs); - case IRNodeType::kMod: - return alloc(lhs, rhs); - case IRNodeType::kMax: - return alloc(lhs, rhs, option); - case IRNodeType::kMin: - return alloc(lhs, rhs, option); - case IRNodeType::kAnd: - return alloc(lhs, rhs); - case IRNodeType::kXor: - return alloc(lhs, rhs); - case IRNodeType::kLshift: - return alloc(lhs, rhs); - case IRNodeType::kRshift: - return alloc(lhs, rhs); - default: - LOG(FATAL) << "unsupported expr_type: " << static_cast(expr_type); - return nullptr; - } -} - // Uses the evaluator to fold an Expression with constant terms. // E.g. evaluateOp(Add(3, 4)) => 7. // Expr v must not have any unbound Vars. @@ -133,7 +97,7 @@ inline ExprPtr evaluateOp(ExprPtr v) { Type val = eval.value(); \ return getImmediateByType(v->dtype().scalar_type(), val); \ } - AT_FORALL_SCALAR_TYPES_AND2(Half, Bool, TYPE_CASE); + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE); #undef TYPE_CASE default: LOG(FATAL) << "Unsupported datatype: " << v->dtype(); @@ -498,21 +462,13 @@ class TORCH_API PolynomialTransformer : public PolynomialBase { ExprPtr mutate(ModPtr v) override; - ExprPtr mutate(AndPtr v) override { - return mutateBinaryOp(v, this); - } + ExprPtr mutate(AndPtr v) override; - ExprPtr mutate(XorPtr v) override { - return mutateBinaryOp(v, this); - } + ExprPtr mutate(XorPtr v) override; - ExprPtr mutate(LshiftPtr v) override { - return mutateBinaryOp(v, this); - } + ExprPtr mutate(LshiftPtr v) override; - ExprPtr mutate(RshiftPtr v) override { - return mutateBinaryOp(v, this); - } + ExprPtr mutate(RshiftPtr v) override; ExprPtr mutate(MaxPtr v) override; @@ -526,30 +482,6 @@ class TORCH_API PolynomialTransformer : public PolynomialBase { ExprPtr mutate(IfThenElsePtr v) override; - template - static ExprPtr mutateBinaryOp( - BinaryOpNode* v, - IRMutator* mutator, - bool option = false) { - ExprPtr lhs = v->lhs(); - ExprPtr rhs = v->rhs(); - ExprPtr lhs_new = lhs->accept_mutator(mutator); - ExprPtr rhs_new = rhs->accept_mutator(mutator); - - ExprPtr node = v; - - if (lhs != lhs_new || rhs != rhs_new) { - node = newBinaryOpOfType(v->expr_type(), lhs_new, rhs_new, option); - } - - // Can only fold if both sides are constant. - if (!lhs_new->isConstant() || !rhs_new->isConstant()) { - return node; - } - - return evaluateOp(node); - } - static ExprPtr simplify(ExprPtr e); static ExprHandle simplify(const ExprHandle& e); static StmtPtr simplify(StmtPtr e); @@ -596,47 +528,11 @@ class TORCH_API TermExpander : public PolynomialBase { class TORCH_API IRSimplifier { public: - static ExprPtr simplify(ExprPtr e) { - SimplifierUnderContext ctxsimplifier; - e = e->accept_mutator(&ctxsimplifier); - - PolynomialTransformer simplifier; - e = e->accept_mutator(&simplifier); - - // There may be terms left in the IR, expand them. - TermExpander expander(&simplifier); - e = e->accept_mutator(&expander); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - if (!expander.check_safe()) { - throw malformed_input("eliminated null Allocation without free"); - } - - return e; - } - + static StmtPtr simplify(StmtPtr s); + static ExprPtr simplify(ExprPtr e); static ExprHandle simplify(const ExprHandle& e) { return ExprHandle(simplify(e.node())); } - - static StmtPtr simplify(StmtPtr s) { - SimplifierUnderContext ctxsimplifier; - s = s->accept_mutator(&ctxsimplifier); - - PolynomialTransformer simplifier; - s = s->accept_mutator(&simplifier); - if (s == nullptr) { - return nullptr; - } - - // There may be terms left in the IR, expand them. - TermExpander expander(&simplifier); - s = s->accept_mutator(&expander); - if (!expander.check_safe()) { - throw malformed_input("eliminated null Allocation without free"); - } - - return s; - } }; // Flattens the buf and performs the simplifier on the flattened dims. diff --git a/torch/csrc/jit/tensorexpr/ir_verifier.cpp b/torch/csrc/jit/tensorexpr/ir_verifier.cpp index c88e92c9a7a82..f31a935291c33 100644 --- a/torch/csrc/jit/tensorexpr/ir_verifier.cpp +++ b/torch/csrc/jit/tensorexpr/ir_verifier.cpp @@ -9,8 +9,19 @@ namespace torch { namespace jit { namespace tensorexpr { -template -void verifyBitwiseOp(const BitwiseOpNode* v, IRVerifier* verifier) { +namespace detail { +template +void deducer(BinaryOpNode); + +bool deducer(...); +} // namespace detail + +template < + typename D, + typename std::enable_if())), + void>::value>::type* = nullptr> +void verifyBitwiseOp(NodePtr v, IRVerifier* verifier) { if (!v->lhs()->dtype().is_integral()) { throw unsupported_dtype(); } @@ -108,7 +119,19 @@ void IRVerifier::visit(IfThenElsePtr v) { } void IRVerifier::visit(IntrinsicsPtr v) { + if (v->op_type() == kIsNan) { + if (v->dtype().scalar_type() != c10::kInt) { + throw malformed_ir("bad dtype in intrinsic arg"); + } + IRVisitor::visit(v); + return; + } // TODO: add a check for OpArgCount and op_type + for (auto const& param : v->params()) { + if (param->dtype() != v->dtype()) { + throw malformed_ir("bad dtype in intrinsic arg"); + } + } IRVisitor::visit(v); } diff --git a/torch/csrc/jit/tensorexpr/ir_visitor.cpp b/torch/csrc/jit/tensorexpr/ir_visitor.cpp index 9066544bd2291..9489422b66ebe 100644 --- a/torch/csrc/jit/tensorexpr/ir_visitor.cpp +++ b/torch/csrc/jit/tensorexpr/ir_visitor.cpp @@ -11,8 +11,12 @@ namespace torch { namespace jit { namespace tensorexpr { -template -static void visit_binary_op(BinaryOpNode* v, IRVisitor* visitor) { +template < + typename Op, + typename std::enable_if())), + void>::value>::type* = nullptr> +static void visit_binary_op(NodePtr v, IRVisitor* visitor) { v->lhs()->accept(visitor); v->rhs()->accept(visitor); } @@ -75,7 +79,7 @@ void IRVisitor::visit(CompareSelectPtr v) { // NOLINTNEXTLINE #define IMM_VISIT(Type, Name) \ void IRVisitor::visit(Name##ImmPtr v) {} -AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_VISIT); +AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_VISIT); #undef IMM_VISIT void IRVisitor::visit(CastPtr v) { diff --git a/torch/csrc/jit/tensorexpr/ir_visitor.h b/torch/csrc/jit/tensorexpr/ir_visitor.h index 001725f961619..e54786b2f9036 100644 --- a/torch/csrc/jit/tensorexpr/ir_visitor.h +++ b/torch/csrc/jit/tensorexpr/ir_visitor.h @@ -26,7 +26,7 @@ class TORCH_API IRVisitor { #define IMM_PRINT_VISIT(Type, Name) virtual void visit(Name##ImmPtr v); - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_PRINT_VISIT) + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_PRINT_VISIT) #undef IMM_PRINT_VISIT virtual void visit(CastPtr v); diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index faacd022e7e0b..a86cb33a1b8bd 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -33,7 +34,10 @@ static bool checkTypes(const ScalarType highType, const int typeConstraints) { } // assume JIT not supporting complex and qint yet - TORCH_INTERNAL_ASSERT((typeConstraints & (kQintTypes | kComplexTypes)) == 0); + TORCH_INTERNAL_ASSERT( + (typeConstraints & (kQintTypes | kComplexTypes)) == 0, + buildErrorMessage( + "Qint and Complex types are not supported in the fuser.")); return false; } @@ -48,7 +52,7 @@ static ExprHandle promoteToDtype(ExprHandle e, ScalarType dt) { case ScalarType::Name: \ e = cast(e); \ break; - AT_FORALL_SCALAR_TYPES_AND2(Half, Bool, TYPE_CASE); + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE); #undef TYPE_CASE default: throw unsupported_dtype(); @@ -62,6 +66,19 @@ namespace torch { namespace jit { namespace tensorexpr { +std::string buildErrorMessage(const std::string& s) { + static const std::string generic_error_message = + "This error occured in the fuser. You can turn off the fuser with " + "torch._C._jit_override_can_fuse_on_cpu(False)"; + if (s.empty()) { + return generic_error_message; + } + if (s.back() == '.') { + return s + " " + generic_error_message; + } + return s + ". " + generic_error_message; +} + static int te_cuda_pointwise_loop_levels = -1; static int te_cuda_pointwise_block_count = -1; static int te_cuda_pointwise_block_size = -1; @@ -163,13 +180,18 @@ c10::optional pickDeviceType(const std::shared_ptr& graph) { for (auto const& input : node->inputs()) { if (auto tt = input->type()->cast()) { if (auto inputDevice = tt->device()) { - TORCH_INTERNAL_ASSERT(!device || *device == *inputDevice); + TORCH_INTERNAL_ASSERT( + !device || *device == *inputDevice, + buildErrorMessage( + "Different devices specified for inputs to the fuser.")); device = inputDevice; } } } } - TORCH_INTERNAL_ASSERT(device); + TORCH_INTERNAL_ASSERT( + device, + buildErrorMessage("Could not find device in fuser graph inputs.")); return device; } @@ -201,11 +223,11 @@ c10::optional getTensorInfoJit(torch::jit::Value* v) { c10::optional getTensorInfo(BufHandle b) { std::vector dims; for (auto dim : b.dims()) { - auto val = to(dim.node()); + auto val = intValue(dim.node()); if (!val) { return c10::nullopt; } - dims.push_back(val->value()); + dims.push_back(*val); } return TensorInfo{dims, static_cast(b.dtype().scalar_type())}; } @@ -355,7 +377,9 @@ bool matmulIsSupported(const torch::jit::Node* node) { void annotateInputShapes( const std::shared_ptr& graph, const std::vector>& example_inputs) { - TORCH_INTERNAL_ASSERT(graph->inputs().size() == example_inputs.size()); + TORCH_INTERNAL_ASSERT( + graph->inputs().size() == example_inputs.size(), + buildErrorMessage("Given inputs do not match the fuser graph inputs.")); for (size_t idx = 0; idx < example_inputs.size(); idx++) { if (auto t = example_inputs[idx]) { auto concrete_tensor_type = tensorTypeInCurrentExecutionContext(*t); @@ -395,7 +419,7 @@ ExprHandle tensorOrConstant( return constant(v); } -size_t normalizeAndCheckIndex(int64_t idx, int64_t list_size) { +int64_t normalizeAndCheckIndex(int64_t idx, int64_t list_size) { if (idx < 0) { // Handle negative indexing idx = list_size + idx; @@ -404,7 +428,7 @@ size_t normalizeAndCheckIndex(int64_t idx, int64_t list_size) { if (idx < 0 || idx >= list_size) { AT_ERROR("Invalid index ", idx, " for list_size", list_size); } - return static_cast(idx); + return idx; } ExprHandle broadcast(BufHandle b, const std::vector& axes) { @@ -440,8 +464,8 @@ std::vector computeIndicesToBroadcast( auto axisIt = outputAxes.rbegin(); auto sizeIt = inputSizes.rbegin(); while (sizeIt != inputSizes.rend()) { - auto const& size = sizeIt->AsNode(); - if (size && size->value() == 1) { + auto const& size = intValue(*sizeIt); + if (size && *size == 1) { bcast.emplace_back(0); } else { bcast.emplace_back(*axisIt); @@ -453,6 +477,11 @@ std::vector computeIndicesToBroadcast( return bcast; } +bool isScalar(ExprHandle e) { + auto n = e.node(); + return n->isConstant() || to(n); +} + void promoteInputs(std::vector& inputs, const int typeConstraints) { if (inputs.empty()) { return; @@ -461,7 +490,16 @@ void promoteInputs(std::vector& inputs, const int typeConstraints) { // Find the highest type among the inputs. ScalarType highType = inputs[0].dtype().scalar_type(); for (auto input : inputs) { - highType = promoteTypes(highType, input.dtype().scalar_type()); + auto inputType = input.dtype().scalar_type(); + if (isScalar(input)) { + if (isIntegralType(highType, false) && isFloatingType(inputType)) { + highType = c10::get_default_dtype_as_scalartype(); + } else if (highType == c10::kBool) { + highType = inputType; + } + } else { + highType = promoteTypes(highType, inputType); + } } if (!checkTypes(highType, typeConstraints)) { @@ -488,7 +526,7 @@ ExprHandle demoteOutput( #define TYPE_CASE(Type, Name) \ case ScalarType::Name: \ return cast(e); - AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE); + AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, TYPE_CASE); #undef TYPE_CASE case ScalarType::Bool: return cast(e); @@ -510,7 +548,9 @@ static at::ScalarType tensorType(BufPtr b) { std::vector bufferSizes(BufPtr b) { std::vector sizes; for (size_t i = 0; i < b->ndim(); i++) { - sizes.push_back(to(b->dim(i))->value()); + auto dim = intValue(b->dim(i)); + TORCH_INTERNAL_ASSERT(dim, buildErrorMessage("Non-constant buf dims")); + sizes.push_back(*dim); } return sizes; } @@ -528,7 +568,8 @@ ExprHandle TensorExprKernel::chunk( std::vector indices; for (size_t i = 0; i < axes.size(); ++i) { if (i == norm_dim) { - indices.push_back(axes[i] + IntImm::make((int)chunkIdx * (int)step)); + indices.push_back( + axes[i] + ExprHandle(immLike(axes[i], chunkIdx * step))); } else { indices.push_back(axes[i]); } @@ -627,7 +668,7 @@ std::vector TensorExprKernel::sizesFromVaryingShape( const c10::VaryingShape& shape) { std::vector dims; for (const auto i : c10::irange(*shape.size())) { - dims.push_back(IntImm::make(*shape[i])); + dims.push_back(*shape[i]); } return dims; } @@ -649,7 +690,7 @@ std::vector TensorExprKernel::sizesForValue( if (v->type()->isSubtypeOf(FloatType::get()) || v->type()->isSubtypeOf(IntType::get())) { - return {1}; + return {int64_t{1}}; } if (v->type()->isSubtypeOf(NoneType::get())) { return {}; @@ -802,10 +843,13 @@ std::vector TensorExprKernel::inferSizesForValue( throw std::runtime_error("Empty input list is passed to aten::cat"); } - TORCH_INTERNAL_ASSERT(n->input(1)->node()->kind() == prim::Constant); + TORCH_INTERNAL_ASSERT( + n->input(1)->node()->kind() == prim::Constant, + buildErrorMessage( + "aten::cat op's dim input is not constant in fuser.")); int64_t dim = n->input(1)->node()->i(attr::value); auto shape = sizesForValue(inputs[0]); - size_t norm_dim = normalizeAndCheckIndex(dim, shape.size()); + auto norm_dim = normalizeAndCheckIndex(dim, shape.size()); ExprHandle concat_dim_size = 0; for (auto input : inputs) { concat_dim_size = concat_dim_size + sizesForValue(input)[norm_dim]; @@ -845,7 +889,8 @@ ExprHandle promoteIntegerToDefaultType(const ExprHandle& e) { // We intend to promote Integers to floating-point types TORCH_INTERNAL_ASSERT( - !c10::isIntegralType(defaultType, /*includeBool*/ true)); + !c10::isIntegralType(defaultType, /*includeBool*/ true), + buildErrorMessage("Non-integer type")); return Cast::make( Dtype( @@ -874,11 +919,11 @@ ExprHandle clamp( } static bool isOne(ExprHandle e) { - auto const& n = e.AsNode(); + auto const& n = intValue(e); if (!n) { return false; } - return n->value() == 1; + return *n == 1; } std::pair, bool> broadcastShapesImpl( @@ -960,7 +1005,7 @@ std::vector TensorExprKernel::broadcastShapesMut( return res.first; } -Tensor* computeOneOperand( +Tensor computeOneOperand( const std::string& name, const std::vector& inputValues, const std::vector& outputShape, @@ -981,7 +1026,7 @@ Tensor* computeOneOperand( }); } -Tensor* computeTwoOperand( +Tensor computeTwoOperand( const std::string& name, const std::vector& inputValues, const std::vector& outputShape, @@ -1004,7 +1049,7 @@ Tensor* computeTwoOperand( }); } -Tensor* computeTwoOperandWithAlpha( +Tensor computeTwoOperandWithAlpha( const std::string& name, const std::vector& inputValues, const std::vector& outputShape, @@ -1028,7 +1073,7 @@ Tensor* computeTwoOperandWithAlpha( }); } -Tensor* computeConditionWithTwoOperand( +Tensor computeConditionWithTwoOperand( const std::string& name, const std::vector& inputValues, const std::vector& outputShape, @@ -1055,7 +1100,7 @@ Tensor* computeConditionWithTwoOperand( }); } -Tensor* computeThreeOperand( +Tensor computeThreeOperand( const std::string& name, const std::vector& inputValues, const std::vector& outputShape, @@ -1083,7 +1128,7 @@ Tensor* computeThreeOperand( return demoteOutput(compute, outputType); }); } -Tensor* computeFourOperand( +Tensor computeFourOperand( const std::string& name, const std::vector& inputValues, const std::vector& outputShape, @@ -1121,7 +1166,8 @@ std::pair> processCatList( std::vector nonEmptyInputs; for (auto buf : bufList) { bufInputs.push_back(buf); - TORCH_INTERNAL_ASSERT(buf.node()->dims().size() > 0); + TORCH_INTERNAL_ASSERT( + buf.node()->dims().size() > 0, buildErrorMessage("Invalid buf rank")); if (buf.node()->dims().size() == 1 && immediateAs(buf.node()->dim(0)) == 0) { continue; @@ -1135,7 +1181,8 @@ std::pair> processCatList( } return {highType, nonEmptyInputs}; } -Tensor* computeCatWoConditionals( + +Tensor computeCatWoConditionals( const std::vector& inputs, const std::vector& outputShape) { // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) @@ -1164,13 +1211,12 @@ Tensor* computeCatWoConditionals( auto output_buf = alloc("aten_cat", output_sizes_expr, ToDtype(high_type)); if (non_empty_inputs.size() == 0) { - return new Tensor( + return Tensor( output_buf, alloc(std::vector({}))); } int64_t concat_dim = c10::get(arg_dim); - size_t norm_concat_dim = - normalizeAndCheckIndex(concat_dim, outputShape.size()); + auto norm_concat_dim = normalizeAndCheckIndex(concat_dim, outputShape.size()); auto gen_code_for_input = [&](const BufHandle& inp, size_t inp_pos, @@ -1181,7 +1227,8 @@ Tensor* computeCatWoConditionals( std::vector store_indices(dims.size()); for (size_t i = 0; i < dims.size(); ++i) { for_vars[i] = alloc( - "i" + c10::to_string(inp_pos) + "_" + c10::to_string(i), kInt); + "i" + c10::to_string(inp_pos) + "_" + c10::to_string(i), + dims[i].dtype()); load_indices[i] = for_vars[i]; if (i == norm_concat_dim) { store_indices[i] = alloc(for_vars[i], concat_dim_size); @@ -1194,8 +1241,8 @@ Tensor* computeCatWoConditionals( auto load_promoted = promoteToDtype(ExprHandle(load_expr), high_type); StmtPtr st = alloc(output_buf, store_indices, load_promoted.node()); for (size_t i = dims.size(); i > 0; --i) { - st = - alloc(for_vars[i - 1], alloc(0), dims[i - 1].node(), st); + st = alloc( + for_vars[i - 1], immLike(dims[i - 1], 0), dims[i - 1].node(), st); } return st; }; @@ -1206,17 +1253,17 @@ Tensor* computeCatWoConditionals( auto input_dims = ExprVectorToExprHandleVector(non_empty_inputs[i].node()->dims()); if (concat_dim_size == nullptr) { - concat_dim_size = alloc(0); + concat_dim_size = immLike(input_dims[norm_concat_dim], 0); } block->append_stmt(gen_code_for_input( non_empty_inputs[i], i, concat_dim_size, input_dims)); concat_dim_size = alloc(concat_dim_size, input_dims[norm_concat_dim].node()); } - return new Tensor(output_buf, IRSimplifier::simplify(block)); + return Tensor(output_buf, IRSimplifier::simplify(block)); } -Tensor* computeCat( +Tensor computeCat( const std::vector& inputs, const std::vector& outputShape, at::Device device) { @@ -1238,7 +1285,7 @@ Tensor* computeCat( } int64_t dim_ = c10::get(argDim); - size_t dim = normalizeAndCheckIndex(dim_, axes.size()); + auto dim = normalizeAndCheckIndex(dim_, axes.size()); // Promote input types. // Note that we need to consider all inputs, including empty - they // also affect the resultant dtype. @@ -1258,25 +1305,25 @@ Tensor* computeCat( std::vector newAxes(axes.begin(), axes.end()); ExprHandle load = promoteToDtype( tensorOrConstant(nonEmptyInputs[0], newAxes), highType); - size_t offset = to(nonEmptyInputs[0].node()->dim(dim))->value(); - newAxes[dim] = newAxes[dim] - IntImm::make(offset); + auto offset = *intValue(nonEmptyInputs[0].node()->dim(dim)); + newAxes[dim] = newAxes[dim] - ExprHandle(immLike(newAxes[dim], offset)); for (size_t ii = 1; ii < nonEmptyInputs.size(); ++ii) { auto input = nonEmptyInputs[ii]; load = ifThenElse( - CompareSelect::make(axes[dim], IntImm::make(offset), kLT), + CompareSelect::make(axes[dim], offset, kLT), load, promoteToDtype(tensorOrConstant(input, newAxes), highType)); - offset += to(input.node()->dim(dim))->value(); - newAxes[dim] = axes[dim] - IntImm::make(offset); + offset += *intValue(input.node()->dim(dim)); + newAxes[dim] = axes[dim] - ExprHandle(immLike(axes[dim], offset)); } return load; }); } -Tensor* computeConv2d( +Tensor computeConv2d( const std::vector& inputs, const std::vector& outputShape, const c10::optional& outputType) { @@ -1319,10 +1366,10 @@ Tensor* computeConv2d( dilation[0], dilation[1], groups}); - return new Tensor(ResultBuf.node(), s); + return Tensor(ResultBuf.node(), s); } -Tensor* tensorexpr::computeOperandValue( +Tensor tensorexpr::computeOperandValue( c10::Symbol op, const std::vector& inputs, const std::vector& outputShape, @@ -1333,7 +1380,9 @@ Tensor* tensorexpr::computeOperandValue( auto add_lambda = [](const ExprHandle& lhs, const ExprHandle& rhs) { return boolToInteger(lhs) + boolToInteger(rhs); }; - TORCH_INTERNAL_ASSERT(inputs.size() == 2 || inputs.size() == 3); + TORCH_INTERNAL_ASSERT( + inputs.size() == 2 || inputs.size() == 3, + buildErrorMessage("Invalid number of input operands")); return (inputs.size() > 2) ? computeTwoOperandWithAlpha( "aten_add", inputs, outputShape, outputType, add_lambda) @@ -1345,7 +1394,9 @@ Tensor* tensorexpr::computeOperandValue( // NB: sub isn't supported on boolean, no need to promote to integer. return lhs - rhs; }; - TORCH_INTERNAL_ASSERT(inputs.size() == 2 || inputs.size() == 3); + TORCH_INTERNAL_ASSERT( + inputs.size() == 2 || inputs.size() == 3, + buildErrorMessage("Invalid number of input operands")); return (inputs.size() > 2) ? computeTwoOperandWithAlpha( "aten_sub", inputs, outputShape, outputType, sub_lambda) @@ -2108,7 +2159,8 @@ Tensor* tensorexpr::computeOperandValue( outputShape, outputType, [outputType](const ExprHandle& a) { - TORCH_INTERNAL_ASSERT(outputType); + TORCH_INTERNAL_ASSERT( + outputType, buildErrorMessage("Output type is null.")); return Cast::make(ToDtype(*outputType), a); }); } break; @@ -2227,7 +2279,9 @@ Tensor* tensorexpr::computeOperandValue( "aten_transpose", c10::fmap(outputShape), [&](std::vector axes) { - TORCH_INTERNAL_ASSERT(axes.size() <= 1); + TORCH_INTERNAL_ASSERT( + axes.size() <= 1, + buildErrorMessage("Invalid axes size in transpose")); return A.load(axes); }); } @@ -2319,12 +2373,12 @@ Tensor* tensorexpr::computeOperandValue( ExprHandle cur_stride = 1; std::vector dims, indices; for (size_t idx = 0; idx < view_dims.size(); idx++) { - dims.push_back(alloc(view_dims[idx])); + dims.push_back(alloc(view_dims[idx])); indices.push_back(axes[idx].node()); } ExprHandle flat_idx = ExprHandle(flatten_index(dims, indices)); std::vector orig_buf_indexes(A.ndim(), ExprHandle(0)); - ExprHandle stride = IntImm::make(1); + ExprHandle stride = ExprHandle(immLike(flat_idx, 1)); for (size_t idx = 0; idx < A.ndim(); idx++) { size_t dim_idx = A.ndim() - idx - 1; // We don't need to generate mod-div for the first dimension - @@ -2391,7 +2445,7 @@ c10::optional findDtypeForValue(const torch::jit::Value* v) { return c10::nullopt; } -Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { +Tensor TensorExprKernel::computeValue(const torch::jit::Value* v) { auto inputs = v->node()->inputs(); auto op = v->node()->kind(); @@ -2487,6 +2541,86 @@ void fuseAllLoops(StmtPtr st) { } } +// Compute the trip count of a loop if it is a constant. +c10::optional tripCount(ForPtr loop) { + auto tc = IRSimplifier::simplify( + cast(ExprHandle(loop->stop()) - ExprHandle(loop->start()))); + if (auto val = to(tc.node())) { + return val->value(); + } + return c10::nullopt; +} + +// Prune innermost loops until iterations satisfies a minimum grain size. +static void pruneByGrainSize(std::vector& loops) { + constexpr int64_t minGrainSize = 32768; + int64_t grainSize = 1; + for (int64_t i = loops.size(); i > 0; i--) { + auto tc = tripCount(loops[i - 1]); + if (!tc) { + break; + } + grainSize *= *tc; + if (grainSize < minGrainSize) { + loops.pop_back(); + } + } +} + +// Retain enough outermost loops to fill the number of threads. +static void pruneByThreadCount(std::vector& loops) { + int64_t trips = 1; + auto threads = at::get_num_threads(); + auto it = loops.begin(); + for (; it != loops.end(); it++) { + if (trips >= threads) { + break; + } + auto tc = tripCount(*it); + if (!tc) { + break; + } + trips *= *tc; + } + loops.erase(it, loops.end()); +} + +// Flatten and parallelize outer loops, subject to a minimum number of elements +// in the inner loop, and a maximum level of thread-level parallelism in the +// outer loops. +template +static void parallelizeOuterLoops(LoopNest& l, Bufs&& bufs) { + for (auto const& buf : bufs) { + auto loops = l.getLoopStmtsFor(buf); + pruneByGrainSize(loops); + pruneByThreadCount(loops); + + // There are no loops to parallelize; give up. + if (loops.size() == 0) { + continue; + } + // The loop nest contains a reduction; give up. + auto reductions = NodeFinder::find(loops[0]); + if (reductions.size() > 0) { + continue; + } + // The loop nest has loop carried dependences; give up. + if (LoopNest::hasLoopCarriedDependence(loops[0])) { + continue; + } + // Try to flatten the outer loops and parallelize them if successful. + ForPtr flattened = nullptr; + if (loops.size() == 1) { + flattened = loops[0]; + } else { + LoopNest::flatten(loops, &flattened); + } + if (flattened) { + flattened->set_parallel(); + } + } +} + StmtPtr TensorExprKernel::transformLoops(BackendType backendType, StmtPtr st) { torch::jit::tensorexpr::LoopNest l(st, bufOutputs_); GRAPH_DEBUG("Original Stmt:\n", std::to_string(l.root_stmt()), "\n"); @@ -2528,6 +2662,8 @@ StmtPtr TensorExprKernel::transformLoops(BackendType backendType, StmtPtr st) { if (backendType == kLLVMCodeGen) { fuseAllLoops(l.root_stmt()); GRAPH_DEBUG("after fuse", *l.root_stmt()); + parallelizeOuterLoops(l, bufOutputs_); + GRAPH_DEBUG("after parallelize", *l.root_stmt()); } if (backendType == kCudaCodeGen) { @@ -2588,7 +2724,11 @@ StmtPtr TensorExprKernel::transformLoops(BackendType backendType, StmtPtr st) { blockSize = default_uint8_blocksize; } std::vector loops = l.getLoopStmtsFor(buf); - TORCH_INTERNAL_ASSERT(!loops.empty(), "loops should not be empty"); + TORCH_INTERNAL_ASSERT( + !loops.empty(), + buildErrorMessage( + "No loops found for the buffer " + buf->name_hint() + + " in the fuser.")); ForPtr flattened = nullptr; LoopNest::flatten(loops, &flattened); assert(flattened); @@ -2602,9 +2742,13 @@ StmtPtr TensorExprKernel::transformLoops(BackendType backendType, StmtPtr st) { } l.prepareForCodegen(); + GRAPH_DEBUG("after prepareForCodegen", *l.root_stmt()); + l.simplify(); + GRAPH_DEBUG("after simplification", *l.root_stmt()); if (backendType == kLLVMCodeGen && !hasReduction) { l.vectorizeInnerLoops(); + GRAPH_DEBUG("after vectorization", *l.root_stmt()); } StmtPtr stmt = l.root_stmt(); @@ -2698,14 +2842,14 @@ static std::vector toExprHandles(const std::vector& sizes) { std::vector dims; dims.reserve(sizes.size()); for (auto const& size : sizes) { - dims.emplace_back(IntImm::make(size)); + dims.emplace_back(size); } return dims; } -Tensor* TensorExprKernel::bindInput(const torch::jit::Value* input) { +Tensor TensorExprKernel::bindInput(const torch::jit::Value* input) { auto const& t = input->type(); - Tensor* result = nullptr; + Tensor result(nullptr, nullptr); switch (t->kind()) { case TypeKind::TensorType: { auto tt = input->type()->cast(); @@ -2730,8 +2874,7 @@ Tensor* TensorExprKernel::bindInput(const torch::jit::Value* input) { std::vector inputTensorDims; for (size_t i = 0; i < *tt->sizes().size(); i++) { auto const size = *tt->sizes()[i]; - inputTensorDims.emplace_back( - DimArg(IntImm::make(size), "i" + c10::to_string(i))); + inputTensorDims.emplace_back(DimArg(size, "i" + c10::to_string(i))); } auto const strides = tt->strides(); result = Compute( @@ -2740,12 +2883,11 @@ Tensor* TensorExprKernel::bindInput(const torch::jit::Value* input) { [&](const std::vector& axes) { ExprHandle idx = 0; for (size_t i = 0; i < axes.size(); i++) { - idx = idx + axes[i] * IntImm::make(*strides[i]); + idx = idx + axes[i] * *strides[i]; } return inBuffer.load(idx); }); - bufs_.emplace(input, result->buf()); - + bufs_.emplace(input, result.buf()); bufferArgs_.emplace_back(inBuffer); break; } @@ -2800,9 +2942,12 @@ bool denseAndNonOverlapping( return (strides == at::infer_dense_strides(sizes, strides)); } -Tensor* TensorExprKernel::convertOutputToCorrectStrides(torch::jit::Value* v) { +Tensor TensorExprKernel::convertOutputToCorrectStrides(torch::jit::Value* v) { const TensorTypePtr& tt = v->type()->expect(); - TORCH_INTERNAL_ASSERT(bufs_.count(v)); + TORCH_INTERNAL_ASSERT( + bufs_.count(v), + buildErrorMessage( + "Ouput tensor has no corresponding bufs in the fuser.")); BufPtr buf = bufs_.at(v); // No shape info is present in the graph @@ -2812,23 +2957,27 @@ Tensor* TensorExprKernel::convertOutputToCorrectStrides(torch::jit::Value* v) { throw malformed_input(msg); } - TORCH_INTERNAL_ASSERT(tt->sizes().concrete_sizes()); + TORCH_INTERNAL_ASSERT( + tt->sizes().concrete_sizes(), + buildErrorMessage("Output shapes are unknown.")); auto sizes = *tt->sizes().concrete_sizes(); std::vector default_strides = TensorType::contiguousStridesOf(sizes); if (!tt->strides().concrete_sizes()) { - return new Tensor(buf, nullptr); + return Tensor(buf, nullptr); } - TORCH_INTERNAL_ASSERT(tt->strides().concrete_sizes()); + TORCH_INTERNAL_ASSERT( + tt->strides().concrete_sizes(), + buildErrorMessage("Output strides are unknown.")); const std::vector strides = *tt->strides().concrete_sizes(); // All Tensors in NNC are layed out in default, contiguous layout. // If the output is also default contiguous we don't need to do anything if (strides == default_strides) { - return new Tensor(buf, nullptr); + return Tensor(buf, nullptr); } // If the tensor is not dense or overlaps, we have // no way of matching the profiled striding if (!denseAndNonOverlapping(sizes, strides)) { - return new Tensor(buf, nullptr); + return Tensor(buf, nullptr); } auto dims = c10::fmap(sizesForValue(v)); @@ -2855,10 +3004,10 @@ Tensor* TensorExprKernel::convertOutputToCorrectStrides(torch::jit::Value* v) { return Compute( "output_1", dims, [&](const std::vector& axes_input) { std::vector axes(axes_input.begin(), axes_input.end()); - auto absolute_position = IntImm::make(0); + auto absolute_position = ExprHandle(immLike(axes[0], 0)); for (size_t i = 0; i < axes.size(); ++i) { - absolute_position = - absolute_position + (IntImm::make(default_strides[i]) * axes[i]); + absolute_position = absolute_position + + (ExprHandle(immLike(axes[i], default_strides[i])) * axes[i]); } std::vector sorted_stride_indices = reverse_sort_indices(strides); @@ -2866,10 +3015,11 @@ Tensor* TensorExprKernel::convertOutputToCorrectStrides(torch::jit::Value* v) { for (size_t stride_index : sorted_stride_indices) { auto stride = strides[stride_index]; auto size = sizes[stride_index]; - auto index = Div::make(absolute_position, IntImm::make(stride)); + auto index = absolute_position / + ExprHandle(immLike(absolute_position, stride)); if (size != 1) { - absolute_position = - Mod::make(absolute_position, IntImm::make(stride)); + absolute_position = absolute_position % + ExprHandle(immLike(absolute_position, stride)); } new_axes[stride_index] = index; } @@ -2891,11 +3041,11 @@ void TensorExprKernel::bindConstant(const torch::jit::Value* v) { std::vector te_sizes; te_sizes.reserve(sizes.size()); for (auto s : sizes) { - te_sizes.push_back(IntImm::make(s)); + te_sizes.push_back(s); } BufPtr buf = alloc( - "const_" + v->debugName(), + "const_" + sanitizeName(v->debugName()), ExprHandleVectorToExprVector(te_sizes), ToDtype(static_cast(*tt->scalarType()))); @@ -2909,7 +3059,6 @@ void TensorExprKernel::bindConstant(const torch::jit::Value* v) { } void TensorExprKernel::compile() { - KernelScope kernelScope(&kernelArena_); GRAPH_DUMP("TensorExprKernel graph:", graph_); device_ = *pickDeviceType(graph_); @@ -2922,8 +3071,9 @@ void TensorExprKernel::compile() { nInputs_ = graph_->inputs().size(); genInputDebugNames(); for (auto const& input : graph_->inputs()) { - if (Tensor* t = bindInput(input)) { - block->append_stmt(t->stmt()); + Tensor t = bindInput(input); + if (t.stmt()) { + block->append_stmt(t.stmt()); } } @@ -2937,10 +3087,9 @@ void TensorExprKernel::compile() { } else { for (auto const& output : n->outputs()) { if (output->hasUses()) { - Tensor* t = computeValue(output); - bufs_.emplace(output, t->buf()); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - block->append_stmt(t->stmt()); + Tensor t = computeValue(output); + bufs_.emplace(output, t.buf()); + block->append_stmt(t.stmt()); } } } @@ -2958,12 +3107,12 @@ void TensorExprKernel::compile() { // The "strided" tensor will be incorrect if used in NNC, // since NNC views it as contiguous. Only convert it to the right // strides at the end of the kernel (if already contiguous it's a no-op) - Tensor* properly_strided_output = convertOutputToCorrectStrides(output); - if (properly_strided_output->stmt()) { - block->append_stmt(properly_strided_output->stmt()); + Tensor properly_strided_output = convertOutputToCorrectStrides(output); + if (properly_strided_output.stmt()) { + block->append_stmt(properly_strided_output.stmt()); } // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - bufs_[output] = properly_strided_output->buf(); + bufs_[output] = properly_strided_output.buf(); const auto& tt = output->type()->expect(); auto sizes = *tt->sizes().concrete_sizes(); tensorOutputSizes_.push_back(sizes); @@ -3080,8 +3229,6 @@ StmtPtr TensorExprKernel::getCodeGenStmt() { } void TensorExprKernel::runKernel(Stack& stack) { - KernelScope kernelScope(&kernelArena_); - // Set up arguments (inputs, then outputs) for kernel call. auto inputs = last(stack, nInputs_); std::vector outputs; @@ -3101,8 +3248,6 @@ void TensorExprKernel::runKernel(Stack& stack) { void TensorExprKernel::runFast( const std::vector& inputs, const std::vector& outputs) { - KernelScope kernelScope(&kernelArena_); - std::vector args(inputs); args.reserve(inputs.size() + outputs.size() + constants_.size()); args.insert(args.end(), outputs.begin(), outputs.end()); diff --git a/torch/csrc/jit/tensorexpr/kernel.h b/torch/csrc/jit/tensorexpr/kernel.h index 7b35e1e44905c..4b92b020fce31 100644 --- a/torch/csrc/jit/tensorexpr/kernel.h +++ b/torch/csrc/jit/tensorexpr/kernel.h @@ -19,7 +19,7 @@ template inline std::vector bufferSizes(const T& t) { std::vector sizes; for (size_t i = 0; i < t->ndim(); i++) { - sizes.push_back(to(t->dim(i))->value()); + sizes.push_back(*intValue(t->dim(i))); } return sizes; } @@ -47,7 +47,7 @@ using ArgValue = c10::variant< IntList, ArgNone>; -using NNCLoweringFunction = std::function&, const std::vector&, const c10::optional&, @@ -62,7 +62,7 @@ ExprHandle tensorOrConstant( const ArgValue& v, const std::vector& axes); -size_t normalizeAndCheckIndex(int64_t idx, int64_t list_size); +int64_t normalizeAndCheckIndex(int64_t idx, int64_t list_size); ExprHandle broadcast(BufHandle b, const std::vector& axes); @@ -123,7 +123,7 @@ struct TensorInfo { c10::ScalarType dtype; }; -TORCH_API Tensor* computeOperandValue( +TORCH_API Tensor computeOperandValue( c10::Symbol op, const std::vector& inputs, const std::vector& outputShape, @@ -209,7 +209,7 @@ class TORCH_API TensorExprKernel { const torch::jit::Value* v, const std::vector& axes); - Tensor* computeValue(const torch::jit::Value* v); + Tensor computeValue(const torch::jit::Value* v); void bindConstant(const torch::jit::Value* v); @@ -222,9 +222,9 @@ class TORCH_API TensorExprKernel { std::vector& outputs); BackendType inferBackendTypeFromDevice(at::Device device); - Tensor* bindInput(const torch::jit::Value* input); + Tensor bindInput(const torch::jit::Value* input); - Tensor* convertOutputToCorrectStrides(torch::jit::Value* v); + Tensor convertOutputToCorrectStrides(torch::jit::Value* v); // Captures the information for reduction operation nodes. struct ReductionInfo { @@ -266,7 +266,6 @@ class TORCH_API TensorExprKernel { std::unordered_map input_name_map_; std::unique_ptr codegen_; at::Device device_ = at::kCPU; - KernelArena kernelArena_; std::shared_ptr graph_; Code code_; bool allow_fallback_{false}; diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp index eac1f82f25c4b..6c212e623df21 100644 --- a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp @@ -231,7 +231,7 @@ class LLVMCodeGenImpl : public IRVisitor { void visit(CompareSelectPtr v) override; #define IMM_VISIT_DECLARE(_1, Name) void visit(Name##ImmPtr v) override; - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_VISIT_DECLARE); + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_VISIT_DECLARE); #undef IMM_VISIT_DECLARE void visit(CastPtr v) override; @@ -274,15 +274,24 @@ class LLVMCodeGenImpl : public IRVisitor { } }; -typedef void (*ParallelCallee)(int index, int8_t* packed_data); -void DispatchParallel(int8_t* func, int start, int stop, int8_t* packed_data) { +extern "C" { +typedef void (*ParallelCallee)(int64_t index, int8_t* packed_data); +void DispatchParallel( + int8_t* func, + int64_t start, + int64_t stop, + int8_t* packed_data) noexcept { // TODO: preserve the func type. - ParallelCallee callee = reinterpret_cast(func); - at::parallel_for(start, stop, 1, [&](int64_t f_begin, int64_t f_end) { - for (int index = f_begin; index < f_end; index++) { - callee(index, packed_data); - } - }); + try { + ParallelCallee callee = reinterpret_cast(func); + at::parallel_for(start, stop, 1, [&](int64_t f_begin, int64_t f_end) { + for (int64_t index = f_begin; index < f_end; index++) { + callee(index, packed_data); + } + }); + } catch (...) { + } +} } } // namespace tensorexpr @@ -415,9 +424,7 @@ LLVMCodeGenImpl::LLVMCodeGenImpl( llvm::FunctionType* fntype = llvm::FunctionType::get(retTy, params, false); fn_ = llvm::Function::Create( fntype, llvm::Function::PrivateLinkage, "pytorch", module_.get()); - fn_->addAttribute( - llvm::AttributeList::AttrIndex::FunctionIndex, - llvm::Attribute::AlwaysInline); + fn_->addFnAttr(llvm::Attribute::AlwaysInline); for (const auto i : c10::irange(args.size())) { if (!args[i].isVar()) { fn_->addParamAttr(i, llvm::Attribute::NoAlias); @@ -488,12 +495,13 @@ class LLVMIntrinsicsExpander : public GenericIntrinsicsExpander { if (v->op_type() == kTanh) { ScalarType stype = v->dtype().scalar_type(); if (stype == ScalarType::Float) { - return fast_tanh(v->param(0)->accept_mutator(this)).node(); + return fast_tanh(ExprHandle(v->param(0)->accept_mutator(this))).node(); } } else if (v->op_type() == kSigmoid) { ScalarType stype = v->dtype().scalar_type(); if (stype == ScalarType::Float) { - return fast_sigmoid(v->param(0)->accept_mutator(this)).node(); + return fast_sigmoid(ExprHandle(v->param(0)->accept_mutator(this))) + .node(); } } // TODO: fast exp @@ -529,10 +537,6 @@ void LLVMCodeGenImpl::emitKernel( irb_.CreateRet(value_); - if (llvm::verifyFunction(*fn_, &llvm::outs())) { - throw std::runtime_error("Function verification failed"); - } - // print graph debug info before optimization llvm::SmallVector asmBuffer; llvm::raw_svector_ostream asmStream(asmBuffer); @@ -542,6 +546,10 @@ void LLVMCodeGenImpl::emitKernel( GRAPH_DEBUG( "\nLLVM module before optimizations\n\n", asmStream.str().str(), "\n"); + if (llvm::verifyFunction(*fn_, &llvm::outs())) { + throw std::runtime_error("Function verification failed"); + } + optimize(*module_); asmBuffer.set_size(0); @@ -894,6 +902,12 @@ void LLVMCodeGenImpl::visit(HalfImmPtr v) { value_ = llvm::ConstantFP::get(HalfTy_, v->value()); } +void LLVMCodeGenImpl::visit(BFloat16ImmPtr v) { + TORCH_INTERNAL_ASSERT( + false, + buildErrorMessage("Fuser's LLVM codegen does not support bfloat16")); +} + void LLVMCodeGenImpl::visit(BoolImmPtr v) { value_ = llvm::ConstantInt::get(BoolTy_, v->value()); } @@ -1136,8 +1150,8 @@ void LLVMCodeGenImpl::visit(LoadPtr v) { // Handle the case where the load is contiguous and unmasked efficiently auto idx_ramp = to(v->flat_index()); if (idx_ramp) { - auto stride_imm = to(idx_ramp->stride()); - if (stride_imm && stride_imm->value() == 1) { + auto stride_imm = intValue(idx_ramp->stride()); + if (stride_imm && *stride_imm == 1) { v->base_handle()->accept(this); auto base = this->value_; idx_ramp->base()->accept(this); @@ -1248,7 +1262,7 @@ void LLVMCodeGenImpl::processParallelFor(ForPtr v) { // Create the new body closure code. auto func_type = - llvm::FunctionType::get(VoidTy_, {IntTy_, Int8PtrTy_}, false); + llvm::FunctionType::get(VoidTy_, {LongTy_, Int8PtrTy_}, false); llvm::Function* func = llvm::Function::Create( func_type, llvm::Function::PrivateLinkage, "func", module_.get()); auto func_body = llvm::BasicBlock::Create(getContext(), "func_body", func); @@ -1260,6 +1274,10 @@ void LLVMCodeGenImpl::processParallelFor(ForPtr v) { packed_func_args_raw, packed_caller_args->getType()); // Unpack the arguments from the opaque buffer. + if (v->var()->dtype().scalar_type() != c10::kLong) { + index = irb_.CreateIntCast( + index, dtypeToLLVM(v->var()->dtype()), v->var()->dtype().is_signed()); + } body_closure_args = unpackFuncArgs(packed_func_args, body_arg_vars.size()); // Set the codegen to the new func. // TODO: this should be replaced by RAII wrappers. @@ -1282,11 +1300,14 @@ void LLVMCodeGenImpl::processParallelFor(ForPtr v) { irb_.CreatePointerCast(packed_caller_args, Int8PtrTy_); llvm::Value* func_value = irb_.CreatePointerCast(func, Int8PtrTy_); llvm::FunctionType* dispatcher_fntype = llvm::FunctionType::get( - VoidTy_, {Int8PtrTy_, IntTy_, IntTy_, Int8PtrTy_}, false); + VoidTy_, {Int8PtrTy_, LongTy_, LongTy_, Int8PtrTy_}, false); FunctionCallee dispatcher_callee = module_->getOrInsertFunction("DispatchParallel", dispatcher_fntype); llvm::Function* dispatcher = llvm::cast(dispatcher_callee.getCallee()); + dispatcher->addFnAttr(llvm::Attribute::NoUnwind); + start = irb_.CreateIntCast(start, LongTy_, true); + stop = irb_.CreateIntCast(stop, LongTy_, true); irb_.CreateCall( dispatcher, {func_value, start, stop, packed_caller_args_ptr}); value_ = llvm::ConstantInt::get(IntTy_, 0); @@ -1311,7 +1332,7 @@ void LLVMCodeGenImpl::visit(ForPtr v) { irb_.SetInsertPoint(condBlock); // Set up phi node for index variable. - auto idx = irb_.CreatePHI(IntTy_, 2); + auto idx = irb_.CreatePHI(start->getType(), 2); idx->addIncoming(start, preheader); if (!varToVal_.count(v->var())) { varToVal_.emplace(v->var(), idx); @@ -1336,7 +1357,8 @@ void LLVMCodeGenImpl::visit(ForPtr v) { body = irb_.GetInsertBlock(); // Increment the index variable and branch back to loop test. - auto inc = irb_.CreateAdd(idx, llvm::ConstantInt::getSigned(IntTy_, 1)); + auto inc = + irb_.CreateAdd(idx, llvm::ConstantInt::getSigned(start->getType(), 1)); irb_.CreateBr(condBlock); idx->addIncoming(inc, body); @@ -1421,8 +1443,8 @@ void LLVMCodeGenImpl::visit(StorePtr v) { // Handle the case where the store is contiguous and unmasked efficiently auto idx_ramp = to(v->flat_index()); if (idx_ramp) { - auto stride_imm = to(idx_ramp->stride()); - if (stride_imm && stride_imm->value() == 1) { + auto stride_imm = intValue(idx_ramp->stride()); + if (stride_imm && *stride_imm == 1) { idx_ramp->base()->accept(this); auto first_idx = value_; @@ -1515,7 +1537,10 @@ void LLVMCodeGenImpl::emitIsNan(IntrinsicsPtr v) { if (!v->param(0)->dtype().is_floating_point()) { value_ = toVec(llvm::ConstantInt::get(dstType, 0), v->dtype().lanes()); } else { - TORCH_INTERNAL_ASSERT(v->dtype().scalar_type() == ScalarType::Int); + TORCH_INTERNAL_ASSERT( + v->dtype().scalar_type() == ScalarType::Int, + buildErrorMessage( + "Unexpected non-Int dtype of Intrinsics' result value in the fuser.")); auto is_nan = irb_.CreateFCmpUNO( value_, llvm::ConstantFP::get(value_->getType(), 0.)); if (v->dtype().lanes() > 1) { @@ -1742,11 +1767,11 @@ void LLVMCodeGenImpl::visit(IntrinsicsPtr v) { } else { TORCH_INTERNAL_ASSERT( false, - v, - "Unimplemented lowering:", - v->op_type(), - " for input of dtype", - v->dtype().scalar_dtype()); + buildErrorMessage( + std::string("Unimplemented lowering for intrinsic '") + + std::to_string(v->op_type()) + "' for input of dtype " + + std::to_string(v->dtype().scalar_dtype()) + + " in LLVM codegen of the fuser.")); } std::vector params; diff --git a/torch/csrc/jit/tensorexpr/llvm_jit.h b/torch/csrc/jit/tensorexpr/llvm_jit.h index 30ad5317a1b3c..a837899cdce1d 100644 --- a/torch/csrc/jit/tensorexpr/llvm_jit.h +++ b/torch/csrc/jit/tensorexpr/llvm_jit.h @@ -17,7 +17,13 @@ namespace torch { namespace jit { namespace tensorexpr { -void DispatchParallel(int8_t* func, int start, int stop, int8_t* packed_data); +extern "C" { +void DispatchParallel( + int8_t* func, + int64_t start, + int64_t stop, + int8_t* packed_data) noexcept; +} inline std::string formatError(llvm::Error&& err, const char* msg) { static constexpr char* defaultErrorMsg = "Unexpected failure in LLVM JIT"; diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp index ea6f09349e444..e67d094065d1a 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.cpp +++ b/torch/csrc/jit/tensorexpr/loopnest.cpp @@ -12,6 +12,7 @@ #include #include +#include #include #include #include @@ -46,14 +47,14 @@ LoopNest::LoopNest(StmtPtr stmt, std::unordered_set output_bufs) // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) LoopNest::LoopNest( - const std::vector& output_tensors, - const std::vector& tensors_to_compute) { + const std::vector& output_tensors, + const std::vector& tensors_to_compute) { initialize(output_tensors, tensors_to_compute); verify(root_stmt_); } // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) -LoopNest::LoopNest(const std::vector& output_tensors) { +LoopNest::LoopNest(const std::vector& output_tensors) { initialize(output_tensors, output_tensors); verify(root_stmt_); } @@ -108,12 +109,13 @@ class IndexFlattener : public IRMutator { ExprPtr value = v->value(); ExprPtr new_value = value->accept_mutator(this); if (v->indices().size() == 1 && value == new_value) { - return (StmtPtr)v; + return v; } - return alloc( - v->buf(), - std::vector({flatten_index(v->buf()->dims(), v->indices())}), - new_value); + std::vector indices = { + flatten_index(v->buf()->dims(), v->indices())}; + v->set_indices(indices); + v->set_value(new_value); + return v; } }; @@ -125,8 +127,8 @@ class Vectorizer : public IRMutator { ExprPtr start = v->start(); ExprPtr stop = v->stop(); - IntImmPtr start_imm = to(start); - IntImmPtr stop_imm = to(stop); + auto start_imm = intValue(start); + auto stop_imm = intValue(stop); if (!start_imm) { throw std::runtime_error( "Can't vectorize due to non-constant loop start!"); @@ -138,8 +140,8 @@ class Vectorizer : public IRMutator { } var_ = var; - start_ = start_imm; - lanes_ = stop_imm->value(); + start_ = immLike(start, *start_imm); + lanes_ = *stop_imm; StmtPtr new_body = body->accept_mutator(this); if (new_body == body) { @@ -177,6 +179,13 @@ class Vectorizer : public IRMutator { }); } + ExprPtr mutate(ModPtr v) override { + std::vector inputs = {v->lhs(), v->rhs()}; + return try_vectorize(v, inputs, [&]() { + return ExprHandle(inputs[0]) % ExprHandle(inputs[1]); + }); + } + ExprPtr mutate(AndPtr v) override { std::vector inputs = {v->lhs(), v->rhs()}; return try_vectorize(v, inputs, [&]() { @@ -484,15 +493,15 @@ bool LoopNest::vectorize(ForPtr f) { } void LoopNest::initialize( - const std::vector& output_tensors, - const std::vector& tensors_to_compute) { + const std::vector& output_tensors, + const std::vector& tensors_to_compute) { for (auto t : output_tensors) { - output_bufs_.insert(t->buf()); + output_bufs_.insert(t.buf()); } std::vector loops; - for (Tensor* t : tensors_to_compute) { - StmtPtr loop = t->stmt(); + for (Tensor t : tensors_to_compute) { + StmtPtr loop = t.stmt(); if (loop->get_parent()) { std::cerr << "Error: creating a loopnest from already used Tensors\n"; loops = {}; @@ -522,12 +531,13 @@ class FunctionInliner : public IRMutator { if (auto index_var = to(i)) { index_vars_.insert(index_var); producer_index_vars_.push_back(index_var); - } else if (to(i) != nullptr) { + } else if (intValue(i)) { // If the index can be a constant, then that dimension must have size 1 // (since we don't support in-place writes). Resolves issue 52581. TORCH_INTERNAL_ASSERT( - to(i)->value() == 0, - "Constant index impression should always be zero"); + *intValue(i) == 0, + buildErrorMessage( + "Unexpected non-zero constant index in inlined buffer in the fuser.")); producer_index_vars_.push_back(nullptr); } else { throw std::logic_error("cannot inline Buf with compound indices"); @@ -538,32 +548,43 @@ class FunctionInliner : public IRMutator { private: ExprPtr mutate_loads(BufPtr buf, std::vector dims) { std::vector index_vars; - TORCH_INTERNAL_ASSERT(buf->ndim() == producer_index_vars_.size()); + TORCH_INTERNAL_ASSERT( + buf->ndim() == producer_index_vars_.size(), + buildErrorMessage( + "Dimensions of producer and consumer expressions do not match in inliner in the fuser.")); for (const auto i : c10::irange(buf->ndim())) { VarPtr func_callee_arg = producer_index_vars_.at(i); ExprPtr func_caller_param = dims.at(i); if (func_callee_arg == nullptr) { TORCH_INTERNAL_ASSERT( - to(func_caller_param) != nullptr && - to(func_caller_param)->value() == 0, - "We are implicitly assuming that if you have an index of 0, that must also be inlined into an index of 0"); + intValue(func_caller_param) && *intValue(func_caller_param) == 0, + buildErrorMessage( + "We are implicitly assuming that if you have an index of 0, that must also be inlined into an index of 0")); continue; } if (func_callee_arg == nullptr) continue; auto iter = inline_mapping_.find(func_callee_arg); if (iter != inline_mapping_.end()) { - throw std::runtime_error( + throw std::logic_error( "Duplicated variables: " + func_callee_arg->name_hint()); } // Add a mapping for each function parameter to it's source name. inline_mapping_[func_callee_arg] = func_caller_param; + GRAPH_DEBUG( + "ComputeInline: Inline mapping: ", + std::to_string(func_callee_arg), + " -> ", + std::to_string(func_caller_param)); index_vars.push_back(func_callee_arg); } // Call the actual replacement. ExprPtr body = producer_->value(); + GRAPH_DEBUG("ComputeInline: Before rewriting body: ", std::to_string(body)); ExprPtr result = Expr::clone(body)->accept_mutator(this); + GRAPH_DEBUG( + "ComputeInline: After rewriting body: ", std::to_string(result)); // Remove the mappings we created for this function parameters. for (auto v : index_vars) { @@ -575,6 +596,7 @@ class FunctionInliner : public IRMutator { } } } + GRAPH_DEBUG("ComputeInline: Inline mapping: erasing", std::to_string(v)); inline_mapping_.erase(v); } return result; @@ -586,10 +608,10 @@ class FunctionInliner : public IRMutator { return IRMutator::mutate(v); } - if (v->indices().size() != buf->ndim()) { - throw malformed_input( - "Placeholder indexed access is inconsistent with its rank", v); - } + TORCH_INTERNAL_ASSERT( + v->indices().size() == buf->ndim(), + buildErrorMessage( + "Number of indices doesn't match buf rank in the fuser.")); return mutate_loads(buf, v->indices()); } @@ -617,6 +639,8 @@ class FunctionInliner : public IRMutator { const std::string& name = buf_->name_hint(); VarPtr new_var = alloc(name, v->dtype()); random_bindings_[alloc(new_var, v)] = index_vars_; + GRAPH_DEBUG( + "ComputeInline: created random bindings for ", std::to_string(new_var)); return new_var; } @@ -627,7 +651,10 @@ class FunctionInliner : public IRMutator { if (v == producer_ && !outputs_.count(buf_)) { in_producer_ = true; producer_ = to(IRMutator::mutate(v)); - TORCH_INTERNAL_ASSERT(producer_ != nullptr); + TORCH_INTERNAL_ASSERT( + producer_, + buildErrorMessage( + "Producer statement for output buf should remain non-null in the fuser")); in_producer_ = false; return nullptr; } else { @@ -729,8 +756,12 @@ bool LoopNest::computeInline(BufPtr b) { } } - TORCH_INTERNAL_ASSERT(relevant_store); + TORCH_INTERNAL_ASSERT( + relevant_store, + buildErrorMessage( + "Cannot find a relevant store to inline a buf in the fuser.")); + GRAPH_DEBUG("ComputeInline: Def: ", std::to_string(relevant_store)); FunctionInliner inliner(relevant_store, output_bufs_); root_stmt_ = root_stmt_->accept_mutator(&inliner); @@ -752,7 +783,11 @@ void LoopNest::inlineIntermediateBufs(bool allow_duplicated_work) { auto input_bufs = getInputBufs(); for (auto buf : intermediate_bufs) { - TORCH_INTERNAL_ASSERT(buf_load_store_uses.count(buf)); + TORCH_INTERNAL_ASSERT( + buf_load_store_uses.count(buf), + buildErrorMessage( + "Could not find uses of buf '" + buf->name_hint() + + "' in the fuser.")); std::vector& uses = buf_load_store_uses[buf]; auto stores = c10::filter( uses, [](const BufLoadOrStoreUse& use) { return use.isStore; }); @@ -769,7 +804,11 @@ void LoopNest::inlineIntermediateBufs(bool allow_duplicated_work) { } } else { // If S is not a store, it must be an ExternalCall. - TORCH_INTERNAL_ASSERT(to(stores[0].s)); + TORCH_INTERNAL_ASSERT( + to(stores[0].s), + buildErrorMessage( + "Expected stmt: " + std::to_string(stores[0].s) + + "\nto be either a Store or an ExternalCall in the fuser.")); } } @@ -1119,7 +1158,7 @@ bool LoopNest::optimizeConditionals() { // only include the RHS of the conditions in the if-then-else expressions // we need to start with `0` which is the initial bound, given that we // only handle normalized loops (check for this is done below). - std::vector comp_values = {alloc(0)}; + std::vector comp_values; std::vector sub_exprs; auto ifthenelse_exprs = NodeFinder::find(store); if (ifthenelse_exprs.empty()) { @@ -1134,6 +1173,11 @@ bool LoopNest::optimizeConditionals() { ifthenelse_exprs.front(), &cond_var, &comp_values, &sub_exprs)) { continue; } + TORCH_INTERNAL_ASSERT( + comp_values.size() >= 1, + buildErrorMessage( + "Expected at least one expression in optimizeConditional in the fuser.")); + comp_values.insert(comp_values.begin(), immLike(comp_values[0], 0)); auto fors = getLoopStmtsFor(store); if (cond_var != fors.back()->var()) { @@ -1269,10 +1313,10 @@ void LoopNest::vectorizeInnerLoops() { } void LoopNest::sliceHead(ForPtr f, int factor, ForPtr* head, ForPtr* tail) { - if (to(f->start()) && to(f->stop())) { - int start_val = to(f->start())->value(); - int stop_val = to(f->stop())->value(); - int size_val = stop_val - start_val; + if (intValue(f->start()) && intValue(f->stop())) { + auto start_val = *intValue(f->start()); + auto stop_val = *intValue(f->stop()); + auto size_val = stop_val - start_val; if (factor >= size_val) { *head = f; *tail = nullptr; @@ -1290,13 +1334,12 @@ void LoopNest::sliceHead(ForPtr f, int factor, ForPtr* head, ForPtr* tail) { } ExprPtr head_end = alloc( - alloc(f->start(), alloc(factor)), f->stop(), true); + alloc(f->start(), immLike(f->stop(), factor)), f->stop(), true); *head = alloc(f->var(), f->start(), head_end, Stmt::clone(f->body())); - *tail = alloc( - f->var(), head_end, f->stop(), Stmt::clone(f->body()), f->loop_options()); + p->insert_stmt_before(*head, f); - p->replace_stmt(f, *head); - p->insert_stmt_after(*tail, *head); + f->set_start(head_end); + *tail = f; if (f->loop_options().is_gpu_block_index() || f->loop_options().is_gpu_thread_index()) { @@ -1310,10 +1353,10 @@ void LoopNest::sliceHead(ForPtr f, int factor) { } void LoopNest::sliceTail(ForPtr f, int factor, ForPtr* head, ForPtr* tail) { - if (to(f->start()) && to(f->stop())) { - int start_val = to(f->start())->value(); - int stop_val = to(f->stop())->value(); - int size_val = stop_val - start_val; + if (intValue(f->start()) && intValue(f->stop())) { + auto start_val = *intValue(f->start()); + auto stop_val = *intValue(f->stop()); + auto size_val = stop_val - start_val; if (factor >= size_val) { *head = nullptr; *tail = f; @@ -1331,17 +1374,12 @@ void LoopNest::sliceTail(ForPtr f, int factor, ForPtr* head, ForPtr* tail) { } ExprPtr tail_start = alloc( - f->start(), alloc(f->stop(), alloc(factor)), true); - *head = alloc( - f->var(), - f->start(), - tail_start, - Stmt::clone(f->body()), - f->loop_options()); + f->start(), alloc(f->stop(), immLike(f->stop(), factor)), true); *tail = alloc(f->var(), tail_start, f->stop(), Stmt::clone(f->body())); + p->insert_stmt_after(*tail, f); - p->replace_stmt(f, *head); - p->insert_stmt_after(*tail, *head); + f->set_stop(tail_start); + *head = f; if (f->loop_options().is_gpu_block_index() || f->loop_options().is_gpu_thread_index()) { @@ -1375,17 +1413,17 @@ void LoopNest::splitWithTail( } bool tail_is_needed = true; - if (to(f->start()) && to(f->stop())) { - int start_val = to(f->start())->value(); - int stop_val = to(f->stop())->value(); - int size_val = stop_val - start_val; - int tail_size = size_val % factor; + if (intValue(f->start()) && intValue(f->stop())) { + auto const start_val = *intValue(f->start()); + auto const stop_val = *intValue(f->stop()); + auto const size_val = stop_val - start_val; + auto const tail_size = size_val % factor; if (tail_size == 0) { tail_is_needed = false; } } - IntImmPtr factor_expr = alloc(factor); + ExprPtr factor_expr = immLike(f->stop(), factor); ExprPtr size = alloc(f->stop(), f->start()); ExprPtr split_count = alloc
(size, factor_expr); ExprPtr tail_size = alloc(size, factor_expr); @@ -1408,7 +1446,7 @@ void LoopNest::splitWithTail( StmtPtr body_tail = SubstituteInClone(f->body(), {{f->var(), combined_index2}}); - *tail = alloc(i_tail, alloc(0), tail_size, body_tail); + *tail = alloc(i_tail, immLike(tail_size, 0), tail_size, body_tail); p->insert_stmt_after(*tail, f); } else { @@ -1418,10 +1456,11 @@ void LoopNest::splitWithTail( StmtPtr body_inner = Substitute(f->removeBody(), {{f->var(), combined_index1}}); - *inner = alloc(i_inner, alloc(0), factor_expr, body_inner); + *inner = + alloc(i_inner, immLike(factor_expr, 0), factor_expr, body_inner); // The input loop `f` will be the outer loop after split. f->set_var(i_outer); - f->set_start(alloc(0)); + f->set_start(immLike(split_count, 0)); f->set_stop(split_count); f->set_body(*inner); } @@ -1443,20 +1482,20 @@ void LoopNest::splitWithMask(ForPtr f, int factor, ForPtr* inner) { ExprPtr start = IRSimplifier::simplify(f->start()); ExprPtr stop = IRSimplifier::simplify(f->stop()); if (start->isConstant() && stop->isConstant()) { - int start_val = immediateAs(start); - int stop_val = immediateAs(stop); - int size_val = stop_val - start_val; - int tail_size = size_val % factor; + auto start_val = *intValue(start); + auto stop_val = *intValue(stop); + auto size_val = stop_val - start_val; + auto tail_size = size_val % factor; if (tail_size == 0) { tail_is_needed = false; } } - IntImmPtr factor_expr = alloc(factor); + auto factor_expr = immLike(f->stop(), factor); ExprPtr size = alloc(f->stop(), f->start()); // split_count = (size + factor - 1) / factor ExprPtr split_count = alloc
( - alloc(alloc(size, factor_expr), alloc(1)), factor_expr); + alloc(alloc(size, factor_expr), immLike(size, 1)), factor_expr); const std::string& loop_var_name = f->var()->name_hint(); Dtype loop_var_dtype = f->var()->dtype(); @@ -1472,8 +1511,8 @@ void LoopNest::splitWithMask(ForPtr f, int factor, ForPtr* inner) { // TODO: is it ok that we're doing it eagerly? In the other implementation we // are only materializing predicates at the last, lowering, step. if (tail_is_needed) { - IntImmPtr start = to(f->start()); - if (!start || start->value() != 0) { + auto start = intValue(f->start()); + if (!start || *start != 0) { throw unimplemented_lowering(); } @@ -1484,10 +1523,11 @@ void LoopNest::splitWithMask(ForPtr f, int factor, ForPtr* inner) { } body_inner = Substitute(body_inner, {{f->var(), combined_index}}); - *inner = alloc(i_inner, alloc(0), factor_expr, body_inner); + *inner = + alloc(i_inner, immLike(factor_expr, 0), factor_expr, body_inner); // The input loop `f` will be the outer loop after split. f->set_var(i_outer); - f->set_start(alloc(0)); + f->set_start(immLike(split_count, 0)); f->set_stop(split_count); f->set_body(*inner); } @@ -1495,7 +1535,10 @@ void LoopNest::splitWithMask(ForPtr f, int factor, ForPtr* inner) { std::vector LoopNest::distributeLoop( ForPtr loop, const std::unordered_set& pivots) { - TORCH_INTERNAL_ASSERT(loop); + TORCH_INTERNAL_ASSERT( + loop, + buildErrorMessage( + "Expected non-null loop in distributeLoop in the fuser.")); auto root = loop->get_parent(); if (root == nullptr) { throw malformed_input("Loop without parent: ", loop); @@ -1740,7 +1783,10 @@ bool LoopNest::unsafeFuseLoops( break; } } - TORCH_INTERNAL_ASSERT(it != root_block->end()); + TORCH_INTERNAL_ASSERT( + it != root_block->end(), + buildErrorMessage( + "Could not find the given loop in the root stmt in unsafeFuseLoop the fuser.")); for (auto l : loops) { if (*it != l) { return false; @@ -2014,7 +2060,10 @@ std::vector LoopNest::reorder( parent->replace_stmt(loops.front(), empty_block); for (size_t i = 1; i < loops.size(); ++i) { auto block = to(loops[i]->get_parent()); - TORCH_INTERNAL_ASSERT(block); + TORCH_INTERNAL_ASSERT( + block, + buildErrorMessage( + "Expected parent stmt to be a non-null Block in reorder transformation the fuser.")); block->remove_stmt(loops[i]); } @@ -2162,7 +2211,7 @@ bool LoopNest::normalize(ForPtr f) { {{f->var(), (VarHandle(f->var()) + ExprHandle(f->start())).node()}}); f->set_body(IRSimplifier::simplify(for_body_normalized)); f->set_stop(IRSimplifier::simplify(alloc(f->stop(), f->start()))); - f->set_start(alloc(0)); + f->set_start(immLike(f->stop(), 0)); return true; } @@ -2173,9 +2222,13 @@ std::vector LoopNest::getLoopStmtsInLoopNest(ForPtr f, size_t num) { ForPtr curr_for = f; loops[0] = curr_for; for (size_t i = 1; i < num; ++i) { - TORCH_INTERNAL_ASSERT(curr_for->body()->nstmts() == 1); + TORCH_INTERNAL_ASSERT( + curr_for->body()->nstmts() == 1, + buildErrorMessage("Expected a single stmt in the loop body.")); curr_for = to(curr_for->body()->front()); - TORCH_INTERNAL_ASSERT(curr_for); + TORCH_INTERNAL_ASSERT( + curr_for, + buildErrorMessage("Expected the only child stmt to be a For loop.")); loops[i] = curr_for; } return loops; @@ -2227,7 +2280,7 @@ bool LoopNest::flatten(const std::vector& loops, ForPtr* flattened) { normalized_loops[0]->var()->name_hint() + "_flat", normalized_loops[0]->var()->dtype()); VarMapping var_mapping; - ExprPtr stop = alloc(1); + ExprPtr stop = immLike(flat_var, 1); for (size_t i = 0; i < normalized_loops.size(); ++i) { size_t idx = normalized_loops.size() - i - 1; auto curr_loop = normalized_loops[idx]; @@ -2240,7 +2293,7 @@ bool LoopNest::flatten(const std::vector& loops, ForPtr* flattened) { Substitute(normalized_loops.back()->removeBody(), var_mapping); normalized_loops.front()->set_var(flat_var); - normalized_loops.front()->set_start(alloc(0)); + normalized_loops.front()->set_start(immLike(stop, 0)); normalized_loops.front()->set_stop(stop); normalized_loops.front()->set_body(flattened_body); *flattened = normalized_loops.front(); @@ -2285,7 +2338,10 @@ void LoopNest::compressBuffer(BufPtr buf, StmtPtr stmt) { // Find the parent common to all the buffer accesses. BlockPtr parent = to(writes.front()->get_parent()); - TORCH_INTERNAL_ASSERT(parent); + TORCH_INTERNAL_ASSERT( + parent, + buildErrorMessage( + "Expected parent stmt to be a non-null block in compressBuffer in the fuser.")); for (auto w : writes) { parent = Block::getSharedParent(parent, w); } @@ -2307,7 +2363,10 @@ void LoopNest::compressBuffer(BufPtr buf, StmtPtr stmt) { // Vector to indicate which dimensions could be compressed away. std::vector dims(buf->dims().size(), true); auto check_indices = [&](const std::vector& indices) { - TORCH_INTERNAL_ASSERT(indices.size() == dims.size()); + TORCH_INTERNAL_ASSERT( + indices.size() == dims.size(), + buildErrorMessage( + "Expected ranks to match in compressBuffer in the fuser.")); for (size_t i = 0; i < indices.size(); ++i) { auto index_vars = NodeFinder::find(indices[i]); for (auto iv : index_vars) { @@ -2342,18 +2401,21 @@ void LoopNest::compressBuffer(BufPtr buf, StmtPtr stmt) { std::vector new_dims(buf->dims()); for (size_t i = 0; i < dims.size(); ++i) { if (dims[i]) { - new_dims[i] = alloc(1); + new_dims[i] = immLike(buf->dims()[i], 1); } } buf->set_dims(new_dims); // Modify all access to reflect the removed dims. auto get_new_indices = [&](const std::vector& indices) { - TORCH_INTERNAL_ASSERT(indices.size() == dims.size()); + TORCH_INTERNAL_ASSERT( + indices.size() == dims.size(), + buildErrorMessage( + "Expected ranks to match in compressBuffer in the fuser.")); std::vector new_indices(indices); for (size_t i = 0; i < dims.size(); ++i) { if (dims[i]) { - new_indices[i] = alloc(0); + new_indices[i] = immLike(indices[i], 0); } } return new_indices; @@ -2372,11 +2434,11 @@ void LoopNest::compressBuffer(BufPtr buf, StmtPtr stmt) { void LoopNest::compressAllBuffers(StmtPtr stmt) { for (auto buf : BufFinder::find(stmt)) { - compressBuffer(const_cast(buf), stmt); + compressBuffer(buf, stmt); } } -std::vector LoopNest::getLoopStmtsFor(Tensor* t) const { +std::vector LoopNest::getLoopStmtsFor(Tensor t) const { StmtPtr cur_stmt = getLoopBodyFor(t); return getLoopStmtsFor(cur_stmt); } @@ -2399,8 +2461,8 @@ std::vector LoopNest::getLoopStmtsFor(StmtPtr s) const { return result; } -StmtPtr LoopNest::getLoopBodyFor(Tensor* t) const { - return getLoopBodyFor(t->buf()); +StmtPtr LoopNest::getLoopBodyFor(Tensor t) const { + return getLoopBodyFor(t.buf()); } StmtPtr LoopNest::getLoopBodyFor(BufPtr buf) const { @@ -2555,15 +2617,19 @@ class CacheReplacer : public IRMutator { // Map indices to call-parameters. std::vector newIndices; - TORCH_INTERNAL_ASSERT(offsets_.size() == v->indices().size()); + TORCH_INTERNAL_ASSERT( + offsets_.size() == v->indices().size(), + buildErrorMessage( + "Expected ranks to match in CacheReplacer in the fuser.")); for (size_t i = 0; i < v->indices().size(); ++i) { ExprPtr index = v->indices()[i]->accept_mutator(this); ExprPtr offset = offsets_[i]; ExprPtr sub = IRSimplifier::simplify(alloc(index, offset)); newIndices.push_back(sub); } - - return alloc(cache_, newIndices); + v->set_buf(cache_); + v->set_indices(newIndices); + return v; } StmtPtr mutate(StorePtr v) override { @@ -2576,15 +2642,20 @@ class CacheReplacer : public IRMutator { // Map indices to call-parameters. std::vector newIndices; - TORCH_INTERNAL_ASSERT(offsets_.size() == v->indices().size()); + TORCH_INTERNAL_ASSERT( + offsets_.size() == v->indices().size(), + buildErrorMessage( + "Expected ranks to match in CacheReplacer in the fuser.")); for (size_t i = 0; i < v->indices().size(); ++i) { ExprPtr index = v->indices()[i]->accept_mutator(this); ExprPtr offset = offsets_[i]; ExprPtr sub = IRSimplifier::simplify(alloc(index, offset)); newIndices.push_back(sub); } - - return alloc(cache_, newIndices, newValue); + v->set_buf(cache_); + v->set_indices(newIndices); + v->set_value(newValue); + return v; } BufPtr buf_; @@ -2622,7 +2693,10 @@ LoopNest::AccessResult LoopNest::cacheAccesses( return {nullptr, nullptr}; } - TORCH_INTERNAL_ASSERT(bounds_it->second.size() == 1); + TORCH_INTERNAL_ASSERT( + bounds_it->second.size() == 1, + buildErrorMessage( + "Unexpected number of bound info entries in cacheAccesses in the fuser.")); TensorAccessBoundsInfo& info = bounds_it->second[0]; bool hasReads = info.kind == kLoad || info.kind == kMutate; bool hasWrites = info.kind == kStore || info.kind == kMutate; @@ -2634,12 +2708,13 @@ LoopNest::AccessResult LoopNest::cacheAccesses( // Determine the size of the cache, and create a loop var for each dimension. for (size_t i = 0; i < info.start.size(); ++i) { - ExprPtr dim = IRSimplifier::simplify( - alloc(alloc(info.stop[i], info.start[i]), alloc(1))); + ExprPtr dim = IRSimplifier::simplify(alloc( + alloc(info.stop[i], info.start[i]), immLike(info.stop[i], 1))); tmp_dims.push_back(dim); - new_loop_vars.push_back(alloc(var_names[i % var_names.size()], kInt)); + new_loop_vars.push_back( + alloc(var_names[i % var_names.size()], info.stop[i]->dtype())); new_loop_vars_expr.push_back(new_loop_vars[i]); } @@ -2656,21 +2731,13 @@ LoopNest::AccessResult LoopNest::cacheAccesses( // Replace acceses to the producer in the consumer with the cache. CacheReplacer replacer(producer, tmp_buf, info.start); - // TODO: Can we reuse 'consumer' below without cloning? - StmtPtr new_consumer = - IRSimplifier::simplify(Stmt::clone(consumer)->accept_mutator(&replacer)); + consumer->accept_mutator(&replacer); // replace the old consumer with the replaced consumer. - BlockPtr consumer_block = nullptr; + BlockPtr consumer_block = to(consumer); + BlockPtr parent_block = to(consumer->get_parent()); // if the consumer is a block, we should mutate it in place. - if ((consumer_block = to(consumer))) { - consumer_block->clear(); - consumer_block->append_stmt(new_consumer); - } else { - consumer_block = to(consumer->get_parent()); - assert(consumer_block); - consumer_block->replace_stmt(consumer, new_consumer); - } + bool is_block = consumer_block != nullptr; // If there's a reduction and we are operating on the reduce axis, we need to // initialize the cache with 0s. Also, we can't just write the result straight @@ -2698,11 +2765,15 @@ LoopNest::AccessResult LoopNest::cacheAccesses( tmp_buf, new_loop_vars_expr, getImmediateByType(tmp_buf->dtype(), 0)); for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) { - tmp_init = - alloc(new_loop_vars[i], alloc(0), tmp_dims[i], tmp_init); + tmp_init = alloc( + new_loop_vars[i], immLike(tmp_dims[i], 0), tmp_dims[i], tmp_init); } - consumer_block->insert_stmt_before(tmp_init, new_consumer); + if (is_block) { + consumer_block->prepend_stmt(tmp_init); + } else { + parent_block->insert_stmt_before(tmp_init, consumer); + } // Reduce back to the original buffer: StmtPtr tmp_store = alloc( @@ -2716,12 +2787,16 @@ LoopNest::AccessResult LoopNest::cacheAccesses( for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) { tmp_store = alloc( - new_loop_vars[i], alloc(0), tmp_dims[i], tmp_store); + new_loop_vars[i], immLike(tmp_dims[i], 0), tmp_dims[i], tmp_store); } - consumer_block->insert_stmt_after(tmp_store, new_consumer); + if (is_block) { + consumer_block->append_stmt(tmp_store); + } else { + parent_block->insert_stmt_after(tmp_store, consumer); + } - return std::make_pair(tmp_buf, new_consumer); + return std::make_pair(tmp_buf, consumer); } if (hasReads) { @@ -2731,10 +2806,14 @@ LoopNest::AccessResult LoopNest::cacheAccesses( for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) { tmp_store = alloc( - new_loop_vars[i], alloc(0), tmp_dims[i], tmp_store); + new_loop_vars[i], immLike(tmp_dims[i], 0), tmp_dims[i], tmp_store); } - consumer_block->insert_stmt_before(tmp_store, new_consumer); + if (is_block) { + consumer_block->prepend_stmt(tmp_store); + } else { + parent_block->insert_stmt_before(tmp_store, consumer); + } } if (hasWrites) { @@ -2744,13 +2823,17 @@ LoopNest::AccessResult LoopNest::cacheAccesses( for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) { tmp_store = alloc( - new_loop_vars[i], alloc(0), tmp_dims[i], tmp_store); + new_loop_vars[i], immLike(tmp_dims[i], 0), tmp_dims[i], tmp_store); } - consumer_block->insert_stmt_after(tmp_store, new_consumer); + if (is_block) { + consumer_block->append_stmt(tmp_store); + } else { + parent_block->insert_stmt_after(tmp_store, consumer); + } } - return std::make_pair(tmp_buf, new_consumer); + return std::make_pair(tmp_buf, consumer); } /* @@ -2888,7 +2971,8 @@ void LoopNest::computeAt(StmtPtr s, ForPtr f) { std::vector temp_indices(dims.size()); for (const auto i : c10::irange(dims.size())) { // TODO: Use name-hint of the producer indices instead of 'idx' - temp_indices[i] = alloc(std::string("idx") + c10::to_string(i), kInt); + temp_indices[i] = + alloc(std::string("idx") + c10::to_string(i), dims[i]->dtype()); } // Prepare substitute rules for constructing the temp statement from the prod @@ -2929,7 +3013,10 @@ void LoopNest::computeAt(StmtPtr s, ForPtr f) { // dimensions in reversed order. size_t dim_idx = dims.size() - 1 - i; bd = alloc( - to(temp_indices[dim_idx]), alloc(0), dims[dim_idx], bd); + to(temp_indices[dim_idx]), + immLike(dims[dim_idx], 0), + dims[dim_idx], + bd); } // Add constructed stmts to the consumer loop @@ -2964,7 +3051,10 @@ class RfactorStoreRewriter : public IRMutator { return IRMutator::mutate(v); } - TORCH_INTERNAL_ASSERT(old_indices_.size() == v->indices().size()); + TORCH_INTERNAL_ASSERT( + old_indices_.size() == v->indices().size(), + buildErrorMessage( + "Expected ranks to match in RfactorStoreRewriter in the fuser.")); bool equal_indices = true; for (size_t i = 0; i < v->indices().size(); ++i) { @@ -2998,7 +3088,10 @@ class RfactorStoreRewriter : public IRMutator { return IRMutator::mutate(v); } - TORCH_INTERNAL_ASSERT(old_indices_.size() == v->indices().size()); + TORCH_INTERNAL_ASSERT( + old_indices_.size() == v->indices().size(), + buildErrorMessage( + "Expected ranks to match in RfactorStoreRewriter in the fuser.")); bool equal_indices = true; for (size_t i = 0; i < v->indices().size(); ++i) { @@ -3107,7 +3200,10 @@ bool LoopNest::rfactor( // X[*indexes] = ReduceOp(X[*indexes] + T[*indexes + {reduction_var}], // reduce_axis={reduction_var}) BlockPtr b = outer_reduction_for->body(); - TORCH_INTERNAL_ASSERT(b->nstmts() == 1); + TORCH_INTERNAL_ASSERT( + b->nstmts() == 1, + buildErrorMessage( + "Expected to have a single stmt in the block in rfactor transformation in the fuser.")); StmtPtr first_reduction_loop = b->stmts().front(); auto rfac_buf_indices = orig_buf_indices; rfac_buf_indices.emplace_back(reduction_var); diff --git a/torch/csrc/jit/tensorexpr/loopnest.h b/torch/csrc/jit/tensorexpr/loopnest.h index c8cf2d8553d2d..42f072d2da7d8 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.h +++ b/torch/csrc/jit/tensorexpr/loopnest.h @@ -27,11 +27,11 @@ class TORCH_API LoopNest { public: // A constructor for building a LoopNest from a list of Tensors LoopNest( - const std::vector& output_tensors, - const std::vector& tensors_to_compute); + const std::vector& output_tensors, + const std::vector& tensors_to_compute); // A convenience constructor for the case when all tensors are output tensors - LoopNest(const std::vector& output_tensors); + LoopNest(const std::vector& output_tensors); // A constructor for building a LoopNest from an Stmt and a list of output // buffers. @@ -45,10 +45,10 @@ class TORCH_API LoopNest { return root_stmt_; } - std::vector getLoopStmtsFor(Tensor*) const; + std::vector getLoopStmtsFor(Tensor) const; std::vector getLoopStmtsFor(BufPtr) const; std::vector getLoopStmtsFor(StmtPtr) const; - StmtPtr getLoopBodyFor(Tensor*) const; + StmtPtr getLoopBodyFor(Tensor) const; StmtPtr getLoopBodyFor(BufPtr) const; // Returns the For stmt indexed by 'indices' in the 'root' For stmt. @@ -547,8 +547,8 @@ class TORCH_API LoopNest { private: void initialize( - const std::vector& output_tensors, - const std::vector& tensors_to_compute); + const std::vector& output_tensors, + const std::vector& tensors_to_compute); StmtPtr insertAllocFree(StmtPtr stmt); const std::unordered_set getIntermediateBufs() const; diff --git a/torch/csrc/jit/tensorexpr/mem_arena.cpp b/torch/csrc/jit/tensorexpr/mem_arena.cpp deleted file mode 100644 index 1769563424f5c..0000000000000 --- a/torch/csrc/jit/tensorexpr/mem_arena.cpp +++ /dev/null @@ -1,67 +0,0 @@ -#include -#include -#include - -namespace torch { -namespace jit { -namespace tensorexpr { - -namespace { -// Define in an anonymous namespace to hide this symbol from other compilation -// units -thread_local KernelArena* current_arena = nullptr; -} // namespace - -KernelArena::~KernelArena() { - for (KernelScopedObject* p : kernel_objects_) { - delete p; - } -} - -KernelScopedObject::KernelScopedObject() { - KernelArena* kernel = KernelArena::GetCurrentKernelArena(); - if (kernel == nullptr) { - throw std::runtime_error( - "KernelScope() must be constructed before calling this"); - } - kernel->kernel_objects_.push_back(this); -} - -void KernelArena::SetCurrentKernelArena(KernelArena* new_kernel_arena) { - current_arena = new_kernel_arena; -} - -KernelArena* KernelArena::GetCurrentKernelArena() { - return current_arena; -} - -KernelScope::KernelScope() - : kernel_arena_(new KernelArena()), - old_kernel_arena_(KernelArena::GetCurrentKernelArena()), - owning_(true) { - KernelArena::SetCurrentKernelArena(kernel_arena_); -} - -KernelScope::KernelScope(KernelArena* arena_) - : kernel_arena_(arena_), - old_kernel_arena_(KernelArena::GetCurrentKernelArena()), - owning_(false) { - KernelArena::SetCurrentKernelArena(kernel_arena_); -} - -KernelScope::~KernelScope() { - if (KernelArena::GetCurrentKernelArena() != kernel_arena_) { - // This should be an error, but it gets triggered in - // caffe2/benchmarks/static_runtime:static_runtime_cpptest - TORCH_WARN("KernelScope() destructed out of order, leaking memory"); - return; - } - KernelArena::SetCurrentKernelArena(old_kernel_arena_); - if (owning_) { - delete kernel_arena_; - } -} - -} // namespace tensorexpr -} // namespace jit -} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/mem_arena.h b/torch/csrc/jit/tensorexpr/mem_arena.h deleted file mode 100644 index a39ab6f0068c7..0000000000000 --- a/torch/csrc/jit/tensorexpr/mem_arena.h +++ /dev/null @@ -1,60 +0,0 @@ -#pragma once -#include -#include - -namespace torch { -namespace jit { -namespace tensorexpr { - -class KernelScopedObject; - -// An arena that manages all the underlying kernel-scoped objects. -// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) -class KernelArena { - public: - static KernelArena* GetCurrentKernelArena(); - static void SetCurrentKernelArena(KernelArena* new_arena); - TORCH_API KernelArena() = default; - TORCH_API ~KernelArena(); - KernelArena(const KernelArena&) = delete; - KernelArena& operator=(const KernelArena&) = delete; - - private: - friend class KernelScopedObject; - std::vector kernel_objects_; // owned -}; - -// A RAII convenience wrapper on top of a kernel. -// It either creates or takes an existing Kernel and sets it as the current -// Kernel. When this object is destroyed, the previous Kernel is set as current, -// and the created kernel is freed. If the kernel was passed, it stays alive. -class KernelScope { - public: - TORCH_API KernelScope(); - TORCH_API explicit KernelScope(KernelArena* arena_); - TORCH_API ~KernelScope(); - KernelScope(const KernelScope&) = delete; - KernelScope& operator=(const KernelScope&) = delete; - - private: - KernelArena* kernel_arena_; // maybe owned - KernelArena* old_kernel_arena_; // previous arena, restored in destructor - bool owning_; // determines whether the arena will be freed along with - // the scope object -}; - -// The base object managed by the Kernel. -// The object must be created through "new", and when the Kernel is destroyed, -// All its registered objects are destroyed through "delete". -class TORCH_API KernelScopedObject { - public: - KernelScopedObject(); - virtual ~KernelScopedObject() = default; - - KernelScopedObject(const KernelScopedObject&) = delete; - KernelScopedObject& operator=(const KernelScopedObject&) = delete; -}; - -} // namespace tensorexpr -} // namespace jit -} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/mem_dependency_checker.cpp b/torch/csrc/jit/tensorexpr/mem_dependency_checker.cpp index 8f6f2b106b1b2..3f77041f1a202 100644 --- a/torch/csrc/jit/tensorexpr/mem_dependency_checker.cpp +++ b/torch/csrc/jit/tensorexpr/mem_dependency_checker.cpp @@ -76,12 +76,16 @@ std::vector AccessInfo::getIndices() const { void AccessInfo::addDependency(const std::shared_ptr& write) { auto res = dependencies_.emplace(write->id(), write); - TORCH_INTERNAL_ASSERT(res.second); + TORCH_INTERNAL_ASSERT( + res.second, + buildErrorMessage("Duplicate entry in mem dep checker in the fuser.")); } void AccessInfo::addDependent(const std::shared_ptr& read) { auto res = dependents_.emplace(read->id(), read); - TORCH_INTERNAL_ASSERT(res.second); + TORCH_INTERNAL_ASSERT( + res.second, + buildErrorMessage("Duplicate entry in mem dep checker in the fuser.")); } bool AccessInfo::hasDependency(const std::shared_ptr& info) const { @@ -185,13 +189,13 @@ void AccessInfo::dumpDOT(std::ostream& os) const { if (bounds_.size() > 0) { for (size_t i = 0; i < bounds_.size() - 1; ++i) { os << *IRSimplifier::simplify( - alloc(bounds_[i].end, alloc(1))) + alloc(bounds_[i].end, immLike(bounds_[i].end, 1))) << ", "; } size_t i = bounds_.size() - 1; os << *IRSimplifier::simplify( - alloc(bounds_[i].end, alloc(1))); + alloc(bounds_[i].end, immLike(bounds_[i].end, 1))); os << "]\"\n "; } if (isWrite()) { @@ -590,7 +594,10 @@ bool executionSafetyCheck( if (aStrides.empty() || oStrides.empty()) { return false; } - TORCH_INTERNAL_ASSERT(info->bounds().size() == other->bounds().size()); + TORCH_INTERNAL_ASSERT( + info->bounds().size() == other->bounds().size(), + buildErrorMessage( + "Dimension mismatch for two accesses in mem dep checker in the fuser.")); for (size_t b = 0; b < info->bounds().size(); ++b) { ExprPtr aIndexStride = aStrides[b]; ExprPtr oIndexStride = oStrides[b]; @@ -632,7 +639,7 @@ bool executionSafetyCheck( // Invert the startDiff so mod works. if (diffNegative != strideNegative) { startDiff = - IRSimplifier::simplify(alloc(alloc(0), startDiff)); + IRSimplifier::simplify(alloc(immLike(startDiff, 0), startDiff)); } // If both accesses have the same stride, and the difference in start @@ -650,7 +657,7 @@ bool executionSafetyCheck( CompareSelectOperation op = strideNegative ? kLT : kGT; ExprPtr check = IRSimplifier::simplify( - alloc(startDiff, alloc(0), op)); + alloc(startDiff, immLike(startDiff, 0), op)); // If the start difference modulo the minimum stride is offset from that // stride, then the ranges have distinct strides. @@ -731,7 +738,7 @@ void MemDependencyChecker::visit(ForPtr v) { for (const auto i : c10::irange(indices.size())) { VarFinder vf; if (vf.find(indices[i]).count(var) == 0) { - loopIndicesStride[i] = alloc(0); + loopIndicesStride[i] = immLike(indices[i], 0); } else { // If we've previously swapped the start and end of this bound, we // should apply the substitution to the reverse of the bounds. @@ -740,19 +747,19 @@ void MemDependencyChecker::visit(ForPtr v) { SubstituteInClone(info->bounds()[i].end, {{var, v->start()}})); info->bounds()[i].start = IRSimplifier::simplify(SubstituteInClone( info->bounds()[i].start, - {{var, alloc(v->stop(), alloc(1))}})); + {{var, alloc(v->stop(), immLike(v->stop(), 1))}})); } else { info->bounds()[i].start = IRSimplifier::simplify( SubstituteInClone(info->bounds()[i].start, {{var, v->start()}})); info->bounds()[i].end = IRSimplifier::simplify(SubstituteInClone( info->bounds()[i].end, - {{var, alloc(v->stop(), alloc(1))}})); + {{var, alloc(v->stop(), immLike(v->stop(), 1))}})); } ExprPtr zeroStep = indices[i]; ExprPtr oneStep = SubstituteInClone( - indices[i], {{var, alloc(var, alloc(1))}}); + indices[i], {{var, alloc(var, immLike(var, 1))}}); loopIndicesStride[i] = IRSimplifier::simplify(alloc(oneStep, zeroStep)); @@ -785,7 +792,7 @@ void MemDependencyChecker::visit(ForPtr v) { bound.start = IRSimplifier::simplify( SubstituteInClone(bound.start, {{var, v->start()}})); bound.end = IRSimplifier::simplify(SubstituteInClone( - bound.end, {{var, alloc(v->stop(), alloc(1))}})); + bound.end, {{var, alloc(v->stop(), immLike(v->stop(), 1))}})); // If the start < end then swap the order of the bound. ExprPtr diff = @@ -1037,8 +1044,8 @@ void MemDependencyChecker::insertBuffers( IndexBounds bounds; for (auto d : b->dims()) { bounds.push_back( - {alloc(0), - IRSimplifier::simplify(alloc(d, alloc(1)))}); + {immLike(d, 0), + IRSimplifier::simplify(alloc(d, immLike(d, 1)))}); } auto info = std::make_shared(nextAccess_++, type, nullptr, var, bounds); @@ -1126,8 +1133,9 @@ void MemDependencyChecker::visit(AllocatePtr v) { // avoid failing the bound check. But this is not the correct approach and // should be fixed. ExprPtr flat_size = buf_flat_size(v->buf()); - flat_size = IRSimplifier::simplify(alloc(flat_size, alloc(1))); - bounds.push_back({alloc(0), flat_size}); + flat_size = + IRSimplifier::simplify(alloc(flat_size, immLike(flat_size, 1))); + bounds.push_back({immLike(flat_size, 0), flat_size}); auto info = std::make_shared( nextAccess_++, AccessType::Alloc, nullptr, var, bounds); @@ -1149,7 +1157,11 @@ void MemDependencyChecker::visit(FreePtr v) { VarPtr var = v->buffer_var(); auto it = intermediates_.find(var); - TORCH_INTERNAL_ASSERT(it != intermediates_.end()); + TORCH_INTERNAL_ASSERT( + it != intermediates_.end(), + buildErrorMessage( + "Expected to find '" + var->name_hint() + + "' in intermediate vars in mem dep checker in the fuser.")); IndexBounds bounds = it->second->bounds(); auto info = std::make_shared( diff --git a/torch/csrc/jit/tensorexpr/mem_dependency_checker.h b/torch/csrc/jit/tensorexpr/mem_dependency_checker.h index 5363d2fc5ae93..1965b05009125 100644 --- a/torch/csrc/jit/tensorexpr/mem_dependency_checker.h +++ b/torch/csrc/jit/tensorexpr/mem_dependency_checker.h @@ -299,7 +299,7 @@ class TORCH_API MemDependencyChecker : public IRVisitor { DependencySet getAllReadsWithin(StmtOrExprPtr v) { DependencySet reads; auto insertAllReads = [&](const auto& nodes) { - for (auto* l : nodes) { + for (auto l : nodes) { auto bound = exprToAccess_.equal_range(l); for (auto it = bound.first; it != bound.second; ++it) { if (it->second->isRead()) { @@ -324,7 +324,7 @@ class TORCH_API MemDependencyChecker : public IRVisitor { // writes just Store currently. auto stores = NodeFinder::find(v); - for (auto* s : stores) { + for (auto s : stores) { auto bound = stmtToAccess_.equal_range(s); for (auto it = bound.first; it != bound.second; ++it) { if (it->second->isWrite()) { diff --git a/torch/csrc/jit/tensorexpr/operators/conv2d.cpp b/torch/csrc/jit/tensorexpr/operators/conv2d.cpp index c4af83a8cc6f4..51d323f4130a4 100644 --- a/torch/csrc/jit/tensorexpr/operators/conv2d.cpp +++ b/torch/csrc/jit/tensorexpr/operators/conv2d.cpp @@ -16,7 +16,7 @@ void assert_dims_constant(const BufHandle& buf) { using InitFunc = std::function&)>; -Tensor* conv2d_depthwise_static( +Tensor conv2d_depthwise_static( BufHandle input, BufHandle weight, const InitFunc& init_func, @@ -45,7 +45,7 @@ Tensor* conv2d_depthwise_static( auto OH = (H - R + 2 * pad) / stride + 1; auto OW = (W - S + 2 * pad) / stride + 1; - Tensor* conv = Reduce( + Tensor conv = Reduce( "conv2d_depthwise", {{N, "n"}, {K, "k"}, {OH, "oh"}, {OW, "ow"}}, Sum(), @@ -83,7 +83,7 @@ Tensor* conv2d_depthwise_static( } else if (R == 3 && stride == 1 && pad == 1) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr main, peeled; - auto loops = nest.getAllLoopNestsWritingToBuf(conv->buf()); + auto loops = nest.getAllLoopNestsWritingToBuf(conv.buf()); main = loops[1][kLoopW]; nest.sliceHead(main, 1, &peeled, &main); nest.sliceTail(main, 1, &main, &peeled); @@ -92,10 +92,10 @@ Tensor* conv2d_depthwise_static( nest.sliceTail(main, 1, &main, &peeled); } - return new Tensor(conv->buf(), nest.root_stmt()); + return Tensor(conv.buf(), nest.root_stmt()); } -Tensor* conv2d_depthwise_dynamic( +Tensor conv2d_depthwise_dynamic( BufHandle input, BufHandle weight, const InitFunc& init_func, @@ -144,7 +144,7 @@ Tensor* conv2d_depthwise_dynamic( } // namespace -Tensor* conv2d_depthwise( +Tensor conv2d_depthwise( BufHandle input, BufHandle weight, BufHandle bias, @@ -158,7 +158,7 @@ Tensor* conv2d_depthwise( return conv2d_depthwise_static(input, weight, init_func, stride, pad, groups); } -Tensor* conv2d_depthwise( +Tensor conv2d_depthwise( BufHandle input, BufHandle weight, int stride, @@ -170,7 +170,7 @@ Tensor* conv2d_depthwise( return conv2d_depthwise_static(input, weight, init_func, stride, pad, groups); } -Tensor* conv2d_depthwise( +Tensor conv2d_depthwise( BufHandle input, BufHandle weight, BufHandle bias, @@ -206,7 +206,7 @@ Tensor* conv2d_depthwise( groups); } -Tensor* conv2d_depthwise( +Tensor conv2d_depthwise( BufHandle input, BufHandle weight, ExprHandle N, diff --git a/torch/csrc/jit/tensorexpr/operators/conv2d.h b/torch/csrc/jit/tensorexpr/operators/conv2d.h index 14612fb17ee74..4c2215b38d868 100644 --- a/torch/csrc/jit/tensorexpr/operators/conv2d.h +++ b/torch/csrc/jit/tensorexpr/operators/conv2d.h @@ -7,7 +7,7 @@ namespace jit { namespace tensorexpr { // An API to compute 2D depthwise convolutions with bias. -TORCH_API Tensor* conv2d_depthwise( +TORCH_API Tensor conv2d_depthwise( BufHandle input, BufHandle weight, BufHandle bias, @@ -16,14 +16,14 @@ TORCH_API Tensor* conv2d_depthwise( int groups); // An API to compute 2D depthwise convolutions without bias. -TORCH_API Tensor* conv2d_depthwise( +TORCH_API Tensor conv2d_depthwise( BufHandle input, BufHandle weight, int stride, int pad, int groups); -TORCH_API Tensor* conv2d_depthwise( +TORCH_API Tensor conv2d_depthwise( BufHandle input, BufHandle weight, BufHandle bias, @@ -39,7 +39,7 @@ TORCH_API Tensor* conv2d_depthwise( ExprHandle pad, ExprHandle groups); -TORCH_API Tensor* conv2d_depthwise( +TORCH_API Tensor conv2d_depthwise( BufHandle input, BufHandle weight, ExprHandle N, diff --git a/torch/csrc/jit/tensorexpr/operators/matmul.cpp b/torch/csrc/jit/tensorexpr/operators/matmul.cpp index 23cb45564c97c..581514cdcb095 100644 --- a/torch/csrc/jit/tensorexpr/operators/matmul.cpp +++ b/torch/csrc/jit/tensorexpr/operators/matmul.cpp @@ -5,7 +5,7 @@ namespace torch { namespace jit { namespace tensorexpr { -Tensor* computeMatmul( +Tensor computeMatmul( const std::vector& inputs, const std::vector& outputShape, const c10::optional& outputType) { @@ -44,13 +44,13 @@ Tensor* computeMatmul( }, {{size_a[1], "K"}}); } else { - return new Tensor( + return Tensor( ResultBuf.node(), ExternalCall::make(ResultBuf, "nnc_aten_matmul", {a, b}, {})); } } -Tensor* computeAddMM( +Tensor computeAddMM( const std::vector& inputs, const std::vector& outputShape, const c10::optional& outputType) { @@ -59,7 +59,7 @@ Tensor* computeAddMM( dtype = Dtype(*outputType); } BufHandle ResultBuf("addmm", outputShape, dtype); - return new Tensor( + return Tensor( ResultBuf.node(), ExternalCall::make( ResultBuf, diff --git a/torch/csrc/jit/tensorexpr/operators/matmul.h b/torch/csrc/jit/tensorexpr/operators/matmul.h index 35b30f4168914..0b52ad65c43c8 100644 --- a/torch/csrc/jit/tensorexpr/operators/matmul.h +++ b/torch/csrc/jit/tensorexpr/operators/matmul.h @@ -6,11 +6,11 @@ namespace torch { namespace jit { namespace tensorexpr { -Tensor* computeMatmul( +Tensor computeMatmul( const std::vector& inputs, const std::vector& outputShape, const c10::optional& outputType); -Tensor* computeAddMM( +Tensor computeAddMM( const std::vector& inputs, const std::vector& outputShape, const c10::optional& outputType); diff --git a/torch/csrc/jit/tensorexpr/operators/norm.cpp b/torch/csrc/jit/tensorexpr/operators/norm.cpp index d96ebcd9447db..2e19d735d1809 100644 --- a/torch/csrc/jit/tensorexpr/operators/norm.cpp +++ b/torch/csrc/jit/tensorexpr/operators/norm.cpp @@ -4,7 +4,7 @@ namespace torch { namespace jit { namespace tensorexpr { -Tensor* computeBatchNorm( +Tensor computeBatchNorm( const std::vector& inputs, const std::vector& outputShape, const c10::optional& outputType) { @@ -38,11 +38,15 @@ Tensor* computeBatchNorm( constant(inputs[7]) // eps }; + ExprHandle weight = FloatImm::make(1); + ExprHandle bias = FloatImm::make(0); if (hasWeight) { - exprInputs.push_back(tensorOrConstant(inputs[1], {c})); + weight = tensorOrConstant(inputs[1], {c}); + exprInputs.push_back(weight); } if (hasBias) { - exprInputs.push_back(tensorOrConstant(inputs[2], {c})); + bias = tensorOrConstant(inputs[2], {c}); + exprInputs.push_back(bias); } promoteInputs(exprInputs); @@ -50,18 +54,7 @@ Tensor* computeBatchNorm( ExprHandle mean = exprInputs[1]; ExprHandle var = exprInputs[2]; ExprHandle eps = exprInputs[3]; - ExprHandle weight = FloatImm::make(1); - ExprHandle bias = FloatImm::make(0); - - if (hasWeight) { - weight = exprInputs[4]; - } - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - if (hasBias) { - bias = exprInputs[5]; - } - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) auto inv_var = rsqrt(var + eps); auto alpha = inv_var * weight; auto beta = bias - mean * alpha; diff --git a/torch/csrc/jit/tensorexpr/operators/norm.h b/torch/csrc/jit/tensorexpr/operators/norm.h index 98d53b4c306e3..7f1412f0aecd0 100644 --- a/torch/csrc/jit/tensorexpr/operators/norm.h +++ b/torch/csrc/jit/tensorexpr/operators/norm.h @@ -6,7 +6,7 @@ namespace torch { namespace jit { namespace tensorexpr { -Tensor* computeBatchNorm( +Tensor computeBatchNorm( const std::vector& inputs, const std::vector& outputShape, const c10::optional& outputType); diff --git a/torch/csrc/jit/tensorexpr/operators/reduction.cpp b/torch/csrc/jit/tensorexpr/operators/reduction.cpp index c1f3f7f4f2630..fe5cb6d286bd5 100644 --- a/torch/csrc/jit/tensorexpr/operators/reduction.cpp +++ b/torch/csrc/jit/tensorexpr/operators/reduction.cpp @@ -19,7 +19,7 @@ namespace torch { namespace jit { namespace tensorexpr { -Tensor* computeSum( +Tensor computeSum( const std::vector& inputs, const c10::optional& outputType) { std::vector axes; @@ -100,7 +100,7 @@ Tensor* computeSum( reductionDims); } -Tensor* computeMean( +Tensor computeMean( const std::vector& inputs, const std::vector& outputShape, const c10::optional& outputType) { @@ -120,13 +120,13 @@ Tensor* computeMean( mean_dims_expr.emplace_back(idx); } } - return new Tensor( + return Tensor( ResultBuf.node(), ExternalCall::make( ResultBuf, "nnc_aten_mean", {InputBuf}, mean_dims_expr)); } -Tensor* computeAdaptiveAvgPool2d( +Tensor computeAdaptiveAvgPool2d( const std::vector& inputs, const std::vector& outputShape, const c10::optional& outputType) { @@ -137,7 +137,7 @@ Tensor* computeAdaptiveAvgPool2d( BufHandle ResultBuf("adaptive_avgpool2d", outputShape, dtype); // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) auto out_size_param = c10::get(inputs[1]); - return new Tensor( + return Tensor( ResultBuf.node(), ExternalCall::make( ResultBuf, diff --git a/torch/csrc/jit/tensorexpr/operators/reduction.h b/torch/csrc/jit/tensorexpr/operators/reduction.h index 29f051f323b28..d76bac6aa34a1 100644 --- a/torch/csrc/jit/tensorexpr/operators/reduction.h +++ b/torch/csrc/jit/tensorexpr/operators/reduction.h @@ -6,14 +6,14 @@ namespace torch { namespace jit { namespace tensorexpr { -Tensor* computeSum( +TORCH_API Tensor computeSum( const std::vector& inputs, const c10::optional& outputType); -Tensor* computeMean( +TORCH_API Tensor computeMean( const std::vector& inputs, const std::vector& outputShape, const c10::optional& outputType); -Tensor* computeAdaptiveAvgPool2d( +TORCH_API Tensor computeAdaptiveAvgPool2d( const std::vector& inputs, const std::vector& outputShape, const c10::optional& outputType); diff --git a/torch/csrc/jit/tensorexpr/operators/softmax.cpp b/torch/csrc/jit/tensorexpr/operators/softmax.cpp index d6cb6c0d7d089..c1c2872cc4efe 100644 --- a/torch/csrc/jit/tensorexpr/operators/softmax.cpp +++ b/torch/csrc/jit/tensorexpr/operators/softmax.cpp @@ -6,7 +6,7 @@ namespace tensorexpr { using namespace torch::jit::tensorexpr; -Tensor* computeSoftmax( +Tensor computeSoftmax( const std::vector& inputs, const std::vector& outputShape, bool log_softmax) { @@ -111,48 +111,43 @@ Tensor* computeSoftmax( Compute("aten_softmax_exp", output_dims, [&](ParameterList& indices) { auto inp = tensorOrConstant( inputs[0], convert_indices_to_expr_handle(indices)); - return exp(inp - max->load(remove_softmax_dim_index(indices))); + return exp(inp - max.load(remove_softmax_dim_index(indices))); }); auto sum = Reduce( "aten_softmax_sum", non_softmax_dims, Sum(), [&](ParameterList& indices) { - return e->load(move_softmax_dim_index_to_pos(indices)); + return e.load(move_softmax_dim_index_to_pos(indices)); }, {output_dims[softmax_dim]}); if (!log_softmax) { auto result = Compute("aten_softmax", output_dims, [&](ParameterList& indices) { - return e->load(indices) / - sum->load(remove_softmax_dim_index(indices)); + return e.load(indices) / sum.load(remove_softmax_dim_index(indices)); }); - return new Tensor( - result->buf(), + return Tensor( + result.buf(), alloc(std::vector( - {max->stmt(), e->stmt(), sum->stmt(), result->stmt()}))); + {max.stmt(), e.stmt(), sum.stmt(), result.stmt()}))); } auto log_sum = Compute( "aten_softmax_log_sum", non_softmax_dims, [&](ParameterList& indices) { - return log(sum->load(indices)); + return log(sum.load(indices)); }); auto result = Compute("aten_log_softmax", output_dims, [&](ParameterList& indices) { auto inp = tensorOrConstant( inputs[0], convert_indices_to_expr_handle(indices)); auto non_softmax_indices = remove_softmax_dim_index(indices); - return inp - max->load(non_softmax_indices) - - log_sum->load(non_softmax_indices); + return inp - max.load(non_softmax_indices) - + log_sum.load(non_softmax_indices); }); - return new Tensor( - result->buf(), + return Tensor( + result.buf(), alloc(std::vector( - {max->stmt(), - e->stmt(), - sum->stmt(), - log_sum->stmt(), - result->stmt()}))); + {max.stmt(), e.stmt(), sum.stmt(), log_sum.stmt(), result.stmt()}))); } } // namespace tensorexpr diff --git a/torch/csrc/jit/tensorexpr/operators/softmax.h b/torch/csrc/jit/tensorexpr/operators/softmax.h index 07ddd0f95b355..b74a867a91b9b 100644 --- a/torch/csrc/jit/tensorexpr/operators/softmax.h +++ b/torch/csrc/jit/tensorexpr/operators/softmax.h @@ -6,7 +6,7 @@ namespace torch { namespace jit { namespace tensorexpr { -Tensor* computeSoftmax( +Tensor computeSoftmax( const std::vector& inputs, const std::vector& outputShape, bool log_softmax); diff --git a/torch/csrc/jit/tensorexpr/reduction.h b/torch/csrc/jit/tensorexpr/reduction.h index 08aef01c7d310..22d90b9981b82 100644 --- a/torch/csrc/jit/tensorexpr/reduction.h +++ b/torch/csrc/jit/tensorexpr/reduction.h @@ -171,7 +171,7 @@ inline ExprHandle maximumVal(ScalarType type) { #define MAX_BY_TYPE_CASE(Type, Name) \ case ScalarType::Name: \ return ExprHandle(std::numeric_limits::max()); - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, MAX_BY_TYPE_CASE) + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, MAX_BY_TYPE_CASE) #undef MAX_BY_TYPE_CASE default: throw unsupported_dtype(); @@ -184,7 +184,7 @@ inline ExprHandle minimumVal(ScalarType type) { #define MAX_BY_TYPE_CASE(Type, Name) \ case ScalarType::Name: \ return ExprHandle(std::numeric_limits::min()); - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, MAX_BY_TYPE_CASE) + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, MAX_BY_TYPE_CASE) #undef MAX_BY_TYPE_CASE default: throw unsupported_dtype(); diff --git a/torch/csrc/jit/tensorexpr/registerizer.cpp b/torch/csrc/jit/tensorexpr/registerizer.cpp index 07aee209e6e53..c4c495762a79d 100644 --- a/torch/csrc/jit/tensorexpr/registerizer.cpp +++ b/torch/csrc/jit/tensorexpr/registerizer.cpp @@ -18,7 +18,7 @@ void AccessInfo::addStore(StorePtr store, const std::shared_ptr& scope) { last_usage_ = store; store_cost_ = - IRSimplifier::simplify(alloc(store_cost_, alloc(1))); + IRSimplifier::simplify(alloc(store_cost_, immLike(store_cost_, 1))); stores_.push_back(store); conditionId_ = scope->conditionId(); @@ -34,7 +34,8 @@ void AccessInfo::addLoad( first_usage_ = first_usage_ ? block_->getEnclosedRoot(first_usage_) : usage; last_usage_ = usage; - load_cost_ = IRSimplifier::simplify(alloc(load_cost_, alloc(1))); + load_cost_ = + IRSimplifier::simplify(alloc(load_cost_, immLike(load_cost_, 1))); loads_.push_back(load); conditionId_ = scope->conditionId(); @@ -42,8 +43,14 @@ void AccessInfo::addLoad( } void AccessInfo::merge(const std::shared_ptr& other) { - TORCH_INTERNAL_ASSERT(hash_ == other->hash()); - TORCH_INTERNAL_ASSERT(indices_.size() == other->indices().size()); + TORCH_INTERNAL_ASSERT( + hash_ == other->hash(), + buildErrorMessage( + "Expected hashes to match in registerizer in the fuser.")); + TORCH_INTERNAL_ASSERT( + indices_.size() == other->indices().size(), + buildErrorMessage( + "Expected ranks to match in registerizer in the fuser.")); last_usage_ = other->last_usage(); for (auto s : other->stores()) { @@ -67,7 +74,10 @@ void AccessInfo::merge(const std::shared_ptr& other) { bool AccessInfo::overlaps(const std::shared_ptr& other) { // All accesses to a buf must have the same dimensionality. - TORCH_INTERNAL_ASSERT(indices_.size() == other->indices().size()); + TORCH_INTERNAL_ASSERT( + indices_.size() == other->indices().size(), + buildErrorMessage( + "Expected ranks to match in registerizer in the fuser.")); auto& other_indices = other->indices(); @@ -668,8 +678,10 @@ StmtPtr RegisterizerReplacer::mutate(StorePtr v) { ExprPtr new_val = v->value()->accept_mutator(this); - return alloc( - info->replacement().var_wrapper, std::vector({}), new_val); + v->set_value(new_val); + v->set_buf(info->replacement().var_wrapper); + v->set_indices({}); + return v; } StmtPtr RegisterizerReplacer::mutate(BlockPtr v) { diff --git a/torch/csrc/jit/tensorexpr/stmt.h b/torch/csrc/jit/tensorexpr/stmt.h index 0b4a2e4c5361c..7e4914fbc4aa7 100644 --- a/torch/csrc/jit/tensorexpr/stmt.h +++ b/torch/csrc/jit/tensorexpr/stmt.h @@ -14,14 +14,15 @@ namespace tensorexpr { class Placeholder; // The common base between all statement node. -class TORCH_API Stmt : public KernelScopedObject { +class TORCH_API Stmt : public std::enable_shared_from_this { public: Stmt() = default; + virtual ~Stmt() = default; virtual void accept(IRVisitor* visitor) = 0; virtual StmtPtr accept_mutator(IRMutator* mutator) = 0; StmtPtr get_parent() const { - return parent_; + return parent_ ? parent_->getptr() : nullptr; } /* @@ -34,12 +35,15 @@ class TORCH_API Stmt : public KernelScopedObject { static StmtPtr clone(StmtPtr s); protected: - static void set_parent(StmtPtr s, StmtPtr new_parent) { + static void set_parent(StmtPtr s, Stmt* new_parent) { s->parent_ = new_parent; } + std::shared_ptr getptr() { + return shared_from_this(); + } private: - StmtPtr parent_ = nullptr; + Stmt* parent_ = nullptr; }; template @@ -47,7 +51,7 @@ class StmtNode : public Stmt { public: using StmtNodeBase = StmtNode; void accept(IRVisitor* visitor) override { - visitor->visit(static_to(this)); + visitor->visit(static_to(getptr())); } StmtPtr accept_mutator(IRMutator* mutator) override; StmtNode() = default; @@ -55,7 +59,7 @@ class StmtNode : public Stmt { template StmtPtr StmtNode::accept_mutator(IRMutator* mutator) { - return mutator->mutate(static_to(this)); + return mutator->mutate(static_to(getptr())); } // Concrete Stmt classes @@ -193,7 +197,7 @@ class TORCH_API Block : public StmtNode { } void clear() { - for (auto* s : stmts_) { + for (auto s : stmts_) { set_parent(s, nullptr); } stmts_.clear(); @@ -281,7 +285,7 @@ class TORCH_API Block : public StmtNode { // returns the immediate child containing statement s. StmtPtr getEnclosedRoot(StmtPtr s) const { - while (s && s->get_parent() != this) { + while (s && s->get_parent().get() != this) { s = s->get_parent(); } return s; diff --git a/torch/csrc/jit/tensorexpr/tensor.cpp b/torch/csrc/jit/tensorexpr/tensor.cpp index 9df70f81be4a9..7a219fe728757 100644 --- a/torch/csrc/jit/tensorexpr/tensor.cpp +++ b/torch/csrc/jit/tensorexpr/tensor.cpp @@ -31,8 +31,8 @@ StmtPtr Tensor::constructStmt( for (const auto i : c10::irange(reduce_ndim)) { // Going in reverse order: from innermost loop to the outermost size_t dim_index = reduce_ndim - i - 1; - s = alloc( - reduce_args[dim_index], alloc(0), reduce_dims[dim_index], s); + auto const& dim = reduce_dims[dim_index]; + s = alloc(reduce_args[dim_index], immLike(dim, 0), dim, s); } if (init_expr) { StorePtr init_stmt = alloc(buf(), indices, init_expr); @@ -43,12 +43,13 @@ StmtPtr Tensor::constructStmt( for (const auto i : c10::irange(ndim)) { // Going in reverse order: from innermost loop to the outermost size_t dim_index = ndim - i - 1; - s = alloc(args[dim_index], alloc(0), buf()->dim(dim_index), s); + auto const& dim = buf()->dim(dim_index); + s = alloc(args[dim_index], immLike(dim, 0), dim, s); } return s; } -Tensor* Compute( +Tensor Compute( const std::string& name, const std::vector& dim_args, const std::function&)>& body_func) { @@ -57,10 +58,10 @@ Tensor* Compute( unpack_dim_args(dim_args, &dims, &args); ExprPtr body = body_func(VarVectorToVarHandleVector(args)).node(); BufPtr buf = alloc(name, dims, body->dtype()); - return new Tensor(buf, args, body); + return Tensor(buf, args, body); } -Tensor* Compute( +Tensor Compute( const std::string& name, const std::vector& dim_args, const std::function& body_func) { @@ -73,10 +74,10 @@ Tensor* Compute( unpack_dim_args(dim_args, &dims, &args); ExprPtr body = body_func(VarHandle(args[0])).node(); BufPtr buf = alloc(name, dims, body->dtype()); - return new Tensor(buf, args, body); + return Tensor(buf, args, body); } -Tensor* Compute( +Tensor Compute( const std::string& name, const std::vector& dim_args, const std::function& @@ -89,10 +90,10 @@ Tensor* Compute( unpack_dim_args(dim_args, &dims, &args); ExprPtr body = body_func(VarHandle(args[0]), VarHandle(args[1])).node(); BufPtr buf = alloc(name, dims, body->dtype()); - return new Tensor(buf, args, body); + return Tensor(buf, args, body); } -Tensor* Compute( +Tensor Compute( const std::string& name, const std::vector& dim_args, const std::function< @@ -108,10 +109,10 @@ Tensor* Compute( body_func(VarHandle(args[0]), VarHandle(args[1]), VarHandle(args[2])) .node(); BufPtr buf = alloc(name, dims, body->dtype()); - return new Tensor(buf, args, body); + return Tensor(buf, args, body); } -Tensor* Compute( +Tensor Compute( const std::string& name, const std::vector& dim_args, const std::function(name, dims, body->dtype()); - return new Tensor(buf, args, body); + return Tensor(buf, args, body); } -Tensor* Reduce( +Tensor Reduce( const std::string& name, const std::vector& dim_args, const Reducer& reducer, @@ -149,7 +150,7 @@ Tensor* Reduce( reduce_args); } -Tensor* Reduce( +Tensor Reduce( const std::string& name, const std::vector& dim_args, const Reducer& reducer, @@ -163,17 +164,17 @@ Tensor* Reduce( reduce_args); } -Tensor* Reduce( +Tensor Reduce( const std::string& name, const std::vector& dim_args, const Reducer& reducer, - Tensor* tensor, + Tensor tensor, const std::vector& reduce_args) { return Reduce( name, dim_args, reducer, - [&](ParameterList& p) { return tensor->load(p); }, + [&](ParameterList& p) { return tensor.load(p); }, reduce_args); } diff --git a/torch/csrc/jit/tensorexpr/tensor.h b/torch/csrc/jit/tensorexpr/tensor.h index 3eb02c69bda78..8d8ffe5cfee44 100644 --- a/torch/csrc/jit/tensorexpr/tensor.h +++ b/torch/csrc/jit/tensorexpr/tensor.h @@ -12,7 +12,7 @@ namespace torch { namespace jit { namespace tensorexpr { -class TORCH_API Tensor : KernelScopedObject { +class TORCH_API Tensor { public: // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) Tensor(BufPtr buf, const std::vector& args, ExprPtr body) @@ -42,9 +42,9 @@ class TORCH_API Tensor : KernelScopedObject { } template - inline ExprHandle load(const std::vector& args); + inline ExprHandle load(const std::vector& args) const; template - inline ExprHandle load(const Ts&... ts); + inline ExprHandle load(const Ts&... ts) const; private: StmtPtr constructStmt( @@ -134,22 +134,22 @@ class Placeholder { std::vector strides_; }; -TORCH_API Tensor* Compute( +TORCH_API Tensor Compute( const std::string& func_name, const std::vector& dim_args, const std::function& body_func); -TORCH_API Tensor* Compute( +TORCH_API Tensor Compute( const std::string& func_name, const std::vector& dim_args, const std::function& body_func); -TORCH_API Tensor* Compute( +TORCH_API Tensor Compute( const std::string& func_name, const std::vector& dim_args, const std::function< ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)>& body_func); -TORCH_API Tensor* Compute( +TORCH_API Tensor Compute( const std::string& func_name, const std::vector& dim_args, const std::function& body_func); -TORCH_API Tensor* Compute( +TORCH_API Tensor Compute( const std::string& func_name, const std::vector& dim_args, const std::function&)>& body_func); @@ -179,7 +179,7 @@ inline void unpack_dim_args( // Handle reductions over a Reducer and a body_func which produces values. template -Tensor* Reduce( +Tensor Reduce( const std::string& func_name, const std::vector& dim_args, const Reducer& reducer, @@ -207,7 +207,7 @@ Tensor* Reduce( .node(); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) BufPtr func_result = alloc(func_name, dims, body->dtype()); - return new Tensor(func_result, vars, body); + return Tensor(func_result, vars, body); } // NOLINTNEXTLINE(cppcoreguidelines-init-variables) @@ -227,13 +227,12 @@ Tensor* Reduce( // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ReduceOpPtr reduce_op = reducer(func_result, body, output_args, reduce_vars); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - Tensor* t = - new Tensor(func_result, vars, reduce_dims, reduce_vars, reduce_op); + Tensor t = Tensor(func_result, vars, reduce_dims, reduce_vars, reduce_op); return t; } template -Tensor* Reduce( +Tensor Reduce( const std::string& func_name, const std::vector& dim_args, const Reducer& reducer, @@ -250,7 +249,7 @@ Tensor* Reduce( // Overload which allows inline lambda functions for the body_func. template -Tensor* Reduce( +Tensor Reduce( const std::string& func_name, const std::vector& dim_args, const Reducer& reducer, @@ -260,14 +259,14 @@ Tensor* Reduce( } // Overload for the common case of all dimensions of a Placeholder. -TORCH_API Tensor* Reduce( +TORCH_API Tensor Reduce( const std::string& func_name, const std::vector& dim_args, const Reducer& reducer, const Placeholder& buffer, const std::vector& reduce_args); -TORCH_API Tensor* Reduce( +TORCH_API Tensor Reduce( const std::string& name, const std::vector& dim_args, const Reducer& reducer, @@ -276,22 +275,22 @@ TORCH_API Tensor* Reduce( // Overload for the common case of all dimensions of a prevously Computed // Tensor. -TORCH_API Tensor* Reduce( +TORCH_API Tensor Reduce( const std::string& func_name, const std::vector& dim_args, const Reducer& reducer, - Tensor* tensor, + Tensor tensor, const std::vector& reduce_args); template -inline ExprHandle Tensor::load(const Ts&... ts) { +inline ExprHandle Tensor::load(const Ts&... ts) const { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector params({ExprHandle(ts)...}); return Load::make(BufHandle(this->buf()), params); } template -inline ExprHandle Tensor::load(const std::vector& args) { +inline ExprHandle Tensor::load(const std::vector& args) const { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector params(args.begin(), args.end()); return Load::make(BufHandle(this->buf()), params); diff --git a/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp b/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp index 304a317076c05..c924bded3543c 100644 --- a/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp +++ b/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp @@ -61,7 +61,6 @@ void initTensorExprBindings(PyObject* module) { // Tensor Expr Classes auto te = m.def_submodule("_te"); - py::class_(te, "KernelScope").def(py::init<>()); auto dtype_class = py::class_(te, "Dtype").def(py::init(&parsePythonDtype)); @@ -70,7 +69,7 @@ void initTensorExprBindings(PyObject* module) { #define DTYPE_SINGLETON_ACCESSOR(ctype, name) \ dtype_class.def_property_readonly_static( \ #name, [](py::object) { return k##name; }); // NOLINT - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, DTYPE_SINGLETON_ACCESSOR) + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, DTYPE_SINGLETON_ACCESSOR) #undef DTYPE_SINGLETON_ACCESSOR auto expr_handle_class = @@ -145,7 +144,7 @@ void initTensorExprBindings(PyObject* module) { #define EXPRHANDLE_CTOR(ctype, name) \ expr_handle_class.def_static(#ctype, [](ctype v) { return ExprHandle(v); }); - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, EXPRHANDLE_CTOR) + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, EXPRHANDLE_CTOR) #undef EXPRHANDLE_CTOR py::class_(te, "VarHandle") @@ -184,21 +183,19 @@ void initTensorExprBindings(PyObject* module) { [](Placeholder& self, const std::vector& args, const ExprHandle& val) { return self.store(args, val); }) + .def("data", [](Placeholder& self) { return BufHandle(self.data()); }); + py::class_(te, "Tensor") .def( - "data", - [](Placeholder& self) { return BufHandle(self.data()); }, - py::return_value_policy::reference); - py::class_>(te, "Tensor") - .def(py::init( - [](BufHandle& b, StmtPtr s) { return new Tensor(b.node(), s); })) + py::init([](BufHandle& b, StmtPtr s) { return Tensor(b.node(), s); })) .def( "load", [](Tensor& self, const std::vector& v) { return self.load(v); }) .def("buf", [](Tensor& self) { return BufHandle(self.buf()); }) - .def("stmt", &Tensor::stmt, py::return_value_policy::reference); - py::class_(te, "Cast").def_static("make", &Cast::make); + .def("stmt", &Tensor::stmt); + py::class_>(te, "Cast") + .def_static("make", &Cast::make); py::class_(te, "DimArg") .def(py::init()) @@ -270,7 +267,7 @@ void initTensorExprBindings(PyObject* module) { [](const std::string& func_name, const std::vector& dim_args, const Reducer& reducer, - Tensor* buffer, + Tensor buffer, const std::vector& reduce_args) { return Reduce(func_name, dim_args, reducer, buffer, reduce_args); }, @@ -321,7 +318,7 @@ void initTensorExprBindings(PyObject* module) { }, py::return_value_policy::reference); - py::class_>(te, "Stmt") + py::class_>(te, "Stmt") .def(py::init([](const std::vector& stmts) { return tensorexpr::Block::make(stmts); })) @@ -330,22 +327,18 @@ void initTensorExprBindings(PyObject* module) { ss << self; return ss.str(); }); - py::class_>(te, "Store") + py::class_>(te, "Store") .def_static( "make", [](const BufHandle& buf, std::vector& indices, const ExprHandle& value) { return Store::make(buf, indices, value); - }, - py::return_value_policy::reference); + }); - py::class_>(te, "For") - .def( - "index_var", - [](For& self) { return VarHandle(self.var()); }, - py::return_value_policy::reference) - .def("body", &For::body, py::return_value_policy::reference) + py::class_>(te, "For") + .def("index_var", [](For& self) { return VarHandle(self.var()); }) + .def("body", &For::body) .def("set_parallel", &For::set_parallel) .def( "set_gpu_block_index", @@ -362,38 +355,31 @@ void initTensorExprBindings(PyObject* module) { [](const VarHandle& var, const ExprHandle& start, const ExprHandle& stop, - StmtPtr body) { return For::make(var, start, stop, body); }, - py::return_value_policy::reference); + StmtPtr body) { return For::make(var, start, stop, body); }); - py::class_>(te, "Cond") + py::class_>(te, "Cond") .def_static( "make", [](const ExprHandle& condition, StmtPtr true_stmt, StmtPtr false_stmt) { - return alloc(condition.node(), true_stmt, false_stmt); - }, - py::return_value_policy::reference) - .def("true_stmt", &Cond::true_stmt, py::return_value_policy::reference) - .def("false_stmt", &Cond::false_stmt, py::return_value_policy::reference); + return Cond::make(condition, true_stmt, false_stmt); + }) + .def("true_stmt", &Cond::true_stmt) + .def("false_stmt", &Cond::false_stmt); - py::class_< - tensorexpr::Block, - Stmt, - std::unique_ptr>(te, "Block") + py::class_>( + te, "Block") .def(py::init([](const std::vector& stmts) { return tensorexpr::Block::make(stmts); })) - .def( - "stmts", - &tensorexpr::Block::stmts, - py::return_value_policy::reference); - py::class_>( + .def("stmts", &tensorexpr::Block::stmts); + py::class_>( te, "ExternalCall") - .def(py::init(&ExternalCall::make), py::return_value_policy::reference); + .def(py::init(&ExternalCall::make)); py::class_(te, "LoopNest") - .def(py::init&>()) + .def(py::init&>()) .def(py::init([](StmtPtr s, const std::vector& bufs) { std::unordered_set buf_nodes; for (auto& buf : bufs) { @@ -405,9 +391,7 @@ void initTensorExprBindings(PyObject* module) { .def("prepare_for_codegen", &LoopNest::prepareForCodegen) .def( "get_loop_body_for", - [](const LoopNest& self, Tensor* t) { - return self.getLoopBodyFor(t); - }, + [](const LoopNest& self, Tensor t) { return self.getLoopBodyFor(t); }, py::return_value_policy::reference) .def( "get_loop_body_for", @@ -417,7 +401,7 @@ void initTensorExprBindings(PyObject* module) { py::return_value_policy::reference) .def( "get_loops_for", - [](const LoopNest& self, Tensor* t) { + [](const LoopNest& self, Tensor t) { return self.getLoopStmtsFor(t); }, py::return_value_policy::reference) @@ -773,12 +757,12 @@ void initTensorExprBindings(PyObject* module) { py::class_(te, "BufferArg") .def(py::init()) - .def(py::init()) + .def(py::init()) .def(py::init()) .def(py::init()); py::implicitly_convertible(); - py::implicitly_convertible(); + py::implicitly_convertible(); py::implicitly_convertible(); py::implicitly_convertible(); diff --git a/torch/csrc/jit/tensorexpr/types.cpp b/torch/csrc/jit/tensorexpr/types.cpp index 5cef86a2dfe26..e75ecd9744d61 100644 --- a/torch/csrc/jit/tensorexpr/types.cpp +++ b/torch/csrc/jit/tensorexpr/types.cpp @@ -16,7 +16,7 @@ Dtype Dtype::scalar_dtype() const { // NOLINTNEXTLINE #define DTYPE_DEFINE(_1, n) TORCH_API Dtype k##n(ScalarType::n, 1); -AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, DTYPE_DEFINE) +AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, DTYPE_DEFINE) #undef DTYPE_DEFINE @@ -28,7 +28,7 @@ Dtype ToDtype(ScalarType type) { #define TYPE_CASE(_1, n) \ case ScalarType::n: \ return k##n; - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE) + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE) #undef TYPE_CASE case ScalarType::Undefined: @@ -56,7 +56,7 @@ int Dtype::byte_size() const { scalar_size = sizeof(Type); \ break; - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE); #undef TYPE_CASE default: throw std::runtime_error( @@ -77,6 +77,8 @@ std::string Dtype::ToCppString() const { return "bool"; case ScalarType::Half: return "half"; + case ScalarType::BFloat16: + return "__nv_bfloat16"; default: throw unsupported_dtype(); } diff --git a/torch/csrc/jit/tensorexpr/types.h b/torch/csrc/jit/tensorexpr/types.h index 00cd50db288b3..3716a0a1cd559 100644 --- a/torch/csrc/jit/tensorexpr/types.h +++ b/torch/csrc/jit/tensorexpr/types.h @@ -75,7 +75,7 @@ extern TORCH_API Dtype kHandle; #define NNC_DTYPE_DECLARATION(ctype, name) extern TORCH_API Dtype k##name; -AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, NNC_DTYPE_DECLARATION) +AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, NNC_DTYPE_DECLARATION) #undef NNC_DTYPE_DECLARATION template @@ -86,7 +86,7 @@ TORCH_API Dtype ToDtype(); inline Dtype ToDtype() { \ return k##name; \ } -AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, NNC_TODTYPE_DECLARATION) +AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, NNC_TODTYPE_DECLARATION) #undef NNC_TODTYPE_DECLARATION TORCH_API Dtype ToDtype(ScalarType type); diff --git a/torch/csrc/utils/crash_handler.cpp b/torch/csrc/utils/crash_handler.cpp index 2de22be0d2e86..8fb318b265a83 100644 --- a/torch/csrc/utils/crash_handler.cpp +++ b/torch/csrc/utils/crash_handler.cpp @@ -3,8 +3,16 @@ #include #ifdef ADD_BREAKPAD_SIGNAL_HANDLER -#include +#ifdef __linux__ +#include #include +#elif __APPLE__ +#include +#elif _WIN32 +#include +#else +#error unsupported platform +#endif #endif #include @@ -16,9 +24,10 @@ namespace crash_handler { #ifdef ADD_BREAKPAD_SIGNAL_HANDLER static std::unique_ptr handler; // NOLINT -static std::string minidump_directory; // NOLINT +static STRING_TYPE minidump_directory; // NOLINT static bool enabled_for_exceptions = false; // NOLINT +#if __linux__ bool dump_callback( const google_breakpad::MinidumpDescriptor& descriptor, void* context, @@ -28,10 +37,45 @@ bool dump_callback( } return succeeded; } +#elif __APPLE__ -void enable_minidumps(const std::string& dir) { +bool dump_callback( + const char* dump_dir, + const char* minidump_id, + void* context, + bool succeeded) { + if (succeeded) { + std::cerr << "Wrote minidump to " << dump_dir << "/" << minidump_id + << ".dmp" << std::endl; + } + return succeeded; +} +#elif _WIN32 +bool dump_callback( + const wchar_t* dump_path, + const wchar_t* minidump_id, + void* context, + EXCEPTION_POINTERS* exinfo, + MDRawAssertionInfo* assertion, + bool succeeded) { + if (succeeded) { + // Printing with wcerr inserts spaces between all the characters for some + // reason. If someone figures that out then we can get rid of the std::string + // conversions here. + std::wstring dump_path_ws(dump_path); + std::string dump_path_string(dump_path_ws.begin(), dump_path_ws.end()); + std::wstring minidump_id_ws(minidump_id); + std::string minidump_id_string(minidump_id_ws.begin(), minidump_id_ws.end()); + std::cerr << "Wrote minidump to " << dump_path_string << "\\" << minidump_id_string << ".dmp" << std::endl; + } + return succeeded; +} +#endif + +void enable_minidumps(const STRING_TYPE& dir) { minidump_directory = dir; - // The constructor here registers the actual signal handler +// The constructor here registers the actual signal handler +#ifdef __linux__ handler = std::make_unique( google_breakpad::MinidumpDescriptor(minidump_directory), nullptr, @@ -39,13 +83,30 @@ void enable_minidumps(const std::string& dir) { nullptr, true, -1); +#elif __APPLE__ + handler = std::make_unique( + /*dump_path=*/minidump_directory.c_str(), + /*filter=*/nullptr, + /*callback=*/dump_callback, + /*callback_context=*/nullptr, + /*install_handler=*/true, + /*port_name=*/nullptr); +#elif _WIN32 + handler = std::make_unique( + /*dump_path=*/minidump_directory.c_str(), + /*filter=*/nullptr, + /*callback=*/dump_callback, + /*callback_context=*/nullptr, + /*handler_types=*/ + google_breakpad::ExceptionHandler::HandlerType::HANDLER_ALL); +#endif } void disable_minidumps() { handler.reset(); } -const std::string& get_minidump_directory() { +const STRING_TYPE& get_minidump_directory() { if (handler == nullptr) { AT_ERROR( "Minidump handler is uninintialized, make sure to call enable_minidumps first"); @@ -78,18 +139,16 @@ void enable_minidumps_on_exceptions() { #else // On unspported systems we can't do anything, so stub out everything. -void enable_minidumps(const std::string& dir) { - AT_ERROR( - "Minidump collection is currently only implemented for Linux platforms"); +void enable_minidumps(const STRING_TYPE& dir) { + AT_ERROR("Compiled without minidump support"); } void disable_minidumps() { // Purposefully do nothing } -const std::string& get_minidump_directory() { - AT_ERROR( - "Minidump collection is currently only implemented for Linux platforms"); +const STRING_TYPE& get_minidump_directory() { + AT_ERROR("Compiled without minidump support"); } bool is_enabled_on_exceptions() { @@ -97,13 +156,11 @@ bool is_enabled_on_exceptions() { } void write_minidump() { - AT_ERROR( - "Minidump collection is currently only implemented for Linux platforms"); + AT_ERROR("Compiled without minidump support"); } void enable_minidumps_on_exceptions() { - AT_ERROR( - "Minidump collection is currently only implemented for Linux platforms"); + AT_ERROR("Compiled without minidump support"); } #endif diff --git a/torch/csrc/utils/crash_handler.h b/torch/csrc/utils/crash_handler.h index 5fe0503b2ed00..dc11945195372 100644 --- a/torch/csrc/utils/crash_handler.h +++ b/torch/csrc/utils/crash_handler.h @@ -5,10 +5,16 @@ namespace torch { namespace crash_handler { +#ifdef _WIN32 +typedef std::wstring STRING_TYPE; +#else +typedef std::string STRING_TYPE; +#endif + // Set up a handler that writes minidumps to 'dir' on signals. This is not // necessary to call unless you want to change 'dir' to something other than // the default '/tmp/pytorch_crashes'. -TORCH_API void enable_minidumps(const std::string& dir); +TORCH_API void enable_minidumps(const STRING_TYPE& dir); // Enable minidumps when passing exceptions up to Python. By default these don't // do anything special, but it can be useful to write out a minidump on @@ -19,7 +25,7 @@ TORCH_API void enable_minidumps_on_exceptions(); TORCH_API void disable_minidumps(); // Get the directory that minidumps will be written to -TORCH_API const std::string& get_minidump_directory(); +TORCH_API const STRING_TYPE& get_minidump_directory(); // These are TORCH_API'ed since they are used from libtorch_python.so TORCH_API bool is_enabled_on_exceptions(); diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h index c9a1e4a39aeef..d132185ccaefb 100644 --- a/torch/csrc/utils/python_arg_parser.h +++ b/torch/csrc/utils/python_arg_parser.h @@ -810,4 +810,14 @@ bool is_tensor_and_append_overloaded(PyObject* obj, std::vector* ove */ bool is_tensor_list_and_append_overloaded(PyObject* obj, std::vector* overloaded_args, int argnum, bool throw_error); +/* Given an argument that is definitely a tensor and is definitely overloaded, + * append it to the overloaded arguments list. Use this instead of + * is_tensor_and_append_overloaded in situations where you have a PyObject + * and you know it definitely is a Tensor and it is definitely overloaded. + * + * 'overloaded_args': the vector to append the overloaded args + * 'obj': the input tensor that is overloaded + */ +void append_overloaded_arg(std::vector* overloaded_args, PyObject* obj); + } // namespace torch diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index 29e112fc67abd..80d9e108643b4 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -16,7 +16,8 @@ import threading from typing import List, Optional, Tuple, Union, Any from ._utils import _get_device_index, _dummy_type -from .streams import Stream, Event, _Graph, _graph_pool_handle +from .graphs import CUDAGraph, graph_pool_handle, graph, make_graphed_callables +from .streams import Stream, Event from .. import device as _device import torch._C @@ -78,6 +79,15 @@ def is_available() -> bool: # be initialized return torch._C._cuda_getDeviceCount() > 0 +def is_bf16_supported(): + r"""Returns a bool indicating if the current CUDA device supports dtype bfloat16""" + cu_vers = torch.version.cuda + if cu_vers is not None: + cuda_maj_decide = int(cu_vers.split('.')[0]) >= 11 + + else: + cuda_maj_decide = False + return torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8 and cuda_maj_decide def _sleep(cycles): torch._C._cuda_sleep(cycles) diff --git a/torch/cuda/amp/autocast_mode.py b/torch/cuda/amp/autocast_mode.py index e9bfe06a0a352..ca8a2fcaf29d5 100644 --- a/torch/cuda/amp/autocast_mode.py +++ b/torch/cuda/amp/autocast_mode.py @@ -13,8 +13,8 @@ class autocast(torch.autocast_mode.autocast): See :class:`torch.autocast`. ``torch.cuda.amp.autocast(args...)`` is equivalent to ``torch.autocast("cuda", args...)`` """ - def __init__(self, enabled=True, fast_dtype=torch.float16): - super().__init__("cuda", enabled=enabled, fast_dtype=fast_dtype) + def __init__(self, enabled=True, dtype=torch.float16): + super().__init__("cuda", enabled=enabled, dtype=dtype) # Casts Tensors and containers of Tensors. Special-cases passthroughs for strings and np.ndarrays, which diff --git a/torch/cuda/graphs.py b/torch/cuda/graphs.py new file mode 100644 index 0000000000000..ff8a07f989f9d --- /dev/null +++ b/torch/cuda/graphs.py @@ -0,0 +1,408 @@ +import gc +import torch + +from ._utils import _dummy_type + + +if not hasattr(torch._C, '_CudaStreamBase'): + # Define dummy base classes + torch._C.__dict__['_CUDAGraph'] = _dummy_type('_CUDAGraph') + torch._C.__dict__['_graph_pool_handle'] = _dummy_type('_graph_pool_handle') + +from torch._C import _CUDAGraph # noqa: F401 +from torch._C import _graph_pool_handle + + +# Python shim helps Sphinx process docstrings more reliably. +def graph_pool_handle(): + r""" + Returns an opaque token representing the id of a graph memory pool. + See :ref:`Graph memory management`. + + .. warning:: + This API is a prototype and may change in future releases. + """ + return _graph_pool_handle() + + +# Python shim helps Sphinx process docstrings more reliably. +class CUDAGraph(torch._C._CUDAGraph): + r""" + Wrapper around a CUDA graph. + + .. warning:: + This API is a prototype and may change in future releases. + """ + def __new__(cls): + return super(CUDAGraph, cls).__new__(cls) + + def __init__(self): + super(CUDAGraph, self).__init__() + + def capture_begin(self, pool=None): + r""" + Begins capturing CUDA work on the current stream. + + Typically, you shouldn't call ``capture_begin`` yourself. + Use :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`, + which call ``capture_begin`` internally. + + Arguments: + pool (optional): Token (returned by :func:`~torch.cuda.graph_pool_handle` or + :meth:`other_Graph_instance.pool()`) that hints this graph may share memory + with the indicated pool. See :ref:`Graph memory management`. + """ + # I'm not sure if pybind11 converts a None arg to the default defined on the C++ side, + # so I'm not taking any chances. + if pool is None: + super(CUDAGraph, self).capture_begin() + else: + super(CUDAGraph, self).capture_begin(pool) + + def capture_end(self): + r""" + Ends CUDA graph capture on the current stream. + After ``capture_end``, ``replay`` may be called on this instance. + + Typically, you shouldn't call ``capture_end`` yourself. + Use :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`, + which call ``capture_end`` internally. + """ + super(CUDAGraph, self).capture_end() + + def replay(self): + r""" + Replays the CUDA work captured by this graph. + """ + super(CUDAGraph, self).replay() + + def reset(self): + r""" + Deletes the graph currently held by this instance. + """ + super(CUDAGraph, self).reset() + + def pool(self): + r""" + Returns an opaque token representing the id of this graph's memory pool. + This id can optionally be passed to another graph's ``capture_begin``, + which hints the other graph may share the same memory pool. + """ + return super(CUDAGraph, self).pool() + + +class graph(object): + r""" + Context-manager that captures CUDA work into a :class:`torch.cuda.CUDAGraph` + object for later replay. + + See :ref:`CUDA Graphs ` for a general introduction, + detailed use, and constraints. + + Arguments: + cuda_graph (torch.cuda.CUDAGraph): Graph object used for capture. + pool (optional): Opaque token (returned by a call to :func:`~torch.cuda.graph_pool_handle()` or + :meth:`other_Graph_instance.pool()`) hinting this graph's capture + may share memory from the specified pool. See :ref:`Graph memory management`. + stream (torch.cuda.Stream, optional): If supplied, will be set as the current stream in the context. + If not supplied, ``graph`` sets its own internal side stream as the current stream in the context. + + .. note:: + For effective memory sharing, if you pass a ``pool`` used by a previous capture and the previous capture + used an explicit ``stream`` argument, you should pass the same ``stream`` argument to this capture. + + .. warning:: + This API is a prototype and may change in future releases. + """ + default_capture_stream = None + + def __init__(self, + cuda_graph, + pool=None, + stream=None): + # Lazy-init of default_capture_stream helps avoid circular-import errors. + # Not thread safe, but graphs already have the general (explicitly documented) + # restriction that only one capture may be underway at a time in the process. + if self.__class__.default_capture_stream is None: + self.__class__.default_capture_stream = torch.cuda.Stream() + + self.pool = () if pool is None else (pool,) + self.capture_stream = stream if stream is not None else self.__class__.default_capture_stream + assert self.capture_stream is not None + self.stream_ctx = torch.cuda.stream(self.capture_stream) + self.cuda_graph = cuda_graph + + def __enter__(self): + # Free as much memory as we can for the graph + torch.cuda.synchronize() + gc.collect() + torch.cuda.empty_cache() + + # Stackoverflow seems comfortable with this pattern + # https://stackoverflow.com/questions/26635684/calling-enter-and-exit-manually#39172487 + self.stream_ctx.__enter__() + + self.cuda_graph.capture_begin(*self.pool) + + + def __exit__(self, exc_type, exc_value, traceback): + self.cuda_graph.capture_end() + self.stream_ctx.__exit__(exc_type, exc_value, traceback) + # returning None should propagate exceptions from either capture_end or stream_ctx.__exit__() + + +def make_graphed_callables(callables, sample_args): + r""" + Accepts callables (functions or :class:`nn.Module`\ s) + and returns graphed versions. + + Each graphed callable's forward pass runs its source callable's + forward CUDA work as a CUDA graph inside a single autograd node. + + The graphed callable's forward pass also appends + a backward node to the autograd graph. During backward, this node runs the + callable's backward work as a CUDA graph. + + Therefore, each graphed callable should be a drop-in replacement for its source callable + in an autograd-enabled training loop. + + See :ref:`Partial-network capture` for detailed use and constraints. + + If you pass a tuple of several callables, their captures will use the same memory pool. + See :ref:`Graph memory management` for when this is appropriate. + + Arguments: + callables (torch.nn.Module or Python function, or tuple of these): Callable or callables to graph. + See :ref:`Graph memory management` for when passing a tuple of callables + is appropriate. If you pass a tuple of callables, their order in the tuple must be the same order + they'll run in the live workload. + sample_args (tuple of Tensors, or tuple of tuples of Tensors): Samples args for each callable. + If a single callable was passed, ``sample_args`` must be a single tuple of argument Tensors. + If a tuple of callables was passed, ``sample_args`` must be tuple of tuples of argument Tensors. + + .. note:: + The ``requires_grad`` state of each Tensor in ``sample_args`` must match the state + that's expected for the corresponding real input in the training loop. + + .. warning:: + This API is a prototype and may change in future releases. + + .. warning:: + ``sample_args`` for each callable must be a tuple of Tensors. Other types and keyword args + are not allowed. + + .. warning:: + Returned callables do not support higher order differentiation (e.g., double backward). + + .. warning:: + In any :class:`~torch.nn.Module` passed to :func:`~make_graphed_callables`, only parameters + may be trainable. Buffers must have ``requires_grad=False``. + + .. warning:: + After you pass a :class:`torch.nn.Module` through :func:`~make_graphed_callables`, + you may not add or remove any of that Module's parameters or buffers. + + .. warning:: + :class:`torch.nn.Module`\s passed to :func:`~torch.cuda.make_graphed_callables` must not have module hooks + registered on them at the time they are passed. However, registering hooks on modules *after* passing them + through :func:`~torch.cuda.make_graphed_callables` is allowed. + + .. warning:: + When running a graphed callable, you must pass its arguments in the same order and format + they appeared in that callable's ``sample_args``. + + .. warning:: + All Tensor outputs of graphed callables must require grad. + """ + just_one_callable = False + + if not isinstance(callables, tuple): + just_one_callable = True + callables = (callables,) + sample_args = (sample_args,) + + for c, args in zip(callables, sample_args): + if isinstance(c, torch.nn.Module): + assert len(c._backward_hooks) == 0 and len(c._forward_hooks) == 0 and len(c._forward_pre_hooks) == 0, \ + "Modules must not have hooks registered at the time they are passed. However, registering hooks " + \ + "on modules after passing them through make_graphed_callables is allowed." + assert all(b.requires_grad is False for b in c.buffers()), "In any :class:`~torch.nn.Module` passed to " + \ + ":func:`~make_graphed_callables`, only parameters may be trainable. All buffers must have " + \ + "``requires_grad=False``." + assert all(isinstance(arg, torch.Tensor) for arg in args), "In the prototype API, sample_args " + \ + "for each callable must be a tuple of Tensors. Other types and keyword args are not allowed." + + + # If a callable is an nn.Module, its graph's full input surface is the args the user explicitly + # passes to forward (ie, its sample_args) AND the module's parameter attributes. + per_callable_len_user_args = [len(args) for args in sample_args] + per_callable_module_params = [tuple(c.parameters()) if isinstance(c, torch.nn.Module) else () + for c in callables] + per_callable_static_input_surfaces = [sample_args[i] + per_callable_module_params[i] + for i in range(len(callables))] + + fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))] + bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))] + + mempool = graph_pool_handle() + + # Warmup + # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work + # from ending up in any captures. + torch.cuda.synchronize() + with torch.cuda.stream(torch.cuda.Stream()): + for func, args, static_input_surface in zip(callables, + sample_args, + per_callable_static_input_surfaces): + for _ in range(3): + outputs = func(*args) + outputs = (outputs,) if isinstance(outputs, torch.Tensor) else outputs + grad_inputs = torch.autograd.grad(outputs=outputs, + inputs=tuple(i for i in static_input_surface if i.requires_grad), + grad_outputs=tuple(torch.empty_like(o) for o in outputs), + only_inputs=True, + allow_unused=False) + del outputs, grad_inputs + torch.cuda.synchronize() + + # All captures here share a mempool. To avoid replays corrupting each other's memory, + # the safest approach is to capture all passes in the same order they'll run: + # fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1. + + # Capture forward graphs + per_callable_static_outputs = [] + per_callable_output_was_tensor = [] + for func, args, fwd_graph in zip(callables, + sample_args, + fwd_graphs): + with torch.cuda.graph(fwd_graph, pool=mempool): + outputs = func(*args) + + # Assumes model output is a tensor or tuple of tensors + if isinstance(outputs, torch.Tensor): + per_callable_output_was_tensor.append(True) + outputs = (outputs,) + else: + per_callable_output_was_tensor.append(False) + + per_callable_static_outputs.append(outputs) + + # Capture backward graphs in reverse order + per_callable_static_grad_outputs = [] + per_callable_static_grad_inputs = [] + for static_input_surface, static_outputs, bwd_graph, module_params in \ + zip(reversed(per_callable_static_input_surfaces), + reversed(per_callable_static_outputs), + reversed(bwd_graphs), + reversed(per_callable_module_params)): + + # For now, assumes all static_outputs require grad + assert all(o.requires_grad for o in static_outputs), "Outputs of graphed callables must require grad." + static_grad_outputs = tuple(torch.empty_like(o) for o in static_outputs) + + with torch.cuda.graph(bwd_graph, pool=mempool): + grad_inputs = torch.autograd.grad(outputs=static_outputs, + inputs=tuple(i for i in static_input_surface if i.requires_grad), + grad_outputs=static_grad_outputs, + only_inputs=True, + allow_unused=False) + + # Constructs a tuple suitable for returning from Graphed.backward: + # Pads out the actually-needed grads with Nones in gradient slots for inputs that don't require grad. + # I couldn't think of a slick one-liner for this pattern. + static_grad_inputs = [] + grad_idx = 0 + for arg in static_input_surface: + if arg.requires_grad: + static_grad_inputs.append(grad_inputs[grad_idx]) + grad_idx += 1 + else: + static_grad_inputs.append(None) # type: ignore[arg-type] + static_grad_inputs = tuple(static_grad_inputs) # type: ignore[assignment] + + per_callable_static_grad_outputs.append(static_grad_outputs) + per_callable_static_grad_inputs.append(static_grad_inputs) + + # Reverses the most recent two lists + per_callable_static_grad_outputs = list(reversed(per_callable_static_grad_outputs)) + per_callable_static_grad_inputs = list(reversed(per_callable_static_grad_inputs)) + # Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable. + + def make_graphed_autograd_function(fwd_graph, + bwd_graph, + module_params, + len_user_args, + output_was_tensor, + static_input_surface, + static_outputs, + static_grad_outputs, + static_grad_inputs): + class Graphed(torch.autograd.Function): + @staticmethod + def forward(ctx, *inputs): + # At this stage, only the user args may (potentially) be new tensors. + for i in range(len_user_args): + if static_input_surface[i].data_ptr() != inputs[i].data_ptr(): + static_input_surface[i].copy_(inputs[i]) + fwd_graph.replay() + assert isinstance(static_outputs, tuple) + return tuple(o.detach() for o in static_outputs) + + @staticmethod + @torch.autograd.function.once_differentiable + def backward(ctx, *grads): + for g, grad in zip(static_grad_outputs, grads): + if g is None: + assert grad is None + else: + # don't copy if autograd gods have been kind and the + # incoming grad is already in the right place + if g.data_ptr() != grad.data_ptr(): + g.copy_(grad) + bwd_graph.replay() + + # Input args that didn't require grad expect a None gradient. + assert isinstance(static_grad_inputs, tuple) + return tuple(b.detach() if b is not None else b for b in static_grad_inputs) + + def functionalized(*user_args): + # Runs the autograd function with inputs == all inputs to the graph that might require grad + # (explicit user args + module parameters) + # Assumes module params didn't change since capture. + out = Graphed.apply(*(user_args + module_params)) + return out[0] if output_was_tensor else out + + return functionalized + + # Put together the final graphed callables + ret = [] + for i, func in enumerate(callables): + graphed = make_graphed_autograd_function(fwd_graphs[i], + bwd_graphs[i], + per_callable_module_params[i], + per_callable_len_user_args[i], + per_callable_output_was_tensor[i], + per_callable_static_input_surfaces[i], + per_callable_static_outputs[i], + per_callable_static_grad_outputs[i], + per_callable_static_grad_inputs[i]) + + if isinstance(func, torch.nn.Module): + def make_graphed_forward(func, graph_training_state, graphed, orig_fwd): + def new_fwd(*user_args): + # If the module's training-or-eval state matches what we graphed, + # run the graph, otherwise run the original forward method + if func.training == graph_training_state: + return graphed(*user_args) + else: + return orig_fwd(*user_args) + return new_fwd + func.forward = make_graphed_forward(func, func.training, graphed, func.forward) # type: ignore[assignment] + ret.append(func) + else: + ret.append(graphed) + + if just_one_callable: + return ret[0] + + return tuple(ret) diff --git a/torch/cuda/streams.py b/torch/cuda/streams.py index 0f983728f630a..2b4cc479e095f 100644 --- a/torch/cuda/streams.py +++ b/torch/cuda/streams.py @@ -8,8 +8,6 @@ # Define dummy base classes torch._C.__dict__['_CudaStreamBase'] = _dummy_type('_CudaStreamBase') torch._C.__dict__['_CudaEventBase'] = _dummy_type('_CudaEventBase') - torch._C.__dict__['_CudaGraphBase'] = _dummy_type('_CudaGraphBase') - torch._C.__dict__['_graph_pool_handle'] = _dummy_type('_graph_pool_handle') class Stream(torch._C._CudaStreamBase): r"""Wrapper around a CUDA stream. @@ -226,6 +224,3 @@ def __repr__(self): return ''.format(self._as_parameter_.value) else: return '' - -_Graph = torch._C._CudaGraphBase -_graph_pool_handle = torch._C._graph_pool_handle diff --git a/torch/distributed/CONTRIBUTING.md b/torch/distributed/CONTRIBUTING.md index 5e426466ec67d..6cbaea694f215 100644 --- a/torch/distributed/CONTRIBUTING.md +++ b/torch/distributed/CONTRIBUTING.md @@ -85,7 +85,6 @@ python test/distributed/test_store.py python test/distributed/test_pg_wrapper.py # Run distributed tests, including tests for Distributed Data Parallel. -python test/run_test.py --verbose -i distributed/test_distributed_fork python test/run_test.py --verbose -i distributed/test_distributed_spawn # Run the RPC test suite for the TensorPipeAgent. diff --git a/torch/distributed/_sharded_tensor/__init__.py b/torch/distributed/_sharded_tensor/__init__.py index d9833159dc9de..4f8646d54268c 100644 --- a/torch/distributed/_sharded_tensor/__init__.py +++ b/torch/distributed/_sharded_tensor/__init__.py @@ -1,26 +1,31 @@ -from typing import List +# coding=utf-8 -import torch -from torch.distributed._sharding_spec import ShardingSpec from .api import ( + CreateOp, Shard, ShardedTensor, ShardedTensorMetadata, + TensorInitParams, + TensorProperties, load_with_process_group, ) +from torch.distributed._sharding_spec import ShardingSpec +from typing import List +import torch -def empty( - sharding_spec: ShardingSpec, - *size, - dtype=None, - layout=torch.strided, - requires_grad=False, - pin_memory=False, - memory_format=torch.contiguous_format, - process_group=None, - init_rrefs=False): + +def empty(sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False): """ - Creates an empty :class:`ShardedTensor`. Needs to be called on all ranks in an SPMD fashion. + Returns a :class:`ShardedTensor` filled with uninitialized data. + Needs to be called on all ranks in an SPMD fashion. Args: sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification @@ -49,18 +54,229 @@ def empty( Returns: A :class:`ShardedTensor` object on each rank """ + tensor_properties = TensorProperties(dtype=dtype, layout=layout, + requires_grad=requires_grad, + pin_memory=pin_memory, memory_format=memory_format, ) + tensor_init_params = TensorInitParams(create_op=CreateOp.EMPTY, tensor_properties=tensor_properties, ) + return ShardedTensor( + sharding_spec, + *size, + tensor_init_params=tensor_init_params, + process_group=process_group, + init_rrefs=init_rrefs, + ) + +def ones(sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False): + """ + Returns a :class:`ShardedTensor` with the scalar value 1. + Needs to be called on all ranks in an SPMD fashion. + + Args: + sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + size (int...): a sequence of integers defining the shape of the output + tensor. Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + init_rrefs (bool, optional): Whether or not to initialize + :class:`torch.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + Returns: + A :class:`ShardedTensor` object on each rank + """ + tensor_properties = TensorProperties(dtype=dtype, layout=layout, + requires_grad=requires_grad, + pin_memory=pin_memory, memory_format=memory_format, ) + tensor_init_params = TensorInitParams(create_op=CreateOp.ONES, tensor_properties=tensor_properties) return ShardedTensor( sharding_spec, *size, - dtype=dtype, - layout=layout, - requires_grad=requires_grad, - pin_memory=pin_memory, - memory_format=memory_format, + tensor_init_params=tensor_init_params, process_group=process_group, init_rrefs=init_rrefs, ) + +def rand(sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False): + """ + Returns a :class:`ShardedTensor` filled with random numbers from a uniform distribution on the + interval :math:`[0, 1)`. Needs to be called on all ranks in an SPMD fashion. + + Args: + sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + size (int...): a sequence of integers defining the shape of the output + tensor. Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + init_rrefs (bool, optional): Whether or not to initialize + :class:`torch.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + Returns: + A :class:`ShardedTensor` object on each rank + """ + tensor_properties = TensorProperties( + dtype=dtype, layout=layout, requires_grad=requires_grad, + pin_memory=pin_memory, memory_format=memory_format + ) + tensor_init_params = TensorInitParams(create_op=CreateOp.RAND, tensor_properties=tensor_properties, ) + return ShardedTensor( + sharding_spec, + *size, + tensor_init_params=tensor_init_params, + process_group=process_group, + init_rrefs=init_rrefs, + ) + + +def zeros(sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False): + """ + Returns a :class:`ShardedTensor` filled with the scalar value 0. + Needs to be called on all ranks in an SPMD fashion. + + Args: + sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + size (int...): a sequence of integers defining the shape of the output + tensor. Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + init_rrefs (bool, optional): Whether or not to initialize + :class:`torch.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + Returns: + A :class:`ShardedTensor` object on each rank + """ + tensor_properties = TensorProperties( + dtype=dtype, layout=layout, requires_grad=requires_grad, + pin_memory=pin_memory, memory_format=memory_format, + ) + tensor_init_params = TensorInitParams(create_op=CreateOp.ZEROS, tensor_properties=tensor_properties, ) + return ShardedTensor( + sharding_spec, + *size, + tensor_init_params=tensor_init_params, + process_group=process_group, + init_rrefs=init_rrefs, + ) + + +def full(sharding_spec: ShardingSpec, + size, + fill_value=torch.types.Number, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False): + """ + Creates a :class:`ShardedTensor` filled with fill_value. The tensor’s dtype + is inferred from fill_value. If dtype is specified, it will override the + inferred type from fill_value. Needs to be called on all ranks in an SPMD fashion. + + Args: + sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + size (int...): a list, tuple, or `torch.Size` of integers defining the shape of the + output tensor. + fill_value (Scalar) – the value to fill the output tensor with. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + init_rrefs (bool, optional): Whether or not to initialize + :class:`torch.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + Returns: + A :class:`ShardedTensor` object on each rank + """ + tensor_properties = TensorProperties( + dtype=dtype, layout=layout, requires_grad=requires_grad, + pin_memory=pin_memory, memory_format=memory_format, + ) + tensor_init_params = TensorInitParams( + create_op=CreateOp.FULL, fill_value=fill_value, tensor_properties=tensor_properties) + return ShardedTensor( + sharding_spec, + *size, + tensor_init_params=tensor_init_params, + process_group=process_group, + init_rrefs=init_rrefs, + ) + + def init_from_local_shards( local_shards: List[Shard], sharded_tensor_metadata: ShardedTensorMetadata, diff --git a/torch/distributed/_sharded_tensor/api.py b/torch/distributed/_sharded_tensor/api.py index ca9a05abffa06..d6b7a54732445 100644 --- a/torch/distributed/_sharded_tensor/api.py +++ b/torch/distributed/_sharded_tensor/api.py @@ -1,6 +1,7 @@ import collections from contextlib import contextmanager from dataclasses import dataclass, field +from enum import Enum from typing import ( Dict, List @@ -21,6 +22,7 @@ check_tensor, validate_non_overlapping_shards_metadata ) +from torch.types import Number # Tracking for sharded tensor objects. _sharded_tensor_lock = threading.Lock() @@ -57,6 +59,24 @@ class Shard(object): tensor: torch.Tensor metadata: ShardMetadata +@dataclass +class TensorProperties(object): + """ Properties used to create :class:`Tensor` """ + + # Regular tensor fields + dtype: torch.dtype = field(default=torch.get_default_dtype()) + layout: torch.layout = field(default=torch.strided) + requires_grad: bool = False + memory_format: torch.memory_format = field(default=torch.contiguous_format) + pin_memory: bool = False + + +class MEM_FORMAT_ENCODING(Enum): + TORCH_CONTIGUOUS_FORMAT = 0 + TORCH_CHANNELS_LAST = 1 + TORCH_PRESERVE_FORMAT = 2 + + @dataclass class ShardedTensorMetadata(object): """ @@ -69,50 +89,55 @@ class ShardedTensorMetadata(object): # Size of each dim of the overall Tensor. size: torch.Size = field(default=torch.Size([])) - # Regular tensor fields - dtype: torch.dtype = field(default=torch.get_default_dtype()) - layout: torch.layout = field(default=torch.strided) - requires_grad: bool = False - memory_format: torch.memory_format = field(default=torch.contiguous_format) - pin_memory: bool = False + tensor_properties: TensorProperties = field( + default=TensorProperties(dtype=torch.get_default_dtype(), + layout=torch.strided, + requires_grad=False, + memory_format=torch.contiguous_format, + pin_memory=False)) def __getstate__(self): # Since torch.memory_format cannot be pickled! - if self.memory_format == torch.contiguous_format: - mem_format_encoding = 0 - elif self.memory_format == torch.channels_last: - mem_format_encoding = 1 - elif self.memory_format == torch.preserve_format: - mem_format_encoding = 1 + memory_format = self.tensor_properties.memory_format + if memory_format == torch.contiguous_format: + mem_format_encoding = MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT + elif memory_format == torch.channels_last: + mem_format_encoding = MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST + elif memory_format == torch.preserve_format: + mem_format_encoding = MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT else: - raise RuntimeError(f'Invalid torch.memory_format: {self.memory_format}') + raise RuntimeError(f'Invalid torch.memory_format: {memory_format}') + # Keep old seriazation to ensure backward compatibility return ( self.shards_metadata, self.size, - self.dtype, - self.layout, - self.requires_grad, + self.tensor_properties.dtype, + self.tensor_properties.layout, + self.tensor_properties.requires_grad, mem_format_encoding, - self.pin_memory, + self.tensor_properties.pin_memory, ) def __setstate__( self, state, ): - (self.shards_metadata, self.size, self.dtype, self.layout, - self.requires_grad, mem_format_encoding, self.pin_memory) = state - - if mem_format_encoding == 0: - self.memory_format = torch.contiguous_format - elif mem_format_encoding == 1: - self.memory_format = torch.channels_last - elif mem_format_encoding == 2: - self.memory_format = torch.preserve_format + (self.shards_metadata, self.size, dtype, layout, requires_grad, mem_format_encoding, pin_memory) = state + + if mem_format_encoding == MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT: + memory_format = torch.contiguous_format + elif mem_format_encoding == MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST: + memory_format = torch.channels_last + elif mem_format_encoding == MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT: + memory_format = torch.preserve_format else: raise RuntimeError(f'Invalid torch.memory_format encoding: {mem_format_encoding}') + self.tensor_properties = TensorProperties( + dtype=dtype, layout=layout, requires_grad=requires_grad, + memory_format=memory_format, pin_memory=pin_memory, ) + def _register_remote_shards(sharded_tensor_id: int, rrefs: List[rpc.RRef[Shard]], rpc_rank: int): with _sharded_tensor_lock: @@ -123,6 +148,32 @@ def _register_remote_shards(sharded_tensor_id: int, rrefs: List[rpc.RRef[Shard]] _sharded_tensor_map[sharded_tensor_id]._register_remote_shards(rrefs, rpc_rank) +class CreateOp(Enum): + EMPTY = 0 + FULL = 1 + ONES = 2 + RAND = 3 + ZEROS = 4 + + +@dataclass +class TensorInitParams(object): + """ Container for list of common params to create new local tensor. """ + + create_op: CreateOp + + # needed when create_op is FULL + # default set to False (not None) since None is incompatible with Number. + fill_value: Number = field(default=False) + + tensor_properties: TensorProperties = field( + default=TensorProperties(dtype=torch.get_default_dtype(), + layout=torch.strided, + requires_grad=False, + memory_format=torch.contiguous_format, + pin_memory=False)) + + class ShardedTensor(object): """ ShardedTensor is an abstraction to represent Tensors that are sharded @@ -136,8 +187,9 @@ class ShardedTensor(object): ShardedTensor doesn't provide any Tensor like operations but is a wrapper providing the Tensor representing the local shard and the global metadata. Using these, users can build their custom distributed sharded computations - on top of this primitive. The local shards are all initialized using - :meth:`torch.empty`. + on top of this primitive. The local shards are all initialized using the + create_op specified by tensor_init_params.create_op, e.g., torch.ones, or + torch.empty Args: sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification @@ -146,20 +198,7 @@ class ShardedTensor(object): tensor. Can be a variable number of arguments or a collection like a list or tuple. Keyword args: - dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. - Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`). - layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. - Default: ``torch.strided``. - requires_grad (bool, optional): If autograd should record operations on the - returned tensor. Default: ``False``. - pin_memory (bool, optional): If set, returned tensor would be allocated in - the pinned memory. Works only for CPU tensors. Default: ``False``. - memory_format (:class:`torch.memory_format`, optional): the desired memory format of - returned Tensor. Default: ``torch.contiguous_format``. - process_group (ProcessGroup, optional): The process group to work on. If None, - the default process group will be used. If specified the ShardedTensor is only - built on ranks that are part of this process group and the provided ``sharding_spec`` - is applied in the context of this process group. + tensor_init_params (:class: `TensorInitParams`): common params to create tensor. init_rrefs (bool, optional): Whether or not to initialize :class:`torch.distributed.rpc.RRef`s pointing to remote shards. Need to initialize the RPC Framework if specified as ``True``. @@ -170,11 +209,7 @@ def __init__( self, sharding_spec: ShardingSpec, *size, - dtype=None, - layout=torch.strided, - requires_grad=False, - pin_memory=False, - memory_format=torch.contiguous_format, + tensor_init_params: TensorInitParams, process_group=None, init_rrefs=False, ): @@ -182,13 +217,16 @@ def __init__( # _process_group, _local_shards, etc. self._prepare_init(process_group=process_group, init_rrefs=init_rrefs) - if dtype is None: - dtype = torch.get_default_dtype() + if tensor_init_params.tensor_properties is None: + raise ValueError('tensor_properties must not be None.') + + if tensor_init_params.tensor_properties.dtype is None: + tensor_init_params.tensor_properties.dtype = torch.get_default_dtype() - if layout != torch.strided: + if tensor_init_params.tensor_properties.layout != torch.strided: raise ValueError('Only torch.strided layout is currently supported') - if memory_format != torch.contiguous_format: + if tensor_init_params.tensor_properties.memory_format != torch.contiguous_format: raise ValueError('Only torch.contiguous_format memory_format is currently supported') if len(size) == 1 and isinstance(size[0], collections.Sequence): @@ -203,23 +241,9 @@ def __init__( self._sharding_spec = sharding_spec if isinstance(self._sharding_spec, ChunkShardingSpec): - self._init_chunked( - dims, - dtype, - layout, - requires_grad, - pin_memory, - memory_format, - ) + self._init_chunked(dims, tensor_init_params) elif isinstance(self._sharding_spec, EnumerableShardingSpec): - self._init_enumerable( - dims, - dtype, - layout, - requires_grad, - pin_memory, - memory_format, - ) + self._init_enumerable(dims, tensor_init_params) else: raise ValueError(f'Unsupported sharding_spec: {self._sharding_spec}') @@ -317,11 +341,12 @@ def _init_from_local_shards( init_rrefs=False, ): shards_metadata = sharded_tensor_metadata.shards_metadata + tensor_properties = sharded_tensor_metadata.tensor_properties if len(shards_metadata) == 0: raise ValueError("shards_metadata must not be empty!") - if sharded_tensor_metadata.layout != torch.strided: + if tensor_properties.layout != torch.strided: raise ValueError('Only torch.strided layout is currently supported') sharded_tensor = cls.__new__(cls) @@ -362,11 +387,11 @@ def _init_from_local_shards( assert shard_meta in local_shard_metadatas, \ "local shard metadata not in sharded_tensor_metadata!" - if local_shard_tensor.layout != sharded_tensor_metadata.layout: + if local_shard_tensor.layout != tensor_properties.layout: raise ValueError( - f'Local shard tensor layout does not match with sharded_tensor_metadata! ' + f'Local shard tensor layout does not match with tensor_properties! ' f'local shard tensor layout: {local_shard_tensor.dtype}, ' - f'sharded_tensor_metadata layout: {sharded_tensor_metadata.layout}' + f'tensor_properties layout: {tensor_properties.layout}' ) if not local_shard_tensor.is_contiguous(): @@ -379,11 +404,11 @@ def _init_from_local_shards( f'local ShardMetadata shard lengths: {shard_meta.shard_lengths}' ) - if local_shard_tensor.is_pinned() != sharded_tensor_metadata.pin_memory: + if local_shard_tensor.is_pinned() != tensor_properties.pin_memory: raise ValueError( - f'Local shard tensor pin_memory does not match with sharded_tensor_metadata! ' + f'Local shard tensor pin_memory does not match with tensor_properties! ' f'local shard tensor pin_memory: {local_shard_tensor.is_pinned()}, ' - f'sharded_tensor_metadata pin_memory: {sharded_tensor_metadata.pin_memory}' + f'tensor_properties pin_memory: {tensor_properties.pin_memory}' ) if local_shard_tensor.device != local_device: @@ -393,18 +418,18 @@ def _init_from_local_shards( f'local shard metadata placement device: {local_device}' ) - if local_shard_tensor.dtype != sharded_tensor_metadata.dtype: + if local_shard_tensor.dtype != tensor_properties.dtype: raise ValueError( - f'Local shard tensor dtype does not match with sharded_tensor_metadata! ' + f'Local shard tensor dtype does not match with tensor_properties! ' f'local shard tensor dtype: {local_shard_tensor.dtype}, ' - f'sharded_tensor_metadata dtype: {sharded_tensor_metadata.dtype}' + f'tensor_properties dtype: {tensor_properties.dtype}' ) - if local_shard_tensor.requires_grad != sharded_tensor_metadata.requires_grad: + if local_shard_tensor.requires_grad != tensor_properties.requires_grad: raise ValueError( - f'Local shard tensor requires_grad does not match with sharded_tensor_metadata! ' + f'Local shard tensor requires_grad does not match with tensor_properties! ' f'local shard tensor requires_grad: {local_shard_tensor.requires_grad}, ' - f'sharded_tensor_metadata requires_grad: {sharded_tensor_metadata.requires_grad}' + f'tensor_properties requires_grad: {tensor_properties.requires_grad}' ) # check if shards_metadata have overlap shards @@ -420,15 +445,7 @@ def _init_from_local_shards( sharded_tensor._post_init() return sharded_tensor - def _init_chunked( - self, - dims, - dtype, - layout, - requires_grad, - pin_memory, - memory_format, - ): + def _init_chunked(self, dims, tensor_init_params: TensorInitParams, ): current_rank = dist.get_rank(self._process_group) sharding_dim = self._sharding_spec.dim # type: ignore[attr-defined] @@ -469,38 +486,15 @@ def _init_chunked( # Build the local shard for the current rank if it is involved in the sharding spec. if current_rank == rank: # Initialize the local shard. - local_shard = torch.empty( - *rank_dims, - dtype=dtype, - layout=layout, - device=local_device, - requires_grad=requires_grad, - memory_format=memory_format, - pin_memory=pin_memory, - ) - + local_shard = _create_tensor_from_params( + *rank_dims, local_device=local_device, tensor_init_params=tensor_init_params) self._local_shards.append(Shard(local_shard, shard_metadata)) # Build overall metadata self._metadata = ShardedTensorMetadata( - shards_metadata, - dims, - dtype, - layout, - requires_grad, - memory_format, - pin_memory, - ) + shards_metadata, dims, tensor_init_params.tensor_properties, ) - def _init_enumerable( - self, - dims, - dtype, - layout, - requires_grad, - pin_memory, - memory_format, - ): + def _init_enumerable(self, dims, tensor_init_params: TensorInitParams): # Validate the sharding spec is compatible with the tensor. check_tensor(self._sharding_spec.shards, dims) # type: ignore[attr-defined] @@ -513,28 +507,14 @@ def _init_enumerable( if current_rank == rank: # Initialize the local shard. - local_shard = torch.empty( - *shard_metadata.shard_lengths, - dtype=dtype, - layout=layout, - device=local_device, - requires_grad=requires_grad, - memory_format=memory_format, - pin_memory=pin_memory, - ) - + local_shard = _create_tensor_from_params( + *shard_metadata.shard_lengths, local_device=local_device, + tensor_init_params=tensor_init_params) self._local_shards.append(Shard(local_shard, shard_metadata)) # Build overall metadata self._metadata = ShardedTensorMetadata( - shards_metadata, - dims, - dtype, - layout, - requires_grad, - memory_format, - pin_memory, - ) + shards_metadata, dims, tensor_init_params.tensor_properties, ) def _parse_and_validate_remote_device(self, remote_device: torch.distributed._remote_device): @@ -590,6 +570,35 @@ def size(self) -> torch.Size: """ return self._metadata.size + def is_pinned(self) -> bool: + """ + Returns True if the sharded tensor (each local shard) resides in pinned memory. + """ + return self._metadata.tensor_properties.pin_memory + + def is_contiguous(self) -> bool: + """ + Returns True if the sharded tensor (each local shard) is contiguous in memory + in the order specified by memory format. + """ + return self._metadata.tensor_properties.memory_format == torch.contiguous_format + + @property + def shape(self): + return self._metadata.size + + @property + def requires_grad(self): + return self._metadata.tensor_properties.requires_grad + + @property + def dtype(self): + return self._metadata.tensor_properties.dtype + + @property + def layout(self): + return self._metadata.tensor_properties.layout + def _register_remote_shards(self, remote_shards: List[rpc.RRef[Shard]], rpc_rank: int): self._remote_shards[rpc_rank] = remote_shards @@ -672,3 +681,47 @@ def __setstate__(self, state): f'but at load time was {global_world_size}') self._post_init() + + +def _create_tensor_from_params(*size, local_device, tensor_init_params: TensorInitParams): + """ Helper to construct tensor from size, device and common params. """ + + create_op = tensor_init_params.create_op + dtype = tensor_init_params.tensor_properties.dtype + layout = tensor_init_params.tensor_properties.layout + requires_grad = tensor_init_params.tensor_properties.requires_grad + memory_format = tensor_init_params.tensor_properties.memory_format + pin_memory = tensor_init_params.tensor_properties.pin_memory + + if create_op == CreateOp.ONES: + return torch.ones(*size, dtype=dtype, layout=layout, + device=local_device, pin_memory=pin_memory, + requires_grad=requires_grad,) + elif create_op == CreateOp.EMPTY: + return torch.empty(*size, dtype=dtype, layout=layout, + device=local_device, requires_grad=requires_grad, + # NB: memory_format param is not accepted by torch.ones + memory_format=memory_format, pin_memory=pin_memory,) + elif tensor_init_params.create_op == CreateOp.ZEROS: + return torch.zeros(*size, + dtype=dtype, + layout=layout, + device=local_device, + pin_memory=pin_memory, + requires_grad=requires_grad,) + elif tensor_init_params.create_op == CreateOp.RAND: + return torch.rand(*size, + dtype=dtype, + layout=layout, + device=local_device, + pin_memory=pin_memory, + requires_grad=requires_grad,) + elif tensor_init_params.create_op == CreateOp.FULL: + return torch.full(size=size, + fill_value=tensor_init_params.fill_value, + layout=layout, + dtype=dtype, + requires_grad=requires_grad, + device=local_device, ) + else: + raise ValueError(f'Unsupported create_op: {tensor_init_params.create_op}') diff --git a/torch/distributed/_sharding_spec/_internals.py b/torch/distributed/_sharding_spec/_internals.py index a519a9a3e2b7b..3f2ab2f1a4ea8 100644 --- a/torch/distributed/_sharding_spec/_internals.py +++ b/torch/distributed/_sharding_spec/_internals.py @@ -1,5 +1,6 @@ -from typing import List +from typing import List, Union from dataclasses import dataclass +from torch.distributed.remote_device import _remote_device import torch @@ -24,7 +25,7 @@ class ShardMetadata(object): shard_offsets: List[int] shard_lengths: List[int] - placement: torch.distributed._remote_device + placement: Union[str, _remote_device] def __post_init__(self): if isinstance(self.placement, str): diff --git a/torch/distributed/algorithms/ddp_comm_hooks/__init__.py b/torch/distributed/algorithms/ddp_comm_hooks/__init__.py index 35ddf316e91c5..ff22a818f925d 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/__init__.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/__init__.py @@ -5,6 +5,7 @@ from torch.nn.parallel import DistributedDataParallel from . import ( + debugging_hooks as debugging, default_hooks as default, powerSGD_hook as powerSGD, quantization_hooks as quantization, @@ -46,6 +47,9 @@ class DDPCommHookType(Enum): FP16_COMPRESS = partial( _ddp_comm_hook_wrapper, comm_hook=default.fp16_compress_hook ) + BF16_COMPRESS = partial( + _ddp_comm_hook_wrapper, comm_hook=default.bf16_compress_hook + ) QUANTIZE_PER_TENSOR = partial( _ddp_comm_hook_wrapper, comm_hook=quantization.quantization_pertensor_hook ) @@ -75,6 +79,9 @@ class DDPCommHookType(Enum): comm_hook=powerSGD.batched_powerSGD_hook, matrix_approximation_rank=2, ) + NOOP = partial( + _ddp_comm_hook_wrapper, comm_hook=debugging.noop_hook, + ) def register_ddp_comm_hook( diff --git a/torch/distributed/algorithms/ddp_comm_hooks/debugging_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/debugging_hooks.py new file mode 100644 index 0000000000000..0c60762caf2ed --- /dev/null +++ b/torch/distributed/algorithms/ddp_comm_hooks/debugging_hooks.py @@ -0,0 +1,26 @@ +from typing import Any + +import torch +import torch.distributed as dist + + +def noop_hook(_: Any, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]: + """ + This DDP communication hook returns the a future that wraps the input, + so it is a noop that does not incur any communication overheads. + + This hook should **only** be used for headroom analysis of allreduce optimization, + instead of the normal gradient synchronization. + For example, if only less than 10% speedup of training time can be observed after this hook is registered, + it usually implies that allreduce is not a performance bottleneck for this case. + Such instrumentation can be particularly useful + if GPU traces cannot be easily retrieved or the trace analysis is complicated + some factors such as the overlap between allreduce and computation or the desynchronization across ranks. + + Example:: + >>> ddp_model.register_comm_hook(None, noop_hook) + """ + fut: torch.futures.Future[torch.Tensor] = torch.futures.Future() + fut.set_result(bucket.buffer()) + + return fut diff --git a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py index 0642deace3565..d11e39b23f6f0 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py @@ -69,6 +69,41 @@ def decompress(fut): return fut.then(decompress) +# TODO: create an internal helper function and extract the duplicate code in FP16_compress and BF16_compress. +def bf16_compress_hook( + process_group: dist.ProcessGroup, bucket: dist.GradBucket +) -> torch.futures.Future[torch.Tensor]: + """ + Warning: This API is experimental, and it requires NCCL version later than 2.9.6. + + This DDP communication hook implements a simple gradient compression + approach that casts ``GradBucket`` tensor to half-precision + `Brain floating point format `_ (``torch.bfloat16``) + and then divides it by the process group size. + It allreduces those ``bfloat16`` gradient tensors. Once compressed gradient + tensors are allreduced, the chained callback ``decompress`` casts it back to the input data type (such as ``float32``). + + Example:: + >>> ddp_model.register_comm_hook(process_group, bf16_compress_hook) + """ + group_to_use = process_group if process_group is not None else dist.group.WORLD + world_size = group_to_use.size() + + compressed_tensor = bucket.buffer().to(torch.bfloat16).div_(world_size) + + fut = dist.all_reduce( + compressed_tensor, group=group_to_use, async_op=True + ).get_future() + + def decompress(fut): + decompressed_tensor = bucket.buffer() + # Decompress in place to reduce the peak memory. + # See: https://github.com/pytorch/pytorch/issues/45968 + decompressed_tensor.copy_(fut.value()[0]) + return decompressed_tensor + + return fut.then(decompress) + class _OptimizerHookState(object): """ @@ -160,3 +195,40 @@ def decompress(fut): return fut.then(decompress) return fp16_compress_wrapper_hook + +def bf16_compress_wrapper( + hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]] +) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]: + """ + Warning: This API is experimental, and it requires NCCL version later than 2.9.6. + + This wrapper casts the input gradient tensor of a given DDP communication hook to half-precision + `Brain floating point format `_ (``torch.bfloat16``), + and casts the resulting tensor of the given hook back to the input data type, such as ``float32``. + + Therefore, ``bf16_compress_hook`` is equivalent to ``bf16_compress_wrapper(allreduce_hook)``. + + Example:: + >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10) + >>> ddp_model.register_comm_hook(state, bf16_compress_wrapper(powerSGD_hook)) + """ + + def bf16_compress_wrapper_hook( + hook_state, bucket: dist.GradBucket + ) -> torch.futures.Future[torch.Tensor]: + # Cast bucket tensor to BF16. + bucket.set_buffer(bucket.buffer().to(torch.bfloat16)) + + fut = hook(hook_state, bucket) + + def decompress(fut): + decompressed_tensor = bucket.buffer() + # Decompress in place to reduce the peak memory. + # See: https://github.com/pytorch/pytorch/issues/45968 + decompressed_tensor.copy_(fut.value()) + return decompressed_tensor + + # Decompress after hook has run. + return fut.then(decompress) + + return bf16_compress_wrapper_hook diff --git a/torch/distributed/algorithms/model_averaging/utils.py b/torch/distributed/algorithms/model_averaging/utils.py index 44ee422b9e92d..a2bbac2a25474 100644 --- a/torch/distributed/algorithms/model_averaging/utils.py +++ b/torch/distributed/algorithms/model_averaging/utils.py @@ -20,6 +20,9 @@ def average_parameters( return params_it1, params_it2 = itertools.tee(params) + # If the input parameters have different data types, + # packing these parameters will trigger an implicit type up-casting. + # The original parameter data types will be restored during the subsequent unpacking. flat_params = torch.cat([p.data.view(-1) for p in params_it1]) flat_params /= dist.get_world_size(group_to_use) # Make sure the allreduce will not conflict with any other ongoing process group. @@ -29,5 +32,6 @@ def average_parameters( offset = 0 for p in params_it2: - p.data = flat_params[offset : offset + p.numel()].view_as(p) + with torch.no_grad(): + p.set_(flat_params[offset : offset + p.numel()].view_as(p).type_as(p)) # type: ignore[call-overload] offset += p.numel() diff --git a/torch/distributed/algorithms/quantization.py b/torch/distributed/algorithms/quantization/quantization.py similarity index 73% rename from torch/distributed/algorithms/quantization.py rename to torch/distributed/algorithms/quantization/quantization.py index dead78af600b2..a5e9b4652a805 100644 --- a/torch/distributed/algorithms/quantization.py +++ b/torch/distributed/algorithms/quantization/quantization.py @@ -10,7 +10,12 @@ TORCH_HALF_MAX = torch.finfo(torch.float16).max class DQuantType(Enum): - FP16 = "fp16" + """ + Different quantization methods for auto_quantize API are identified here. + auto_quantize API currently supports fp16 and bfp16 methods. + """ + FP16 = "fp16", + BFP16 = "bfp16" def __str__(self) -> str: return self.value @@ -26,6 +31,8 @@ def _quantize_tensor(tensor, qtype): ) if (qtype == DQuantType.FP16): return _fp32_to_fp16_with_clamp(tensor) + elif (qtype == DQuantType.BFP16): + return torch.ops.q._FloatToBfloat16Quantized(tensor) else: raise RuntimeError( f'Quantization type {qtype} is not supported' @@ -38,13 +45,8 @@ def _quantize_tensor_list(tensor_list, qtype): raise RuntimeError( f"_quantize_tensor_list expecting list of torch.Tensor as input but found {type(tensor_list)}" ) - if (qtype == DQuantType.FP16): - quantized_tensor_list = [_quantize_tensor(t, qtype) for t in tensor_list] - return quantized_tensor_list - else: - raise RuntimeError( - f'Quantization type {qtype} is not supported' - ) + quantized_tensor_list = [_quantize_tensor(t, qtype) for t in tensor_list] + return quantized_tensor_list def _dequantize_tensor(tensor, qtype, quant_loss=None): if not isinstance(tensor, torch.Tensor): @@ -60,6 +62,13 @@ def _dequantize_tensor(tensor, qtype, quant_loss=None): return tensor.float() else: return tensor.float() / quant_loss + elif (qtype == DQuantType.BFP16): + if tensor.dtype != torch.float16: + raise RuntimeError( + f"tensor dtype is {tensor.dtype} while expected to be FP16." + ) + else: + return torch.ops.q._Bfloat16QuantizedToFloat(tensor) else: raise RuntimeError( f'Quantization type {qtype} is not supported' @@ -73,29 +82,22 @@ def _dequantize_tensor_list(tensor_list, qtype, quant_loss=None): raise RuntimeError( f"_dequantize_tensor_list expecting list of torch.Tensor as input but found {type(tensor_list)}" ) - elif (qtype == DQuantType.FP16): - dequantized_tensor_list = [_dequantize_tensor(t, qtype) for t in tensor_list] - return dequantized_tensor_list - else: - raise RuntimeError( - f'Quantization type {qtype} is not supported' - ) + dequantized_tensor_list = [_dequantize_tensor(t, qtype) for t in tensor_list] + return dequantized_tensor_list def auto_quantize(func, qtype, quant_loss=None): """ This is a prototype API that automatically quantize the input tensors, choose the precision types, and pass other necessary arguments and then dequantizes the output. - Currently it only supports: - . FP16 quantization method + . FP16 and BFP16 quantization method supported for gloo and nccl backends . all_gather, all_to_all collective ops - + Note: BFP16 only supports 2D tensors. Args: func (callable): A function representing collective operations. qtype (QuantType): Quantization method quant_loss (float, optional): This can be used to improve accuracy in the dequantization. - Returns: (callable): the same collective as func but enables automatic quantization/dequantization. """ @@ -123,6 +125,16 @@ def wrapper(*args, **kwargs): for i, t in enumerate(_dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss)): tensors[i] = t + elif (func == dist.all_to_all_single): + tensors = args[0] + out_splits = kwargs.get('out_splits', None) + in_splits = kwargs.get('in_splits', None) + # Quantizing the input/output tensor + input_tensors = _quantize_tensor(args[1], qtype) + out_tensors = _quantize_tensor(tensors, qtype) + dist.all_to_all_single(out_tensors, input_tensors, out_splits, in_splits, group=group) + for i, t in enumerate(_dequantize_tensor(out_tensors, qtype, quant_loss=quant_loss)): + tensors[i] = t else: raise RuntimeError( f"The collective op {func} is not supported yet" diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 1b1244d9e37d5..302114e1c7bb6 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -1,6 +1,7 @@ import contextlib import io import logging +import os import pickle import time import warnings @@ -9,28 +10,31 @@ import torch from torch._C._distributed_c10d import ( - AllreduceOptions, AllreduceCoalescedOptions, + AllreduceOptions, AllToAllOptions, BarrierOptions, BroadcastOptions, GatherOptions, PrefixStore, ProcessGroup, - ReduceOptions, ReduceOp, + ReduceOptions, ReduceScatterOptions, ScatterOptions, Store, + _DistributedDebugLevel, + _get_debug_mode, ) -from torch._C._distributed_c10d import _get_debug_mode, _DistributedDebugLevel from torch._six import string_classes +from .constants import default_pg_timeout +from .rendezvous import register_rendezvous_handler, rendezvous # noqa: F401 + + # This module is wildcard imported from torch.distributed. # TODO: specify __all__ -from .constants import default_pg_timeout -from .rendezvous import rendezvous, register_rendezvous_handler # noqa: F401 _MPI_AVAILABLE = True _NCCL_AVAILABLE = True @@ -244,7 +248,9 @@ def _store_based_barrier(rank, store, timeout): ) ) - logger.info(f"Rank {rank}: Completed store-based barrier for key:{store_key} with {world_size} nodes.") + logger.info( + f"Rank {rank}: Completed store-based barrier for key:{store_key} with {world_size} nodes." + ) def _rank_not_in_group(group: ProcessGroup): @@ -384,6 +390,18 @@ def is_initialized(): return GroupMember.WORLD is not None +def is_torchelastic_launched(): + """ + Checks whether this process was launched with ``torch.distributed.elastic`` + (aka torchelastic). The existence of ``TORCHELASTIC_RUN_ID`` environment + variable is used as a proxy to determine whether the current process + was launched with torchelastic. This is a reasonable proxy since + ``TORCHELASTIC_RUN_ID`` maps to the rendezvous id which is always a + non-null value indicating the job id for peer discovery purposes.. + """ + return os.getenv("TORCHELASTIC_RUN_ID") is not None + + def _get_default_group(): """ Getting the default process group created by init_process_group @@ -782,7 +800,8 @@ def destroy_process_group(group=None): def get_rank(group=None): """ - Returns the rank of current process group + Returns the rank of the current process in the provided ``group`` or the + default group if none was provided. Rank is a unique identifier assigned to each process within a distributed process group. They are always consecutive integers ranging from 0 to @@ -1778,8 +1797,8 @@ def broadcast_object_list(object_list, src=0, group=None, device=None): is_nccl_backend = group_backend == Backend.NCCL current_device = None if device is not None: - if is_nccl_backend and device.type != 'cuda': - raise ValueError('device type must be cuda for nccl backend') + if is_nccl_backend and device.type != "cuda": + raise ValueError("device type must be cuda for nccl backend") current_device = device else: current_device = torch.device("cpu") @@ -2229,7 +2248,9 @@ def scatter(tensor, scatter_list=None, src=0, group=None, async_op=False): if _rank_not_in_group(group): return - scatter_list = [t if not t.is_complex() else torch.view_as_real(t) for t in scatter_list] + scatter_list = [ + t if not t.is_complex() else torch.view_as_real(t) for t in scatter_list + ] tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor) my_rank = get_rank() @@ -3026,9 +3047,7 @@ def new_subgroups( if rank in ranks_in_subgroup: cur_subgroup = subgroup logger.info( - "Rank {} is assigned to subgroup {}".format( - rank, ranks_in_subgroup - ) + "Rank {} is assigned to subgroup {}".format(rank, ranks_in_subgroup) ) return cur_subgroup, subgroups @@ -3139,8 +3158,6 @@ def new_subgroups_by_enumeration( rank_to_ranks_dict[rank] = ranks if my_rank == rank: cur_subgroup = subgroup - logging.info( - "Rank {} is assigned to subgroup {}".format(rank, ranks) - ) + logging.info("Rank {} is assigned to subgroup {}".format(rank, ranks)) return cur_subgroup, subgroups diff --git a/torch/distributed/elastic/agent/server/api.py b/torch/distributed/elastic/agent/server/api.py index 6d389a7873a4a..d767233a2ae52 100644 --- a/torch/distributed/elastic/agent/server/api.py +++ b/torch/distributed/elastic/agent/server/api.py @@ -160,7 +160,7 @@ def __init__( # rank of the worker among all the workers with the same role # across all ``agent`` instances. - # Global rank is not stable between re-rendezvous. + # Role rank is not stable between re-rendezvous. self.role_rank: int = role_rank # total number of workers (globally). Due to elasticity diff --git a/torch/distributed/elastic/multiprocessing/errors/__init__.py b/torch/distributed/elastic/multiprocessing/errors/__init__.py index 7746dbace9af5..ab0e0f3b7c874 100644 --- a/torch/distributed/elastic/multiprocessing/errors/__init__.py +++ b/torch/distributed/elastic/multiprocessing/errors/__init__.py @@ -165,7 +165,7 @@ def timestamp_isoformat(self): rank: ${rank} (local_rank: ${local_rank}) exitcode: ${exitcode} (pid: ${pid}) error_file: ${error_file} - msg: \"${message}\"""" + msg: ${message}""" # extra new lines before and after are intentional _MSG_FORMAT_TEMPLATE = """ @@ -258,6 +258,19 @@ def format_msg(self, boarder_delim="*", section_delim="="): def _format_failure( self, idx: int, rank: int, failure: ProcessFailure ) -> Tuple[str, int]: + if isinstance(failure.message, str): + msg = '"' + failure.message + '"' + else: + try: + dmp = json.dumps(failure.message, indent=2) + except ValueError: + msg = failure.message + else: + msg = os.linesep + # Indent by 4 chars. + for l in dmp.splitlines(): + msg += f" {l}{os.linesep}" + fmt = Template(_FAILURE_FORMAT_TEMPLATE).substitute( idx=idx, time=failure.timestamp_isoformat(), @@ -266,7 +279,7 @@ def _format_failure( exitcode=failure.exitcode, pid=failure.pid, error_file=failure.error_file, - message=failure.message, + message=msg, ) width = 0 for line in fmt.split("\n"): diff --git a/torch/distributed/elastic/multiprocessing/errors/error_handler.py b/torch/distributed/elastic/multiprocessing/errors/error_handler.py index 74586e9fd8523..2974355fae88c 100644 --- a/torch/distributed/elastic/multiprocessing/errors/error_handler.py +++ b/torch/distributed/elastic/multiprocessing/errors/error_handler.py @@ -23,7 +23,7 @@ def _write_error(e: BaseException, error_file: Optional[str]): "message": { "message": f"{type(e).__name__}: {e}", "extraInfo": { - "py_callstack": traceback.format_exc(), + "py_callstack": traceback.format_stack(), "timestamp": str(int(time.time())), }, } diff --git a/torch/distributed/elastic/utils/log_level.py b/torch/distributed/elastic/utils/log_level.py new file mode 100644 index 0000000000000..87ea0f7d64182 --- /dev/null +++ b/torch/distributed/elastic/utils/log_level.py @@ -0,0 +1,14 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +def get_log_level() -> str: + """ + Return default log level for pytorch. + """ + return "WARNING" diff --git a/torch/distributed/elastic/utils/logging.py b/torch/distributed/elastic/utils/logging.py index 19c68c03cf552..e4f1345e4c339 100644 --- a/torch/distributed/elastic/utils/logging.py +++ b/torch/distributed/elastic/utils/logging.py @@ -12,6 +12,8 @@ import warnings from typing import Optional +from torch.distributed.elastic.utils.log_level import get_log_level + def get_logger(name: Optional[str] = None): """ @@ -32,7 +34,7 @@ def get_logger(name: Optional[str] = None): def _setup_logger(name: Optional[str] = None): log = logging.getLogger(name) - log.setLevel(os.environ.get("LOGLEVEL", "WARNING")) + log.setLevel(os.environ.get("LOGLEVEL", get_log_level())) return log diff --git a/torch/distributed/launch.py b/torch/distributed/launch.py index 5fcb3eb44c126..6173abb2c9ecf 100644 --- a/torch/distributed/launch.py +++ b/torch/distributed/launch.py @@ -4,7 +4,7 @@ .. warning:: - This module is going to be deprecated in favor of :ref:`torch.distributed.run `. + This module is going to be deprecated in favor of :ref:`torchrun `. The utility can be used for single-node distributed training, in which one or more processes per node will be spawned. The utility can be used for either @@ -97,9 +97,9 @@ >>> # your code to run 3. In your training program, you are supposed to call the following function -at the beginning to start the distributed backend. You need to make sure that -the init_method uses ``env://``, which is the only supported ``init_method`` -by this module. +at the beginning to start the distributed backend. It is strongly recommended +that ``init_method=env://``. Other init methods (e.g. ``tcp://``) may work, +but ``env://`` is the one that is officially supported by this module. :: @@ -147,6 +147,7 @@ from torch.distributed.run import get_args_parser, run + logger = logging.getLogger(__name__) @@ -176,12 +177,13 @@ def launch(args): def main(args=None): warnings.warn( "The module torch.distributed.launch is deprecated\n" - "and will be removed in future. Use torch.distributed.run.\n" - "Note that --use_env is set by default in torch.distributed.run.\n" + "and will be removed in future. Use torchrun.\n" + "Note that --use_env is set by default in torchrun.\n" "If your script expects `--local_rank` argument to be set, please\n" "change it to read from `os.environ['LOCAL_RANK']` instead. See \n" "https://pytorch.org/docs/stable/distributed.html#launch-utility for \n" - "further instructions\n", FutureWarning + "further instructions\n", + FutureWarning, ) args = parse_args(args) launch(args) diff --git a/torch/distributed/nn/api/remote_module.py b/torch/distributed/nn/api/remote_module.py index ef26db64dbed8..fb3b160c8ebcc 100644 --- a/torch/distributed/nn/api/remote_module.py +++ b/torch/distributed/nn/api/remote_module.py @@ -288,11 +288,13 @@ def get_module_rref(self) -> rpc.RRef[nn.Module]: """ return self.module_rref + @torch.jit.export def __getstate__(self): raise RuntimeError( "Cannot pickle RemoteModule in python pickler. RemoteModule can only be pickled when using RPC" ) + @torch.jit.export def __setstate__(self, state): raise RuntimeError( "Cannot unpickle RemoteModule in python pickler. RemoteModule can only be unpickled when using RPC" diff --git a/torch/distributed/optim/functional_adamw.py b/torch/distributed/optim/functional_adamw.py index 5623a0b8d6841..0159aa35a5539 100644 --- a/torch/distributed/optim/functional_adamw.py +++ b/torch/distributed/optim/functional_adamw.py @@ -53,6 +53,55 @@ def __init__( # param group as it's not a common use case. self.param_group = {"params": params} + def step_param(self, param: Tensor, grad: Optional[Tensor]): + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps: List[int] = [] + if grad is not None: + params_with_grad.append(param) + grads.append(grad) + # Lazy state initialization + if param not in self.state: + self.state[param] = {} + state = self.state[param] + state['step'] = torch.tensor(0.0) + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format) + if self.amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format) + + state = self.state[param] + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + + if self.amsgrad: + max_exp_avg_sqs.append(state['max_exp_avg_sq']) + + # update the steps for each param group update + state['step'] += 1 + # record the step after step update + state_steps.append(state['step'].item()) + with torch.no_grad(): + F.adamw(params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=self.amsgrad, + beta1=self.defaults['beta1'], + beta2=self.defaults['beta2'], + lr=self.defaults['lr'], + weight_decay=self.defaults['weight_decay'], + eps=self.defaults['eps']) + def step(self, gradients: List[Optional[Tensor]]): params = self.param_group['params'] params_with_grad = [] diff --git a/torch/distributed/rendezvous.py b/torch/distributed/rendezvous.py index 6a5b680e25011..6e430e273f951 100644 --- a/torch/distributed/rendezvous.py +++ b/torch/distributed/rendezvous.py @@ -1,17 +1,22 @@ try: from urllib.parse import urlparse, urlunparse except ImportError: - raise ImportError("urllib cannot be found, urlparse from python2 is no longer supported.") + raise ImportError( + "urllib cannot be found, urlparse from python2 is no longer supported." + ) -import torch._six as six import numbers import os import sys from datetime import timedelta -from typing import Optional, Dict, Union -from torch.distributed import FileStore, TCPStore, PrefixStore +from typing import Dict, Optional, Union + +import torch._six as six +from torch.distributed import FileStore, PrefixStore, Store, TCPStore + from .constants import default_pg_timeout + _rendezvous_handlers = {} @@ -73,7 +78,9 @@ def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs): query_dict["world_size"] = world_size result = result._replace( - query="{}".format("&".join(["{}={}".format(k, v) for k, v in query_dict.items()])) + query="{}".format( + "&".join(["{}={}".format(k, v) for k, v in query_dict.items()]) + ) ) url = urlunparse(result) @@ -92,8 +99,9 @@ def _error(msg): result = urlparse(url) path = result.path - if sys.platform == 'win32': + if sys.platform == "win32": import urllib.request + full_path = result.netloc + result.path path = urllib.request.url2pathname(full_path) if path: @@ -119,7 +127,41 @@ def _error(msg): raise RuntimeError("Unable to perform rerendezvous using file:// method") -def _tcp_rendezvous_handler(url: str, timeout: timedelta = default_pg_timeout, **kwargs): +def _torchelastic_use_agent_store() -> bool: + return os.environ.get("TORCHELASTIC_USE_AGENT_STORE", None) == str(True) + + +def _create_c10d_store(hostname, port, rank, world_size, timeout) -> Store: + """ + Smartly creates a c10d Store object on ``rank`` based on whether + we need to re-use agent store. The TCPStore server is assumed to be hosted + on ``hostname:port``. + + If ``torchelastic_use_agent_store()`` is ``True``, then it is assumed that + the agent leader (node rank 0) hosts the TCPStore server (for which the + endpoint is specified by the given ``hostname:port``). Hence + ALL ranks will create and return a TCPStore client (e.g. ``start_daemon=False``). + + If ``torchelastic_use_agent_store()`` is ``False``, then rank 0 will host + the TCPStore (with multi-tenancy) and it is assumed that rank 0's hostname + and port are correctly passed via ``hostname`` and ``port``. All + non-zero ranks will create and return a TCPStore client. + """ + + if _torchelastic_use_agent_store(): + attempt = os.environ["TORCHELASTIC_RESTART_COUNT"] + tcp_store = TCPStore(hostname, port, world_size, False, timeout) + return PrefixStore(f"/worker/attempt_{attempt}", tcp_store) + else: + start_daemon = rank == 0 + return TCPStore( + hostname, port, world_size, start_daemon, timeout, multi_tenant=True + ) + + +def _tcp_rendezvous_handler( + url: str, timeout: timedelta = default_pg_timeout, **kwargs +): def _error(msg): return _rendezvous_error("tcp:// rendezvous: " + msg) @@ -136,18 +178,19 @@ def _error(msg): rank = int(query["rank"]) world_size = int(query["world_size"]) - start_daemon = rank == 0 assert result.hostname is not None - store = TCPStore( # type: ignore[call-arg] - result.hostname, result.port, world_size, start_daemon, timeout, multi_tenant=True - ) + + store = _create_c10d_store(result.hostname, result.port, rank, world_size, timeout) + yield (store, rank, world_size) # If this configuration is invalidated, there is nothing we can do about it - raise RuntimeError("Unable to perform rerendezvous using tcp:// method") + raise RuntimeError("Unable to perform re-rendezvous using tcp:// method") -def _env_rendezvous_handler(url: str, timeout: timedelta = default_pg_timeout, **kwargs): +def _env_rendezvous_handler( + url: str, timeout: timedelta = default_pg_timeout, **kwargs +): def _error(msg): return _rendezvous_error("env:// rendezvous: " + msg) @@ -183,29 +226,13 @@ def _get_env_or_raise(env_var: str) -> str: master_addr = _get_env_or_raise("MASTER_ADDR") master_port = int(_get_env_or_raise("MASTER_PORT")) + store = _create_c10d_store(master_addr, master_port, rank, world_size, timeout) - use_torchelastic_store = os.environ.get("TORCHELASTIC_USE_AGENT_STORE", None) - - if use_torchelastic_store == str(True): - attempt = os.environ["TORCHELASTIC_RESTART_COUNT"] - worker_process_prefix = f"/worker/attempt_{attempt}" - # When TORCHELASTIC_USE_AGENT_STORE is set up, the worker process is assumed - # to be invoked by the torchelastic agent. Torchelastic agent creates a tcp daemon thread - # on the GROUP_RANK=0, as a result all user worker processes should create store with: daemon=False - tcp_store = TCPStore(master_addr, master_port, world_size, False, timeout) - # Each if-else condition returns due to: https://github.com/python/mypy/issues/1191 - yield (PrefixStore(worker_process_prefix, tcp_store), rank, world_size) - else: - # Start the TCP store daemon on the rank 0 - start_daemon = rank == 0 - store = TCPStore( # type: ignore[call-arg] - master_addr, master_port, world_size, start_daemon, timeout, multi_tenant=True - ) - # Each if-else condition returns due to: https://github.com/python/mypy/issues/1191 - yield (store, rank, world_size) + yield (store, rank, world_size) # If this configuration is invalidated, there is nothing we can do about it - raise RuntimeError("Unable to perform rerendezvous using env:// method") + raise RuntimeError("Unable to perform re-rendezvous using env:// method") + register_rendezvous_handler("tcp", _tcp_rendezvous_handler) register_rendezvous_handler("env", _env_rendezvous_handler) diff --git a/torch/distributed/run.py b/torch/distributed/run.py index 9fb88fa3a2c96..c6e84d6f65f4b 100644 --- a/torch/distributed/run.py +++ b/torch/distributed/run.py @@ -7,7 +7,7 @@ # LICENSE file in the root directory of this source tree. """ -``torch.distributed.run`` provides a superset of the functionality as ``torch.distributed.launch`` +``torchrun`` provides a superset of the functionality as ``torch.distributed.launch`` with the following additional functionalities: 1. Worker failures are handled gracefully by restarting all workers. @@ -18,33 +18,33 @@ -Transitioning from torch.distributed.launch to torch.distributed.run +Transitioning from torch.distributed.launch to torchrun ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -``torch.distributed.run`` supports the same arguments as ``torch.distributed.launch`` **except** +``torchrun`` supports the same arguments as ``torch.distributed.launch`` **except** for ``--use_env`` which is now deprecated. To migrate from ``torch.distributed.launch`` -to ``torch.distributed.run`` follow these steps: +to ``torchrun`` follow these steps: 1. If your training script is already reading ``local_rank`` from the ``LOCAL_RANK`` environment variable. Then you need simply omit the ``--use_env`` flag, e.g.: - +--------------------------------------------------------------------+------------------------------------------------------+ - | ``torch.distributed.launch`` | ``torch.distributed.run`` | - +====================================================================+======================================================+ - | | | - | .. code-block:: shell-session | .. code-block:: shell-session | - | | | - | $ python -m torch.distributed.launch --use_env train_script.py | $ python -m torch.distributed.run train_script.py | - | | | - +--------------------------------------------------------------------+------------------------------------------------------+ + +--------------------------------------------------------------------+--------------------------------------------+ + | ``torch.distributed.launch`` | ``torchrun`` | + +====================================================================+============================================+ + | | | + | .. code-block:: shell-session | .. code-block:: shell-session | + | | | + | $ python -m torch.distributed.launch --use_env train_script.py | $ torchrun train_script.py | + | | | + +--------------------------------------------------------------------+--------------------------------------------+ 2. If your training script reads local rank from a ``--local_rank`` cmd argument. Change your training script to read from the ``LOCAL_RANK`` environment variable as demonstrated by the following code snippet: +-------------------------------------------------------+----------------------------------------------------+ - | ``torch.distributed.launch`` | ``torch.distributed.run`` | + | ``torch.distributed.launch`` | ``torchrun`` | +=======================================================+====================================================+ | | | | .. code-block:: python | .. code-block:: python | @@ -59,12 +59,12 @@ | | | +-------------------------------------------------------+----------------------------------------------------+ -The aformentioned changes suffice to migrate from ``torch.distributed.launch`` to ``torch.distributed.run``. -To take advantage of new features such as elasticity, fault-tolerance, and error reporting of ``torch.distributed.run`` +The aformentioned changes suffice to migrate from ``torch.distributed.launch`` to ``torchrun``. +To take advantage of new features such as elasticity, fault-tolerance, and error reporting of ``torchrun`` please refer to: -* :ref:`elastic_train_script` for more information on authoring training scripts that are ``torch.distributed.run`` compliant. -* the rest of this page for more information on the features of ``torch.distributed.run``. +* :ref:`elastic_train_script` for more information on authoring training scripts that are ``torchrun`` compliant. +* the rest of this page for more information on the features of ``torchrun``. @@ -75,7 +75,7 @@ :: - >>> python -m torch.distributed.run + >>> torchrun --standalone --nnodes=1 --nproc_per_node=$NUM_TRAINERS @@ -85,7 +85,7 @@ :: - >>> python -m torch.distributed.run + >>> torchrun --nnodes=$NUM_NODES --nproc_per_node=$NUM_TRAINERS --rdzv_id=$JOB_ID @@ -104,7 +104,7 @@ :: - >>> python -m torch.distributed.run + >>> torchrun --nnodes=1:4 --nproc_per_node=$NUM_TRAINERS --rdzv_id=$JOB_ID @@ -186,7 +186,7 @@ of the worker is specified in the ``WorkerSpec``. 5. ``LOCAL_WORLD_SIZE`` - The local world size (e.g. number of workers running locally); equals to - ``--nproc_per_node`` specified on ``torch.distributed.run``. + ``--nproc_per_node`` specified on ``torchrun``. 6. ``WORLD_SIZE`` - The world size (total number of workers in the job). @@ -321,6 +321,7 @@ def train(): from torch.distributed.elastic.utils.logging import get_logger from torch.distributed.launcher.api import LaunchConfig, elastic_launch + log = get_logger() @@ -595,7 +596,7 @@ def config_from_args(args) -> Tuple[LaunchConfig, Union[Callable, str], List[str nproc_per_node = determine_local_world_size(args.nproc_per_node) if "OMP_NUM_THREADS" not in os.environ and nproc_per_node > 1: omp_num_threads = 1 - print( + log.warning( f"*****************************************\n" f"Setting OMP_NUM_THREADS environment variable for each process to be " f"{omp_num_threads} in default, to avoid your system being overloaded, " diff --git a/torch/distributions/constraint_registry.py b/torch/distributions/constraint_registry.py index cbe987e72c798..c03f0ad02d2c6 100644 --- a/torch/distributions/constraint_registry.py +++ b/torch/distributions/constraint_registry.py @@ -173,7 +173,9 @@ def _transform_to_independent(constraint): @biject_to.register(constraints.positive) +@biject_to.register(constraints.nonnegative) @transform_to.register(constraints.positive) +@transform_to.register(constraints.nonnegative) def _transform_to_positive(constraint): return transforms.ExpTransform() diff --git a/torch/distributions/constraints.py b/torch/distributions/constraints.py index 99808b6b80beb..5eed19afd09ec 100644 --- a/torch/distributions/constraints.py +++ b/torch/distributions/constraints.py @@ -545,6 +545,7 @@ def check(self, value): real = _Real() real_vector = independent(real, 1) positive = _GreaterThan(0.) +nonnegative = _GreaterThanEq(0.) greater_than = _GreaterThan greater_than_eq = _GreaterThanEq less_than = _LessThan diff --git a/torch/distributions/poisson.py b/torch/distributions/poisson.py index 954ed6e0d3206..9adb641d7fcee 100644 --- a/torch/distributions/poisson.py +++ b/torch/distributions/poisson.py @@ -24,7 +24,7 @@ class Poisson(ExponentialFamily): Args: rate (Number, Tensor): the rate parameter """ - arg_constraints = {'rate': constraints.positive} + arg_constraints = {'rate': constraints.nonnegative} support = constraints.nonnegative_integer @property @@ -60,7 +60,7 @@ def log_prob(self, value): if self._validate_args: self._validate_sample(value) rate, value = broadcast_all(self.rate, value) - return (rate.log() * value) - rate - (value + 1).lgamma() + return value.xlogy(rate) - rate - (value + 1).lgamma() @property def _natural_params(self): diff --git a/torch/functional.py b/torch/functional.py index ab8f70f6bffaf..63470cf2d443f 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -66,6 +66,7 @@ def broadcast_tensors(*tensors): tensor([[0, 1, 2], [0, 1, 2]]) """ + # This wrapper exists to support variadic args. if has_torch_function(tensors): return handle_torch_function(broadcast_tensors, tensors, *tensors) return _VF.broadcast_tensors(tensors) # type: ignore[attr-defined] @@ -96,6 +97,7 @@ def broadcast_shapes(*shapes): Raises: RuntimeError: If shapes are incompatible. """ + # This wrapper exists to support variadic args. # TODO Movie this to C++ once the jit has better support for torch.Size. with torch.no_grad(): scalar = torch.zeros((), device="cpu") @@ -277,6 +279,7 @@ def einsum(*args): tensor([[-0.3430, -5.2405, 0.4494], [ 0.3311, 5.5201, -3.0356]]) """ + # This wrapper exists to support variadic args. if len(args) < 2: raise ValueError('einsum(): must specify the equation string and at least one operand, ' 'or at least one operand and its subscripts list') @@ -324,29 +327,61 @@ def parse_subscript(n: int) -> str: return _VF.einsum(equation, operands) # type: ignore[attr-defined] +# This wrapper exists to support variadic args. if TYPE_CHECKING: # The JIT doesn't understand Union, so only add type annotation for mypy def meshgrid(*tensors: Union[Tensor, List[Tensor]]) -> Tuple[Tensor, ...]: return _meshgrid(*tensors) else: def meshgrid(*tensors): - r"""Take :math:`N` tensors, each of which can be either scalar or 1-dimensional - vector, and create :math:`N` N-dimensional grids, where the :math:`i` :sup:`th` grid is defined by - expanding the :math:`i` :sup:`th` input over dimensions defined by other inputs. + r"""Creates grids of coordinates specified by the 1D inputs in `attr`:tensors. + + This is helpful when you want to visualize data over some + range of inputs. See below for a plotting example. + + Given :math:`N` 1D tensors :math:`T_0 \ldots T_{N-1}` as + inputs with corresponding sizes :math:`S_0 \ldots S_{N-1}`, + this creates :math:`N` N-dimensional tensors :math:`G_0 \ldots + G_{N-1}`, each with shape :math:`(S_0, ..., S_{N-1})` where + the output :math:`G_i` is constructed by expanding :math:`T_i` + to the result shape. + + .. note:: + 0D inputs are treated equivalently to 1D inputs of a + single element. + + .. warning:: + `torch.meshgrid` has the same behavior as calling + `numpy.meshgrid(..., indexing='ij')`, and in the future + `torch.meshgrid` will also support the `indexing` + argument. + + https://github.com/pytorch/pytorch/issues/50276 tracks + this issue with the goal of migrating to NumPy's behavior. + + .. seealso:: + + :func:`torch.cartesian_prod` has the same effect but it + collects the data in a tensor of vectors. Args: tensors (list of Tensor): list of scalars or 1 dimensional tensors. Scalars will be treated as tensors of size :math:`(1,)` automatically Returns: - seq (sequence of Tensors): If the input has :math:`k` tensors of size - :math:`(N_1,), (N_2,), \ldots , (N_k,)`, then the output would also have :math:`k` tensors, - where all tensors are of size :math:`(N_1, N_2, \ldots , N_k)`. + seq (sequence of Tensors): If the input has :math:`N` + tensors of size :math:`S_0 \ldots S_{N-1}``, then the + output will also have :math:`N` tensors, where each tensor + is of shape :math:`(S_0, ..., S_{N-1})`. Example:: >>> x = torch.tensor([1, 2, 3]) >>> y = torch.tensor([4, 5, 6]) + + Observe the element-wise pairings across the grid, (1, 4), + (1, 5), ..., (3, 6). This is the same thing as the + cartesian product. >>> grid_x, grid_y = torch.meshgrid(x, y) >>> grid_x tensor([[1, 1, 1], @@ -356,6 +391,28 @@ def meshgrid(*tensors): tensor([[4, 5, 6], [4, 5, 6], [4, 5, 6]]) + + This correspondence can be seen when these grids are + stacked properly. + >>> torch.equal(torch.cat(tuple(torch.dstack([grid_x, grid_y]))), + ... torch.cartesian_prod(x, y)) + True + + `torch.meshgrid` is commonly used to produce a grid for + plotting. + >>> import matplotlib.pyplot as plt + >>> xs = torch.linspace(-5, 5, steps=100) + >>> ys = torch.linspace(-5, 5, steps=100) + >>> x, y = torch.meshgrid(xs, ys) + >>> z = torch.sin(torch.sqrt(x * x + y * y)) + >>> ax = plt.axes(projection='3d') + >>> ax.plot_surface(x.numpy(), y.numpy(), z.numpy()) + + >>> plt.show() + + .. image:: ../_static/img/meshgrid.png + :width: 512 + """ return _meshgrid(*tensors) @@ -512,7 +569,8 @@ def istft(input: Tensor, n_fft: int, hop_length: Optional[int] = None, Since :func:`~torch.stft` discards elements at the end of the signal if they do not fit in a frame, ``istft`` may return a shorter signal than the original signal (can occur if :attr:`center` is False - since the signal isn't padded). + since the signal isn't padded). If `length` is given in the arguments and is longer than expected, + ``istft`` will pad zeros to the end of the returned signal. If :attr:`center` is ``True``, then there will be padding e.g. ``'constant'``, ``'reflect'``, etc. Left padding can be trimmed off exactly because they can be calculated but right padding cannot be @@ -989,6 +1047,7 @@ def cartesian_prod(*tensors): [3, 4], [3, 5]]) """ + # This wrapper exists to support variadic args. if has_torch_function(tensors): return handle_torch_function(cartesian_prod, tensors, *tensors) return _VF.cartesian_prod(tensors) # type: ignore[attr-defined] @@ -1023,6 +1082,7 @@ def block_diag(*tensors): [0, 0, 0, 0, 0, 0, 0, 0, 0, 5], [0, 0, 0, 0, 0, 0, 0, 0, 0, 6]]) """ + # This wrapper exists to support variadic args. if has_torch_function(tensors): return handle_torch_function(block_diag, tensors, *tensors) return torch._C._VariableFunctions.block_diag(tensors) # type: ignore[attr-defined] @@ -1110,6 +1170,7 @@ def atleast_1d(*tensors): >>> torch.atleast_1d((x,y)) (tensor([0.5000]), tensor([1.])) """ + # This wrapper exists to support variadic args. if has_torch_function(tensors): return handle_torch_function(atleast_1d, tensors, *tensors) if len(tensors) == 1: @@ -1146,6 +1207,7 @@ def atleast_2d(*tensors): >>> torch.atleast_2d((x,y)) (tensor([[0.5000]]), tensor([[1.]])) """ + # This wrapper exists to support variadic args. if has_torch_function(tensors): return handle_torch_function(atleast_2d, tensors, *tensors) if len(tensors) == 1: @@ -1190,6 +1252,7 @@ def atleast_3d(*tensors): >>> torch.atleast_3d((x,y)) (tensor([[[0.5000]]]), tensor([[[1.]]])) """ + # This wrapper exists to support variadic args. if has_torch_function(tensors): return handle_torch_function(atleast_3d, tensors, *tensors) if len(tensors) == 1: @@ -1426,6 +1489,7 @@ def chain_matmul(*matrices, out=None): .. _`[CLRS]`: https://mitpress.mit.edu/books/introduction-algorithms-third-edition """ + # This wrapper exists to support variadic args. if has_torch_function(matrices): return handle_torch_function(chain_matmul, matrices, *matrices) diff --git a/torch/fx/__init__.py b/torch/fx/__init__.py index 4ff795e632944..6524c2d1b8716 100644 --- a/torch/fx/__init__.py +++ b/torch/fx/__init__.py @@ -1,6 +1,4 @@ r''' -**This feature is under a Beta release and its API may change.** - FX is a toolkit for developers to use to transform ``nn.Module`` instances. FX consists of three main components: a **symbolic tracer,** an **intermediate representation**, and **Python code generation**. A @@ -28,12 +26,13 @@ def forward(self, x): # High-level intermediate representation (IR) - Graph representation print(symbolic_traced.graph) """ - graph(x): - %param : [#users=1] = self.param - %add_1 : [#users=1] = call_function[target=](args = (%x, %param), kwargs = {}) - %linear_1 : [#users=1] = call_module[target=linear](args = (%add_1,), kwargs = {}) - %clamp_1 : [#users=1] = call_method[target=clamp](args = (%linear_1,), kwargs = {min: 0.0, max: 1.0}) - return clamp_1 + graph(): + %x : [#users=1] = placeholder[target=x] + %param : [#users=1] = get_attr[target=param] + %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {}) + %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {}) + %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0}) + return clamp """ # Code generation - valid Python code @@ -41,10 +40,10 @@ def forward(self, x): """ def forward(self, x): param = self.param - add_1 = x + param; x = param = None - linear_1 = self.linear(add_1); add_1 = None - clamp_1 = linear_1.clamp(min = 0.0, max = 1.0); linear_1 = None - return clamp_1 + add = x + param; x = param = None + linear = self.linear(add); add = None + clamp = linear.clamp(min = 0.0, max = 1.0); linear = None + return clamp """ The **symbolic tracer** performs "symbolic execution" of the Python diff --git a/torch/fx/_compatibility.py b/torch/fx/_compatibility.py new file mode 100644 index 0000000000000..2d33813200be2 --- /dev/null +++ b/torch/fx/_compatibility.py @@ -0,0 +1,34 @@ +from typing import Any, Dict +import textwrap + +_BACK_COMPAT_OBJECTS : Dict[Any, None] = {} +_MARKED_WITH_COMATIBLITY : Dict[Any, None] = {} + +def compatibility(is_backward_compatible : bool): + if is_backward_compatible: + + def mark_back_compat(fn): + docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '') + docstring += """ +.. note:: + Backwards-compatibility for this API is guaranteed. +""" + fn.__doc__ = docstring + _BACK_COMPAT_OBJECTS.setdefault(fn) + _MARKED_WITH_COMATIBLITY.setdefault(fn) + return fn + + return mark_back_compat + else: + + def mark_not_back_compat(fn): + docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '') + docstring += """ +.. warning:: + This API is experimental and is *NOT* backward-compatible. +""" + fn.__doc__ = docstring + _MARKED_WITH_COMATIBLITY.setdefault(fn) + return fn + + return mark_not_back_compat diff --git a/torch/fx/_symbolic_trace.py b/torch/fx/_symbolic_trace.py index 25f739e49f9ad..d38197322fab1 100644 --- a/torch/fx/_symbolic_trace.py +++ b/torch/fx/_symbolic_trace.py @@ -12,6 +12,7 @@ import torch.utils._pytree as pytree import sys +from ._compatibility import compatibility from .node import Argument, map_aggregate, base_types from .graph import Graph, _PyTreeInfo from .graph_module import GraphModule @@ -25,6 +26,7 @@ _proxyable_classes : Dict[Type, None] = {} +@compatibility(is_backward_compatible=True) class ProxyableClassMeta(type): """ ProxyableClassMeta allows you to make construction of a given Python class @@ -157,6 +159,7 @@ def __enter__(self): def __exit__(self, type, value, tb): sys.setprofile(None) +@compatibility(is_backward_compatible=False) class PHBase(object): """ Object representing an input placeholder to `concrete_args` @@ -166,6 +169,7 @@ def __repr__(self): PH = PHBase() +@compatibility(is_backward_compatible=True) class Tracer(TracerBase): # Reference: https://github.com/pytorch/pytorch/issues/54354 # The first line of this docstring overrides the one Sphinx generates for the @@ -182,6 +186,11 @@ class Tracer(TracerBase): process. The different behaviors that can be overridden are described in the docstrings of the methods on this class. """ + + # Not checking BC on this API because the default value for `autowrap_modules` + # includes the local filepath to the `math` module, which would jitter + # across machines. + @compatibility(is_backward_compatible=True) def __init__(self, autowrap_modules: Tuple[ModuleType] = (math, ), autowrap_functions: Tuple[Callable, ...] = (), enable_cpatching: bool = False, @@ -197,11 +206,19 @@ def __init__(self, autowrap_modules: Tuple[ModuleType] = (math, ), autowrap_modules (Tuple[ModuleType]): defaults to `(math, )`, Python modules whose functions should be wrapped automatically - without needing to use fx.wrap(). + without needing to use fx.wrap(). Backward-compatibility for + this parameter is guaranteed. autowrap_function (Tuple[Callable, ...]): defaults to `()`, Python functions that should be wrapped automatically without - needing to use fx.wrap(). + needing to use fx.wrap(). Backward compabilibility for this + parameter is guaranteed. + + param_shapes_constant (bool): When this flag is set, calls to shape, + size and a few other shape like attributes of a module's parameter + will be evaluted directly, rather than returning a new Proxy value + for an attribute access. Backward compatibility for this parameter + is guaranteed. enable_cpatching (bool): defaults to `False`, Allows you to enable/disable monkeypatching of torch functions at the @@ -210,12 +227,9 @@ def __init__(self, autowrap_modules: Tuple[ModuleType] = (math, ), C-level monkeypatching works by directly modifying the PyCFunctionObject* so that calling it returns a different function. - Turning this on is likely to slow down tracing by 1.5-3x. - - param_shapes_constant (bool): see https://github.com/pytorch/pytorch/issues/61733. When - this flag is set, calls to shape, size and a few other shape like attributes of a module's parameter - will be evaluted directly, rather than returning a new Proxy value for an attribute access. - + Turning this on is likely to slow down tracing by 1.5-3x. This + parameter is experimental and its backward-compatibility is NOT + guaranteed. """ super().__init__() @@ -235,6 +249,7 @@ def __init__(self, autowrap_modules: Tuple[ModuleType] = (math, ), self.submodule_paths: Optional[Dict[torch.nn.Module, str]] = None + @compatibility(is_backward_compatible=True) def create_arg(self, a: Any) -> 'Argument': """ A method to specify the behavior of tracing when preparing values to @@ -325,6 +340,7 @@ def create_arg(self, a: Any) -> 'Argument': return super().create_arg(a) + @compatibility(is_backward_compatible=True) def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool: """ A method to specify whether a given ``nn.Module`` is a "leaf" module. @@ -346,6 +362,7 @@ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> boo """ return m.__module__.startswith('torch.nn') and not isinstance(m, torch.nn.Sequential) + @compatibility(is_backward_compatible=True) def path_of_module(self, mod : torch.nn.Module) -> str: """ Helper method to find the qualified name of ``mod`` in the Module hierarchy @@ -372,6 +389,7 @@ def path_of_module(self, mod : torch.nn.Module) -> str: return n raise NameError('module is not installed as a submodule') + @compatibility(is_backward_compatible=True) def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args : Tuple[Any, ...], kwargs : Dict[str, Any]) -> Any: """ Method that specifies the behavior of this ``Tracer`` when it encounters @@ -404,6 +422,8 @@ def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args : Tu return forward(*args, **kwargs) return self.create_proxy('call_module', module_qualified_name, args, kwargs) + # This method will be refactored + @compatibility(is_backward_compatible=False) def create_args_for_root(self, root_fn, is_module, concrete_args=None): """ Create ``placeholder`` nodes corresponding to the signature of the ``root`` @@ -509,8 +529,8 @@ def _module_getattr(self, attr, attr_val, parameter_proxy_cache): return attr_val - - def trace(self, root: Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, Any]] = None) -> Graph: + @compatibility(is_backward_compatible=True) + def trace(self, root: Union[torch.nn.Module, Callable[..., Any]], concrete_args: Optional[Dict[str, Any]] = None) -> Graph: """ Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root`` can either be an ``nn.Module`` instance or a Python callable. @@ -524,8 +544,11 @@ def trace(self, root: Union[torch.nn.Module, Callable], concrete_args: Optional[ Args: root (Union[Module, Callable]): Either a ``Module`` or a function to be - traced through. - concrete_args (Optional[Dict[str, any]]): Concrete arguments that should not be treated as Proxies. + traced through. Backwards-compatibility for this parameter is + guaranteed. + concrete_args (Optional[Dict[str, any]]): Concrete arguments that should + not be treated as Proxies. This parameter is experimental and + its backwards-compatibility is *NOT* guaranteed. Returns: @@ -772,6 +795,7 @@ def _autowrap_check(patcher : _Patcher, frame_dict : Dict[str, Any], function_id patcher.patch(frame_dict, name, _create_wrapped_func(value)) +@compatibility(is_backward_compatible=True) def wrap(fn_or_name : Union[str, Callable]): """ This function can be called at module-level scope to register fn_or_name as a "leaf function". @@ -828,9 +852,11 @@ def my_custom_function(x, y): _wrapped_fns_to_patch.append((f.f_globals, fn_name)) return fn_or_name -def symbolic_trace(root : Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, Any]] = None, +@compatibility(is_backward_compatible=True) +def symbolic_trace(root : Union[torch.nn.Module, Callable[..., Any]], concrete_args: Optional[Dict[str, Any]] = None, enable_cpatching: bool = False) -> GraphModule: - """Symbolic tracing API + """ + Symbolic tracing API Given an ``nn.Module`` or function instance ``root``, this function will return a ``GraphModule`` constructed by recording operations seen while tracing through ``root``. @@ -876,7 +902,6 @@ def f(x): Returns: GraphModule: a Module created from the recorded operations from ``root``. - """ tracer = Tracer(enable_cpatching=enable_cpatching) graph = tracer.trace(root, concrete_args) diff --git a/torch/fx/annotate.py b/torch/fx/annotate.py index 6e0646a58ec52..032ce14b6ec70 100644 --- a/torch/fx/annotate.py +++ b/torch/fx/annotate.py @@ -1,6 +1,7 @@ from torch.fx.proxy import Proxy +from ._compatibility import compatibility - +@compatibility(is_backward_compatible=False) def annotate(val, type): # val could be either a regular value (not tracing) # or fx.Proxy (tracing) diff --git a/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py b/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py index 506bf2cdbec93..e101b6b7f22ff 100644 --- a/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py +++ b/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py @@ -414,6 +414,66 @@ def acc_ops_batch_norm(network, target, args, kwargs, name): return layer.get_output(0) +@tensorrt_converter(acc_ops.layer_norm) +def acc_ops_layer_norm(network, target, args, kwargs, name): + input_val = kwargs["input"] + + if not isinstance(input_val, trt.tensorrt.ITensor): + raise RuntimeError(f"LayerNorm received input {input_val} that is not part " + "of the TensorRT region!") + + shape = kwargs["weight"].shape + broadcasted_shape = (1,) * (len(input_val.shape) - len(shape)) + shape + gamma = to_numpy(kwargs["weight"].reshape(*shape)) + beta = to_numpy(kwargs["bias"].reshape(*shape)) + eps = kwargs["eps"] + normalized_shape = kwargs["normalized_shape"] + + axes = 0 + for d in range(len(normalized_shape)): + axes |= 1 << (len(input_val.shape) - d - 1) + + # E[x] + mean_expected_layer = network.add_reduce(input_val, trt.ReduceOperation.AVG, axes, keep_dims=True) + mean_expected_layer.name = f"{name}_mean_expected" + # X-E[x] + sub_trt = add_binary_elementwise_layer( + network, input_val, mean_expected_layer.get_output(0), trt.ElementWiseOperation.SUB, f"{name}_sub" + ) + # Variance = mean(pow(x_sub_mean,2)) + pow_tensor = network.add_constant( + (1,) * len(input_val.shape), trt.Weights(np.ascontiguousarray([2.0], dtype=np.float32)) + ) + pow_tensor.name = f"{name}_power" + pow_var = add_binary_elementwise_layer( + network, sub_trt, pow_tensor.get_output(0), trt.ElementWiseOperation.POW, f"{name}_pow_var" + ) + mean_trt_layer = network.add_reduce(pow_var, trt.ReduceOperation.AVG, axes, keep_dims=True) + mean_trt_layer.name = f"{name}_mean" + # Variance + eps + eps_tensor = network.add_constant( + (1,) * len(input_val.shape), trt.Weights(np.ascontiguousarray([eps], dtype=np.float32)) + ) + eps_tensor.name = f"{name}_eps" + add_trt = add_binary_elementwise_layer( + network, mean_trt_layer.get_output(0), eps_tensor.get_output(0), trt.ElementWiseOperation.SUM, f"{name}_add" + ) + # SQRT((Var + eps)) + sqrt_trt = add_unary_layer(network, add_trt, trt.UnaryOperation.SQRT, f"{name}_sqrt") + # (x - E[x]) / sqrt((var + eps)) + div_trt = add_binary_elementwise_layer(network, sub_trt, sqrt_trt, trt.ElementWiseOperation.DIV, f"{name}_div_trt") + + gamma_tensor = network.add_constant(gamma.shape, trt.Weights(np.ascontiguousarray(gamma))) + gamma_tensor.name = f"{name}_gamma" + beta_tensor = network.add_constant(gamma.shape, trt.Weights(np.ascontiguousarray(beta))) + beta_tensor.name = f"{name}_beta" + # y * gamma + beta + scale_layer = add_binary_elementwise_layer( + network, div_trt, gamma_tensor.get_output(0), trt.ElementWiseOperation.PROD, f"{name}_scale" + ) + return add_binary_elementwise_layer( + network, scale_layer, beta_tensor.get_output(0), trt.ElementWiseOperation.SUM, name + ) @tensorrt_converter(acc_ops.softmax) def acc_ops_softmax(network, target, args, kwargs, name): @@ -657,6 +717,7 @@ def acc_ops_squeeze(network, target, args, kwargs, name): # dim, which is a very rare case. For now we just claim not supporting dim=None. assert dim is not None, "We don't support dim=None right now." + dim = dim % (len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)) if network.has_implicit_batch_dimension: assert dim != 0, "We don't support squeeze batch dim when it's implicit." dim -= 1 @@ -704,6 +765,11 @@ def acc_ops_mul(network, target, args, kwargs, name): network, kwargs["input"], kwargs["other"], trt.ElementWiseOperation.PROD, name ) +@tensorrt_converter(acc_ops.pow) +def acc_ops_pow(network, target, args, kwargs, name): + return add_binary_elementwise_layer( + network, kwargs["input"], kwargs["exponent"], trt.ElementWiseOperation.POW, name + ) @tensorrt_converter(acc_ops.min_two_tensors_input) def acc_ops_min_two_tensors_input(network, target, args, kwargs, name): @@ -731,6 +797,29 @@ def acc_ops_unsqueeze(network, target, args, kwargs, name): layer.name = name return layer.get_output(0) +@tensorrt_converter(acc_ops.topk) +def acc_ops_topk(network, target, args, kwargs, name): + input_val = kwargs["input"] + + if not isinstance(input_val, trt.tensorrt.ITensor): + raise RuntimeError(f"topk received input {input_val} that is not part " + "of the TensorRT region!") + + if kwargs["sorted"] and kwargs["k"] != 1: + raise RuntimeError("Currently we don't support sorted=True in topk.") + + if not network.has_implicit_batch_dimension and len(input_val.shape) <= 1: + raise RuntimeError("At least 2 dimensions are required for input to topk.") + + num_dims = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0) + k = kwargs["k"] + dim = (kwargs["dim"] if kwargs["dim"] else -1) % num_dims + operation = trt.TopKOperation.MAX if kwargs["largest"] else trt.TopKOperation.MIN + layer = network.add_topk( + input_val, operation, k, get_axes_for_reduce_op(dim, network.has_implicit_batch_dimension) + ) + layer.name = name + return (layer.get_output(0), layer.get_output(1)) @tensorrt_converter(acc_ops.adaptive_avg_pool2d) def acc_ops_adaptive_avg_pool2d(network, target, args, kwargs, name): @@ -842,6 +931,77 @@ def acc_ops_reshape(network, target, args, kwargs, name): layer.name = name return layer.get_output(0) +@tensorrt_converter(acc_ops.slice_tensor) +def acc_ops_slice_tensor(network, target, args, kwargs, name): + input_val = kwargs["input"] + + if not isinstance(input_val, trt.tensorrt.ITensor): + raise RuntimeError(f"slice_tensor received input {input_val} that is not part " + "of the TensorRT region!") + + dims = kwargs["dims"] + if network.has_implicit_batch_dimension: + if not len(dims): + raise RuntimeError("dim argument cannot be empty!") + if any([dim == 0 for dim in dims]): + raise RuntimeError( + f"We do not support slice_tensor at batch dim when it's implicit, got {dims}!" + ) + dims = [d - 1 for d in dims] + else: + raise RuntimeError("We don't support slice_tensor with explicit batch dimension yet!") + + start = [0] * len(input_val.shape) + stride = [1] * len(start) + output_shape = list(input_val.shape) + starts = kwargs["starts"] + stops = kwargs["stops"] + steps = kwargs["steps"] + + for i, dim in enumerate(dims): + start[dim] = starts[i] + stride[dim] = steps[i] + output_shape[dim] = (stops[i] - start[i]) // steps[i] + + layer = network.add_slice(input_val, start=start, shape=output_shape, stride=stride) + layer.name = name + return layer.get_output(0) + +@tensorrt_converter(acc_ops.split) +def acc_ops_split(network, target, args, kwargs, name): + input_val = kwargs["input"] + + if not isinstance(input_val, trt.tensorrt.ITensor): + raise RuntimeError(f"split received input {input_val} that is not part " + "of the TensorRT region!") + + dim = kwargs["dim"] + if network.has_implicit_batch_dimension: + assert dim != 0, "Can't split on batch dim when it's implicit!" + dim -= 1 + else: + raise RuntimeError("We don't support split with explicit batch dimension yet!") + + split_size = kwargs["split_size"] + start = [0] * len(input_val.shape) + stride = [1] * len(start) + offset = 0 + num_splits = (input_val.shape[dim] + split_size - 1) // split_size + if num_splits < 1: + raise RuntimeError(f"Invalid split: {input_val.shape[dim]} wuth split_size={split_size}") + + max_offset = input_val.shape[dim] + # add slice layers + output = [] + for i in range(num_splits): + shape = list(input_val.shape) + shape[dim] = min(split_size, max_offset - offset) + start[dim] = offset + layer = network.add_slice(input_val, start=start, shape=shape, stride=stride) + offset += split_size + layer.name = f"{name}_{i}" + output.append(layer.get_output(0)) + return output @tensorrt_converter(acc_ops.linear) def acc_ops_linear(network, target, args, kwargs, name): @@ -859,13 +1019,42 @@ def acc_ops_linear(network, target, args, kwargs, name): "dim for linear and it can't be the last dim." ) - # add matrix multiply and add - weight = get_trt_tensor(network, kwargs["weight"], f"{name}_linear_weight", squeeze_vector=False) - output = add_matrix_multiply_layer(network, input_val, weight, f"{name}_linear_mm", transpose_other=True) - if kwargs["bias"] is not None: - return add_binary_elementwise_layer(network, output, kwargs["bias"], trt.ElementWiseOperation.SUM, f"{name}_linear_add") + weight = kwargs["weight"] + + # For quantization, weight here would be a trt tensor because it goes through + # quant + dequant. In this case, we need to use matmul + add because fully_connected + # can't take non-constant weight. + # TODO: Need to benchmark the performance of lowering linear as fully_connected versus + # lowering as matmul + add. TensorRT documentation suggests to always lower it as + # matmul + add but we found in some cases this results in performance regression compared + # with lowering to fully_connected layer. + if isinstance(weight, torch.Tensor): + layer = network.add_shuffle(input_val) + layer.reshape_dims = tuple(input_val.shape) + (1, 1) + layer.name = f"{name}_pre_shuffle" + + # add fully connected + layer = network.add_fully_connected( + input=layer.get_output(0), + num_outputs=kwargs["weight"].shape[0], + kernel=to_numpy(kwargs["weight"]), + bias=to_numpy(kwargs["bias"]), + ) + layer.name = f"{name}_linear" + + # reshape back + layer = network.add_shuffle(layer.get_output(0)) + layer.reshape_dims = tuple(input_val.shape[:-1]) + (kwargs["weight"].shape[0],) + layer.name = f"{name}_post_shuffle" + + return layer.get_output(0) else: - return output + # add matrix multiply and add + output = add_matrix_multiply_layer(network, input_val, weight, f"{name}_linear_mm", transpose_other=True) + if kwargs["bias"] is not None: + return add_binary_elementwise_layer(network, output, kwargs["bias"], trt.ElementWiseOperation.SUM, f"{name}_linear_add") + else: + return output def add_clamp(network, input, val, op): @@ -909,6 +1098,15 @@ def acc_ops_clamp(network, target, args, kwargs, name): return input_val +@tensorrt_converter(acc_ops.tuple_construct) +def acc_ops_tuple_construct(network, target, args, kwargs, name): + return kwargs["tensors"] + + +@tensorrt_converter(acc_ops.contiguous) +def acc_ops_contiguous(network, target, args, kwargs, name): + return kwargs["input"] + @tensorrt_converter(acc_ops.getitem) def acc_ops_getitem(network, target, args, kwargs, name): @@ -951,7 +1149,7 @@ def slice_to_trt_params(py_slice, dim_size): batch_subscript = slices[0] if batch_subscript != slice(None, None, None): raise RuntimeError( - f"Can't subscript batch dimension when it's implicit. Got {slices}" + f"{name}: Can't subscript batch dimension when it's implicit. Got {slices}" ) # Remove batch_dim subscript @@ -1106,3 +1304,60 @@ def acc_ops_permute(network, target, args, kwargs, name): layer.second_transpose = tuple(permutation) layer.name = name return layer.get_output(0) + +@tensorrt_converter(acc_ops.quantize_per_tensor) +def acc_ops_quantize_per_tensor(network, target, args, kwargs, name): + input_val = kwargs["input"] + + if not isinstance(input_val, trt.tensorrt.ITensor): + raise RuntimeError(f"{name} received input {input_val} that is not part " + "of the TensorRT region!") + + q_scale = acc_utils.get_field_from_acc_out_ty(kwargs["acc_out_ty"], "q_scale") + q_zero_point = acc_utils.get_field_from_acc_out_ty(kwargs["acc_out_ty"], "q_zero_point") + dtype = acc_utils.get_field_from_acc_out_ty(kwargs["acc_out_ty"], "dtype") + if dtype not in (torch.quint8, torch.qint8, torch.qint32): + raise RuntimeError("Only support (torch.quint8, torch.qint8, torch.qint32) " + f"quantized type in quantize_per_tensor, get {dtype}.") + + if q_zero_point != 0: + raise RuntimeError(f"Only support zero_point == 0, get {q_zero_point}") + + scale_layer = network.add_constant((1,), trt.Weights(np.ascontiguousarray([float(q_scale)], dtype=np.float32))) + scale_layer.name = input_val.name + ".quant.scale" + scale = scale_layer.get_output(0) + # assert trt.__version__ > "8.0", "Explicit quantize op is only supported in " + # "TensorRT 8.0 or above, current TensorRT version:" + trt.__version__ + layer = network.add_quantize(input=input_val, scale=scale) + layer.axis = 0 + layer.name = input_val.name + ".quant" + return layer.get_output(0) + +@tensorrt_converter(acc_ops.dequantize) +def acc_ops_dequantize(network, target, args, kwargs, name): + input_val = kwargs["input"] + + if not isinstance(input_val, trt.tensorrt.ITensor): + raise RuntimeError(f"{name} received input {input_val} that is not part " + "of the TensorRT region!") + + q_scale = acc_utils.get_field_from_acc_out_ty(kwargs["input_tensor_meta"], "q_scale") + q_zero_point = acc_utils.get_field_from_acc_out_ty(kwargs["input_tensor_meta"], "q_zero_point") + dtype = acc_utils.get_field_from_acc_out_ty(kwargs["input_tensor_meta"], "dtype") + + if dtype not in (torch.quint8, torch.qint8, torch.qint32): + raise RuntimeError("Only support (torch.quint8, torch.qint8, torch.qint32) " + f"quantized type in dequantize, get {dtype}.") + + if q_zero_point != 0: + raise RuntimeError(f"Only support zero_point == 0, get {q_zero_point}") + + scale_layer = network.add_constant((1,), trt.Weights(np.ascontiguousarray([q_scale], dtype=np.float32))) + scale_layer.name = input_val.name + ".dequant.scale" + scale = scale_layer.get_output(0) + # assert trt.__version__ > "8.0", "Explicit dequantize op is only supported in " + # "TensorRT 8.0 or above, current TensorRT version:" + trt.__version__ + layer = network.add_dequantize(input=input_val, scale=scale) + layer.name = input_val.name + ".dequant" + layer.axis = 0 + return layer.get_output(0) diff --git a/torch/fx/experimental/fx2trt/example/fx2trt_example.py b/torch/fx/experimental/fx2trt/example/fx2trt_example.py index fff539d3bbe99..76bf69a181ad6 100644 --- a/torch/fx/experimental/fx2trt/example/fx2trt_example.py +++ b/torch/fx/experimental/fx2trt/example/fx2trt_example.py @@ -236,7 +236,7 @@ def _find_culprit(self, mod, inputs): # Assert results are equal with the original model. rn18 = rn18.cuda() - torch.testing.assert_allclose(split_mod(x), rn18(x)) + torch.testing.assert_close(split_mod(x), rn18(x)) import time NITER = 100 diff --git a/torch/fx/experimental/fx2trt/example/quantized_resnet_test.py b/torch/fx/experimental/fx2trt/example/quantized_resnet_test.py new file mode 100644 index 0000000000000..140f4fb50bd76 --- /dev/null +++ b/torch/fx/experimental/fx2trt/example/quantized_resnet_test.py @@ -0,0 +1,117 @@ +import torch.fx +import torchvision.models as models +from torch.fx.experimental.fx2trt.fx2trt import TRTInterpreter, InputTensorSpec, TRTModule +from torch.quantization.quantize_fx import prepare_fx, convert_fx +import torch.fx.experimental.fx_acc.acc_tracer as acc_tracer +import copy +from torch.fx.passes import shape_prop +from torch.fx.experimental.normalize import NormalizeArgs + +rn18 = models.resnet18().eval() + +def build_fp16_trt(rn18): + rn18 = copy.deepcopy(rn18) + rn18 = acc_tracer.trace(rn18, [torch.randn(1, 3, 224, 224)]) # type: ignore[attr-defined] + interp = TRTInterpreter(rn18, [InputTensorSpec(torch.Size([3, 224, 224]), torch.float, has_batch_dim=False)]) + engine, input_names, output_names = interp.run(fp16_mode=True) + return TRTModule(engine, input_names, output_names) + +@torch.no_grad() +def build_int8_trt(rn18): + rn18 = copy.deepcopy(rn18) + data = torch.randn(1, 3, 224, 224) + # data = torch.randn(1, 64, 10, 10) + # TensorRT only supports symmetric quantization + qconfig = torch.quantization.QConfig( + activation=torch.quantization.observer.HistogramObserver.with_args( + qscheme=torch.per_tensor_symmetric, dtype=torch.qint8 + ), + weight=torch.quantization.default_weight_observer + ) + prepared = prepare_fx(rn18, {"": qconfig}) + for _ in range(10): + prepared(data) + quantized_rn18 = convert_fx(prepared, is_reference=True) + print("quantized model:", quantized_rn18) + + quantized_rn18 = acc_tracer.trace(quantized_rn18, [data]) # type: ignore[attr-defined] + interp = TRTInterpreter(quantized_rn18, [InputTensorSpec(data.shape[1:], torch.float, has_batch_dim=False)]) + engine, input_names, output_names = interp.run(fp16_mode=False, int8_mode=True) + return TRTModule(engine, input_names, output_names) + +@torch.no_grad() +def build_int8_trt_implicit_quant(rn18): + rn18 = copy.deepcopy(rn18) + data = torch.randn(1, 3, 224, 224) + # Quantization + qconfig = torch.quantization.QConfig( + activation=torch.quantization.observer.HistogramObserver.with_args( + qscheme=torch.per_tensor_symmetric, reduce_range=True + ), + weight=torch.quantization.default_per_channel_weight_observer + ) + prepared = prepare_fx(rn18, {"": qconfig}) + for _ in range(10): + prepared(data) + quantized_rn18 = convert_fx(prepared, is_reference=True) + + # Build trt int8 model + traced_rn18 = torch.fx.symbolic_trace(quantized_rn18) + shape_prop.ShapeProp(traced_rn18).propagate(data) + traced_rn18 = NormalizeArgs(traced_rn18).transform() + interp = TRTInterpreter(traced_rn18, InputTensorSpec.from_tensors([data])) + engine, input_names, output_names = interp.run(fp16_mode=False, int8_mode=True, strict_type_constraints=True) + trt_mod = TRTModule(engine, input_names, output_names) + return trt_mod + +class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 3, 3, padding=1) + + def forward(self, x): + out = self.conv(x) + # out = torch.nn.functional.relu(out) + out += x + out += out + out = torch.nn.functional.relu(out) + return out + +# rn18 = M().eval() +# rn18 = rn18.layer1 +int8_trt = build_int8_trt(rn18) +implicit_int8_trt = build_int8_trt_implicit_quant(rn18) +fp16_trt = build_fp16_trt(rn18) +x = torch.randn(5, 3, 224, 224, device="cuda") +rn18 = rn18.cuda() + +import time +NITER = 100 + +torch.cuda.synchronize() +s = time.time() +for _ in range(NITER): + fp16_trt(x) + torch.cuda.synchronize() +print('trt fp16 time (ms/iter)', (time.time() - s) / NITER * 1000) + +torch.cuda.synchronize() +s = time.time() +for _ in range(NITER): + int8_trt(x) + torch.cuda.synchronize() +print('trt int8 time (ms/iter)', (time.time() - s) / NITER * 1000) + +torch.cuda.synchronize() +s = time.time() +for _ in range(NITER): + implicit_int8_trt(x) + torch.cuda.synchronize() +print('trt implicit int8 time (ms/iter)', (time.time() - s) / NITER * 1000) + +torch.cuda.synchronize() +s = time.time() +for _ in range(NITER): + rn18(x) + torch.cuda.synchronize() +print('PyTorch time (ms/iter)', (time.time() - s) / NITER * 1000) diff --git a/torch/fx/experimental/fx2trt/fx2trt.py b/torch/fx/experimental/fx2trt/fx2trt.py index 160b4a7317a69..4c0b44c83085f 100644 --- a/torch/fx/experimental/fx2trt/fx2trt.py +++ b/torch/fx/experimental/fx2trt/fx2trt.py @@ -1,9 +1,10 @@ import warnings -from typing import List, NamedTuple, Iterable, Any, Optional, Tuple +from typing import List, NamedTuple, Iterable, Any, Optional, Tuple, Sequence import tensorrt as trt import torch import torch.fx +from torch.fx.node import _get_qualified_name # Borrowed from torch2trt @@ -52,6 +53,12 @@ def __init__( # Indicate output is in fp16 self.fp16_output = fp16_output + # Indices of outputs into the CUDA engine bindings, in the order as they are + # in the fx graph's `output` node. + self.output_indices_in_order: Sequence[int] = [ + self.engine.get_binding_index(name) for name in self.output_names + ] + def _on_state_dict(self, state_dict, prefix, local_metadata): state_dict[prefix + "engine"] = bytearray(self.engine.serialize()) state_dict[prefix + "input_names"] = self.input_names @@ -86,6 +93,7 @@ def forward(self, *inputs): bindings: List[Any] = [None] * (len(self.input_names) + len(self.output_names)) for i, input_name in enumerate(self.input_names): + assert inputs[i].is_cuda, f"{i}th input is not on cuda device." idx = self.engine.get_binding_index(input_name) bindings[idx] = contiguous_inputs[i].data_ptr() @@ -94,7 +102,7 @@ def forward(self, *inputs): # create output tensors outputs: List[torch.Tensor] = [] - for idx in range(len(inputs), len(inputs) + len(self.output_names)): + for idx in self.output_indices_in_order: dtype = torch_dtype_from_trt(self.engine.get_binding_dtype(idx)) if self.engine.has_implicit_batch_dimension: @@ -225,6 +233,11 @@ def __init__( else: self.network = self.builder.create_network() + missing_ops = self.validate_conversion() + if missing_ops: + warnings.warn("Interpretation will fail due to missing operations \n" + + "\n".join(f"{i}" for i in missing_ops)) + self.optimization_profiles: Optional[List] = None self.input_specs = input_specs self.input_specs_iter = 0 @@ -290,6 +303,22 @@ def validate_input_specs(self): len(shape_ranges) == 0 ), "shape_ranges are provided for input that doesn't have dynamic dim." + def validate_conversion(self): + missing_converter = set() + + for node in self.module.graph.nodes: + if node.op == "call_function" and not CONVERTERS.get(node.target): + missing_converter.add(f"{node.op} {_get_qualified_name(node.target)}") + elif node.op == "call_method" and not CONVERTERS.get(node.target): + missing_converter.add(f"{node.op} torch.Tensor.{node.target}") + elif node.op == "call_module": + submod = self.fetch_attr(node.target) + submod_type = getattr(submod, "_base_class_origin", type(submod)) + if not CONVERTERS.get(submod_type): + missing_converter.add(f"{node.op} {torch.typename(submod_type)}") + + return missing_converter + def run( self, max_batch_size=64, @@ -356,12 +385,11 @@ def placeholder(self, target, args, kwargs): def call_module(self, target, args, kwargs): assert isinstance(target, str) submod = self.fetch_attr(target) - converter = CONVERTERS.get(type(submod)) + submod_type = getattr(submod, "_base_class_origin", type(submod)) + converter = CONVERTERS.get(submod_type) if not converter: - raise RuntimeError( - f"Conversion of module of type {type(submod)} not currently supported!" - ) + raise RuntimeError(f'Conversion of module of type {submod_type} not currently supported!') return converter(self.network, submod, args, kwargs, self._cur_node_name) @@ -397,8 +425,6 @@ def output(self, target, args, kwargs): name = f"output{i}" output.name = name self.network.mark_output(output) - if self.fp16_mode: + if self.fp16_mode and output.dtype == trt.float32: output.dtype = trt.float16 - else: - output.dtype = trt.float32 self._output_names.append(name) diff --git a/torch/fx/experimental/fx2trt/passes/remove_duplicate_output_args.py b/torch/fx/experimental/fx2trt/passes/remove_duplicate_output_args.py new file mode 100644 index 0000000000000..bfddab57c0935 --- /dev/null +++ b/torch/fx/experimental/fx2trt/passes/remove_duplicate_output_args.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python3 + +import operator +import typing as t +import logging +import torch.fx as fx +import dataclasses as dc + + +_LOGGER = logging.getLogger(__name__) + + +def remove_duplicate_output_args( + top_level: fx.GraphModule, + target_subnets: t.Collection[str] +) -> t.Mapping[str, "RemoveDuplicateResult"]: + """Removes duplicate output args. + + This pass removes duplicate output args from the target subnets and fixes + their uses in the top level module where the subnets are called. This pass + must be called after acc split on the top-level net and subsequent calls to + the acc trace on the subnets. + + This pass will change both the subnets and top level module. + + Returns: + a mapping of the target subnet name to its dedupcate result + """ + + processed_subnets = {} + for node in top_level.graph.nodes: # type: fx.Node + if node.op == "call_module" and node.name in target_subnets: + assert isinstance(node.target, str) + sub_gm = top_level.get_submodule(node.target) + assert isinstance(sub_gm, fx.GraphModule) + + replace_res = _remove_duplicate_output_args(sub_gm) + processed_subnets[node.name] = replace_res + if replace_res.replacement_map is None: + continue + sub_gm.recompile() + + needs_recompile = False + # iterate on the copy since we will be changing elements of node.users + for user in list(node.users): + idx = _ensure_proper_output_use(user, node) + idx_new = replace_res.replacement_map[idx] + if idx_new != idx: + user.args = (user.args[0], idx_new) + needs_recompile = True + + if needs_recompile: + top_level.recompile() + return processed_subnets + + +@dc.dataclass(frozen=True) +class RemoveDuplicateResult: + replacement_map: t.Optional[t.List[int]] + module: fx.GraphModule + + +def _ensure_proper_output_use(user: fx.Node, target_node: fx.Node) -> int: + """ + Ensures the node looks in proper form of calling the output of an fx2trt + splitter sub-net. Specifically: + + 1. op is call function, target: operator.getitem + 2. args is a 2-element tuple + 3. args[0] is the name of the subnet's output + 4. args[1] is the index into the subnet output tuple + + E.g.: + + %getitem_4 : [#users=1] = call_function[target=operator.getitem](args = (%_run_on_acc_1, 4), kwargs = {}) + + returns the index into the subnet output tuple + """ + _LOGGER.info(f"Checking user node: {user.format_node()}") + assert ( + user.op == "call_function" + and user.target == operator.getitem + and len(user.args) == 2 + and isinstance(user.args[0], fx.Node) + and user.args[0].name == target_node.name + and isinstance(user.args[1], int) + ), f"Node is not a proper user of splitter output: {user.format_node()}" + + return user.args[1] + + +def _remove_duplicate_output_args(gm: fx.GraphModule) -> RemoveDuplicateResult: + output_nodes = [n for n in gm.graph.nodes if n.op == "output"] + assert len(output_nodes) == 1, \ + f"Expecting exactly one `output` node, but got {len(output_nodes)}" + + changed = False + # arg node name to its index in the new output args tuple + name_to_idx: t.Dict[str, int] = {} + output_node = output_nodes[0] + + # Output op only uses its `args[0]`, and it does not have `kwargs`. + # https://pytorch.org/docs/stable/fx.html#torch.fx.Node + args: t.Sequence[t.Any] = output_node.args[0] + + # Only concern outselves to the case where the args is an iterable of fx.Node. + # Other return cases (e.g., a single value) is possible and we don't handle + # that in this pass. + if not (isinstance(args, t.Iterable) and all(isinstance(a, fx.Node) for a in args)): + return RemoveDuplicateResult(replacement_map=None, module=gm) + + # Map old index of the arg node to the remaining node's idx, + # initialized to `i => i` + replacement_map: t.List[int] = list(range(len(args))) + args_new = [] + for idx, a in enumerate(args): + assert isinstance(a, fx.Node), \ + f"Expecting fx.Node instance, but got: {type(a)}" + + if a.name not in name_to_idx: + args_new.append(a) + name_to_idx[a.name] = len(args_new) - 1 + else: + changed = True + _LOGGER.warning( + f"Replaced duplicate output arg '{a.name}': " + f"{idx} -> {name_to_idx[a.name]}" + ) + replacement_map[idx] = name_to_idx[a.name] + + output_node.args = (tuple(args_new),) + if changed: + gm.recompile() + return RemoveDuplicateResult(replacement_map, module=gm) diff --git a/torch/fx/experimental/fx2trt/tools/graph_util.py b/torch/fx/experimental/fx2trt/tools/graph_util.py new file mode 100644 index 0000000000000..96c8b12915da4 --- /dev/null +++ b/torch/fx/experimental/fx2trt/tools/graph_util.py @@ -0,0 +1,64 @@ +import graphviz # type: ignore[import] + +def get_layer_name_type(layer): + return "\n".join(f"{i}" for i in [layer.name, layer.type]) + +def trt_network_to_dot_graph(network): + dot = graphviz.Digraph(comment="Network") + + # add nodes (layers) + for i in range(network.num_layers): + layer = network.get_layer(i) + dot.node(get_layer_name_type(layer)) + + # add nodes (inputs) + for i in range(network.num_inputs): + dot.node(network.get_input(i).name) + + # add nodes (outputs) + for i in range(network.num_outputs): + dot.node(network.get_output(i).name) + + # add layer->layer edges + for a in range(network.num_layers): + layer_a = network.get_layer(a) + + for b in range(network.num_layers): + layer_b = network.get_layer(b) + + for i in range(layer_a.num_outputs): + output_i = layer_a.get_output(i) + + for j in range(layer_b.num_inputs): + input_j = layer_b.get_input(j) + + if output_i == input_j: + dot.edge(get_layer_name_type(layer_a), get_layer_name_type(layer_b), label=str(input_j.shape)) + + # add input->layer edges + for i in range(network.num_inputs): + input_i = network.get_input(i) + + for b in range(network.num_layers): + layer_b = network.get_layer(b) + + for j in range(layer_b.num_inputs): + input_j = layer_b.get_input(j) + + if input_i == input_j: + dot.edge(input_i.name, get_layer_name_type(layer_b), label=str(input_j.shape)) + + # add layer->output edges + for i in range(network.num_outputs): + input_i = network.get_output(i) + + for b in range(network.num_layers): + layer_b = network.get_layer(b) + + for j in range(layer_b.num_outputs): + input_j = layer_b.get_output(j) + + if input_i == input_j: + dot.edge(get_layer_name_type(layer_b), input_i.name, label=str(input_j.shape)) + + return dot diff --git a/torch/fx/experimental/fx_acc/acc_ops.py b/torch/fx/experimental/fx_acc/acc_ops.py index bc4dfb3c4fe5f..b10d35edd5baa 100644 --- a/torch/fx/experimental/fx_acc/acc_ops.py +++ b/torch/fx/experimental/fx_acc/acc_ops.py @@ -10,7 +10,7 @@ register_acc_op_mapping, register_custom_acc_mapper_fn, ) -from torch.fx.passes.shape_prop import extract_tensor_metadata +from torch.fx.passes.shape_prop import _extract_tensor_metadata this_arg_is_optional = True @@ -95,6 +95,12 @@ def avg_pool2d( return nn.functional.avg_pool2d(**locals()) +@register_acc_op_mapping(op_and_target=("call_function", torch.sign)) +@register_acc_op +def sign(*, input): + return torch.sign(input) + + @register_acc_op def size(*, input): return input.size() @@ -162,6 +168,7 @@ def add(*, input, other): return input + other +@register_acc_op_mapping(op_and_target=("call_method", "unsqueeze")) @register_acc_op_mapping(op_and_target=("call_function", torch.unsqueeze)) @register_acc_op def unsqueeze(*, input, dim): @@ -222,6 +229,12 @@ def transpose(*, input, dim0, dim1): return torch.transpose(**locals()) +@register_acc_op_mapping(op_and_target=("call_method", "contiguous")) +@register_acc_op +def contiguous(*, input): + return input.contiguous() + + @register_acc_op_mapping(op_and_target=("call_function", torch.nn.functional.softmax)) @register_acc_op def softmax(*, input, dim, dtype): @@ -462,10 +475,12 @@ def quantize_per_tensor(*, input, acc_out_ty=None): ) -@register_acc_op_mapping(op_and_target=("call_function", torch.dequantize)) -@register_acc_op_mapping(op_and_target=("call_method", "dequantize")) @register_acc_op -def dequantize(*, input): +def dequantize(*, input, input_tensor_meta): + """ `input_tensor_meta` contains extra argument of quantization + parameters, e.g. scale/zero_point and will be using for + lowring dequantize op to TensorRT + """ return torch.dequantize(input) @@ -487,6 +502,12 @@ def div(*, input, other): return input / other +@register_acc_op_mapping(op_and_target=("call_function", torch.pow)) +@register_acc_op +def pow(*, input, exponent): + return torch.pow(input, exponent) + + @register_acc_op_mapping(op_and_target=("call_function", nn.functional.relu)) @register_acc_op_mapping( op_and_target=("call_function", torch.relu), @@ -500,6 +521,21 @@ def div(*, input, other): def relu(*, input, inplace=False): return nn.functional.relu(**locals()) +@register_custom_acc_mapper_fn( + op_and_target=("call_function", torch.log1p), + arg_replacement_tuples=[ + ("input", "input"), + ], +) +def torch_log1p_mapper(node: torch.fx.Node, _: torch.nn.Module) -> torch.fx.Node: + with node.graph.inserting_before(node): + add_kwargs = {"input": node.kwargs["input"], "other": 1} + add_node = node.graph.call_function(add, kwargs=add_kwargs) + add_node.meta = node.meta.copy() + log_kwargs = {"input": add_node} + log_node = node.graph.call_function(log, kwargs=log_kwargs) + log_node.meta = node.meta.copy() + return log_node @register_custom_acc_mapper_fn( op_and_target=("call_method", "sum"), @@ -675,6 +711,57 @@ def batch_norm( def layer_norm(*, input, normalized_shape, weight, bias, eps): return nn.functional.layer_norm(**locals()) +def argmin_max_mapper_impl(node: torch.fx.Node, largest: bool) -> torch.fx.Node: + """ + Map torch.argmin or torch.argmax to acc_ops.flatten (depend on dim) + acc_ops.topk + + acc_ops.getitem + acc_ops.squeeze (depends on keepdim). + """ + input_node = node.kwargs["input"] + dim = node.kwargs["dim"] + keepdim = node.kwargs["keepdim"] + + if dim is None and keepdim: + raise RuntimeError("We currently don't support argmin/argmax with dim=None and keepdim=True") + + with node.graph.inserting_before(node): + if dim is None: + flatten_kwargs = {"input": node.kwargs["input"], "start_dim": 0, "end_dim": -1} + flatten_node = node.graph.call_function(flatten, kwargs=flatten_kwargs) + flatten_node.meta["type"] = torch.Tensor + input_node = flatten_node + dim = -1 + + topk_kwargs = {"input": input_node, "k": 1, "dim": dim, "largest": largest, "sorted": False} + topk_node = node.graph.call_function(topk, kwargs=topk_kwargs) + # It's actually more like NamedTuple but tuple here should be fine. + topk_node.meta["type"] = tuple + + getitem_kwargs = {"input": topk_node, "idx": 1} + getitem_node = node.graph.call_function(getitem, kwargs=getitem_kwargs) + getitem_node.meta["type"] = torch.Tensor + output_node = getitem_node + + if not keepdim: + squeeze_kwargs = {"input": getitem_node, "dim": dim} + output_node = node.graph.call_function(squeeze, kwargs=squeeze_kwargs) + + output_node.meta = node.meta.copy() + return output_node + +@register_custom_acc_mapper_fn( + op_and_target=("call_function", torch.argmin), + arg_replacement_tuples=[ + ("input", "input"), + ("dim", "dim"), + ("keepdim", "keepdim"), + ], +) +def torch_argmin_mapper(node: torch.fx.Node, _: torch.nn.Module) -> torch.fx.Node: + """ + Map torch.argmin to acc_ops.flatten (depend on dim) + acc_ops.topk + acc_ops.getitem + + acc_ops.squeeze (depends on keepdim). + """ + return argmin_max_mapper_impl(node, largest=False) @register_custom_acc_mapper_fn( op_and_target=("call_method", "split"), @@ -871,6 +958,15 @@ def slice_tensor(*, input, dims, starts, stops, steps): ("length", "length"), ], ) +@register_custom_acc_mapper_fn( + op_and_target=("call_method", "narrow"), + arg_replacement_tuples=[ + ("input", "input"), + ("dim", "dim"), + ("start", "start"), + ("length", "length"), + ], +) def custom_narrow_mapper(node: torch.fx.Node, mod: nn.Module) -> torch.fx.Node: kwargs = { "input": node.kwargs["input"], @@ -1044,12 +1140,12 @@ def packed_quantized_linear_mapper( with node.graph.inserting_before(node): # Insert get_attr nodes for weight and bias get_weight = node.graph.get_attr(weight_name) - get_weight.meta["tensor_meta"] = extract_tensor_metadata(linear_module.weight()) + get_weight.meta["tensor_meta"] = _extract_tensor_metadata(linear_module.weight()) get_bias = None if linear_module.bias() is not None: get_bias = node.graph.get_attr(bias_name) - get_bias.meta["tensor_meta"] = extract_tensor_metadata(linear_module.bias()) + get_bias.meta["tensor_meta"] = _extract_tensor_metadata(linear_module.bias()) # Create kwargs for acc_op.quantized_linear kwargs = { @@ -1092,12 +1188,12 @@ def packed_quantized_conv2d_mapper( with node.graph.inserting_before(node): # Insert get_attr nodes for weight and bias get_weight = node.graph.get_attr(weight_name) - get_weight.meta["tensor_meta"] = extract_tensor_metadata(conv_module.weight()) + get_weight.meta["tensor_meta"] = _extract_tensor_metadata(conv_module.weight()) get_bias = None if conv_module.bias() is not None: get_bias = node.graph.get_attr(bias_name) - get_bias.meta["tensor_meta"] = extract_tensor_metadata(conv_module.bias()) + get_bias.meta["tensor_meta"] = _extract_tensor_metadata(conv_module.bias()) # Create kwargs for acc_op.conv kwargs = { @@ -1174,3 +1270,27 @@ def packed_quantized_convrelu2d_mapper( ) relu_node.meta = node.meta return relu_node + +@register_custom_acc_mapper_fn( + op_and_target=("call_function", torch.dequantize), + arg_replacement_tuples=[ + ("input", "input") + ] +) +@register_custom_acc_mapper_fn( + op_and_target=("call_method", "dequantize"), + arg_replacement_tuples=[ + ("input", "input") + ] +) +def custom_dequantize_mapper(node: torch.fx.Node, mod: nn.Module) -> torch.fx.Node: + assert "tensor_meta" in node.kwargs["input"].meta + new_kwargs = {"input": node.kwargs["input"], "input_tensor_meta": node.kwargs["input"].meta["tensor_meta"]} + # `input_tensor_meta` contains quantization parameters that can be used to lower + # acc_ops.dequantize to TensorRT ops + with node.graph.inserting_before(node): + new_node = node.graph.create_node( + "call_function", dequantize, kwargs=new_kwargs, name=node.name + ) + new_node.meta = node.meta + return new_node diff --git a/torch/fx/experimental/graph_gradual_typechecker.py b/torch/fx/experimental/graph_gradual_typechecker.py index e3c1ce82d7a46..6094952f1695e 100644 --- a/torch/fx/experimental/graph_gradual_typechecker.py +++ b/torch/fx/experimental/graph_gradual_typechecker.py @@ -9,12 +9,18 @@ from torch.fx.experimental.refinement_types import Equality import itertools - from torch.fx.experimental.unification import Var # type: ignore[attr-defined] +try: + import sympy # type: ignore[import] + HAS_SYMPY = True +except ImportError: + HAS_SYMPY = False + _INFERENCE_RULES: Dict[Target, Callable] = {} _REFINEMENT_RULES: Dict[Target, Callable] = {} +_RULES: Dict[Target, Callable] = {} def expand_to_tensor_dim(t, n): @@ -22,7 +28,7 @@ def expand_to_tensor_dim(t, n): Expand a type to the desired tensor dimension if possible Raise an error otherwise. - t is the given type - - n is a number to expand to + - n is a number of dimensions to expand to """ if t == Dyn: dims = [Dyn] * n @@ -36,6 +42,13 @@ def expand_to_tensor_dim(t, n): def broadcast_types(t1, t2): + """ + Applies broadcasting to both given types such that they + become consistent with eachother and returns two new + resulting types + """ + + # if either type is Dyn, do nothing since the types are already consistent if t1 == Dyn or t2 == Dyn or isinstance(t1, Var) or isinstance(t2, Var): return t1, t2 @@ -46,7 +59,8 @@ def broadcast_types(t1, t2): new_t1 = list(t1.__args__) new_t2 = list(t2.__args__) - # here, we make our tensors the same length + # We make the types the same length which is the first requirement + # for consistency if s1 > s2: for i in range(s1 - s2): new_t2.insert(0, 1) @@ -55,17 +69,18 @@ def broadcast_types(t1, t2): for i in range(s2 - s1): new_t1.insert(0, 1) + # we replace occurrences of "1" with each tensor with + # the corresponding type from the other tensor for i, (x, y) in enumerate(zip(new_t1, new_t2)): if x == 1: new_t1[i] = y elif y == 1: new_t2[i] = x + # at this point our tensors should be consistent + # and we can apply the element-wise operation and find the right dimension + # for the output of the operation (t1, t2) = TensorType(tuple(new_t1)), TensorType(tuple(new_t2)) - - if not is_consistent(t1, t2): - raise TypeError - return (t1, t2) else: raise TypeError(f'Cannot broadcast types {t1} and {t2}') @@ -73,7 +88,7 @@ def broadcast_types(t1, t2): def register_inference_rule(call_target): def register(fn): if call_target in _INFERENCE_RULES: - raise RuntimeError('Inference rule already registered for {call_target}!') + raise RuntimeError(f'Inference rule already registered for {call_target}!') _INFERENCE_RULES[call_target] = fn return fn return register @@ -81,15 +96,33 @@ def register(fn): def register_refinement_rule(call_target): def register(fn): if call_target in _REFINEMENT_RULES: - raise RuntimeError('Refinement rule already registered for {call_target}!') + raise RuntimeError(f'Refinement rule already registered for {call_target}!') _REFINEMENT_RULES[call_target] = fn return fn return register +def register_algebraic_expressions_inference_rule(call_target): + def register(fn): + if call_target in _RULES: + raise RuntimeError(f'Rule already registered for {call_target}!') + _RULES[call_target] = fn + return fn + return register @register_inference_rule(torch.add) @register_inference_rule(operator.add) def add_inference_rule(n: Node): + """ + Apply the addition inference rule. This includes: + - scalar addition + - broadcasting semantics + + Note that we always return the least precise type between + the operands (after applying broadcasting) to be the final type of the operation + + Note that we do not modify the operand types themselves after applying broadcasting + to them. We only use them to calculate the final type + """ assert isinstance(n.args[0], Node) assert isinstance(n.args[1], Node) t1 = n.args[0].type @@ -100,10 +133,15 @@ def add_inference_rule(n: Node): n.type = t2 return n.type + # handle scalar addition elif t2 == int and isinstance(t1, TensorType): n.type = t1 return n.type + # we bring the new types to the point where + # we can check for consistency + # any inconsistency would not have been caused + # by broadcasting at this point (new_t1, new_t2) = broadcast_types(t1, t2) if new_t1 != t1 or new_t2 != t2: @@ -111,13 +149,13 @@ def add_inference_rule(n: Node): n.meta[str(n.args[0])] = new_t1 n.meta[str(n.args[1])] = new_t2 - # Todo: maybe figure out that broadcasting definitely did not happen? else: n.meta['broadcast'] = False new_t1 = t1 if not n.meta['broadcast'] else new_t1 new_t2 = t2 if not n.meta['broadcast'] else new_t2 + # we check for consistency between the new types if is_consistent(new_t1, new_t2): # we return the less precise type because # broadcasting may have happened @@ -134,6 +172,12 @@ def add_inference_rule(n: Node): @register_inference_rule(getattr) def get_attr_inference_rule(n: Node, traced): + """ + The current getattr rule only handles the shape attribute + Can be extended to other attributes + The most representitive type we have is "Dyn" but the system + can be extended with more types, such as a type to represent shapes + """ attr_node = n.args[0] attr_name = n.args[1] @@ -147,6 +191,10 @@ def get_attr_inference_rule(n: Node, traced): @register_inference_rule(torch.transpose) def transpose_inference_rule(n: Node): + """ + We check that dimentions for the transpose operations + are within range of the tensor type of the node + """ if n.target == torch.transpose: assert isinstance(n.args[0], Node) t = n.args[0].type @@ -160,12 +208,11 @@ def transpose_inference_rule(n: Node): return n.type elif isinstance(t, TensorType): - if 0 <= dim1 < len(t.__args__) and 0 <= dim2 < len(t.__args__): new_type = list(t.__args__) new_type[dim1], new_type[dim2] = new_type[dim2], new_type[dim1] final = TensorType(new_type) - n.type = final + n.type = get_greatest_upper_bound(n.type, final) return n.type else: raise TypeError(f'Cannot transpose {dim1} and {dim2} in type {t} for node {n}') @@ -175,6 +222,15 @@ def transpose_inference_rule(n: Node): @register_inference_rule(torch.reshape) def reshape_inference_rule(n: Node): + """ + Without dynamism, the rule checks that the + product of the elements of the argument tensor + type is equal to the product of the elements + of the required shape. We gradualize this rule + by adding a case to handle fully dynamic input + as well as input where some of the tensor dimensions + are unknown. In this case we check for divisibility + """ assert isinstance(n.args[0], Node) t1 = n.args[0].type @@ -190,7 +246,7 @@ def reshape_inference_rule(n: Node): # if any of the dimensions are unknown, # we check for divisibility - elif isinstance(t1, TensorType) and Dyn in t1.__args__ or -1 in t2: + elif isinstance(t1, TensorType): assert isinstance(t1, TensorType) a = [e if e != Dyn else 1 for e in t1.__args__] p1 = reduce(lambda x, y: x * y, a) @@ -200,17 +256,6 @@ def reshape_inference_rule(n: Node): return t2_type else: raise TypeError(f'Cannot reshape in node {n} from {t1} to {t2_type}') - - # if all dimensions are known we check the products - elif isinstance(t1, TensorType): - p1 = reduce(lambda x, y: x * y, t1.__args__) - p2 = reduce(lambda x, y: x * y, t2) - if p1 == p2: - n.type = t2_type - return t2_type - else: - raise TypeError(f'Cannot reshape in node {n} from {t1} to {t2_type}') - else: raise TypeError(f'Cannot reshape in node {n} from {t1} to {t2_type}') @@ -249,7 +294,7 @@ def bn2d_inference_rule(n: Node, module_instance): def calculate_out_dimension(d_in, module_instance, index): """ - For calculating h_in and w_out. + For calculating h_in and w_out according to the conv2D documentation """ padding = (module_instance.padding, module_instance.padding) \ if isinstance(module_instance.padding, int) else module_instance.padding @@ -260,10 +305,12 @@ def calculate_out_dimension(d_in, module_instance, index): dilation = (module_instance.dilation, module_instance.dilation) \ if isinstance(module_instance.dilation, int) else module_instance.dilation + DIMENSION_TYPES = (int, sympy.Symbol) if HAS_SYMPY else (int,) + if d_in == Dyn: return Dyn - elif isinstance(d_in, int): + elif isinstance(d_in, DIMENSION_TYPES): n = d_in + 2 * padding[index] - \ dilation[index] * \ (kernel_size[index] - 1) - 1 @@ -271,7 +318,7 @@ def calculate_out_dimension(d_in, module_instance, index): return (n // stride[0]) + 1 else: - raise TypeError(f'{d_in} in {module_instance} must be a number or Dyn') + raise TypeError(f'{d_in} in {module_instance} must be a number or Dyn. Received {type(d_in)}') def get_greatest_upper_bound(type1, type2): @@ -333,6 +380,10 @@ def relu_inference_rule(n: Node, module_instance): def maxpool2d_check(typ, module_instance): + """ + Applies the maxpool2d shape information to the input + this affects the last two dimensions + """ new_type_list = list(typ.__args__) if len(new_type_list) == 4 or len(new_type_list) == 3: w_in = new_type_list[-1] @@ -378,7 +429,6 @@ def linear_check(tensor_type, module_instance): """ if len(tensor_type.__args__) >= 2: if is_consistent(module_instance.in_features, tensor_type.__args__[-1]): - # Todo backwards propagation new_type_args = list(tensor_type.__args__) new_type_args[-1] = module_instance.out_features return TensorType(tuple(new_type_args)) @@ -390,6 +440,10 @@ def linear_check(tensor_type, module_instance): @register_inference_rule(torch.nn.Linear) def linear_inference_rule(n: Node, module_instance): + """ + Applies the shape information to the input then gets the greatest upper bound + of the resulting type and the existing type + """ assert isinstance(n.args[0], Node) if n.args[0].type == Dyn and isinstance(n.type, TensorType): n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__)) @@ -438,7 +492,7 @@ def adaptiveavgpool2d_inference_rule(n: Node, module_instance): def flatten_check(tensor_type, start_dim, end_dim): l = len(tensor_type.__args__) - start_dim = l if start_dim == -1 else start_dim + start_dim = l if start_dim == -1 else abs(start_dim) end_dim = l + end_dim + 1 if end_dim < 0 else end_dim + 1 if 0 <= start_dim <= (l - 1) and 0 <= end_dim <= l and start_dim < end_dim: @@ -457,6 +511,10 @@ def flatten_check(tensor_type, start_dim, end_dim): @register_inference_rule(torch.flatten) def flatten_inference_rule(n: Node): + """ + Applies the flatten shape information to the input then gets the + greatest upper bound of the resulting type and the existing type + """ assert isinstance(n.args[0], Node) # set the default start and end dims @@ -521,7 +579,7 @@ def type_check_node(self, n: Node): return n.type elif n.op == 'get_attr': - t = self.traced.get_parameter(n.target) + t = get_parameter(self.traced, n.target) # type: ignore[arg-type] if isinstance(t.data, torch.Tensor): n.type = TensorType(t.data.shape) return n.type @@ -554,8 +612,25 @@ def get_node_type(a): @register_refinement_rule(Conv2d) +def conv_refinement_rule(n: Node): + """ + The equality constraints are between the first dimension of + the input and output + """ + res = [] + assert isinstance(n.args[0], Node) + arg_type = n.args[0].type + if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType): + res = [Equality(arg_type.__args__[0], n.type.__args__[0])] + return res + + @register_refinement_rule(torch.nn.Linear) -def first_one(n: Node): +def linear_refinement_rule(n: Node): + """ + The equality constraints are between the first dimension of + the input and output + """ res = [] assert isinstance(n.args[0], Node) arg_type = n.args[0].type @@ -563,11 +638,12 @@ def first_one(n: Node): res = [Equality(arg_type.__args__[0], n.type.__args__[0])] return res -# todo needs review for addition. Is this constraint correct? @register_refinement_rule(BatchNorm2d) @register_refinement_rule(torch.nn.ReLU) -@register_refinement_rule(torch.nn.AdaptiveAvgPool2d) def all_eq(n: Node): + """ + For operations where the input shape is equal to the output shape + """ res = [] assert isinstance(n.args[0], Node) arg_type = n.args[0].type @@ -577,19 +653,54 @@ def all_eq(n: Node): res = [Equality(args1[i], args2[i]) for i in range(len(args1))] return res + +@register_refinement_rule(torch.nn.AdaptiveAvgPool2d) +@register_refinement_rule(torch.nn.MaxPool2d) +def first_two_eq(n: Node): + """ + For operations where the first two dimensions of the input and output shape + are equal + """ + res = [] + assert isinstance(n.args[0], Node) + arg_type = n.args[0].type + if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType): + args1 = arg_type.__args__ + args2 = n.type.__args__ + res = [Equality(args1[0], args2[0]), Equality(args1[1], args2[1])] + return res + + @register_refinement_rule(torch.add) @register_refinement_rule(operator.add) -def add_eq(n: Node): +def element_wise_eq(n: Node): + """ + For element-wise operations and handles broadcasting. + Note that after applying broadcasting to the arguments + we are able to determine if certain dimensions have not been broadcast + if they are symbolicallu equal. + + in this case, we can establish equality between those dimensions and the + corresponding output dimensions. + + Note that it takes two iterations for this result. One iteration to establish + equality between certain dimensions of the operands (requiring the whole solver + including unification) and another iteration to establish equality between the operands + and the resulting type, requiring another round of constraint generation and unificaiton. + """ res = [] if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): arg_type1 = n.args[0].type arg_type2 = n.args[1].type if isinstance(arg_type1, TensorType) and isinstance(arg_type2, TensorType) and isinstance(n.type, TensorType): args1, args2 = broadcast_types(arg_type1, arg_type2) - # by this point, we know for sure that args1 and args2 are the same size. + # by this point, we know that args1 and args2 are the same size. a1 = args1.__args__ a2 = args2.__args__ a3 = n.type.__args__ + + # we would be here in the second iteration where we establish equality + # between operand type dimensions and the resulting type dimensions r = [] for x, y, z in zip(a1, a2, a3): if x == y: @@ -597,19 +708,13 @@ def add_eq(n: Node): res = r return res -@register_refinement_rule(torch.nn.MaxPool2d) -def first_two(n: Node): - res = [] - assert isinstance(n.args[0], Node) - arg_type = n.args[0].type - if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType): - args1 = arg_type.__args__ - args2 = n.type.__args__ - res = [Equality(args1[0], args2[0]), Equality(args1[1], args2[1])] - return res @register_refinement_rule(torch.flatten) def flatten_refinement_rule(n: Node): + """ + Generates equality constraints between the dimensions of the input and output + that will not be involved in the flatten operation + """ assert isinstance(n.args[0], Node) eq_const = [] @@ -638,6 +743,24 @@ def flatten_refinement_rule(n: Node): eq_const.append(Equality(t1, t2)) return eq_const + +@register_algebraic_expressions_inference_rule(Conv2d) +def conv_rule(n: Node, module_instance): + """ + Represents the outout in terms of an algrbraic expression w.r.t + the input when possible + """ + assert isinstance(n.args[0], Node) + arg_type = n.args[0].type + if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType): + w_in = arg_type.__args__[3] + h_in = arg_type.__args__[2] + h_out = calculate_out_dimension(h_in, module_instance, 0) + w_out = calculate_out_dimension(w_in, module_instance, 1) + new_type = TensorType((n.type.__args__[0], n.type.__args__[1], h_out, w_out)) + n.type = new_type + return new_type + class Refine: """ Symbolic shape inference. @@ -660,6 +783,15 @@ def refine(self): self.refine_node(n) return True + def symbolic_relations(self): + """ + Infers algebraic relations + """ + graph = self.traced.graph + for n in graph.nodes: + self.infer_symbolic_relations(n) + return True + def replace_dyn_with_fresh_var(self, typ): """ Replace all unknown types with fresh type variables. @@ -670,6 +802,30 @@ def replace_dyn_with_fresh_var(self, typ): elif isinstance(typ, TensorType): new_args = [self.replace_dyn_with_fresh_var(a) for a in typ.__args__] return TensorType(tuple(new_args)) + elif isinstance(typ, list): + return [self.replace_dyn_with_fresh_var(t) for t in typ] + elif isinstance(typ, tuple): + return (self.replace_dyn_with_fresh_var(t) for t in typ) + else: + return typ + + + def convert_to_sympy_symbols(self, typ): + """ + Replace all unknown types with fresh type variables. + """ + if HAS_SYMPY: + if isinstance(typ, Var): + return sympy.symbols(str(typ)) + elif isinstance(typ, TensorType): + new_args = [self.convert_to_sympy_symbols(a) for a in typ.__args__] + return TensorType(tuple(new_args)) + elif isinstance(typ, list): + return [self.convert_to_sympy_symbols(t) for t in typ] + elif isinstance(typ, tuple): + return (self.convert_to_sympy_symbols(t) for t in typ) + else: + return typ else: return typ @@ -700,8 +856,70 @@ def refine_node(self, n: Node): pass if n.op == 'output': - assert isinstance(n.args[0], Node) - n.type = n.args[0].type + def get_node_type(a): + return a.type + n.type = torch.fx.node.map_arg(n.args[0], get_node_type) + return n.type + + else: + pass + + def infer_symbolic_relations(self, n: Node): + if HAS_SYMPY: + n.type = self.convert_to_sympy_symbols(n.type) + if n.op == 'call_function': + if n.target in _RULES: + return _RULES[n.target](n) + else: + pass + + if n.op == 'call_module': + module_instance = self.traced.get_submodule(n.target) + if type(module_instance) in _RULES: + return _RULES[type(module_instance)](n, module_instance) + else: + pass + + if n.op == 'output': + def get_node_type(a): + return a.type + n.type = torch.fx.node.map_arg(n.args[0], get_node_type) + return n.type + else: + pass else: pass + +def get_parameter(traced, target: str): + """ + Returns the parameter given by ``target`` if it exists, + otherwise throws an error. + + See the docstring for ``get_submodule`` for a more detailed + explanation of this method's functionality as well as how to + correctly specify ``target``. + + Args: + target: The fully-qualified string name of the Parameter + to look for. (See ``get_submodule`` for how to specify a + fully-qualified string.) + + Returns: + torch.nn.Parameter: The Parameter referenced by ``target`` + + Raises: + AttributeError: If the target string references an invalid + path or resolves to something that is not an + ``nn.Parameter`` + """ + module_path, _, param_name = target.rpartition(".") + + mod: torch.nn.Module = traced.get_submodule(module_path) + + if not hasattr(mod, param_name): + raise AttributeError(mod._get_name() + " has no attribute `" + param_name + "`") + + param: torch.nn.Parameter = getattr(mod, param_name) + + return param diff --git a/torch/fx/experimental/graph_manipulation.py b/torch/fx/experimental/graph_manipulation.py index 9d0af5343ae9a..6daa000f609d1 100644 --- a/torch/fx/experimental/graph_manipulation.py +++ b/torch/fx/experimental/graph_manipulation.py @@ -412,7 +412,7 @@ def get_user_info(user_node: Argument) -> Any: def get_arg_info(arg: Argument) -> Any: if isinstance(arg, torch.fx.Node): return {"is_node": True, "name": str(arg)} - elif isinstance(arg, torch.dtype): + elif isinstance(arg, (torch.dtype, torch.memory_format, torch.qscheme)): return str(arg) else: return arg diff --git a/torch/fx/experimental/rewriter.py b/torch/fx/experimental/rewriter.py index b3f71d5de6cd2..de08ebaa69880 100644 --- a/torch/fx/experimental/rewriter.py +++ b/torch/fx/experimental/rewriter.py @@ -2,6 +2,7 @@ import inspect import textwrap import copy +import functools from types import FunctionType from typing import cast, Union, Callable, Dict, Optional, Any from torch.fx._symbolic_trace import Tracer @@ -41,8 +42,23 @@ def rewrite(self, fn: FunctionType): assert len(new_keys) == 1 fn_compiled = globals_dict[new_keys[0]] + # return the compiled function with the original globals + def change_func_globals(f, globals): + """Based on https://stackoverflow.com/a/13503277/2988730 (@unutbu)""" + # __globals__ is a private member of the function class + # so we have to copy the function, f, all of its member, except f.__globals__ + g = FunctionType( + f.__code__, + globals, + name=f.__name__, + argdefs=f.__defaults__, + closure=f.__closure__, + ) + g = functools.update_wrapper(g, f) + g.__kwdefaults__ = copy.copy(f.__kwdefaults__) + return g # Return the correct FunctionType object - return fn_compiled + return change_func_globals(fn_compiled, globals=fn.__globals__) def visit_Assert(self, node): """ diff --git a/torch/fx/experimental/unify_refinements.py b/torch/fx/experimental/unify_refinements.py index c8561041472ae..532d2784fb49a 100644 --- a/torch/fx/experimental/unify_refinements.py +++ b/torch/fx/experimental/unify_refinements.py @@ -2,11 +2,10 @@ from torch.fx.tensor_type import TensorType from torch.fx.experimental.unification import Var, unify # type: ignore[attr-defined] + def infer_symbolic_types_single_pass(traced): """ - Generate constraints over types, - solve constraints with unification, - apply solution back to the types + Calls our symbolic inferencer once. """ r = Refine(traced) r.refine() @@ -20,8 +19,17 @@ def infer_symbolic_types(traced): to infer all the information such as the case for braodcasting. """ - infer_symbolic_types_single_pass(traced) - infer_symbolic_types_single_pass(traced) + r = Refine(traced) + r.refine() + mgu = unify_eq(r.constraints) + substitute_all_types(traced.graph, mgu) + + r = Refine(traced) + r.refine() + mgu = unify_eq(r.constraints) + substitute_all_types(traced.graph, mgu) + + r.symbolic_relations() def convert_eq(list_of_eq): """ @@ -52,6 +60,8 @@ def substitute_solution_one_type(mapping, t): if isinstance(t, Var): if t in mapping.keys(): return mapping[t] + else: + return t elif isinstance(t, TensorType): new_type = [] @@ -62,6 +72,21 @@ def substitute_solution_one_type(mapping, t): new_type.append(typ) return TensorType(tuple(new_type)) + elif isinstance(t, list): + new_type = [] + for typ in t: + new_type.append(substitute_solution_one_type(mapping, typ)) + return new_type + + elif isinstance(t, tuple): + new_type = [] + for typ in t: + new_type.append(substitute_solution_one_type(mapping, typ)) + return tuple(new_type) + + else: + return t + def substitute_all_types(graph, mapping): """ diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 88c7b54a06ce4..65e93d0ccc7a1 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -1,6 +1,7 @@ from .node import Node, Argument, Target, map_arg, _type_repr, _get_qualified_name import torch.utils._pytree as pytree from . import _pytree as fx_pytree +from ._compatibility import compatibility from typing import TYPE_CHECKING, Callable, Any, List, Dict, NamedTuple, Optional, Tuple, Set, FrozenSet, Type from dataclasses import dataclass @@ -175,9 +176,12 @@ def _is_illegal_name(self, name: str, obj: Any) -> bool: return False +@compatibility(is_backward_compatible=True) @dataclass class PythonCode: - """Represents all the information necessary to exec or save a graph as Python code.""" + """ + Represents all the information necessary to exec or save a graph as Python code. + """ # Python source code for the forward function definition. src: str # Values in global scope during exection of `src_def`. @@ -240,6 +244,7 @@ class _PyTreeInfo(NamedTuple): in_spec: pytree.TreeSpec out_spec: Optional[pytree.TreeSpec] +@compatibility(is_backward_compatible=True) class Graph: """ ``Graph`` is the main data structure used in the FX Intermediate Representation. @@ -283,6 +288,8 @@ def forward(self, x): For the semantics of operations represented in the ``Graph``, please see :class:`Node`. """ + + @compatibility(is_backward_compatible=True) def __init__(self, owning_module: Optional["GraphModule"] = None, tracer_cls: Optional[Type["Tracer"]] = None): """ Construct an empty Graph. @@ -299,6 +306,11 @@ def __init__(self, owning_module: Optional["GraphModule"] = None, tracer_cls: Op @property def owning_module(self): + """ + Return the module that owns this ``GraphModule``, if there is one, + ``None`` if there is no owning module or if there are multiple owning + modules. + """ return self._owning_module @owning_module.setter @@ -322,6 +334,7 @@ def nodes(self) -> _node_list: """ return _node_list(self) + @compatibility(is_backward_compatible=True) def graph_copy(self, g : 'Graph', val_map : Dict[Node, Node], return_output_node=False) -> 'Optional[Argument]': """ Copy all nodes from a given graph into ``self``. @@ -354,7 +367,7 @@ def __deepcopy__(self, memo=None) -> 'Graph': from the default implementation. This uses graph_copy to copy the nodes in an iterative way, rather than recursive. It also populates the memoization table to prevent unnecessary copies (e.g. references to - nodes or other parts of the Graph from a custom GraphModule implementation + nodes or other parts of the Graph from a custom GraphModule implementation. """ memo = memo if memo else {} g = Graph(tracer_cls=self._tracer_cls) @@ -364,6 +377,7 @@ def __deepcopy__(self, memo=None) -> 'Graph': g.output(output_val, type_expr=getattr(old_output_val, 'type', None)) return g + @compatibility(is_backward_compatible=True) def create_node(self, op: str, target: 'Target', args: Optional[Tuple['Argument', ...]] = None, kwargs: Optional[Dict[str, 'Argument']] = None, @@ -410,10 +424,12 @@ def create_node(self, op: str, target: 'Target', self._len += 1 return n + @compatibility(is_backward_compatible=False) def flatten_inps(self, *args): flat_args, args_spec = pytree.tree_flatten(args) return flat_args + @compatibility(is_backward_compatible=False) def unflatten_outs(self, out): if self._pytree_info is None: return out @@ -422,6 +438,7 @@ def unflatten_outs(self, out): assert(self._pytree_info.out_spec is not None) return pytree.tree_unflatten(out, self._pytree_info.out_spec) + @compatibility(is_backward_compatible=True) def erase_node(self, to_erase : Node) -> None: """ Erases a ``Node`` from the ``Graph``. Throws an exception if @@ -448,6 +465,7 @@ def erase_node(self, to_erase : Node) -> None: assert isinstance(new_kwargs, dict) to_erase.kwargs = new_kwargs + @compatibility(is_backward_compatible=True) def inserting_before(self, n: Optional[Node] = None): """Set the point at which create_node and companion methods will insert into the graph. When used within a 'with' statement, this will temporary set the insert point and @@ -470,6 +488,7 @@ def inserting_before(self, n: Optional[Node] = None): assert n.graph == self, "Node to insert before is not in graph." return _InsertPoint(self, n.prepend) + @compatibility(is_backward_compatible=True) def inserting_after(self, n: Optional[Node] = None): """Set the point at which create_node and companion methods will insert into the graph. When used within a 'with' statement, this will temporary set the insert point and @@ -492,7 +511,7 @@ def inserting_after(self, n: Optional[Node] = None): assert n.graph == self, "Node to insert after is not in graph." return _InsertPoint(self, n.append) - # sugar for create_node when you know the op + @compatibility(is_backward_compatible=True) def placeholder(self, name: str, type_expr: Optional[Any] = None) -> Node: """ Insert a ``placeholder`` node into the Graph. A ``placeholder`` represents @@ -514,6 +533,7 @@ def placeholder(self, name: str, type_expr: Optional[Any] = None) -> Node: """ return self.create_node('placeholder', name, type_expr=type_expr) + @compatibility(is_backward_compatible=True) def get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> Node: """ Insert a ``get_attr`` node into the Graph. A ``get_attr`` ``Node`` represents the @@ -571,6 +591,7 @@ def _get_attr_reference_exists(mod: torch.nn.Module, qualified_name: str) -> boo "necessary buffer") return self.create_node('get_attr', qualified_name, type_expr=type_expr) + @compatibility(is_backward_compatible=True) def call_module(self, module_name: str, args: Optional[Tuple['Argument', ...]] = None, @@ -615,6 +636,7 @@ def call_module(self, "necessary submodule") return self.create_node('call_module', module_name, args, kwargs, type_expr=type_expr) + @compatibility(is_backward_compatible=True) def call_method(self, method_name: str, args: Optional[Tuple['Argument', ...]] = None, @@ -649,6 +671,7 @@ def call_method(self, """ return self.create_node('call_method', method_name, args, kwargs, type_expr=type_expr) + @compatibility(is_backward_compatible=True) def call_function(self, the_function: Callable[..., Any], args: Optional[Tuple['Argument', ...]] = None, @@ -684,6 +707,7 @@ def call_function(self, """ return self.create_node('call_function', the_function, args, kwargs, type_expr=type_expr) + @compatibility(is_backward_compatible=True) def node_copy(self, node: Node, arg_transform: Callable[[Node], 'Argument'] = lambda x: x) -> Node: """ Copy a node from one graph into another. ``arg_transform`` needs to transform arguments from @@ -714,6 +738,7 @@ def node_copy(self, node: Node, arg_transform: Callable[[Node], 'Argument'] = la result_node.meta = copy.copy(node.meta) return result_node + @compatibility(is_backward_compatible=True) def output(self, result: 'Argument', type_expr: Optional[Any] = None): """ Insert an ``output`` ``Node`` into the ``Graph``. An ``output`` node represents @@ -745,6 +770,7 @@ def _target_to_str(self, target : Target) -> str: op = _snake_case(op) return op + @compatibility(is_backward_compatible=True) def python_code(self, root_module: str) -> PythonCode: """ Turn this ``Graph`` into valid Python code. @@ -923,11 +949,13 @@ def emit_node(node : Node): return qualified_name = _get_qualified_name(node.target) global_name = add_global(qualified_name, node.target) + # special case for getattr: node.args could be 2-argument or 3-argument + # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value if global_name == 'getattr' and \ isinstance(node.args, tuple) and \ isinstance(node.args[1], str) and \ - node.args[1].isidentifier(): - # pretty print attribute access + node.args[1].isidentifier() and \ + len(node.args) == 2: body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}') return body.append(f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})') @@ -995,7 +1023,7 @@ def forward({', '.join(orig_args)}){maybe_return_annotation[0]}: def __str__(self) -> str: """ - Print a human-readable (not machine-readable) string representation + Return a human-readable (not machine-readable) string representation of this Graph """ placeholder_names : List[str] = [] @@ -1011,10 +1039,12 @@ def __str__(self) -> str: s += '\n ' + node_str return s + @compatibility(is_backward_compatible=True) def print_tabular(self): """ Prints the intermediate representation of the graph in tabular - format. + format. Note that this API requires the ``tabulate`` module to be + installed. """ try: from tabulate import tabulate @@ -1027,6 +1057,7 @@ def print_tabular(self): print(tabulate(node_specs, headers=['opcode', 'name', 'target', 'args', 'kwargs'])) + @compatibility(is_backward_compatible=True) def lint(self): """ Runs various checks on this Graph to make sure it is well-formed. In @@ -1066,8 +1097,15 @@ def check_arg(arg : Node, n : Optional[Node] = None) -> None: # Check targets are legit if self.owning_module: for node in self.nodes: + if node.op == 'call_function': + if not callable(node.target): + raise ValueError(f'Node {node} target {node.target} has type {torch.typename(node.target)} but ' + 'a Callable is expected') + else: + if not isinstance(node.target, str): + raise ValueError(f'Node {node} target {node.target} has type {torch.typename(node.target)} but ' + 'a str is expected') if node.op in ['get_attr', 'call_module']: - assert isinstance(node.target, str) target_atoms = node.target.split('.') m_itr = self.owning_module for i, atom in enumerate(target_atoms): @@ -1090,6 +1128,7 @@ def check_arg(arg : Node, n : Optional[Node] = None) -> None: else: m_itr = new_m_itr + @compatibility(is_backward_compatible=True) def eliminate_dead_code(self): """ Remove all dead code from the graph, based on each node's number of @@ -1117,7 +1156,6 @@ def forward(self, x): def forward(self, x): return x + self.attr_1 - """ # Lint the graph first to make sure its topologically sorted, otherwise # DCE below will not behave as expected. diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index 0cbbd9373027a..ca82d49e07cbe 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -6,6 +6,7 @@ import linecache from typing import Type, Dict, List, Any, Union, Optional, Set from .graph import Graph, _is_from_torch, _custom_builtins, PythonCode +from ._compatibility import compatibility from torch.package import Importer, sys_importer import copy import itertools @@ -15,33 +16,65 @@ import os import warnings -# normal exec loses the source code, however we can patch -# the linecache module to still recover it. -# using exec_with_source will add it to our local cache +# Normal exec loses the source code, however we can work with +# the linecache module to recover it. +# Using _exec_with_source will add it to our local cache # and then tools like TorchScript will be able to get source info. -_next_id = 0 -def exec_with_source(src: str, globals: Dict[str, Any]): - global _next_id - key = f'' - _next_id += 1 - _eval_cache[key] = [line + '\n' for line in src.splitlines()] - exec(compile(src, key, 'exec'), globals) +class _EvalCacheLoader(object): + def __init__(self): + self.eval_cache = {} + self.next_id = 0 + + def cache(self, src: str, globals: Dict[str, Any]): + """Store the source in a private cache, and add a lazy entry in linecache + that allows the source to be retrieved by 'filename'. + + Args: + src (str): The module source to cache + globals (dict): The module globals + + Returns: + str: The cache key (and dummy filename) generated for src. + """ + + key = self._get_key() + self.eval_cache[key] = src + + # Don't mutate globals so that this loader is only used + # to populate linecache, and doesn't interact with other modules + # that might check `__loader__` + globals_copy = globals.copy() + globals_copy['__file__'] = key + globals_copy['__name__'] = key + globals_copy['__loader__'] = self + linecache.lazycache(key, globals_copy) + + return key -# patch linecache so that any code we exec using exec_with_source -# works with inspect -_eval_cache : Dict[str, List[str]] = {} -_orig_getlines = linecache.getlines -def patched_getline(*args, **kwargs): - if args[0] in _eval_cache: - return _eval_cache[args[0]] - return _orig_getlines(*args, **kwargs) -linecache.getlines = patched_getline + # Part of the loader protocol (PEP 302) + # linecache will use this method when trying to find source code + def get_source(self, module_name) -> Optional[str]: + if module_name in self.eval_cache: + return self.eval_cache[module_name] + return None + + def _get_key(self): + key = f'.{self.next_id}' + self.next_id += 1 + return key + +_loader = _EvalCacheLoader() + + +def _exec_with_source(src: str, globals: Dict[str, Any]): + key = _loader.cache(src, globals) + exec(compile(src, key, 'exec'), globals) def _forward_from_src(src: str, globals: Dict[str, Any]): # avoid mutating the passed in dict globals_copy = globals.copy() - exec_with_source(src, globals_copy) + _exec_with_source(src, globals_copy) forward_fn = globals_copy['forward'] del globals_copy['forward'] return forward_fn @@ -63,6 +96,7 @@ def _format_import_block(globals: Dict[str, Any], importer: Importer): return '\n'.join(import_strs) +@compatibility(is_backward_compatible=True) def reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Module: # BC: attribute name was changed from `code` to `_code` to facilitate # making `code` into a property and adding a docstring to it @@ -71,13 +105,14 @@ def reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Mod return _deserialize_graph_module(forward, body) +@compatibility(is_backward_compatible=True) def reduce_package_graph_module( importer: PackageImporter, body: Dict[Any, Any], generated_module_name: str ) -> torch.nn.Module: forward = importer.import_module(generated_module_name).forward return _deserialize_graph_module(forward, body) - +@compatibility(is_backward_compatible=True) def reduce_deploy_graph_module( importer: PackageImporter, body: Dict[Any, Any], import_block: str ) -> torch.nn.Module: @@ -187,6 +222,7 @@ def _assign_attr(from_obj: Any, to_module: torch.nn.Module, target: str): else: setattr(to_module, field, from_obj) +@compatibility(is_backward_compatible=True) class GraphModule(torch.nn.Module): """ GraphModule is an nn.Module generated from an fx.Graph. Graphmodule has a @@ -199,7 +235,6 @@ class GraphModule(torch.nn.Module): regenerated. However, if you edit the contents of the ``graph`` without reassigning the ``graph`` attribute itself, you must call ``recompile()`` to update the generated code. - """ def __new__(cls: 'Type[GraphModule]', *args, **kwargs): # each instance of a graph module needs its own forward method @@ -207,10 +242,19 @@ def __new__(cls: 'Type[GraphModule]', *args, **kwargs): # it is a subclass of the user-defined class, the only difference # is an extra layer to install the forward method + # address issue described at https://github.com/pytorch/pytorch/issues/63883 + # in other words, traverse class hierarchy to fix the redundant class definition problem + for t in cls.__mro__: + c = t.__qualname__.split('.')[-1] + if c != 'GraphModuleImpl': + cls = t + break + class GraphModuleImpl(cls): # type: ignore[misc, valid-type] pass return super().__new__(GraphModuleImpl) + @compatibility(is_backward_compatible=True) def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, @@ -234,7 +278,6 @@ def __init__(self, class_name (str): ``name`` denotes the name of this GraphModule for debugging purposes. If it's unset, all error messages will report as originating from ``GraphModule``. It may be helpful to set this to ``root``'s original name or a name that makes sense within the context of your transform. - """ super().__init__() self.__class__.__name__ = class_name @@ -302,6 +345,7 @@ def graph(self, g : Graph) -> None: g.owning_module = self self.recompile() + @compatibility(is_backward_compatible=False) def to_folder(self, folder: Union[str, os.PathLike], module_name : str = "FxModule"): """Dumps out module to ``folder`` with ``module_name`` so that it can be imported with ``from import `` @@ -366,6 +410,7 @@ def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]: warnings.warn("Was not able to save the following children modules as reprs -" f"saved as pickled files instead: {blobified_modules}") + @compatibility(is_backward_compatible=True) def add_submodule(self, target: str, m: torch.nn.Module) -> bool: """ Adds the given submodule to ``self``. @@ -386,7 +431,6 @@ def add_submodule(self, target: str, m: torch.nn.Module) -> bool: denoted by ``target`` must either a) not exist yet, or b) reference an ``nn.Module`` (not a parameter or other attribute) - """ *prefix, field = target.split('.') mod: torch.nn.Module = self @@ -407,6 +451,7 @@ def add_submodule(self, target: str, m: torch.nn.Module) -> bool: mod.add_module(field, m) return True + @compatibility(is_backward_compatible=True) def delete_submodule(self, target: str) -> bool: """ Deletes the given submodule from ``self``. @@ -449,6 +494,7 @@ def delete_submodule(self, target: str) -> bool: delattr(mod, target_submod) return True + @compatibility(is_backward_compatible=True) def delete_all_unused_submodules(self) -> None: """ Deletes all unused submodules from ``self``. @@ -503,6 +549,7 @@ def code(self) -> str: raise RuntimeError('Code has not been generated! Please report a bug to PyTorch') return self._code + @compatibility(is_backward_compatible=True) def recompile(self) -> PythonCode: """ Recompile this GraphModule from its ``graph`` attribute. This should be @@ -539,7 +586,7 @@ def generate_error_message(frame_summary: traceback.FrameSummary) -> str: # auxiliary variables (for readability) err_lineno = frame_summary.lineno err_line_len = len(frame_summary.line) - all_src_lines = _eval_cache[frame_summary.filename] + all_src_lines = linecache.getlines(frame_summary.filename) # constituent substrings of the error message tb_repr = traceback.format_exc() @@ -615,7 +662,7 @@ def __reduce__(self): def __deepcopy__(self, memo): fake_mod = torch.nn.Module() fake_mod.__dict__ = copy.deepcopy(self.__dict__) - return GraphModule(fake_mod, self.graph) + return GraphModule(fake_mod, fake_mod.__dict__['_graph']) def __copy__(self): return GraphModule(self, self.graph) @@ -624,6 +671,11 @@ def __str__(self) -> str: orig_str = super().__str__() return '\n'.join([orig_str, self._code]) + def _replicate_for_data_parallel(self): + new_gm = self.__copy__() + new_gm._is_replica = True + return new_gm + # workarounds for issues in __torch_function__ # WAR for __torch_function__ not handling tensor lists, diff --git a/torch/fx/immutable_collections.py b/torch/fx/immutable_collections.py index 459c30e745dfd..1093a07c8d229 100644 --- a/torch/fx/immutable_collections.py +++ b/torch/fx/immutable_collections.py @@ -1,3 +1,4 @@ +from ._compatibility import compatibility _help_mutation = """\ If you are attempting to modify the kwargs or args of a torch.fx.Node object, @@ -20,5 +21,8 @@ def _create_immutable_container(base, mutable_functions): 'clear', 'extend', 'insert', 'pop', 'remove']) immutable_list.__reduce__ = lambda self: (immutable_list, (tuple(iter(self)),)) +compatibility(is_backward_compatible=True)(immutable_list) + immutable_dict = _create_immutable_container(dict, ['__delitem__', '__setitem__', 'clear', 'pop', 'popitem', 'update']) immutable_dict.__reduce__ = lambda self: (immutable_dict, (iter(self.items()),)) +compatibility(is_backward_compatible=True)(immutable_dict) diff --git a/torch/fx/interpreter.py b/torch/fx/interpreter.py index 20dcf62e0c3cb..64233b4cf18b6 100644 --- a/torch/fx/interpreter.py +++ b/torch/fx/interpreter.py @@ -3,8 +3,10 @@ from .node import Argument, Node, Target, map_arg, map_aggregate from .proxy import Proxy from ._symbolic_trace import Tracer +from ._compatibility import compatibility from typing import Any, Dict, Iterator, List, Optional, Tuple, Union +@compatibility(is_backward_compatible=True) class Interpreter: """ An Interpreter executes an FX graph Node-by-Node. This pattern @@ -59,6 +61,7 @@ def fn(x): execution. This can be disabled to, for example, examine all of the intermediate values in the execution by looking at the ``Interpreter.env`` attribute. """ + @compatibility(is_backward_compatible=True) def __init__(self, module : GraphModule, garbage_collect_values : bool = True): assert isinstance(module, GraphModule) self.module = module @@ -84,6 +87,7 @@ def register_last_uses(n : Node, user : Node): map_arg(node.args, lambda n: register_last_uses(n, node)) map_arg(node.kwargs, lambda n: register_last_uses(n, node)) + @compatibility(is_backward_compatible=True) def run(self, *args, initial_env : Optional[Dict[Node, Any]] = None) -> Any: """ Run `module` via interpretation and return the result. @@ -123,6 +127,7 @@ def run(self, *args, initial_env : Optional[Dict[Node, Any]] = None) -> Any: output_val = self.env[node] return output_val + @compatibility(is_backward_compatible=True) def run_node(self, n : Node) -> Any: """ Run a specific node ``n`` and return the result. @@ -142,7 +147,7 @@ def run_node(self, n : Node) -> Any: return getattr(self, n.op)(n.target, args, kwargs) # Main Node running APIs - + @compatibility(is_backward_compatible=True) def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: """ Execute a ``placeholder`` node. Note that this is stateful: @@ -168,6 +173,7 @@ def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : D else: return next(self.args_iter) + @compatibility(is_backward_compatible=True) def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: """ Execute a ``get_attr`` node. Will retrieve an attribute @@ -186,6 +192,7 @@ def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict assert isinstance(target, str) return self.fetch_attr(target) + @compatibility(is_backward_compatible=True) def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: """ Execute a ``call_function`` node and return the result. @@ -205,6 +212,7 @@ def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : # Execute the function and return the result return target(*args, **kwargs) + @compatibility(is_backward_compatible=True) def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: """ Execute a ``call_method`` node and return the result. @@ -226,6 +234,7 @@ def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : D assert isinstance(target, str) return getattr(self_obj, target)(*args_tail, **kwargs) + @compatibility(is_backward_compatible=True) def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: """ Execute a ``call_module`` node and return the result. @@ -248,6 +257,7 @@ def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : D return submod(*args, **kwargs) + @compatibility(is_backward_compatible=True) def output(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: """ Execute an ``output`` node. This really just retrieves @@ -266,7 +276,7 @@ def output(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[s return args[0] # Helper methods - + @compatibility(is_backward_compatible=True) def fetch_attr(self, target : str): """ Fetch an attribute from the ``Module`` hierarchy of ``self.module``. @@ -285,6 +295,7 @@ def fetch_attr(self, target : str): attr_itr = getattr(attr_itr, atom) return attr_itr + @compatibility(is_backward_compatible=True) def fetch_args_kwargs_from_env(self, n : Node) -> Tuple[Tuple, Dict]: """ Fetch the concrete values of ``args`` and ``kwargs`` of node ``n`` @@ -302,6 +313,7 @@ def fetch_args_kwargs_from_env(self, n : Node) -> Tuple[Tuple, Dict]: assert isinstance(kwargs, dict) return args, kwargs + @compatibility(is_backward_compatible=True) def map_nodes_to_values(self, args : Argument, n : Node) -> Argument: """ Recursively descend through ``args`` and look up the concrete value @@ -319,6 +331,7 @@ def load_arg(n_arg : Node) -> Any: return self.env[n_arg] return map_arg(args, load_arg) +@compatibility(is_backward_compatible=True) class Transformer(Interpreter): """ ``Transformer`` is a special type of interpreter that produces a @@ -357,6 +370,8 @@ def fn(x): Args: module (GraphModule): The ``Module`` to be transformed. """ + + @compatibility(is_backward_compatible=True) def __init__(self, module): super().__init__(module) self.new_graph = Graph() @@ -371,6 +386,7 @@ def is_leaf_module(self, _, __) -> bool: self.tracer = TransformerTracer(self.new_graph) self.tracer.root = module + @compatibility(is_backward_compatible=True) def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy: """ Execute a ``placeholder`` node. In ``Transformer``, this is @@ -387,6 +403,7 @@ def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : D assert isinstance(target, str) return Proxy(self.new_graph.placeholder(target), self.tracer) + @compatibility(is_backward_compatible=True) def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy: """ Execute a ``get_attr`` node. In ``Transformer``, this is @@ -403,16 +420,19 @@ def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict assert isinstance(target, str) return Proxy(self.new_graph.get_attr(target), self.tracer) + @compatibility(is_backward_compatible=True) def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: # Override so that the leaf module policy from `self.tracer` is respected. assert isinstance(target, str) submod = self.fetch_attr(target) return self.tracer.call_module(submod, submod.forward, args, kwargs) + @compatibility(is_backward_compatible=True) def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: # Override so that functions that were wrapped are still wrapped. return self.tracer.create_proxy('call_function', target, args, kwargs) + @compatibility(is_backward_compatible=True) def transform(self) -> GraphModule: """ Transform ``self.module`` and return the transformed diff --git a/torch/fx/node.py b/torch/fx/node.py index e00f25f47a2ee..61dfba7acb03f 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -1,5 +1,6 @@ # Nodes represent a definition of a value in our graph of operators. from typing import TYPE_CHECKING, Union, Callable, Any, Tuple, List, Optional, Dict, Set +from ._compatibility import compatibility from .immutable_collections import immutable_dict, immutable_list import torch import builtins @@ -85,6 +86,7 @@ def _format_arg(arg) -> str: else: return str(arg) +@compatibility(is_backward_compatible=True) class Node: """ ``Node`` is the data structure that represents individual operations within @@ -112,15 +114,49 @@ class Node: - ``output`` contains the output of the traced function in its ``args[0]`` attribute. This corresponds to the "return" statement in the Graph printout. """ + + @compatibility(is_backward_compatible=True) def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', args: Tuple['Argument', ...], kwargs: Dict[str, 'Argument'], - type : Optional[Any] = None) -> None: + return_type : Optional[Any] = None) -> None: + """ + Instantiate an instance of ``Node``. Note: most often, you want to use the + Graph APIs, i.e. ``Graph.call_module``, ``Graph.call_method``, etc. rather + than instantiating a ``Node`` directly. + + Args: + graph (Graph): The ``Graph`` to which this ``Node`` should belong. + + name (str): The name to which the output of this ``Node`` should be assigned + + op (str): The opcode for this ``Node``. Can be one of 'placeholder', + 'call_method', 'call_module', 'call_function', 'get_attr', + 'output' + + target ('Target'): The target this op should call. See the broader + ``Node`` docstring for more details. + + args (Tuple['Argument']): The args to be passed to ``target`` + + kwargs (Dict[str, 'Argument']): The kwargs to be passed to ``target`` + + return_type (Optional[Any]): The python type expression representing the + type of the output of this node. This field can be used for + annotation of values in the generated code or for other types + of analyses. + """ self.graph = graph self.name = name # unique name of value being created assert op in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output', 'root'] self.op = op # the kind of operation = placeholder|call_method|call_module|call_function|get_attr - if op in ['call_method', 'call_module']: - assert isinstance(target, str) + if op == 'call_function': + if not callable(target): + raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} ' + 'but a Callable is expected') + else: + if not isinstance(target, str): + raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} ' + 'but a str is expected') self.target = target # for method/module/function, the name of the method/module/function/attr # being invoked, e.g add, layer1, or torch.add @@ -146,7 +182,7 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', # generated function return type. (Note this is a special case. ``return`` # does not produce a value, it's more of a notation. Thus, this value # describes the type of args[0] in the ``return`` node. - self.type : Optional[Any] = type + self.type : Optional[Any] = return_type self._prev = self self._next = self self._erased = False @@ -181,6 +217,7 @@ def prev(self) -> 'Node': """ return self._prev + @compatibility(is_backward_compatible=True) def prepend(self, x: 'Node') -> None: """ Insert x before this node in the list of nodes in the graph. Example:: @@ -199,6 +236,7 @@ def prepend(self, x: 'Node') -> None: p._next, x._prev = x, p x._next, self._prev = self, x + @compatibility(is_backward_compatible=True) def append(self, x: 'Node') -> None: """ Insert x after this node in the list of nodes in the graph. @@ -273,6 +311,7 @@ def all_input_nodes(self) -> List['Node']: """ return list(self._input_nodes.keys()) + @compatibility(is_backward_compatible=True) def update_arg(self, idx : int, arg : Argument) -> None: """ Update an existing positional argument to contain the new value @@ -287,6 +326,7 @@ def update_arg(self, idx : int, arg : Argument) -> None: args[idx] = arg self.args = tuple(args) + @compatibility(is_backward_compatible=True) def update_kwarg(self, key : str, arg : Argument) -> None: """ Update an existing keyword argument to contain the new value @@ -359,6 +399,7 @@ def _pretty_print_target(self, target): return f'operator.{target.__name__}' return _get_qualified_name(target) + @compatibility(is_backward_compatible=True) def format_node(self, placeholder_names: List[str] = None, maybe_return_typename: List[str] = None) -> Optional[str]: @@ -414,6 +455,7 @@ def format_node(self, f'{self.op}[target={self._pretty_print_target(self.target)}](' \ f'args = {_format_arg(self.args)}, kwargs = {_format_arg(self.kwargs)})' + @compatibility(is_backward_compatible=True) def replace_all_uses_with(self, replace_with : 'Node') -> List['Node']: """ Replace all uses of ``self`` in the Graph with the Node ``replace_with``. @@ -443,6 +485,7 @@ def maybe_replace_node(n : Node) -> Node: assert len(self.users) == 0 return to_process + @compatibility(is_backward_compatible=False) def is_impure(self): """ Returns whether this op is impure, i.e. if its op is a placeholder or @@ -472,6 +515,7 @@ def is_impure(self): return False + @compatibility(is_backward_compatible=False) def normalized_arguments( self, root : torch.nn.Module, arg_types : Optional[Tuple[Any]] = None, kwarg_types : Optional[Dict[str, Any]] = None, @@ -507,7 +551,7 @@ def normalized_arguments( return None - + @compatibility(is_backward_compatible=True) def replace_input_with(self, old_input: 'Node', new_input: 'Node'): """ Loop through input nodes of ``self``, and replace all instances of @@ -517,7 +561,6 @@ def replace_input_with(self, old_input: 'Node', new_input: 'Node'): old_input (Node): The old input node to be replaced. new_input (Node): The new input node to replace ``old_input``. - """ def maybe_replace_node(n : Node) -> Node: return new_input if n == old_input else n @@ -529,13 +572,19 @@ def maybe_replace_node(n : Node) -> Node: self.__update_args_kwargs(new_args, new_kwargs) +@compatibility(is_backward_compatible=True) def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument: - """ Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. """ + """ + Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. + """ assert callable(fn), "torch.fx.map_arg(a, fn): fn must be a callable" return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x) +@compatibility(is_backward_compatible=True) def map_aggregate(a: Argument, fn: Callable[[Argument], Argument]) -> Argument: - """ Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. """ + """ + Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. + """ if isinstance(a, tuple): return tuple(map_aggregate(elem, fn) for elem in a) elif isinstance(a, list): diff --git a/torch/fx/operator_schemas.py b/torch/fx/operator_schemas.py index 5f61ebe718ff1..d7ddc3e0360c7 100644 --- a/torch/fx/operator_schemas.py +++ b/torch/fx/operator_schemas.py @@ -4,9 +4,14 @@ import typing import enum import warnings -from typing import Any, Callable, Dict, List, Optional, Tuple, NamedTuple, cast +from typing import Any, Callable, Dict, List, Optional, Tuple, NamedTuple, cast, TYPE_CHECKING from torch._jit_internal import boolean_dispatched +from ._compatibility import compatibility +if TYPE_CHECKING: + from .node import Argument + +@compatibility(is_backward_compatible=False) class ArgsKwargsPair(NamedTuple): """ Simple named tuple for wrapping args/kwargs pairs. @@ -76,7 +81,44 @@ def _torchscript_schema_to_signature(ts_schema : torch._C.FunctionSchema) -> ins return inspect.Signature(parameters, return_annotation=return_type) -def get_signature_for_torch_op(op : Callable) -> Optional[List[inspect.Signature]]: +@compatibility(is_backward_compatible=False) +def check_for_mutable_operation(target : Callable, args : Tuple['Argument', ...], kwargs : Dict[str, 'Argument']): + signatures, schemas = get_signature_for_torch_op(target, return_schemas=True) + + if signatures and schemas: + matched_schemas = [] + + # Iterate through all of the schema until we find one that matches + # If one matches, populate `new_args_and_kwargs` with the new args/kwargs + # values. If none matches, `new_args_and_kwargs` will be None + for candidate_signature, schema in zip(signatures, schemas): + try: + candidate_signature.bind(*args, **kwargs) + matched_schemas.append((candidate_signature, schema)) + except TypeError as e: + continue + + def throw_if_mutable(schema): + if schema.is_mutable: + raise RuntimeError(f'Tried to trace mutable operation {schema}. FX only supports functional ' + f'code, so operations that mutate operands in-place (e.g. via `out` arguments) ' + f'are not supported') + + if len(matched_schemas) == 0: + # Did not match any schema. Cannot check for mutation + pass + elif len(matched_schemas) == 1: + # Matched exactly one schema, unambiguous + _, schema_to_check = matched_schemas[0] + throw_if_mutable(schema_to_check) + pass + else: + # Ambiguous schema match. Since mutability checking is best effort, + # do nothing. + pass + +@compatibility(is_backward_compatible=False) +def get_signature_for_torch_op(op : Callable, return_schemas : bool = False): """ Given an operator on the `torch` namespace, return a list of `inspect.Signature` objects corresponding to the overloads of that op.. May return `None` if a signature @@ -87,22 +129,25 @@ def get_signature_for_torch_op(op : Callable) -> Optional[List[inspect.Signature Returns: Optional[List[inspect.Signature]]: A list of signatures for the overloads of this - operator, or None if the operator signatures could not be retrieved. + operator, or None if the operator signatures could not be retrieved. If + return_schemas=True, returns a tuple containing the optional Python signatures + and the optional TorchScript Function signature """ override = _manual_overrides.get(op) if override: - return override + return (override, None) if return_schemas else None aten_fn = torch.jit._builtins._find_builtin(op) if aten_fn is None: - return None + return (None, None) if return_schemas else None schemas = torch._C._jit_get_schemas_for_operator(aten_fn) signatures = [_torchscript_schema_to_signature(schema) for schema in schemas] - return signatures + return (signatures, schemas) if return_schemas else signatures +@compatibility(is_backward_compatible=False) def create_type_hint(x): try: if isinstance(x, list) or isinstance(x, tuple): @@ -130,6 +175,7 @@ def ret_type(x): pass return x +@compatibility(is_backward_compatible=False) def type_matches(signature_type : Any, argument_type : Any): sig_origin_type = getattr(signature_type, '__origin__', signature_type) @@ -177,6 +223,7 @@ def is_homogeneous_tuple(t): return False +@compatibility(is_backward_compatible=False) def normalize_function( target: Callable, args: Tuple[Any], kwargs : Optional[Dict[str, Any]] = None, arg_types : Optional[Tuple[Any]] = None, kwarg_types : Optional[Dict[str, Any]] = None, @@ -272,6 +319,7 @@ def normalize_function( return new_args_and_kwargs +@compatibility(is_backward_compatible=False) def normalize_module( root: torch.nn.Module, target: str, args: Tuple[Any], kwargs : Optional[Dict[str, Any]] = None, normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]: diff --git a/torch/fx/passes/net_min_base.py b/torch/fx/passes/net_min_base.py index 2a093bea49a4c..b7a911e4bf3db 100644 --- a/torch/fx/passes/net_min_base.py +++ b/torch/fx/passes/net_min_base.py @@ -1,5 +1,6 @@ import argparse from typing import Any, Callable, Tuple, Dict, Optional +import logging import torch import torch.fx @@ -17,6 +18,8 @@ Names ) +_LOGGER = logging.getLogger(__name__) + class FxNetMinimizerBadModuleError(Exception): """ @@ -403,6 +406,7 @@ def _sequential_traverse(self, nodes: NodeList) -> NodeSet: culprits: NodeSet = set() for node in nodes: + _LOGGER.info(f"Visit node: {node.name}") cur_nodes: NodeSet = {node} if node in self.fusions: diff --git a/torch/fx/passes/shape_prop.py b/torch/fx/passes/shape_prop.py index 6f0f72d38c75f..816fbe7aaac6c 100644 --- a/torch/fx/passes/shape_prop.py +++ b/torch/fx/passes/shape_prop.py @@ -2,7 +2,9 @@ import torch.fx from torch.fx.node import Node, map_aggregate from typing import Any, Tuple, NamedTuple, Optional +from torch.fx._compatibility import compatibility +@compatibility(is_backward_compatible=True) class TensorMetadata(NamedTuple): # TensorMetadata is a structure containing pertinent information # about a tensor within a PyTorch program. @@ -20,7 +22,7 @@ class TensorMetadata(NamedTuple): q_scale : Optional[float] q_zero_point : Optional[int] -def extract_tensor_metadata(result : torch.Tensor) -> TensorMetadata: +def _extract_tensor_metadata(result : torch.Tensor) -> TensorMetadata: """ Extract a TensorMetadata NamedTuple describing `result`. """ @@ -58,7 +60,7 @@ def extract_tensor_metadata(result : torch.Tensor) -> TensorMetadata: return TensorMetadata( shape, dtype, requires_grad, stride, memory_format, is_quantized, qscheme, q_scale, q_zero_point) - +@compatibility(is_backward_compatible=True) class ShapeProp(torch.fx.Interpreter): """ Execute an FX graph Node-by-Node and @@ -113,7 +115,7 @@ def extract_tensor_meta(obj): if isinstance(obj, torch.Tensor): nonlocal found_tensor found_tensor = True - return extract_tensor_metadata(obj) + return _extract_tensor_metadata(obj) else: return obj diff --git a/torch/fx/passes/split_module.py b/torch/fx/passes/split_module.py index 989ec92777cc3..c42af7e9c2d9b 100644 --- a/torch/fx/passes/split_module.py +++ b/torch/fx/passes/split_module.py @@ -1,7 +1,9 @@ import torch from torch.fx.graph_module import GraphModule from typing import Callable, List, Dict, Any, Optional +from torch.fx._compatibility import compatibility +@compatibility(is_backward_compatible=True) class Partition: def __init__(self, name: str): self.name: str = name @@ -23,6 +25,7 @@ def __repr__(self) -> str: f" parition dependents: {self.partition_dependents}" # Creates subgraphs out of main graph +@compatibility(is_backward_compatible=True) def split_module( m: GraphModule, root_m: torch.nn.Module, diff --git a/torch/fx/passes/splitter_base.py b/torch/fx/passes/splitter_base.py index 42087bde9ef89..65419055dad82 100644 --- a/torch/fx/passes/splitter_base.py +++ b/torch/fx/passes/splitter_base.py @@ -2,6 +2,7 @@ from collections import defaultdict from dataclasses import dataclass from typing import List, Dict, Optional, Tuple +import logging import torch from torch.fx.experimental.graph_manipulation import get_size_of_node @@ -20,8 +21,12 @@ Tensors, NodeList, NodeSet, + is_node_output_tensor, ) +_LOGGER = logging.getLogger(__name__) + + class _SplitterSettingBase: def __init__(self): parser = argparse.ArgumentParser() @@ -98,7 +103,7 @@ def reduce_acc_nodes_non_tensor_input_helper( for user in node.users: if user in self.acc_nodes: self.acc_nodes.remove(user) - if "tensor_meta" not in user.meta: + if not is_node_output_tensor(user): cpu_worklist.append(user) def reduce_acc_nodes_non_tensor_input(self): @@ -113,7 +118,7 @@ def reduce_acc_nodes_non_tensor_input(self): continue if node in self.acc_nodes: continue - if "tensor_meta" in node.meta: + if is_node_output_tensor(node): continue non_tensor_cpu_nodes.append(node) @@ -128,7 +133,7 @@ def reduce_acc_nodes_non_tensor_output(self): new_cpu_nodes: NodeList = [] for acc_node in self.acc_nodes: - if "tensor_meta" in acc_node.meta: + if is_node_output_tensor(acc_node): continue for user in acc_node.users: if user not in self.acc_nodes: @@ -461,7 +466,7 @@ def get_inputs(self, inputs): reports += "Checking inputs...\n" for n in submod.graph.nodes: if n.op == "placeholder": - if "tensor_meta" not in n.meta: + if not is_node_output_tensor(n): reports += f"Input {n.name} is not a tensor, this might cause problems during lowering!\n" else: total_input_bytes += get_size_of_node(submod, n)[0] @@ -473,7 +478,7 @@ def get_inputs(self, inputs): def get_bytes(node: torch.fx.Node): nonlocal total_output_bytes nonlocal reports - if "tensor_meta" not in node.meta: + if not is_node_output_tensor(node): reports += f"Output {node.name} is not a tensor, this might cause problems during lowering!\n" else: total_output_bytes += get_size_of_node(submod, node)[0] diff --git a/torch/fx/passes/tools_common.py b/torch/fx/passes/tools_common.py index a996dc8b36521..8274f4bf3b625 100644 --- a/torch/fx/passes/tools_common.py +++ b/torch/fx/passes/tools_common.py @@ -48,6 +48,17 @@ def get_node_target(submodules: Dict[str, torch.nn.Module], node: torch.fx.Node) return node.target +def is_node_output_tensor(node: torch.fx.Node) -> bool: + """Checks if the node output produces a Tensor or not. + + NOTE: This requires to run `ShapeProp` on the containing fx graph before + calling this function. This is because it works by checking the `type` + metadata on the node. This metadata is produced by the `ShapeProp`. + """ + type_ = node.meta.get("type", None) + return type_ is not None and issubclass(type_, torch.Tensor) + + class FxNetAccFusionsFinder: """ Finds groups of connected ACC nodes that pass non-tensor data between each other. diff --git a/torch/fx/proxy.py b/torch/fx/proxy.py index c0b83bc5c3734..b25e45d206a51 100644 --- a/torch/fx/proxy.py +++ b/torch/fx/proxy.py @@ -7,11 +7,18 @@ from .graph import magic_methods, reflectable_magic_methods, Graph from typing import Tuple, Dict, Optional, Iterable, Any, Iterator, Callable from .node import Target, Node, Argument, base_types, map_aggregate +from ._compatibility import compatibility +from .operator_schemas import check_for_mutable_operation +@compatibility(is_backward_compatible=True) class TracerBase: graph: Graph record_stack_traces : bool = False + # Feature flag for mutable schema checking + # Enableby default in 1.12 + check_mutable_operations : bool = False + @compatibility(is_backward_compatible=True) def create_node(self, kind : str, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None, type_expr : Optional[Any] = None) -> Node: @@ -22,13 +29,16 @@ def create_node(self, kind : str, target : Target, modification of values used in node creation. For example, one might want to disallow in-place operations from being recorded. """ + if kind == 'call_function' and self.check_mutable_operations: + check_for_mutable_operation(target, args, kwargs) + return self.graph.create_node(kind, target, args, kwargs, name, type_expr) + @compatibility(is_backward_compatible=True) def proxy(self, node: Node) -> 'Proxy': return Proxy(node, self) - - + @compatibility(is_backward_compatible=True) def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any], name: Optional[str] = None, type_expr : Optional[Any] = None, proxy_factory_fn: Callable[[Node], 'Proxy'] = None): @@ -86,6 +96,7 @@ def _find_user_frame(self): return frame + @compatibility(is_backward_compatible=True) def create_arg(self, a: Any) -> Argument: """ A method that lowers the objects seen as arguments during symbolic evaluation @@ -131,6 +142,7 @@ def no_node(arg): raise NotImplementedError(f"argument of type: {type(a)}") + @compatibility(is_backward_compatible=True) def to_bool(self, obj: 'Proxy') -> bool: """Called when a proxy object is being converted to a boolean, such as when used in control flow. Normally we don't know what to do because @@ -139,6 +151,7 @@ def to_bool(self, obj: 'Proxy') -> bool: """ raise TraceError('symbolically traced variables cannot be used as inputs to control flow') + @compatibility(is_backward_compatible=True) def iter(self, obj: 'Proxy') -> Iterator: """Called when a proxy object is being iterated over, such as when used in control flow. Normally we don't know what to do because @@ -154,6 +167,7 @@ def iter(self, obj: 'Proxy') -> Iterator: ' Proxy docstring for help troubleshooting ' 'Proxy iteration errors') + @compatibility(is_backward_compatible=True) def keys(self, obj: 'Proxy') -> Any: """Called when a proxy object is has the keys() method called. This is what happens when ** is called on a proxy. This should return an @@ -163,15 +177,17 @@ def keys(self, obj: 'Proxy') -> Any: # used in Proxy object when just appending to the graph while not tracing. +@compatibility(is_backward_compatible=True) class GraphAppendingTracer(TracerBase): def __init__(self, graph: Graph): super().__init__() self.graph = graph +@compatibility(is_backward_compatible=True) class TraceError(ValueError): pass - +@compatibility(is_backward_compatible=True) class Proxy: """ ``Proxy`` objects are ``Node`` wrappers that flow through the @@ -200,6 +216,8 @@ class Proxy: For a more detailed description into the Proxy internals, check out the "Proxy" section in `torch/fx/OVERVIEW.md` """ + + @compatibility(is_backward_compatible=True) def __init__(self, node: Node, tracer: 'Optional[TracerBase]' = None): if tracer is None: # This allows you to create a Proxy object around a raw Node @@ -232,6 +250,7 @@ def __iter__(self) -> Iterable['Proxy']: def __bool__(self) -> bool: return self.tracer.to_bool(self) + @compatibility(is_backward_compatible=True) def keys(self): return self.tracer.keys(self) @@ -253,7 +272,9 @@ def __torch_function__(self, orig_method, types, args=None, kwargs=None): return self.tracer.create_proxy('call_function', orig_method, args, kwargs, name=self.tracer.graph._target_to_str(orig_method.__name__)) +@compatibility(is_backward_compatible=True) class Attribute(Proxy): + @compatibility(is_backward_compatible=True) def __init__(self, root: Proxy, attr: str): self.root = root self.attr = attr @@ -272,9 +293,10 @@ def __call__(self, *args, **kwargs): return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs) +@compatibility(is_backward_compatible=False) class ParameterProxy(Proxy): """ - a special proxy which lets "shape", "size", "dim", and a few other + A special proxy which lets "shape", "size", "dim", and a few other attribute accesses pass through to the underlying module parameter object, so that conditional tests on these attributes will not throw exception during tracing """ @@ -309,7 +331,7 @@ def nelement(self): for method in magic_methods: - def scope(method): + def _scope(method): def impl(*args, **kwargs): tracer = args[0].tracer target = getattr(operator, method) @@ -317,7 +339,7 @@ def impl(*args, **kwargs): impl.__name__ = method as_magic = f'__{method}__' setattr(Proxy, as_magic, impl) - scope(method) + _scope(method) def _define_reflectable(orig_method_name): method_name = f'__r{orig_method_name}__' diff --git a/torch/fx/subgraph_rewriter.py b/torch/fx/subgraph_rewriter.py index e779f6ca9e6b1..72ea56aa31196 100644 --- a/torch/fx/subgraph_rewriter.py +++ b/torch/fx/subgraph_rewriter.py @@ -2,22 +2,24 @@ from .graph import Graph from .node import Node from ._symbolic_trace import symbolic_trace +from ._compatibility import compatibility import copy from typing import Callable, Dict, List, NamedTuple, Optional, Set import torch +@compatibility(is_backward_compatible=True) class Match(NamedTuple): # Node from which the match was found anchor: Node # Maps nodes in the pattern subgraph to nodes in the larger graph nodes_map: Dict[Node, Node] -class SubgraphMatcher: +class _SubgraphMatcher: def __init__(self, pattern: Graph) -> None: self.pattern = pattern if len(pattern.nodes) == 0: - raise ValueError("SubgraphMatcher cannot be initialized with an " + raise ValueError("_SubgraphMatcher cannot be initialized with an " "empty pattern") # `self.pattern_anchor` is the output Node in `pattern` self.pattern_anchor = next(iter(reversed(pattern.nodes))) @@ -129,6 +131,7 @@ def try_get_submodule(mod: torch.nn.Module, target: str) -> Optional[torch.nn.Mo gm.graph.lint() +@compatibility(is_backward_compatible=True) def replace_pattern(gm: GraphModule, pattern: Callable, replacement: Callable) -> List[Match]: """ Matches all possible non-overlapping sets of operators and their @@ -242,7 +245,6 @@ def forward(self, x, w1, w2): max_2 = torch.max(sum_2) add_2 = add_1 + max_2 return add_2 - """ # Get the graphs for `gm`, `pattern`, `replacement` original_graph = gm.graph @@ -251,7 +253,7 @@ def forward(self, x, w1, w2): # Find all possible pattern matches in original_graph. Note that # pattern matches may overlap with each other. - matcher = SubgraphMatcher(pattern_graph) + matcher = _SubgraphMatcher(pattern_graph) matches: List[Match] = [] # Consider each node as an "anchor" (deepest matching graph node) diff --git a/torch/fx/tensor_type.py b/torch/fx/tensor_type.py index 18387ee3c78f7..0840122a9b168 100644 --- a/torch/fx/tensor_type.py +++ b/torch/fx/tensor_type.py @@ -1,6 +1,9 @@ from torch.fx.experimental.unification import Var # type: ignore[attr-defined] +from ._compatibility import compatibility + +@compatibility(is_backward_compatible=False) class TensorType: """ TensorType defines a type for tensors, which consists of a list of dimensions. @@ -48,7 +51,7 @@ def __repr__(self): Dyn = _DynType() - +@compatibility(is_backward_compatible=False) def is_consistent(t1, t2): """ A binary relation denoted by ~ that determines if t1 is consistent with t2. @@ -74,6 +77,7 @@ def is_consistent(t1, t2): return False +@compatibility(is_backward_compatible=False) def is_more_precise(t1, t2): """ A binary relation denoted by <= that determines if t1 is more precise than t2. diff --git a/torch/hub.py b/torch/hub.py index 499640b8bc6ee..82287d84b14f6 100644 --- a/torch/hub.py +++ b/torch/hub.py @@ -257,18 +257,20 @@ def set_dir(d): def list(github, force_reload=False, skip_validation=False): r""" - List all entrypoints available in `github` hubconf. + List all callable entrypoints available in the repo specified by ``github``. Args: github (string): a string with format "repo_owner/repo_name[:tag_name]" with an optional - tag/branch. The default branch is `master` if not specified. + tag/branch. The default branch is ``master`` if not specified. Example: 'pytorch/vision[:hub]' force_reload (bool, optional): whether to discard the existing cache and force a fresh download. - Default is `False`. - skip_validation (bool, optional): whether to check package validity against github. - Default is `False`. + Default is ``False``. + skip_validation (bool, optional): if ``False``, torchhub will check that the branch or commit + specified by the ``github`` argument properly belongs to the repo owner. This will make + requests to the GitHub API; you can specify a non-default GitHub token by setting the + ``GITHUB_TOKEN`` environment variable. Default is ``False``. Returns: - entrypoints: a list of available entrypoint names + list: The available callables entrypoint Example: >>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True) @@ -277,7 +279,8 @@ def list(github, force_reload=False, skip_validation=False): sys.path.insert(0, repo_dir) - hub_module = import_module(MODULE_HUBCONF, repo_dir + '/' + MODULE_HUBCONF) + hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF) + hub_module = import_module(MODULE_HUBCONF, hubconf_path) sys.path.remove(repo_dir) @@ -289,17 +292,19 @@ def list(github, force_reload=False, skip_validation=False): def help(github, model, force_reload=False, skip_validation=False): r""" - Show the docstring of entrypoint `model`. + Show the docstring of entrypoint ``model``. Args: github (string): a string with format with an optional - tag/branch. The default branch is `master` if not specified. + tag/branch. The default branch is ``master`` if not specified. Example: 'pytorch/vision[:hub]' - model (string): a string of entrypoint name defined in repo's hubconf.py + model (string): a string of entrypoint name defined in repo's ``hubconf.py`` force_reload (bool, optional): whether to discard the existing cache and force a fresh download. - Default is `False`. - skip_validation (bool, optional): whether to check package validity against github. - Default is `False`. + Default is ``False``. + skip_validation (bool, optional): if ``False``, torchhub will check that the branch or commit + specified by the ``github`` argument properly belongs to the repo owner. This will make + requests to the GitHub API; you can specify a non-default GitHub token by setting the + ``GITHUB_TOKEN`` environment variable. Default is ``False``. Example: >>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True)) """ @@ -307,7 +312,8 @@ def help(github, model, force_reload=False, skip_validation=False): sys.path.insert(0, repo_dir) - hub_module = import_module(MODULE_HUBCONF, repo_dir + '/' + MODULE_HUBCONF) + hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF) + hub_module = import_module(MODULE_HUBCONF, hubconf_path) sys.path.remove(repo_dir) @@ -316,22 +322,19 @@ def help(github, model, force_reload=False, skip_validation=False): return entry.__doc__ -# Ideally this should be `def load(github, model, *args, forece_reload=False, **kwargs):`, -# but Python2 complains syntax error for it. We have to skip force_reload in function -# signature here but detect it in kwargs instead. -# TODO: fix it after Python2 EOL -def load(repo_or_dir, model, *args, **kwargs): +def load(repo_or_dir, model, *args, source='github', force_reload=False, verbose=True, skip_validation=False, + **kwargs): r""" Load a model from a github repo or a local directory. Note: Loading a model is the typical use case, but this can also be used to for loading other objects such as tokenizers, loss functions, etc. - If :attr:`source` is ``'github'``, :attr:`repo_or_dir` is expected to be + If ``source`` is 'github', ``repo_or_dir`` is expected to be of the form ``repo_owner/repo_name[:tag_name]`` with an optional - tag/branch. + tag/branch. The default branch is ``master`` if not specified. - If :attr:`source` is ``'local'``, :attr:`repo_or_dir` is expected to be a + If ``source`` is 'local', ``repo_or_dir`` is expected to be a path to a local directory. Args: @@ -340,9 +343,9 @@ def load(repo_or_dir, model, *args, **kwargs): ``source = 'local'``. model (string): the name of a callable (entrypoint) defined in the repo/dir's ``hubconf.py``. - *args (optional): the corresponding args for callable :attr:`model`. - source (string, optional): ``'github'`` | ``'local'``. Specifies how - ``repo_or_dir`` is to be interpreted. Default is ``'github'``. + *args (optional): the corresponding args for callable ``model``. + source (string, optional): 'github' or 'local'. Specifies how + ``repo_or_dir`` is to be interpreted. Default is 'github'. force_reload (bool, optional): whether to force a fresh download of the github repo unconditionally. Does not have any effect if ``source = 'local'``. Default is ``False``. @@ -350,13 +353,14 @@ def load(repo_or_dir, model, *args, **kwargs): local caches. Note that the message about first download cannot be muted. Does not have any effect if ``source = 'local'``. Default is ``True``. - skip_validation (bool, optional): whether to check package validity against github. - Default is `False`. - **kwargs (optional): the corresponding kwargs for callable - :attr:`model`. + skip_validation (bool, optional): if ``False``, torchhub will check that the branch or commit + specified by the ``github`` argument properly belongs to the repo owner. This will make + requests to the GitHub API; you can specify a non-default GitHub token by setting the + ``GITHUB_TOKEN`` environment variable. Default is ``False``. + **kwargs (optional): the corresponding kwargs for callable ``model``. Returns: - The output of the :attr:`model` callable when called with the given + The output of the ``model`` callable when called with the given ``*args`` and ``**kwargs``. Example: @@ -367,10 +371,7 @@ def load(repo_or_dir, model, *args, **kwargs): >>> path = '/some/local/path/pytorch/vision' >>> model = torch.hub.load(path, 'resnet50', pretrained=True) """ - source = kwargs.pop('source', 'github').lower() - force_reload = kwargs.pop('force_reload', False) - verbose = kwargs.pop('verbose', True) - skip_validation = kwargs.pop('skip_validation', False) + source = source.lower() if source not in ('github', 'local'): raise ValueError( @@ -391,7 +392,7 @@ def _load_local(hubconf_dir, model, *args, **kwargs): hubconf_dir (string): path to a local directory that contains a ``hubconf.py``. model (string): name of an entrypoint defined in the directory's - `hubconf.py`. + ``hubconf.py``. *args (optional): the corresponding args for callable ``model``. **kwargs (optional): the corresponding kwargs for callable ``model``. @@ -420,8 +421,8 @@ def download_url_to_file(url, dst, hash_prefix=None, progress=True): Args: url (string): URL of the object to download - dst (string): Full path where object will be saved, e.g. `/tmp/temporary_file` - hash_prefix (string, optional): If not None, the SHA256 downloaded file should start with `hash_prefix`. + dst (string): Full path where object will be saved, e.g. ``/tmp/temporary_file`` + hash_prefix (string, optional): If not None, the SHA256 downloaded file should start with ``hash_prefix``. Default: None progress (bool, optional): whether or not to display a progress bar to stderr Default: True @@ -431,8 +432,6 @@ def download_url_to_file(url, dst, hash_prefix=None, progress=True): """ file_size = None - # We use a different API for python2 since urllib(2) doesn't recognize the CA - # certificates in older Python req = Request(url, headers={"User-Agent": "torch.hub"}) u = urlopen(req) meta = u.info() @@ -519,8 +518,8 @@ def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=Tr If the object is already present in `model_dir`, it's deserialized and returned. - The default value of `model_dir` is ``/checkpoints`` where - `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`. + The default value of ``model_dir`` is ``/checkpoints`` where + ``hub_dir`` is the directory returned by :func:`~torch.hub.get_dir`. Args: url (string): URL of the object to download @@ -533,7 +532,7 @@ def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=Tr digits of the SHA256 hash of the contents of the file. The hash is used to ensure unique names and to verify the contents of the file. Default: False - file_name (string, optional): name for the downloaded file. Filename from `url` will be used if not set. + file_name (string, optional): name for the downloaded file. Filename from ``url`` will be used if not set. Example: >>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth') diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index c9fd886c7336d..f7fa58bd36434 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -20,7 +20,6 @@ ) from torch.jit._script import ( script, - _script_pdt, Attribute, ScriptModule, script_method, diff --git a/torch/jit/_freeze.py b/torch/jit/_freeze.py index cab6d3c8f71ef..582baf7422343 100644 --- a/torch/jit/_freeze.py +++ b/torch/jit/_freeze.py @@ -179,6 +179,18 @@ def optimize_for_inference(mod: ScriptModule) -> ScriptModule: This is still in prototype, and may have the potential to slow down your model. Primary use cases that have been targeted so far have been vision models on cpu and gpu to a lesser extent. + + Example (optimizing a module with Conv->Batchnorm):: + + import torch + in_channels, out_channels = 3, 32 + conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=True) + bn = torch.nn.BatchNorm2d(out_channels, eps=.001) + mod = torch.nn.Sequential(conv, bn) + frozen_mod = torch.jit.optimize_for_inference(torch.jit.script(mod.eval())) + assert "batch_norm" not in str(frozen_mod.graph) + # if built with MKLDNN, convolution will be run with MKLDNN weights + assert "MKLDNN" in frozen_mod.graph """ if not isinstance(mod, ScriptModule): raise RuntimeError( diff --git a/torch/jit/_monkeytype_config.py b/torch/jit/_monkeytype_config.py index b5a698eca7006..9957541ff25d1 100644 --- a/torch/jit/_monkeytype_config.py +++ b/torch/jit/_monkeytype_config.py @@ -1,7 +1,10 @@ + +import torch + import inspect import typing import pathlib -import torch +import sys from typing import Optional, Iterable, List, Dict from collections import defaultdict from types import CodeType @@ -16,25 +19,50 @@ except ImportError: _IS_MONKEYTYPE_INSTALLED = False -def get_optional_of_element_type(types: str): +# Checks whether a class is defind in `torch.*` modules +def is_torch_native_class(cls): + if not hasattr(cls, '__module__'): + return False + + parent_modules = cls.__module__.split('.') + if not parent_modules: + return False + + root_module = sys.modules.get(parent_modules[0]) + return root_module is torch + +def get_type(type): + """ + Helper function which converts the given type to a torchScript acceptable format. + """ + if isinstance(type, str): + return type + elif inspect.getmodule(type) == typing: + # If the type is a type imported from typing + # like Tuple, List, Dict then replace `typing.` + # with a null string. This needs to be done since + # typing.List is not accepted by TorchScript. + type_to_string = str(type) + return type_to_string.replace(type.__module__ + '.', '') + elif is_torch_native_class(type): + # If the type is a subtype of torch module, then TorchScript expects a fully qualified name + # for the type which is obtained by combining the module name and type name. + return type.__module__ + '.' + type.__name__ + else: + # For all other types use the name for the type. + return type.__name__ + +def get_optional_of_element_type(types): """ Helper function to extracts the type of the element to be annotated to Optional from the list of consolidated types and returns `Optional[element type]`. - TODO: To remove this check once Union support lands. """ - elements = types.split(",") - elem_type = elements[0] if 'NoneType' in elements[1] else elements[1] - - # If the type is from typing module, then extract the element type - start = elem_type.find("[") - end = elem_type.rfind("]") - if start != -1 and end != -1: - return elem_type[:start + 1] + 'Optional[' + elem_type[start + 1: end] + ']]' - - # Else return Optional[element type] - if elem_type == 'Tensor': - elem_type = 'torch.Tensor' + elem_type = types[1] if type(None) == types[0] else types[0] + elem_type = get_type(elem_type) + + # Optional type is internally converted to Union[type, NoneType], which + # is not supported yet in TorchScript. Hence, representing the optional type as string. return 'Optional[' + elem_type + ']' def get_qualified_name(func): @@ -88,30 +116,15 @@ def consolidate_types(self, qualified_name: str) -> Dict: # then consolidate the type to `Any` and replace the entry # by type `Any`. for arg, types in all_args.items(): - _all_type = " " - for _type in types: - # If the type is a type imported from typing - # like Tuple, List, Dict then replace "typing." - # with a null string. - if inspect.getmodule(_type) == typing: - _type_to_string = str(_type) - _all_type += _type_to_string.replace('typing.', '') + ',' - elif _type is torch.nn.parameter.Parameter: - # Check if the type is torch.nn.parameter.Parameter, - # use the entire quaalified name `torch.nn.parameter.Parameter` - # for type - _all_type += 'torch.nn.parameter.Parameter' + ',' - else: - _all_type += _type.__name__ + ',' - _all_type = _all_type.lstrip(" ") # Remove any trailing spaces - - if len(types) == 2 and 'NoneType' in _all_type: + types = list(types) + type_length = len(types) + if type_length == 2 and type(None) in types: # TODO: To remove this check once Union suppport in TorchScript lands. - all_args[arg] = {get_optional_of_element_type(_all_type)} - elif len(types) > 1: - all_args[arg] = {'Any'} - else: - all_args[arg] = {_all_type[:-1]} + all_args[arg] = get_optional_of_element_type(types) + elif type_length > 1: + all_args[arg] = 'Any' + elif type_length == 1: + all_args[arg] = get_type(types[0]) return all_args def get_args_types(self, qualified_name: str) -> Dict: @@ -157,7 +170,6 @@ def jit_code_filter(code: CodeType) -> bool: The custom CodeFilter is required while scripting a FX Traced forward calls. FX Traced forward calls have `code.co_filename` start with '<' which is used to exclude tracing of stdlib and site-packages in the default code filter. - Since we need all forward calls to be traced, this custom code filter checks for code.co_name to be 'forward' and enables tracing for all such calls. The code filter is similar to default code filter for monkeytype and diff --git a/torch/jit/_script.py b/torch/jit/_script.py index 0c3e5ef7f0726..acc9e7c44f51f 100644 --- a/torch/jit/_script.py +++ b/torch/jit/_script.py @@ -449,7 +449,7 @@ def method_template(self, *args, **kwargs): setattr(RecursiveScriptClass, method_name, method_template) # this is a Python 'non-data descriptor' that causes the first access - # to ScriptModule's forward to lookup the forward method and stash + # to ScriptModule's forward to look up the forward method and stash # it in the objects dict. Due to the standard rules for attribute lookup, # subsequent lookups will just directly return the previously looked up method. # This is necessary because nn.Module defines forward as a method. If we @@ -785,13 +785,6 @@ def __setattr__(self, attr, value): # It's fairly trivial to save enough info to warn in this case. return super(RecursiveScriptModule, self).__setattr__(attr, value) - def __getstate__(self): - raise pickle.PickleError( - "ScriptModules cannot be deepcopied using copy.deepcopy or saved using torch.save. " - + "Mixed serialization of script and non-script modules is not supported. " - + "For purely script modules use my_script_module.save() instead." - ) - def __copy__(self): return torch.jit._recursive.wrap_cpp_module(copy.copy(self._c)) @@ -912,6 +905,8 @@ def _get_methods(cls): "_tracing_name", "eval", "train", + "get_extra_state", + "set_extra_state" } def _make_fail(name): @@ -982,57 +977,6 @@ def call_prepare_scriptable_func(obj): memo: Dict[int, torch.nn.Module] = {} return call_prepare_scriptable_func_impl(obj, memo) - -def _script_pdt(obj, optimize=None, _frames_up=0, _rcb=None, - example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None): - # This is a private API, intended for internal use only. Usage of this API is only for experimental - # purposes only and is highly discouraged. - global type_trace_db - if not _enabled: - return obj - - if optimize is not None: - warnings.warn( - "`optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead" - ) - - # No-op for modules and functions that are already scripted - if isinstance(obj, ScriptModule): - return obj - if isinstance(obj, ScriptFunction): - return obj - - if example_inputs: - # If MonkeyType is installed, enable profile directed type annotation - # Check if example_inputs are defined and generate call traces - # for the method by running eager mode version of the method with - # the provide example inputs. This logs all the traces in type_trace_db - type_trace_db = JitTypeTraceStore() - if monkeytype_trace: - monkeytype_config = JitTypeTraceConfig(type_trace_db) - with monkeytype_trace(monkeytype_config): - if isinstance(example_inputs, Dict): - # If the obj is an nn.Module or a class, then each method is - # executed with the arguments provided in the example inputs. - # example inputs here will be of type Dict(class.method, (arguments)) - # This is used to infer type annotations for those methods - # which are not called directly under the hood of monkeytype. - for module, example_input in example_inputs.items(): - for example in example_input: - module(*example) - elif isinstance(example_inputs, List): - for examples in example_inputs: - obj(*examples) - else: - warnings.warn("Error: Unable to infer types. Please format the inputs to type `List[Tuple]`" - " or `Dict[Callable, List[Tuple]]` to be run with MonkeyType.") - else: - warnings.warn("Warning: monkeytype is not installed. Please install https://github.com/Instagram/MonkeyType " - "to enable Profile-Directed Typing in TorchScript. Refer to " - "https://github.com/Instagram/MonkeyType/blob/master/README.rst to install MonkeyType. ") - return script(obj, optimize, _frames_up, _rcb) - - def create_script_dict(obj): """ Create a ``torch._C.ScriptDict`` instance with the data from ``obj``. @@ -1063,7 +1007,8 @@ def create_script_list(obj, type_hint=None): return torch._C.ScriptList(obj) # type: ignore[attr-defined] -def script(obj, optimize=None, _frames_up=0, _rcb=None): +def script(obj, optimize=None, _frames_up=0, _rcb=None, + example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None): r""" Scripting a function or ``nn.Module`` will inspect the source code, compile it as TorchScript code using the TorchScript compiler, and return a :class:`ScriptModule` or @@ -1081,6 +1026,8 @@ def script(obj, optimize=None, _frames_up=0, _rcb=None): Args: obj (callable, class, or ``nn.Module``): The ``nn.Module``, function, class type, dictionary, or list to compile. + example_inputs (Union[List[Tuple], Dict[Callable, List[Tuple]], None]): Provide example inputs + to annotate the arguments for a function or ``nn.Module``. Returns: If ``obj`` is ``nn.Module``, ``script`` returns @@ -1122,6 +1069,34 @@ def foo(x, y): ... + ****Scripting a function using example_inputs** + Example inputs can be used to annotate a function arguments. + + Example (annotating a function before scripting): + + .. testcode:: + + import torch + + def test_sum(a, b): + return a + b + + # Annotate the arguments to be int + scripted_fn = torch.jit.script(test_sum, example_inputs=[(3, 4)]) + + print(type(scripted_fn)) # torch.jit.ScriptFunction + + # See the compiled graph as Python code + print(scripted_fn.code) + + # Call the function using the TorchScript interpreter + scripted_fn(20, 100) + + .. testoutput:: + :hide: + + ... + **Scripting an nn.Module** Scripting an ``nn.Module`` by default will compile the ``forward`` method and recursively compile any methods, submodules, and functions called by ``forward``. If a ``nn.Module`` only uses @@ -1208,7 +1183,30 @@ def forward(self, input): scripted_module = torch.jit.script(MyModule()) print(scripted_module.some_entry_point(torch.randn(2, 2))) print(scripted_module(torch.randn(2, 2))) + + Example ( Annotating forward of nn.Module using example_inputs):: + + import torch + import torch.nn as nn + from typing import NamedTuple + + class MyModule(NamedTuple): + result: List[int] + + class TestNNModule(torch.nn.Module): + def forward(self, a) -> MyModule: + result = MyModule(result=a) + return result + + pdt_model = TestNNModule() + + # Runs the pdt_model in eager model with the inputs provided and annotates the arguments of forward + scripted_model = torch.jit.script(pdt_model, example_inputs={pdt_model: [([10, 20, ], ), ], }) + + # Run the scripted_model with actual inputs + print(scripted_model([20])) """ + global type_trace_db if not _enabled: return obj @@ -1225,6 +1223,35 @@ def forward(self, input): if isinstance(obj, ScriptFunction): return obj + if example_inputs: + # If MonkeyType is installed, enable profile directed type annotation + # Check if example_inputs are defined and generate call traces + # for the method by running eager mode version of the method with + # the provide example inputs. This logs all the traces in type_trace_db + type_trace_db = JitTypeTraceStore() + if monkeytype_trace: + monkeytype_config = JitTypeTraceConfig(type_trace_db) + with monkeytype_trace(monkeytype_config): + if isinstance(example_inputs, Dict): + # If the obj is an nn.Module or a class, then each method is + # executed with the arguments provided in the example inputs. + # example inputs here will be of type Dict(class.method, (arguments)) + # This is used to infer type annotations for those methods + # which are not called directly under the hood of monkeytype. + for module, example_input in example_inputs.items(): + for example in example_input: + module(*example) + elif isinstance(example_inputs, List): + for examples in example_inputs: + obj(*examples) + else: + raise ValueError("Error: Unable to infer types. Please format the inputs to type `List[Tuple]`" + " or `Dict[Callable, List[Tuple]]` to be run with MonkeyType.") + else: + warnings.warn("Warning: monkeytype is not installed. Please install https://github.com/Instagram/MonkeyType " + "to enable Profile-Directed Typing in TorchScript. Refer to " + "https://github.com/Instagram/MonkeyType/blob/master/README.rst to install MonkeyType. ") + if isinstance(obj, torch.nn.Module): obj = call_prepare_scriptable_func(obj) return torch.jit._recursive.create_script_module( diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py index 069b73e847d5a..5a2f6e5e0c487 100644 --- a/torch/jit/_trace.py +++ b/torch/jit/_trace.py @@ -24,7 +24,7 @@ from torch.autograd import function from torch.nn import Module -from torch.testing._core import _get_default_tolerance +from torch.testing._asserts import _get_default_rtol_and_atol _flatten = torch._C._jit_flatten _unflatten = torch._C._jit_unflatten @@ -417,7 +417,7 @@ def graph_diagnostic_info(): check_tensor_val = n_check.t("value") try: - torch.testing.assert_allclose(mod_tensor_val, check_tensor_val) + torch.testing.assert_close(mod_tensor_val, check_tensor_val, equal_nan=True) except (RuntimeError, AssertionError) as e: if tensor_compare_errors is None: tensor_compare_errors = "" @@ -489,11 +489,12 @@ def compare_outputs(original, reference, match_what): orig = orig.to_dense() if ref.is_mkldnn: ref = ref.to_dense() - torch.testing.assert_allclose( + torch.testing.assert_close( orig.double(), ref.double(), rtol=check_tolerance, - atol=_get_default_tolerance(orig, ref)[1], + atol=_get_default_rtol_and_atol(orig, ref)[1], + equal_nan=True, ) except AssertionError as e: maybe_warn_nondeterministic() diff --git a/torch/jit/annotations.py b/torch/jit/annotations.py index f2cf78949b47d..b189f36c4107f 100644 --- a/torch/jit/annotations.py +++ b/torch/jit/annotations.py @@ -6,13 +6,13 @@ import torch import warnings from .._jit_internal import List, Tuple, is_tuple, is_list, Dict, is_dict, Optional, \ - is_optional, _qualified_name, Any, Future, is_future, is_ignored_fn + is_optional, _qualified_name, Any, Future, is_future, is_ignored_fn, Union, is_union from .._jit_internal import BroadcastingList1, BroadcastingList2, BroadcastingList3 # type: ignore[attr-defined] from ._state import _get_script_class from torch._C import TensorType, TupleType, FloatType, IntType, ComplexType, \ - ListType, StringType, DictType, BoolType, OptionalType, InterfaceType, AnyType, NoneType, \ - DeviceObjType, StreamObjType, FutureType, EnumType + ListType, StringType, DictType, BoolType, OptionalType, InterfaceType, AnyType, \ + NoneType, DeviceObjType, StreamObjType, FutureType, EnumType, UnionType from textwrap import dedent @@ -45,7 +45,8 @@ class EvalEnv(object): 'List': List, 'Dict': Dict, 'Optional': Optional, - 'Future': Future, + 'Union': Union, + 'Future': Future } def __init__(self, rcb): @@ -245,6 +246,9 @@ def split_type_line(type_line): def try_real_annotations(fn, loc): """Tries to use the Py3.5+ annotation syntax to get the type.""" try: + # Note: anything annotated as `Optional[T]` will automatically + # be returned as `Union[T, None]` per + # https://github.com/python/typing/blob/master/src/typing.py#L850 sig = inspect.signature(fn) except ValueError: return None @@ -276,7 +280,6 @@ def get_enum_value_type(e: Type[enum.Enum], loc): return torch._C.unify_type_list(ir_types) def is_tensor(ann): - if issubclass(ann, torch.Tensor): return True @@ -326,6 +329,19 @@ def try_ann_to_type(ann, loc): msg = "Unsupported annotation {} could not be resolved because {} could not be resolved." assert valid_type, msg.format(repr(ann), repr(contained)) return OptionalType(valid_type) + if is_union(ann): + inner: List = [] + # We need these extra checks because both `None` and invalid + # values will return `None` + # TODO: Determine if the other cases need to be fixed as well + for a in ann.__args__: + if a is None: + inner.append(NoneType.get()) + maybe_type = try_ann_to_type(a, loc) + msg = "Unsupported annotation {} could not be resolved because {} could not be resolved." + assert maybe_type, msg.format(repr(ann), repr(maybe_type)) + inner.append(maybe_type) + return UnionType(inner) # type: ignore[arg-type] if torch.distributed.rpc.is_available() and is_rref(ann): return RRefType(try_ann_to_type(ann.__args__[0], loc)) if is_future(ann): @@ -390,6 +406,8 @@ def ann_to_type(ann, loc): 'is_list', 'Dict', 'is_dict', + 'is_optional', + 'is_union', 'TensorType', 'TupleType', 'FloatType', diff --git a/torch/jit/frontend.py b/torch/jit/frontend.py index b0228b132980a..6053ee7ee7f63 100644 --- a/torch/jit/frontend.py +++ b/torch/jit/frontend.py @@ -337,9 +337,9 @@ def build_param_list(ctx, py_args, self_name, pdt_arg_types=None): raise NotSupportedError(ctx_range, _vararg_kwarg_err) # List of Tuple of args and type as inferred by profile directed typing - arg_and_types = [(arg, next(iter(pdt_arg_types[arg.arg])) if pdt_arg_types and bool(pdt_arg_types[arg.arg]) else None) + arg_and_types = [(arg, pdt_arg_types[arg.arg] if pdt_arg_types and bool(pdt_arg_types[arg.arg]) else None) for arg in py_args.args] - arg_and_types_kwonlyargs = [(arg, next(iter(pdt_arg_types[arg.arg])) if pdt_arg_types and bool(pdt_arg_types[arg.arg]) + arg_and_types_kwonlyargs = [(arg, pdt_arg_types[arg.arg] if pdt_arg_types and bool(pdt_arg_types[arg.arg]) else None) for arg in py_args.kwonlyargs] result = [build_param(ctx, arg, self_name, kwarg_only=False, pdt_arg_type=arg_type) @@ -452,6 +452,7 @@ def get_default_args(fn): return {} signature = inspect.signature(fn) + return { k: v.default for k, v in signature.parameters.items() diff --git a/torch/library.h b/torch/library.h index ce2bb92e5723e..a873b4226dbca 100644 --- a/torch/library.h +++ b/torch/library.h @@ -317,8 +317,8 @@ inline CppFunction dispatch(c10::DeviceType type, Func&& raw_f) { return c10::DispatchKey::Meta; case c10::DeviceType::HIP: return c10::DispatchKey::HIP; - case c10::DeviceType::MSNPU: - return c10::DispatchKey::MSNPU; + case c10::DeviceType::ORT: + return c10::DispatchKey::ORT; case c10::DeviceType::HPU: return c10::DispatchKey::HPU; default: diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index df3507f1b3561..f98930e471630 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -1284,6 +1284,12 @@ tensor([ 3.1623, 10.0000, 17.2627]) """) +matmul = _add_docstr(_linalg.linalg_matmul, r""" +linalg.matmul(input, other, *, out=None) -> Tensor + +Alias for :func:`torch.matmul` +""") + multi_dot = _add_docstr(_linalg.linalg_multi_dot, r""" linalg.multi_dot(tensors, *, out=None) diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 52125864000f1..4b0449c8f5672 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -442,10 +442,10 @@ def fractional_max_pool2d_with_indices( .. _Fractional MaxPooling: http://arxiv.org/abs/1412.6071 """ - if has_torch_function_unary(input): + if has_torch_function_variadic(input, _random_samples): return handle_torch_function( fractional_max_pool2d_with_indices, - (input,), + (input, _random_samples), input, kernel_size, output_size=output_size, @@ -473,10 +473,10 @@ def _fractional_max_pool2d( return_indices: bool = False, _random_samples: Optional[Tensor] = None ) -> Tensor: - if has_torch_function_unary(input): + if has_torch_function_variadic(input, _random_samples): return handle_torch_function( fractional_max_pool2d, - (input,), + (input, _random_samples), input, kernel_size, output_size=output_size, @@ -537,10 +537,10 @@ def fractional_max_pool3d_with_indices( .. _Fractional MaxPooling: http://arxiv.org/abs/1412.6071 """ - if has_torch_function_unary(input): + if has_torch_function_variadic(input, _random_samples): return handle_torch_function( fractional_max_pool3d_with_indices, - (input,), + (input, _random_samples), input, kernel_size, output_size=output_size, @@ -571,10 +571,10 @@ def _fractional_max_pool3d( return_indices: bool = False, _random_samples: Optional[Tensor] = None ) -> Tensor: - if has_torch_function_unary(input): + if has_torch_function_variadic(input, _random_samples): return handle_torch_function( fractional_max_pool3d, - (input,), + (input, _random_samples), input, kernel_size, output_size=output_size, @@ -1843,8 +1843,8 @@ def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tens - Bias: :math:`(out\_features)` - Output: :math:`(N, *, out\_features)` """ - if has_torch_function_variadic(input, weight): - return handle_torch_function(linear, (input, weight), input, weight, bias=bias) + if has_torch_function_variadic(input, weight, bias): + return handle_torch_function(linear, (input, weight, bias), input, weight, bias=bias) return torch._C._nn.linear(input, weight, bias) @@ -1865,10 +1865,10 @@ def bilinear(input1: Tensor, input2: Tensor, weight: Tensor, bias: Optional[Tens - output: :math:`(N, *, H_{out})` where :math:`H_{out}=\text{out\_features}` and all but the last dimension are the same shape as the input. """ - if has_torch_function_variadic(input1, input2, weight): + if has_torch_function_variadic(input1, input2, weight, bias): return handle_torch_function( bilinear, - (input1, input2, weight), + (input1, input2, weight, bias), input1, input2, weight, bias=bias ) @@ -2135,10 +2135,10 @@ def embedding_bag( tensor([[ 0.0000, 0.0000, 0.0000], [-0.7082, 3.2145, -2.6251]]) """ - if has_torch_function_variadic(input, weight): + if has_torch_function_variadic(input, weight, offsets, per_sample_weights): return handle_torch_function( embedding_bag, - (input, weight), + (input, weight, offsets, per_sample_weights), input, weight, offsets=offsets, @@ -2263,10 +2263,10 @@ def batch_norm( See :class:`~torch.nn.BatchNorm1d`, :class:`~torch.nn.BatchNorm2d`, :class:`~torch.nn.BatchNorm3d` for details. """ - if has_torch_function_unary(input): + if has_torch_function_variadic(input, running_mean, running_var, weight, bias): return handle_torch_function( batch_norm, - (input,), + (input, running_mean, running_var, weight, bias), input, running_mean, running_var, @@ -2309,10 +2309,10 @@ def instance_norm( See :class:`~torch.nn.InstanceNorm1d`, :class:`~torch.nn.InstanceNorm2d`, :class:`~torch.nn.InstanceNorm3d` for details. """ - if has_torch_function_unary(input): + if has_torch_function_variadic(input, running_mean, running_var, weight, bias): return handle_torch_function( instance_norm, - (input,), + (input, running_mean, running_var, weight, bias), input, running_mean=running_mean, running_var=running_var, @@ -2340,9 +2340,9 @@ def layer_norm( See :class:`~torch.nn.LayerNorm` for details. """ - if has_torch_function_unary(input): + if has_torch_function_variadic(input, weight, bias): return handle_torch_function( - layer_norm, (input,), input, normalized_shape, weight=weight, bias=bias, eps=eps + layer_norm, (input, weight, bias), input, normalized_shape, weight=weight, bias=bias, eps=eps ) return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled) @@ -2354,8 +2354,8 @@ def group_norm( See :class:`~torch.nn.GroupNorm` for details. """ - if has_torch_function_unary(input): - return handle_torch_function(group_norm, (input,), input, num_groups, weight=weight, bias=bias, eps=eps) + if has_torch_function_variadic(input, weight, bias): + return handle_torch_function(group_norm, (input, weight, bias,), input, num_groups, weight=weight, bias=bias, eps=eps) _verify_batch_size([input.size(0) * input.size(1) // num_groups, num_groups] + list(input.size()[2:])) return torch.group_norm(input, num_groups, weight, bias, eps, torch.backends.cudnn.enabled) @@ -2515,10 +2515,10 @@ def nll_loss( >>> output = F.nll_loss(F.log_softmax(input), target) >>> output.backward() """ - if has_torch_function_variadic(input, target): + if has_torch_function_variadic(input, target, weight): return handle_torch_function( nll_loss, - (input, target), + (input, target, weight), input, target, weight=weight, @@ -2772,6 +2772,7 @@ def cross_entropy( ignore_index: int = -100, reduce: Optional[bool] = None, reduction: str = "mean", + label_smoothing: float = 0.0, ) -> Tensor: r"""This criterion computes the cross entropy loss between input and target. @@ -2808,6 +2809,10 @@ def cross_entropy( elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` and :attr:`reduce` are in the process of being deprecated, and in the meantime, specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount + of smoothing when computing the loss, where 0.0 means no smoothing. The targets + become a mixture of the original ground truth and a uniform distribution as described in + `Rethinking the Inception Architecture for Computer Vision `__. Default: :math:`0.0`. Examples:: @@ -2823,10 +2828,10 @@ def cross_entropy( >>> loss = F.cross_entropy(input, target) >>> loss.backward() """ - if has_torch_function_variadic(input, target): + if has_torch_function_variadic(input, target, weight): return handle_torch_function( cross_entropy, - (input, target), + (input, target, weight), input, target, weight=weight, @@ -2834,10 +2839,11 @@ def cross_entropy( ignore_index=ignore_index, reduce=reduce, reduction=reduction, + label_smoothing=label_smoothing, ) if size_average is not None or reduce is not None: reduction = _Reduction.legacy_get_string(size_average, reduce) - return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index) + return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing) def binary_cross_entropy( @@ -2881,10 +2887,10 @@ def binary_cross_entropy( >>> loss = F.binary_cross_entropy(F.sigmoid(input), target) >>> loss.backward() """ - if has_torch_function_variadic(input, target): + if has_torch_function_variadic(input, target, weight): return handle_torch_function( binary_cross_entropy, - (input, target), + (input, target, weight), input, target, weight=weight, @@ -2953,10 +2959,10 @@ def binary_cross_entropy_with_logits( >>> loss = F.binary_cross_entropy_with_logits(input, target) >>> loss.backward() """ - if has_torch_function_variadic(input, target): + if has_torch_function_variadic(input, target, weight, pos_weight): return handle_torch_function( binary_cross_entropy_with_logits, - (input, target), + (input, target, weight, pos_weight), input, target, weight=weight, @@ -3237,10 +3243,10 @@ def multilabel_soft_margin_loss( See :class:`~torch.nn.MultiLabelSoftMarginLoss` for details. """ - if has_torch_function_variadic(input, target): + if has_torch_function_variadic(input, target, weight): return handle_torch_function( multilabel_soft_margin_loss, - (input, target), + (input, target, weight), input, target, weight=weight, @@ -3317,10 +3323,10 @@ def multi_margin_loss( See :class:`~torch.nn.MultiMarginLoss` for details. """ - if has_torch_function_variadic(input, target): + if has_torch_function_variadic(input, target, weight): return handle_torch_function( multi_margin_loss, - (input, target), + (input, target, weight), input, target, p=p, @@ -4437,8 +4443,8 @@ def normalize(input: Tensor, p: float = 2.0, dim: int = 1, eps: float = 1e-12, o out (Tensor, optional): the output tensor. If :attr:`out` is used, this operation won't be differentiable. """ - if has_torch_function_unary(input): - return handle_torch_function(normalize, (input,), input, p=p, dim=dim, eps=eps, out=out) + if has_torch_function_variadic(input, out): + return handle_torch_function(normalize, (input, out), input, p=p, dim=dim, eps=eps, out=out) if out is None: denom = input.norm(p, dim, keepdim=True).clamp_min(eps).expand_as(input) return input / denom diff --git a/torch/nn/functional.pyi.in b/torch/nn/functional.pyi.in index 828f8df2185b5..cbd05d7e3dedb 100644 --- a/torch/nn/functional.pyi.in +++ b/torch/nn/functional.pyi.in @@ -239,7 +239,8 @@ def kl_div(input: Tensor, target: Tensor, size_average: Optional[bool] = ..., re def cross_entropy(input: Tensor, target: Tensor, weight: Optional[Tensor] = ..., size_average: Optional[bool] = ..., - ignore_index: int = ..., reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ... + ignore_index: int = ..., reduce: Optional[bool] = ..., reduction: str = ..., + label_smoothing: float = ...) -> Tensor: ... def binary_cross_entropy(input: Tensor, target: Tensor, weight: Optional[Tensor] = ..., diff --git a/torch/nn/intrinsic/quantized/_reference/modules/__init__.py b/torch/nn/intrinsic/quantized/_reference/modules/__init__.py deleted file mode 100644 index bf8ff3a3db5e1..0000000000000 --- a/torch/nn/intrinsic/quantized/_reference/modules/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -import torch -from .linear_relu import LinearReLU -from .conv_relu import ConvReLU1d, ConvReLU2d, ConvReLU3d - -__all__ = [ - 'LinearReLU', - 'ConvReLU1d', - 'ConvReLU2d', - 'ConvReLU3d', -] diff --git a/torch/nn/intrinsic/quantized/_reference/modules/conv_relu.py b/torch/nn/intrinsic/quantized/_reference/modules/conv_relu.py deleted file mode 100644 index b0305f6207d95..0000000000000 --- a/torch/nn/intrinsic/quantized/_reference/modules/conv_relu.py +++ /dev/null @@ -1,58 +0,0 @@ -import torch -import torch.nn.quantized._reference as nnqr -import torch.nn.functional as F - -class ConvReLU1d(nnqr.Conv1d): - _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU1d - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x_dequant = x.dequantize() - weight_dequant = self._qweight.dequantize() - float_result = F.conv1d( - x_dequant, weight_dequant, self._bias, self._conv1d_stride, # type: ignore[has-type] - self._conv1d_padding, self._conv1d_dilation, self.groups) # type: ignore[has-type] - float_result = F.relu(float_result, inplace=True) - # NEEDFIX: we don't have dtype in the Linear module APIs right now! - result = torch.quantize_per_tensor( - float_result, self.scale, self.zero_point, torch.quint8) - return result - - def _get_name(self): - return "QuantizedConvReLU1d(Reference)" - - -class ConvReLU2d(nnqr.Conv2d): - _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU2d - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x_dequant = x.dequantize() - weight_dequant = self._qweight.dequantize() - float_result = F.conv2d( - x_dequant, weight_dequant, self._bias, self.stride, - self.padding, self.dilation, self.groups) - float_result = F.relu(float_result, inplace=True) - # NEEDFIX: we don't have dtype in the Linear module APIs right now! - result = torch.quantize_per_tensor( - float_result, self.scale, self.zero_point, torch.quint8) - return result - - def _get_name(self): - return "QuantizedConvReLU2d(Reference)" - -class ConvReLU3d(nnqr.Conv3d): - _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU3d - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x_dequant = x.dequantize() - weight_dequant = self._qweight.dequantize() - float_result = F.conv3d( - x_dequant, weight_dequant, self._bias, self.stride, - self.padding, self.dilation, self.groups) - float_result = F.relu(float_result, inplace=True) - # NEEDFIX: we don't have dtype in the Linear module APIs right now! - result = torch.quantize_per_tensor( - float_result, self.scale, self.zero_point, torch.quint8) - return result - - def _get_name(self): - return "QuantizedConvReLU3d(Reference)" diff --git a/torch/nn/intrinsic/quantized/_reference/modules/linear_relu.py b/torch/nn/intrinsic/quantized/_reference/modules/linear_relu.py deleted file mode 100644 index 39c595376fded..0000000000000 --- a/torch/nn/intrinsic/quantized/_reference/modules/linear_relu.py +++ /dev/null @@ -1,28 +0,0 @@ -import torch -import torch.nn.intrinsic as nni -import torch.nn.quantized._reference as nnqr -import torch.nn.functional as F - -class LinearReLU(nnqr.Linear): - _FLOAT_MODULE = nni.LinearReLU - - def __init__( - self, - in_features, - out_features, - bias=True, - dtype=torch.qint8): - super().__init__(in_features, out_features, bias, dtype) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x_dequant = x.dequantize() - weight_dequant = self._qweight.dequantize() - float_result = F.linear(x_dequant, weight_dequant, self._bias) - float_result = F.relu(float_result, inplace=True) - # NEEDFIX: we don't have dtype in the Linear module APIs right now! - result = torch.quantize_per_tensor( - float_result, self.scale, self.zero_point, torch.quint8) - return result - - def _get_name(self): - return "QuantizedLinearReLU(Reference)" diff --git a/torch/nn/intrinsic/quantized/_reference/__init__.py b/torch/nn/intrinsic/quantized/dynamic/__init__.py similarity index 100% rename from torch/nn/intrinsic/quantized/_reference/__init__.py rename to torch/nn/intrinsic/quantized/dynamic/__init__.py diff --git a/torch/nn/intrinsic/quantized/dynamic/modules/__init__.py b/torch/nn/intrinsic/quantized/dynamic/modules/__init__.py new file mode 100644 index 0000000000000..ce571862b4275 --- /dev/null +++ b/torch/nn/intrinsic/quantized/dynamic/modules/__init__.py @@ -0,0 +1,6 @@ +import torch +from .linear_relu import LinearReLU + +__all__ = [ + 'LinearReLU', +] diff --git a/torch/nn/intrinsic/quantized/dynamic/modules/linear_relu.py b/torch/nn/intrinsic/quantized/dynamic/modules/linear_relu.py new file mode 100644 index 0000000000000..c30b3109ef601 --- /dev/null +++ b/torch/nn/intrinsic/quantized/dynamic/modules/linear_relu.py @@ -0,0 +1,46 @@ +import torch +import torch.nn.quantized.dynamic as nnqd +import torch.nn.intrinsic as nni + +class LinearReLU(nnqd.Linear): + r""" + A LinearReLU module fused from Linear and ReLU modules that can be used + for dynamic quantization. + Supports both, FP16 and INT8 quantization. + + We adopt the same interface as :class:`torch.nn.quantized.dynamic.Linear`. + + Attributes: + Same as torch.nn.quantized.dynamic.Linear + + Examples:: + + >>> m = nn.intrinsic.quantized.dynamic.LinearReLU(20, 30) + >>> input = torch.randn(128, 20) + >>> output = m(input) + >>> print(output.size()) + torch.Size([128, 30]) + """ + _FLOAT_MODULE = nni.LinearReLU # type: ignore[assignment] + + def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8): + super().__init__(in_features, out_features, bias, dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self._packed_params.dtype == torch.qint8: + # TODO check if we should set reduce_rage = True by default here + Y = torch.ops.quantized.linear_relu_dynamic( + x, self._packed_params._packed_params, reduce_range=True) + elif self._packed_params.dtype == torch.float16: + Y = torch.ops.quantized.linear_relu_dynamic_fp16( + x, self._packed_params._packed_params) + else: + raise RuntimeError('Unsupported dtype on dynamic quantized linear relu!') + return Y.to(x.dtype) + + def _get_name(self): + return 'DynamicQuantizedLinearReLU' + + @classmethod + def from_float(cls, mod): + return super(LinearReLU, cls).from_float(mod) diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index 91427c8aea2cd..90b901d9b690a 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -1236,8 +1236,8 @@ class Softmax2d(Module): apply `Softmax` to each location :math:`(Channels, h_i, w_j)` Shape: - - Input: :math:`(N, C, H, W)` - - Output: :math:`(N, C, H, W)` (same shape as input) + - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`. + - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input) Returns: a Tensor of the same dimension and shape as the input with @@ -1252,8 +1252,8 @@ class Softmax2d(Module): """ def forward(self, input: Tensor) -> Tensor: - assert input.dim() == 4, 'Softmax2d requires a 4D tensor as input' - return F.softmax(input, 1, _stacklevel=5) + assert input.dim() == 4 or input.dim() == 3, 'Softmax2d requires a 3D or 4D tensor as input' + return F.softmax(input, -3, _stacklevel=5) class LogSoftmax(Module): diff --git a/torch/nn/modules/linear.py b/torch/nn/modules/linear.py index 07fe1063283fc..21425f2be2aad 100644 --- a/torch/nn/modules/linear.py +++ b/torch/nn/modules/linear.py @@ -16,6 +16,10 @@ class Identity(Module): args: any argument (unused) kwargs: any keyword argument (unused) + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + Examples:: >>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False) diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py index 7f39db405c861..d72c614c88048 100644 --- a/torch/nn/modules/loss.py +++ b/torch/nn/modules/loss.py @@ -164,10 +164,11 @@ class NLLLoss(_WeightedLoss): :attr:`reduction`. Default: ``'mean'`` Shape: - - Input: :math:`(N, C)` where `C = number of classes`, or + - Input: :math:`(N, C)` or :math:`(C)`, where `C = number of classes`, or :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of `K`-dimensional loss. - - Target: :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`, or + - Target: :math:`(N)` or :math:`()`, where each value is + :math:`0 \leq \text{targets}[i] \leq C-1`, or :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of K-dimensional loss. - Output: If :attr:`reduction` is ``'none'``, shape :math:`(N)` or @@ -1103,6 +1104,10 @@ class probabilities only when a single class label per minibatch item is too res and :attr:`reduce` are in the process of being deprecated, and in the meantime, specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount + of smoothing when computing the loss, where 0.0 means no smoothing. The targets + become a mixture of the original ground truth and a uniform distribution as described in + `Rethinking the Inception Architecture for Computer Vision `__. Default: :math:`0.0`. Shape: - Input: :math:`(N, C)` where `C = number of classes`, or @@ -1131,17 +1136,20 @@ class probabilities only when a single class label per minibatch item is too res >>> output = loss(input, target) >>> output.backward() """ - __constants__ = ['ignore_index', 'reduction'] + __constants__ = ['ignore_index', 'reduction', 'label_smoothing'] ignore_index: int + label_smoothing: float def __init__(self, weight: Optional[Tensor] = None, size_average=None, ignore_index: int = -100, - reduce=None, reduction: str = 'mean') -> None: + reduce=None, reduction: str = 'mean', label_smoothing: float = 0.0) -> None: super(CrossEntropyLoss, self).__init__(weight, size_average, reduce, reduction) self.ignore_index = ignore_index + self.label_smoothing = label_smoothing def forward(self, input: Tensor, target: Tensor) -> Tensor: return F.cross_entropy(input, target, weight=self.weight, - ignore_index=self.ignore_index, reduction=self.reduction) + ignore_index=self.ignore_index, reduction=self.reduction, + label_smoothing=self.label_smoothing) class MultiLabelSoftMarginLoss(_WeightedLoss): @@ -1322,7 +1330,7 @@ class MultiMarginLoss(_WeightedLoss): The loss function then becomes: .. math:: - \text{loss}(x, y) = \frac{\sum_i \max(0, w[y] * (\text{margin} - x[y] + x[i]))^p)}{\text{x.size}(0)} + \text{loss}(x, y) = \frac{\sum_i \max(0, w[y] * (\text{margin} - x[y] + x[i]))^p}{\text{x.size}(0)} Args: p (int, optional): Has a default value of :math:`1`. :math:`1` and :math:`2` @@ -1346,6 +1354,20 @@ class MultiMarginLoss(_WeightedLoss): elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` and :attr:`reduce` are in the process of being deprecated, and in the meantime, specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(N, C)` or :math:`(C)`, where :math:`N` is the batch size and :math:`C` is the number of classes. + - Target: :math:`(N)` or :math:`()`, where each value is :math:`0 \leq \text{targets}[i] \leq C-1`. + - Output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the target. + + Examples:: + + >>> loss = nn.MultiMarginLoss() + >>> x = torch.tensor([[0.1, 0.2, 0.4, 0.8]]) + >>> y = torch.tensor([3]) + >>> loss(x, y) + >>> # 0.25 * ((1-(0.8-0.1)) + (1-(0.8-0.2)) + (1-(0.8-0.4))) + tensor(0.3250) """ __constants__ = ['p', 'margin', 'reduction'] margin: float diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index f4ef4533de600..28b220e24037f 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -46,6 +46,8 @@ def _addindent(s_, numSpaces): _global_forward_pre_hooks: Dict[int, Callable] = OrderedDict() _global_forward_hooks: Dict[int, Callable] = OrderedDict() +_EXTRA_STATE_KEY_SUFFIX = '_extra_state' + def register_module_forward_pre_hook(hook: Callable[..., None]) -> RemovableHandle: r"""Registers a forward pre-hook common to all modules. @@ -145,13 +147,6 @@ def register_module_full_backward_hook( This adds global state to the `nn.module` module and it is only intended for debugging/profiling purposes. - The current implementation will not have the presented behavior - for complex :class:`Module` that perform many operations. - In some failure cases, :attr:`grad_input` and :attr:`grad_output` will only - contain the gradients for a subset of the inputs and outputs. - For such :class:`Module`, you should use :func:`torch.Tensor.register_hook` - directly on a specific input or output to get the required gradients. - The hook will be called every time the gradients with respect to module inputs are computed. The hook should have the following signature:: @@ -165,6 +160,10 @@ def register_module_full_backward_hook( in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor arguments. + For technical reasons, when this hook is applied to a Module, its forward function will + receive a view of each Tensor passed to the Module. Similarly the caller will receive a view + of each Tensor returned by the Module's forward function. + Global hooks are called before hooks registered with `register_backward_hook` Returns: @@ -531,6 +530,41 @@ def get_buffer(self, target: str) -> "Tensor": return buffer + def get_extra_state(self) -> Any: + """ + Returns any extra state to include in the module's state_dict. + Implement this and a corresponding :func:`set_extra_state` for your module + if you need to store extra state. This function is called when building the + module's `state_dict()`. + + Note that extra state should be pickleable to ensure working serialization + of the state_dict. We only provide provide backwards compatibility guarantees + for serializing Tensors; other objects may break backwards compatibility if + their serialized pickled form changes. + + Returns: + object: Any extra state to store in the module's state_dict + """ + raise RuntimeError( + "Reached a code path in Module.get_extra_state() that should never be called. " + "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.md " + "to report this bug.") + + def set_extra_state(self, state: Any): + """ + This function is called from :func:`load_state_dict` to handle any extra state + found within the `state_dict`. Implement this function and a corresponding + :func:`get_extra_state` for your module if you need to store extra state within its + `state_dict`. + + Args: + state (dict): Extra state from the `state_dict` + """ + raise RuntimeError( + "Reached a code path in Module.set_extra_state() that should never be called. " + "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.md " + "to report this bug.") + def _apply(self, fn): for module in self.children(): module._apply(fn) @@ -907,6 +941,10 @@ def register_full_backward_hook( in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor arguments. + For technical reasons, when this hook is applied to a Module, its forward function will + receive a view of each Tensor passed to the Module. Similarly the caller will receive a view + of each Tensor returned by the Module's forward function. + .. warning :: Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error. @@ -1227,6 +1265,9 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): for name, buf in self._buffers.items(): if buf is not None and name not in self._non_persistent_buffers_set: destination[prefix + name] = buf if keep_vars else buf.detach() + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if getattr(self.__class__, "get_extra_state", Module.get_extra_state) is not Module.get_extra_state: + destination[extra_state_key] = self.get_extra_state() # The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns # back that same object. But if they pass nothing, an `OrederedDict` is created and returned. @@ -1364,9 +1405,18 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, elif strict: missing_keys.append(key) + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state: + if extra_state_key in state_dict: + self.set_extra_state(state_dict[extra_state_key]) + elif strict: + missing_keys.append(extra_state_key) + elif strict and (extra_state_key in state_dict): + unexpected_keys.append(extra_state_key) + if strict: for key in state_dict.keys(): - if key.startswith(prefix): + if key.startswith(prefix) and key != extra_state_key: input_name = key[len(prefix):] input_name = input_name.split('.', 1)[0] # get the name of param/buffer/child if input_name not in self._modules and input_name not in local_state: diff --git a/torch/nn/modules/pooling.py b/torch/nn/modules/pooling.py index d09e257452e44..3665e893fa5ec 100644 --- a/torch/nn/modules/pooling.py +++ b/torch/nn/modules/pooling.py @@ -882,8 +882,8 @@ class LPPool1d(_LPPoolNd): ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape Shape: - - Input: :math:`(N, C, L_{in})` - - Output: :math:`(N, C, L_{out})`, where + - Input: :math:`(N, C, L_{in})` or :math:`(C, L_{in})`. + - Output: :math:`(N, C, L_{out})` or :math:`(C, L_{out})`, where .. math:: L_{out} = \left\lfloor\frac{L_{in} - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py index ed5c3656203ee..60d21431dc5bf 100644 --- a/torch/nn/parallel/distributed.py +++ b/torch/nn/parallel/distributed.py @@ -1408,6 +1408,19 @@ def _check_comm_hook(self, hook): "Communication hook: return annotation should be torch.futures.Future[torch.Tensor].", ) + if ( + hook.__name__ in ["bf16_compress_hook", "bf16_compress_wrapper_hook"] + and + ( + torch.version.cuda is None + or int(torch.version.cuda.split('.')[0]) < 11 + or not dist.is_available() + or not dist.is_nccl_available() + or torch.cuda.nccl.version() < (2, 9, 7) + ) + ): + self._log_and_throw(TypeError, "BF16 all reduce communication hook required CUDA 11+ and NCCL 2.9.7+.") + @property def _distributed_rank(self): return dist.get_rank(self.process_group) diff --git a/torch/nn/quantizable/modules/rnn.py b/torch/nn/quantizable/modules/rnn.py index 32e813ce94eae..bdfd7788533b5 100644 --- a/torch/nn/quantizable/modules/rnn.py +++ b/torch/nn/quantizable/modules/rnn.py @@ -407,6 +407,7 @@ def from_float(cls, other, qconfig=None): observed = torch.quantization.prepare(observed, inplace=True) return observed - def from_observed(self, other): - return torch.quantization.convert(self, inplace=False, + @classmethod + def from_observed(cls, other): + return torch.quantization.convert(other, inplace=False, remove_qconfig=True) diff --git a/torch/nn/quantized/_reference/modules/conv.py b/torch/nn/quantized/_reference/modules/conv.py index 036f8e46212c5..6b03bb0491ad1 100644 --- a/torch/nn/quantized/_reference/modules/conv.py +++ b/torch/nn/quantized/_reference/modules/conv.py @@ -1,42 +1,101 @@ import torch -import torch.nn.quantized as nnq +import torch.nn as nn import torch.nn.functional as F -from typing import Optional +from typing import Optional, Dict, Any from torch.nn.common_types import _size_1_t -from torch.nn.modules.utils import _single +from .utils import _quantize_and_dequantize_weight +from .utils import _save_weight_qparams +from .utils import _get_weight_qparam_keys -class _ConvNd(nnq._ConvNd): +class _ConvNd(torch.nn.modules.conv._ConvNd): """ A reference version of nn.quantized.Conv2d we will not pack the parameters in this module, since weight packing is an optimization for quantized backends supported in PyTorch (fbgemm/qnnpack), this is useful when user want to use this module in other backends like Glow. """ - __annotations__ = {"_bias": Optional[torch.Tensor]} + __annotations__ = {"bias": Optional[torch.Tensor]} def _save_to_state_dict(self, destination, prefix, keep_vars): super()._save_to_state_dict(destination, prefix, keep_vars) - destination[prefix + '_qweight'] = self._qweight - destination[prefix + '_bias'] = self._bias + _save_weight_qparams( + destination, prefix, self.weight_qscheme, self.weight_dtype, + self.weight_scale, self.weight_zero_point, self.weight_axis) def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - self._qweight = state_dict[prefix + '_qweight'] - self._bias = state_dict[prefix + '_bias'] - state_dict.pop(prefix + '_qweight') - state_dict.pop(prefix + '_bias') + for key in _get_weight_qparam_keys(state_dict, prefix): + setattr(self, key, state_dict[prefix + key]) + state_dict.pop(prefix + key) super()._load_from_state_dict( state_dict, prefix, local_metadata, False, missing_keys, unexpected_keys, error_msgs) - def _weight_bias(self): - return self._qweight, self._bias - - def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: - self._qweight = w - self._bias = b - -class Conv1d(_ConvNd, nnq.Conv1d): + def _init_weight_qparams(self, weight_qparams, device): + if weight_qparams is None: + weight_qparams = { + "qscheme": torch.per_tensor_affine, + "dtype": torch.quint8, + "scale": 1.0, + "zero_point": 0 + } + self.weight_qscheme = weight_qparams["qscheme"] + self.weight_dtype = weight_qparams["dtype"] + assert self.weight_qscheme in [None, torch.per_tensor_affine, torch.per_channel_affine], \ + Exception(f"qscheme: {self.weight_qscheme} is not support in reference quantized linear module") + if self.weight_qscheme is not None: + self.register_buffer( + "weight_scale", + torch.tensor(weight_qparams["scale"], dtype=torch.float, device=device)) + self.register_buffer( + "weight_zero_point", + torch.tensor(weight_qparams["zero_point"], dtype=torch.int, device=device)) + if self.weight_qscheme == torch.per_channel_affine: + self.register_buffer( + "weight_axis", + torch.tensor(weight_qparams["axis"], dtype=torch.int, device=device)) + else: + # added for TorchScriptability, not used + self.register_buffer( + "weight_axis", torch.tensor(0, dtype=torch.int, device=device)) + + def get_weight(self): + """ + Fake quantize (quantize and dequantize) the weight with + the quantization parameters for weight, this is used to + simulate the numerics for the quantized weight in a quantized + model + """ + # supress mypy warning + assert isinstance(self.weight, torch.Tensor) + assert isinstance(self.weight_scale, torch.Tensor) + assert isinstance(self.weight_zero_point, torch.Tensor) + assert isinstance(self.weight_axis, torch.Tensor) + return _quantize_and_dequantize_weight( + self.weight, self.weight_qscheme, + self.weight_dtype, self.weight_scale, self.weight_zero_point, self.weight_axis) + + @staticmethod + def from_float(cls, float_conv, weight_qparams): + qref_conv = cls( + float_conv.in_channels, + float_conv.out_channels, + float_conv.kernel_size, # type: ignore[arg-type] + float_conv.stride, # type: ignore[arg-type] + float_conv.padding, # type: ignore[arg-type] + float_conv.dilation, # type: ignore[arg-type] + float_conv.groups, + float_conv.bias is not None, # type: ignore[arg-type] + float_conv.padding_mode, + device=float_conv.weight.device, + dtype=float_conv.weight.dtype, + weight_qparams=weight_qparams) + qref_conv.weight = torch.nn.Parameter(float_conv.weight.detach()) + if float_conv.bias is not None: + qref_conv.bias = torch.nn.Parameter(float_conv.bias.detach()) + return qref_conv + +class Conv1d(_ConvNd, nn.Conv1d): def __init__(self, in_channels: int, out_channels: int, @@ -46,91 +105,107 @@ def __init__(self, dilation: _size_1_t = 1, groups: int = 1, bias: bool = True, - padding_mode: str = 'zeros'): - nnq.Conv1d.__init__( + padding_mode: str = "zeros", + device=None, + dtype=None, + weight_qparams: Optional[Dict[str, Any]] = None): + nn.Conv1d.__init__( self, in_channels, out_channels, kernel_size, stride, padding, dilation, - groups, bias, padding_mode) - # self.stride, self.padding, self.dilation are 2d tuple since - # current quantized conv1d is using Conv2dPackedParams - # TODO: we should fix this if we implemenet Conv1dPackedParams - self._conv1d_stride = _single(self.stride[0]) - self._conv1d_padding = _single(self.padding[0]) - self._conv1d_dilation = _single(self.dilation[0]) + groups, bias, padding_mode, device, dtype) + self._init_weight_qparams(weight_qparams, device) def forward(self, x: torch.Tensor) -> torch.Tensor: - x_dequant = x.dequantize() - weight_dequant = self._qweight.dequantize() - float_result = F.conv1d( - x_dequant, weight_dequant, self._bias, self._conv1d_stride, - self._conv1d_padding, self._conv1d_dilation, self.groups) - # NEEDFIX: we don't have dtype in the Linear module APIs right now! - result = torch.quantize_per_tensor( - float_result, self.scale, self.zero_point, torch.quint8) + """ + we have: + w(float) -- quant - dequant \ + x(float) ------------- F.conv1d --- + + In the full model, we will see + w(float) -- quant - *dequant \ + x -- quant --- *dequant -- *F.conv1d --- *quant - dequant + and the backend should be able to fuse the ops with `*` into a quantized conv1d + """ + weight_dequant = self.get_weight() + result = F.conv1d( + x, weight_dequant, self.bias, self.stride, + self.padding, self.dilation, self.groups) return result def _get_name(self): - return 'QuantizedConv1d(Reference)' - - @torch.jit.export - def __setstate__(self, state): - self.in_channels = state[0] - self.out_channels = state[1] - self.kernel_size = state[2] - self.stride = state[3] - self.padding = state[4] - self.dilation = state[5] - self.transposed = state[6] - self.output_padding = state[7] - self.groups = state[8] - self.padding_mode = state[9] - self.set_weight_bias(state[10], state[11]) - self.scale = state[12] - self.zero_point = state[13] - self.training = state[14] - self._conv1d_stride = (self.stride[0],) - self._conv1d_padding = (self.padding[0],) - self._conv1d_dilation = (self.dilation[0],) - -class Conv2d(_ConvNd, nnq.Conv2d): + return "QuantizedConv1d(Reference)" + + @classmethod + def from_float(cls, float_conv, weight_qparams): + return _ConvNd.from_float(cls, float_conv, weight_qparams) + +class Conv2d(_ConvNd, nn.Conv2d): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, - padding_mode='zeros'): - nnq.Conv2d.__init__( + padding_mode='zeros', + device=None, + dtype=None, + weight_qparams: Optional[Dict[str, Any]] = None): + nn.Conv2d.__init__( self, in_channels, out_channels, kernel_size, stride, padding, dilation, - groups, bias, padding_mode) + groups, bias, padding_mode, device, dtype) + self._init_weight_qparams(weight_qparams, device) def forward(self, x: torch.Tensor) -> torch.Tensor: - x_dequant = x.dequantize() - weight_dequant = self._qweight.dequantize() - float_result = F.conv2d( - x_dequant, weight_dequant, self._bias, self.stride, + """ + we have: + w(float) -- quant - dequant \ + x(float) ------------- F.conv2d --- + + In the full model, we will see + w(float) -- quant - *dequant \ + x -- quant --- *dequant -- *F.conv2d --- *quant - dequant + and the backend should be able to fuse the ops with `*` into a quantized conv2d + """ + weight_dequant = self.get_weight() + result = F.conv2d( + x, weight_dequant, self.bias, self.stride, self.padding, self.dilation, self.groups) - # NEEDFIX: we don't have dtype in the Linear module APIs right now! - result = torch.quantize_per_tensor( - float_result, self.scale, self.zero_point, torch.quint8) return result def _get_name(self): - return 'QuantizedConv2d(Reference)' + return "QuantizedConv2d(Reference)" -class Conv3d(_ConvNd, nnq.Conv3d): + @classmethod + def from_float(cls, float_conv, weight_qparams): + return _ConvNd.from_float(cls, float_conv, weight_qparams) + +class Conv3d(_ConvNd, nn.Conv3d): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, - padding_mode='zeros'): - nnq.Conv3d.__init__( + padding_mode="zeros", + device=None, + dtype=None, + weight_qparams: Optional[Dict[str, Any]] = None): + nn.Conv3d.__init__( self, in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode) + self._init_weight_qparams(weight_qparams, device) def forward(self, x: torch.Tensor) -> torch.Tensor: - x_dequant = x.dequantize() - weight_dequant = self._qweight.dequantize() - float_result = F.conv3d( - x_dequant, weight_dequant, self._bias, self.stride, + """ + we have: + w(float) -- quant - dequant \ + x(float) ------------- F.conv3d --- + + In the full model, we will see + w(float) -- quant - *dequant \ + x -- quant --- *dequant -- *F.conv3d --- *quant - dequant + and the backend should be able to fuse the ops with `*` into a quantized conv3d + """ + weight_dequant = self.get_weight() + result = F.conv3d( + x, weight_dequant, self.bias, self.stride, self.padding, self.dilation, self.groups) - # NEEDFIX: we don't have dtype in the Linear module APIs right now! - result = torch.quantize_per_tensor( - float_result, self.scale, self.zero_point, torch.quint8) return result def _get_name(self): - return 'QuantizedConv3d(Reference)' + return "QuantizedConv3d(Reference)" + + @classmethod + def from_float(cls, float_conv, weight_qparams): + return _ConvNd.from_float(cls, float_conv, weight_qparams) diff --git a/torch/nn/quantized/_reference/modules/linear.py b/torch/nn/quantized/_reference/modules/linear.py index 276dc0161ded8..1df5499433d1c 100644 --- a/torch/nn/quantized/_reference/modules/linear.py +++ b/torch/nn/quantized/_reference/modules/linear.py @@ -1,51 +1,115 @@ import torch -import torch.nn.quantized as nnq +import torch.nn as nn import torch.nn.functional as F -from typing import Optional +from typing import Optional, Dict, Any +from .utils import _quantize_and_dequantize_weight +from .utils import _save_weight_qparams +from .utils import _get_weight_qparam_keys -class Linear(nnq.Linear): - """ A backend independent version of nn.quantized.Linear - we will not pack the parameters in this module, since weight packing is an - optimization for quantized backends supported in PyTorch (fbgemm/qnnpack), - this is useful when user want to use this module in other backends like Glow. +class Linear(nn.Linear): + """ A reference quantized linear module that fits into the FX + Graph Mode Quantization workflow + activation will be floating point Tensor, we will store floating + point weight as well in the module, but in forward we'll quantize + and dequantize the weight before running the floating point functional + linear operator. """ - def __init__(self, in_features, out_features, bias_=True, - dtype=torch.qint8): - super().__init__(in_features, out_features, bias_, dtype) - self._qweight, self._bias = self._packed_params._weight_bias() - del self._packed_params + def __init__( + self, + in_features: int, + out_features: int, + bias_: bool = True, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + weight_qparams: Optional[Dict[str, Any]] = None): + super().__init__(in_features, out_features, bias_, device, dtype) + if weight_qparams is None: + weight_qparams = { + "qscheme": torch.per_tensor_affine, + "dtype": torch.quint8, + "scale": 1.0, + "zero_point": 0 + } + self.weight_qscheme = weight_qparams["qscheme"] + self.weight_dtype = weight_qparams["dtype"] + assert self.weight_qscheme in [None, torch.per_tensor_affine, torch.per_channel_affine], \ + Exception(f"qscheme: {self.weight_qscheme} is not support in reference quantized linear module") + if self.weight_qscheme is not None: + self.register_buffer( + "weight_scale", + torch.tensor(weight_qparams["scale"], dtype=torch.float, device=device)) + self.register_buffer( + "weight_zero_point", + torch.tensor( + weight_qparams["zero_point"], + dtype=torch.int, device=device)) + if self.weight_qscheme == torch.per_channel_affine: + self.register_buffer( + "weight_axis", + torch.tensor(weight_qparams["axis"], dtype=torch.int, device=device)) + else: + # added for TorchScriptability, not used + self.register_buffer( + "weight_axis", + torch.tensor(0, dtype=torch.int, device=device)) def _get_name(self): return "QuantizedLinear(Reference)" + def get_weight(self): + """ + Fake quantize (quantize and dequantize) the weight with + the quantization parameters for weight, this is used to + simulate the numerics for the quantized weight in a quantized + model + """ + # supress mypy warning + assert isinstance(self.weight, torch.Tensor) + assert isinstance(self.weight_scale, torch.Tensor) + assert isinstance(self.weight_zero_point, torch.Tensor) + assert isinstance(self.weight_axis, torch.Tensor) + return _quantize_and_dequantize_weight( + self.weight, self.weight_qscheme, self.weight_dtype, self.weight_scale, + self.weight_zero_point, self.weight_axis) + def forward(self, x: torch.Tensor) -> torch.Tensor: - x_dequant = x.dequantize() - weight_dequant = self._qweight.dequantize() - float_result = F.linear(x_dequant, weight_dequant, self._bias) - # NEEDFIX: we don't have dtype in the Linear module APIs right now! - result = torch.quantize_per_tensor( - float_result, self.scale, self.zero_point, torch.quint8) + """ + we have: + w(float) -- quant - dequant \ + x(float) ------------- F.linear --- + + In the full model, we will see + w(float) -- quant - *dequant \ + x -- quant --- *dequant -- *F.linear --- *quant - dequant + and the backend should be able to fuse the ops with `*` into a quantized linear + """ + weight_dequant = self.get_weight() + result = F.linear(x, weight_dequant, self.bias) return result def _save_to_state_dict(self, destination, prefix, keep_vars): super()._save_to_state_dict(destination, prefix, keep_vars) - destination[prefix + '_qweight'] = self._qweight - destination[prefix + '_bias'] = self._bias + _save_weight_qparams( + destination, prefix, self.weight_qscheme, self.weight_dtype, + self.weight_scale, self.weight_zero_point, self.weight_axis) def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - self._qweight = state_dict[prefix + '_qweight'] - self._bias = state_dict[prefix + '_bias'] - state_dict.pop(prefix + '_qweight') - state_dict.pop(prefix + '_bias') + for key in _get_weight_qparam_keys(state_dict, prefix): + setattr(self, key, state_dict[prefix + key]) + state_dict.pop(prefix + key) super()._load_from_state_dict( state_dict, prefix, local_metadata, False, missing_keys, unexpected_keys, error_msgs) - def _weight_bias(self): - return self._qweight, self._bias - - def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: - self._qweight = w - self._bias = b + @classmethod + def from_float(cls, float_linear, weight_qparams): + qref_linear = Linear( + float_linear.in_features, float_linear.out_features, + float_linear.bias is not None, device=float_linear.weight.device, + dtype=float_linear.weight.dtype, weight_qparams=weight_qparams) + qref_linear.weight = torch.nn.Parameter(float_linear.weight.detach()) + if float_linear.bias is not None: + qref_linear.bias = torch.nn.Parameter(float_linear.bias.detach()) + return qref_linear diff --git a/torch/nn/quantized/_reference/modules/utils.py b/torch/nn/quantized/_reference/modules/utils.py new file mode 100644 index 0000000000000..7c366503dd872 --- /dev/null +++ b/torch/nn/quantized/_reference/modules/utils.py @@ -0,0 +1,45 @@ +import torch +from typing import Dict, Any + +def _quantize_and_dequantize_weight( + weight: torch.Tensor, + weight_qscheme: torch.qscheme, + weight_dtype: torch.dtype, + weight_scale: torch.Tensor, + weight_zero_point: torch.Tensor, + weight_axis: torch.Tensor): + """ Quantize and then dequantize the weight based on + the quantization parameters + """ + if weight_qscheme == torch.per_tensor_affine: + weight = torch.quantize_per_tensor(weight, weight_scale, weight_zero_point, weight_dtype) + weight_dequant = weight.dequantize() + elif weight_qscheme == torch.per_channel_affine: + weight = torch.quantize_per_channel( + weight, weight_scale, + weight_zero_point, weight_axis.item(), weight_dtype) # type: ignore[arg-type] + weight_dequant = weight.dequantize() + else: + weight_dequant = weight + return weight_dequant + +def _save_weight_qparams(destination, prefix, weight_qscheme, weight_dtype, weight_scale, weight_zero_point, weight_axis): + destination[prefix + "weight_qscheme"] = weight_qscheme + destination[prefix + "weight_dtype"] = weight_dtype + if weight_qscheme is not None: + destination[prefix + "weight_scale"] = weight_scale + destination[prefix + "weight_zero_point"] = weight_zero_point + if weight_qscheme == torch.per_channel_affine: + destination[prefix + "weight_axis"] = weight_axis + +def _get_weight_qparam_keys( + state_dict: Dict[str, Any], + prefix: str): + keys = ["weight_qscheme", "weight_dtype"] + weight_qscheme = state_dict[prefix + "weight_qscheme"] + if weight_qscheme is not None: + keys.append("weight_scale") + keys.append("weight_zero_point") + if weight_qscheme == torch.quantize_per_channel: + keys.append("weight_axis") + return keys diff --git a/torch/nn/quantized/dynamic/modules/linear.py b/torch/nn/quantized/dynamic/modules/linear.py index 07cfdfe2846cc..ee153b10d2466 100644 --- a/torch/nn/quantized/dynamic/modules/linear.py +++ b/torch/nn/quantized/dynamic/modules/linear.py @@ -1,5 +1,6 @@ import torch import torch.nn.quantized as nnq +import torch.nn.intrinsic as nni from torch.nn.quantized.modules.utils import _quantize_weight class Linear(nnq.Linear): @@ -79,11 +80,15 @@ def from_float(cls, mod): mod (Module): a float module, either produced by torch.quantization utilities or provided by the user """ - float_modules = [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear] + float_modules = [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, + torch.nn.intrinsic.modules.fused.LinearReLU] + assert type(mod) in float_modules, \ 'nn.quantized.dynamic.Linear.from_float only works for one of' + \ str([float_mod.__name__ for float_mod in float_modules]) assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined' + if type(mod) == nni.LinearReLU: + mod = mod[0] if mod.qconfig is not None and mod.qconfig.weight is not None: weight_observer = mod.qconfig.weight() else: @@ -102,6 +107,6 @@ def from_float(cls, mod): qweight = mod.weight.float() else: raise RuntimeError('Unsupported dtype specified for dynamic quantized Linear!') - qlinear = Linear(mod.in_features, mod.out_features, dtype=dtype) + qlinear = cls(mod.in_features, mod.out_features, dtype=dtype) qlinear.set_weight_bias(qweight, mod.bias) return qlinear diff --git a/torch/nn/quantized/modules/linear.py b/torch/nn/quantized/modules/linear.py index 4abd2115e4125..4df775105ba82 100644 --- a/torch/nn/quantized/modules/linear.py +++ b/torch/nn/quantized/modules/linear.py @@ -94,6 +94,16 @@ def __setstate__(self, state): self.set_weight_bias(state[0], state[1]) self.training = state[2] + def __deepcopy__(self, memo): + new_instance = type(self).__new__(type(self)) + torch.nn.Module.__init__(new_instance) + state = self.__getstate__() + new_instance.__setstate__(state) + return new_instance + + def __copy__(self): + return self.__deepcopy__({}) + def __repr__(self): return self._weight_bias().__repr__() diff --git a/torch/nn/utils/parametrizations.py b/torch/nn/utils/parametrizations.py index 7941f41f19cac..de67aa814f39c 100644 --- a/torch/nn/utils/parametrizations.py +++ b/torch/nn/utils/parametrizations.py @@ -1,10 +1,286 @@ +from enum import Enum, auto + import torch +from torch import Tensor from ..utils import parametrize from ..modules import Module from .. import functional as F from typing import Optional + +def _is_orthogonal(Q, eps=None): + n, k = Q.size(-2), Q.size(-1) + Id = torch.eye(k, dtype=Q.dtype, device=Q.device) + # A reasonable eps, but not too large + eps = 10. * n * torch.finfo(Q.dtype).eps + return torch.allclose(Q.transpose(-2, -1).conj() @ Q, Id, atol=eps) + + +def _make_orthogonal(A): + """ Assume that A is a tall matrix. + Compute the Q factor s.t. A = QR (A may be complex) and diag(R) is real and non-negative + """ + X, tau = torch.geqrf(A) + Q = torch.linalg.householder_product(X, tau) + # The diagonal of X is the diagonal of R (which is always real) so we normalise by its signs + Q *= X.diagonal(dim1=-2, dim2=-1).sgn().unsqueeze(-2) + return Q + + +class _OrthMaps(Enum): + matrix_exp = auto() + cayley = auto() + householder = auto() + + +class _Orthogonal(Module): + base: Tensor + + def __init__(self, + weight, + orthogonal_map: _OrthMaps, + *, + use_trivialization=True) -> None: + super().__init__() + + # Note [Householder complex] + # For complex tensors, it is not possible to compute the tensor `tau` necessary for + # linalg.householder_product from the reflectors. + # To see this, note that the reflectors have a shape like: + # 0 0 0 + # * 0 0 + # * * 0 + # which, for complex matrices, give n(n-1) (real) parameters. Now, you need n^2 parameters + # to parametrize the unitary matrices. Saving tau on its own does not work either, because + # not every combination of `(A, tau)` gives a unitary matrix, meaning that if we optimise + # them as independent tensors we would not maintain the constraint + # An equivalent reasoning holds for rectangular matrices + if weight.is_complex() and orthogonal_map == _OrthMaps.householder: + raise ValueError("The householder parametrization does not support complex tensors.") + + self.shape = weight.shape + self.orthogonal_map = orthogonal_map + if use_trivialization: + self.register_buffer("base", None) + + def forward(self, X: torch.Tensor) -> torch.Tensor: + n, k = X.size(-2), X.size(-1) + transposed = n < k + if transposed: + X = X.transpose(-2, -1) + n, k = k, n + # Here n > k and X is a tall matrix + if self.orthogonal_map == _OrthMaps.matrix_exp or self.orthogonal_map == _OrthMaps.cayley: + # We just need n x k - k(k-1)/2 parameters + X = X.tril() + if n != k: + # Embed into a square matrix + X = torch.cat([X, X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1) + A = X - X.transpose(-2, -1).conj() + # A is skew-symmetric (or skew-hermitian) + if self.orthogonal_map == _OrthMaps.matrix_exp: + Q = torch.matrix_exp(A) + elif self.orthogonal_map == _OrthMaps.cayley: + # Computes the Cayley retraction (I+A/2)(I-A/2)^{-1} + Id = torch.eye(n, dtype=A.dtype, device=A.device) + Q = torch.linalg.solve(torch.add(Id, A, alpha=-0.5), torch.add(Id, A, alpha=0.5)) + # Q is now orthogonal (or unitary) of size (..., n, n) + if n != k: + Q = Q[..., :k] + # Q is now the size of the X (albeit perhaps transposed) + else: + # X is real here, as we do not support householder with complex numbers + A = X.tril(diagonal=-1) + tau = 2. / (1. + (A * A).sum(dim=-2)) + Q = torch.linalg.householder_product(A, tau) + # The diagonal of X is 1's and -1's + # We do not want to differentiate through this or update the diagonal of X hence the casting + Q = Q * X.diagonal(dim1=-2, dim2=-1).int().unsqueeze(-2) + + if hasattr(self, "base"): + Q = self.base @ Q + if transposed: + Q = Q.transpose(-2, -1) + return Q + + @torch.autograd.no_grad() + def right_inverse(self, Q: torch.Tensor) -> torch.Tensor: + if Q.shape != self.shape: + raise ValueError(f"Expected a matrix or batch of matrices of shape {self.shape}. " + f"Got a tensor of shape {Q.shape}.") + + Q_init = Q + n, k = Q.size(-2), Q.size(-1) + transpose = n < k + if transpose: + Q = Q.transpose(-2, -1) + n, k = k, n + + # We always make sure to always copy Q in every path + if not hasattr(self, "base"): + # Note [right_inverse expm cayley] + # If we do not have use_trivialization=True, we just implement the inverse of the forward + # map for the Householder. To see why, think that for the Cayley map, + # we would need to find the matrix X \in R^{n x k} such that: + # Y = torch.cat([X.tril(), X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1) + # A = Y - Y.transpose(-2, -1).conj() + # cayley(A)[:, :k] + # gives the original tensor. It is not clear how to do this. + # Perhaps via some algebraic manipulation involving the QR like that of + # Corollary 2.2 in Edelman, Arias and Smith? + if self.orthogonal_map == _OrthMaps.cayley or self.orthogonal_map == _OrthMaps.matrix_exp: + raise NotImplementedError("It is not possible to assign to the matrix exponential " + "or the Cayley parametrizations when use_trivialization=False.") + + # If parametrization == _OrthMaps.householder, make Q orthogonal via the QR decomposition. + # Here Q is always real because we do not support householder and complex matrices. + # See note [Householder complex] + A, tau = torch.geqrf(Q) + # We want to have a decomposition X = QR with diag(R) > 0, as otherwise we could + # decompose an orthogonal matrix Q as Q = (-Q)@(-Id), which is a valid QR decomposition + # The diagonal of Q is the diagonal of R from the qr decomposition + A.diagonal(dim1=-2, dim2=-1).sign_() + # Equality with zero is ok because LAPACK returns exactly zero when it does not want + # to use a particular reflection + A.diagonal(dim1=-2, dim2=-1)[tau == 0.] *= -1 + return A.transpose(-2, -1) if transpose else A + else: + if n == k: + # We check whether Q is orthogonal + if not _is_orthogonal(Q): + Q = _make_orthogonal(Q) + else: # Is orthogonal + Q = Q.clone() + else: + # Complete Q into a full n x n orthogonal matrix + N = torch.randn(*(Q.size()[:-2] + (n, n - k)), dtype=Q.dtype, device=Q.device) + Q = torch.cat([Q, N], dim=-1) + Q = _make_orthogonal(Q) + self.base = Q + + # It is necessary to return the -Id, as we use the diagonal for the + # Householder parametrization. Using -Id makes: + # householder(torch.zeros(m,n)) == torch.eye(m,n) + # Poor man's version of eye_like + neg_Id = torch.zeros_like(Q_init) + neg_Id.diagonal(dim1=-2, dim2=-1).fill_(-1.) + return neg_Id + + +def orthogonal(module: Module, + name: str = 'weight', + orthogonal_map: Optional[str] = None, + *, + use_trivialization: bool = True) -> Module: + r"""Applies an orthogonal or unitary parametrization to a matrix or a batch of matrices. + + Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, the parametrized + matrix :math:`Q \in \mathbb{K}^{m \times n}` is **orthogonal** as + + .. math:: + + \begin{align*} + Q^{\text{H}}Q &= \mathrm{I}_n \mathrlap{\qquad \text{if }m \geq n}\\ + QQ^{\text{H}} &= \mathrm{I}_m \mathrlap{\qquad \text{if }m < n} + \end{align*} + + where :math:`Q^{\text{H}}` is the conjugate transpose when :math:`Q` is complex + and the transpose when :math:`Q` is real-valued, and + :math:`\mathrm{I}_n` is the `n`-dimensional identity matrix. + In plain words, :math:`Q` will have orthonormal columns whenever :math:`m \geq n` + and orthonormal rows otherwise. + + If the tensor has more than two dimensions, we consider it as a batch of matrices of shape `(..., m, n)`. + + The matrix :math:`Q` may be parametrized via three different ``orthogonal_map`` in terms of the original tensor: + + - ``"matrix_exp"``/``"cayley"``: + the :func:`~torch.matrix_exp` :math:`Q = \exp(A)` and the `Cayley map`_ + :math:`Q = (\mathrm{I}_n + A/2)(\mathrm{I}_n - A/2)^{-1}` are applied to a skew-symmetric + :math:`A` to give an orthogonal matrix. + - ``"householder"``: computes a product of Householder reflectors + (:func:`~torch.linalg.householder_product`). + + ``"matrix_exp"``/``"cayley"`` often make the parametrized weight converge faster than + ``"householder"``, but they are slower to compute for very thin or very wide matrices. + + If ``use_trivialization=True`` (default), the parametrization implements the "Dynamic Trivialization Framework", + where an extra matrix :math:`B \in \mathbb{K}^{n \times n}` is stored under + ``module.parametrizations.weight[0].base``. This helps the + convergence of the parametrized layer at the expense of some extra memory use. + See `Trivializations for Gradient-Based Optimization on Manifolds`_ . + + Initial value of :math:`Q`: + If the original tensor is not parametrized and ``use_trivialization=True`` (default), the initial value + of :math:`Q` is that of the original tensor if it is orthogonal (or unitary in the complex case) + and it is orthogonalized via the QR decomposition otherwise (see :func:`torch.linalg.qr`). + Same happens when it is not parametrized and ``orthogonal_map="householder"`` even when ``use_trivialization=False``. + Otherwise, the initial value is the result of the composition of all the registered + parametrizations applied to the original tensor. + + .. note:: + This function is implemented using the parametrization functionality + in :func:`~torch.nn.utils.parametrize.register_parametrization`. + + + .. _`Cayley map`: https://en.wikipedia.org/wiki/Cayley_transform#Matrix_map + .. _`Trivializations for Gradient-Based Optimization on Manifolds`: https://arxiv.org/abs/1909.09501 + + Args: + module (nn.Module): module on which to register the parametrization. + name (str, optional): name of the tensor to make orthogonal. Default: ``"weight"``. + orthogonal_map (str, optional): One of the following: ``"matrix_exp"``, ``"cayley"``, ``"householder"``. + Default: ``"matrix_exp"`` if the matrix is square or complex, ``"householder"`` otherwise. + use_trivialization (bool, optional): whether to use the dynamic trivialization framework. + Default: ``True``. + + Returns: + The original module with an orthogonal parametrization registered to the specified + weight + + Example:: + + >>> orth_linear = orthogonal(nn.Linear(20, 40)) + >>> orth_linear + ParametrizedLinear( + in_features=20, out_features=40, bias=True + (parametrizations): ModuleDict( + (weight): ParametrizationList( + (0): _Orthogonal() + ) + ) + ) + >>> Q = orth_linear.weight + >>> torch.dist(Q.T @ Q, torch.eye(20)) + tensor(4.9332e-07) + """ + weight = getattr(module, name, None) + if not isinstance(weight, Tensor): + raise ValueError( + "Module '{}' has no parameter ot buffer with name '{}'".format(module, name) + ) + + # We could implement this for 1-dim tensors as the maps on the sphere + # but I believe it'd bite more people than it'd help + if weight.ndim < 2: + raise ValueError("Expected a matrix or batch of matrices. " + f"Got a tensor of {weight.ndim} dimensions.") + + if orthogonal_map is None: + orthogonal_map = "matrix_exp" if weight.size(-2) == weight.size(-1) or weight.is_complex() else "householder" + + orth_enum = getattr(_OrthMaps, orthogonal_map, None) + if orth_enum is None: + raise ValueError('orthogonal_map has to be one of "matrix_exp", "cayley", "householder". ' + f'Got: {orthogonal_map}') + orth = _Orthogonal(weight, + orth_enum, + use_trivialization=use_trivialization) + parametrize.register_parametrization(module, name, orth, unsafe=True) + return module + + class _SpectralNorm(Module): def __init__( self, @@ -84,6 +360,7 @@ def _power_method(self, weight_mat: torch.Tensor, n_power_iterations: int) -> No # Precondition assert weight_mat.ndim > 1 + for _ in range(n_power_iterations): # Spectral norm of weight equals to `u^T W v`, where `u` and `v` # are the first left and right singular vectors. @@ -92,9 +369,6 @@ def _power_method(self, weight_mat: torch.Tensor, n_power_iterations: int) -> No dim=0, eps=self.eps, out=self._u) # type: ignore[has-type] self._v = F.normalize(torch.mv(weight_mat.t(), self._u), dim=0, eps=self.eps, out=self._v) # type: ignore[has-type] - # See above on why we need to clone - self._u = self._u.clone(memory_format=torch.contiguous_format) - self._v = self._v.clone(memory_format=torch.contiguous_format) def forward(self, weight: torch.Tensor) -> torch.Tensor: if weight.ndim == 1: @@ -104,10 +378,13 @@ def forward(self, weight: torch.Tensor) -> torch.Tensor: weight_mat = self._reshape_weight_to_matrix(weight) if self.training: self._power_method(weight_mat, self.n_power_iterations) + # See above on why we need to clone + u = self._u.clone(memory_format=torch.contiguous_format) + v = self._v.clone(memory_format=torch.contiguous_format) # The proper way of computing this should be through F.bilinear, but # it seems to have some efficiency issues: # https://github.com/pytorch/pytorch/issues/58093 - sigma = torch.dot(self._u, torch.mv(weight_mat, self._v)) + sigma = torch.dot(u, torch.mv(weight_mat, v)) return weight / sigma def right_inverse(self, value: torch.Tensor) -> torch.Tensor: @@ -146,8 +423,8 @@ def spectral_norm(module: Module, .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957 .. note:: - This function is implemented using the new parametrization functionality - in :func:`torch.nn.utils.parametrize.register_parametrization`. It is a + This function is implemented using the parametrization functionality + in :func:`~torch.nn.utils.parametrize.register_parametrization`. It is a reimplementation of :func:`torch.nn.utils.spectral_norm`. .. note:: @@ -164,13 +441,13 @@ def spectral_norm(module: Module, Args: module (nn.Module): containing module - name (str, optional): name of weight parameter + name (str, optional): name of weight parameter. Default: ``"weight"``. n_power_iterations (int, optional): number of power iterations to - calculate spectral norm + calculate spectral norm. Default: ``1``. eps (float, optional): epsilon for numerical stability in - calculating norms - dim (int, optional): dimension corresponding to number of outputs, - the default is ``0``, except for modules that are instances of + calculating norms. Default: ``1e-12``. + dim (int, optional): dimension corresponding to number of outputs. + Default: ``0``, except for modules that are instances of ConvTranspose{1,2,3}d, when it is ``1`` Returns: @@ -192,13 +469,11 @@ def spectral_norm(module: Module, >>> torch.linalg.matrix_norm(snm.weight, 2) tensor(1.0000, grad_fn=) """ - if not hasattr(module, name): + weight = getattr(module, name, None) + if not isinstance(weight, Tensor): raise ValueError( - "Module '{}' has no attribute with name '{}'".format(module, name) + "Module '{}' has no parameter or buffer with name '{}'".format(module, name) ) - # getattr should get the correct parametrized weight if there - # is already an parametrization registered - weight = getattr(module, name) if dim is None: if isinstance(module, (torch.nn.ConvTranspose1d, diff --git a/torch/nn/utils/parametrize.py b/torch/nn/utils/parametrize.py index 332fe762b8309..d8f2a947352cf 100644 --- a/torch/nn/utils/parametrize.py +++ b/torch/nn/utils/parametrize.py @@ -129,8 +129,11 @@ def __init__( new = original for module in reversed(self): # type: ignore[call-overload] if hasattr(module, "right_inverse"): - new = module.right_inverse(new) - # else, we assume that right_inverse is the identity + try: + new = module.right_inverse(new) + except NotImplementedError: + pass + # else, or if it throws, we assume that right_inverse is the identity if not isinstance(new, Tensor) and not isinstance(new, collections.abc.Sequence): raise ValueError("'right_inverse' must return a Tensor or a Sequence of tensors (list, tuple...). " @@ -209,7 +212,9 @@ def right_inverse(self, value: Tensor) -> None: for module in reversed(self): # type: ignore[call-overload] if hasattr(module, "right_inverse"): value = module.right_inverse(value) - # else we assume that right_inverse is the identity + else: + raise RuntimeError(f"parametrization {type(module).__name__} does not implement " + "right_inverse.") if self.is_tensor: # These exceptions should only throw when a right_inverse function does not # return the same dtype for every input, which should most likely be caused by a bug @@ -372,16 +377,12 @@ def register_parametrization( def right_inverse(self, X: Tensor) -> Union[Tensor, Sequence[Tensor]] - If this method is not implemented, it defaults to the identity. This method is called on the unparametrized tensor when the first parametrization - is registered. + is registered to compute the initial value of the original tensor. + If this method is not implemented, the original tensor will be just the unparametrized tensor. - In most situations, ``right_inverse`` will be a function such that - ``forward(right_inverse(X)) == X`` (see - `right inverse `_). - Sometimes, when the parametrization is not surjective, it may be reasonable - to relax this. - This may be used to initialize the tensor, as shown in the example below. + If all the parametrizations registered on a tensor implement `right_inverse` it is possible + to initialize a parametrized tensor by assigning to it, as shown in the example below. It is possible for the first parametrization to depend on several inputs. This may be implemented returning a tuple of tensors from ``right_inverse`` @@ -397,6 +398,14 @@ def right_inverse(self, X: Tensor) -> Union[Tensor, Sequence[Tensor]] If unsafe=True, then right_inverse will be called if the tensor is not parametrized, and nothing will be called otherwise. + .. note:: + + In most situations, ``right_inverse`` will be a function such that + ``forward(right_inverse(X)) == X`` (see + `right inverse `_). + Sometimes, when the parametrization is not surjective, it may be reasonable + to relax this. + .. warning:: If a parametrization depends on several inputs, :func:`~register_parametrization` @@ -483,25 +492,29 @@ def right_inverse(self, X: Tensor) -> Union[Tensor, Sequence[Tensor]] f"parametrization(module.{tensor_name}).shape: {X.shape}" ) if hasattr(parametrization, "right_inverse"): - Z = parametrization.right_inverse(X) # type: ignore[operator] - if not isinstance(Z, Tensor): - raise ValueError( - f"parametrization.right_inverse must return a tensor. Got: {type(Z).__name__}" - ) - if Z.dtype != Y.dtype: - raise ValueError( - "The tensor returned by parametrization.right_inverse must have the same dtype " - f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n" - f"module.{tensor_name}.dtype: {Y.dtype}\n" - f"returned dtype: {Z.dtype}" - ) - if Z.shape != Y.shape: - raise ValueError( - "The tensor returned by parametrization.right_inverse must have the same shape " - f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n" - f"module.{tensor_name}.shape: {Y.shape}\n" - f"returned shape: {Z.shape}" - ) + try: + Z = parametrization.right_inverse(X) # type: ignore[operator] + except NotImplementedError: + pass + else: + if not isinstance(Z, Tensor): + raise ValueError( + f"parametrization.right_inverse must return a tensor. Got: {type(Z).__name__}" + ) + if Z.dtype != Y.dtype: + raise ValueError( + "The tensor returned by parametrization.right_inverse must have the same dtype " + f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n" + f"module.{tensor_name}.dtype: {Y.dtype}\n" + f"returned dtype: {Z.dtype}" + ) + if Z.shape != Y.shape: + raise ValueError( + "The tensor returned by parametrization.right_inverse must have the same shape " + f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n" + f"module.{tensor_name}.shape: {Y.shape}\n" + f"returned shape: {Z.shape}" + ) # else right_inverse is assumed to be the identity # add the new parametrization to the parametrization list diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index c859419cf38a5..b726b2b55e8b6 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -13,7 +13,7 @@ ir_version = _C._onnx.IR_VERSION producer_name = "pytorch" producer_version = _C._onnx.PRODUCER_VERSION -constant_folding_opset_versions = [9, 10, 11, 12, 13] +constant_folding_opset_versions = [9, 10, 11, 12, 13, 14] class ExportTypes: @@ -30,11 +30,11 @@ def _export(*args, **kwargs): def export(model, args, f, export_params=True, verbose=False, training=TrainingMode.EVAL, - input_names=None, output_names=None, aten=False, - operator_export_type=None, opset_version=None, _retain_param_name=True, - do_constant_folding=True, example_outputs=None, strip_doc_string=True, - dynamic_axes=None, keep_initializers_as_inputs=None, custom_opsets=None, - enable_onnx_checker=True, use_external_data_format=False): + input_names=None, output_names=None, operator_export_type=None, + opset_version=None, _retain_param_name=True, do_constant_folding=True, + example_outputs=None, strip_doc_string=True, dynamic_axes=None, + keep_initializers_as_inputs=None, custom_opsets=None, enable_onnx_checker=True, + use_external_data_format=False): r""" Exports a model into ONNX format. If ``model`` is not a :class:`torch.jit.ScriptModule` nor a :class:`torch.jit.ScriptFunction`, this runs @@ -116,9 +116,12 @@ def export(model, args, f, export_params=True, verbose=False, training=TrainingM input nodes of the graph, in order. output_names (list of str, default empty list): names to assign to the output nodes of the graph, in order. - aten (bool, default False): [DEPRECATED. use operator_export_type] equivalent to - setting ``operator_export_type=OperatorExportTypes.ONNX_ATEN``. - operator_export_type (enum, default OperatorExportTypes.ONNX): + operator_export_type (enum, default None): + + None usually means ``OperatorExportTypes.ONNX``. + However if PyTorch was built with ``-DPYTORCH_ONNX_CAFFE2_BUNDLE``, None means + ``OperatorExportTypes.ONNX_ATEN_FALLBACK``. + * ``OperatorExportTypes.ONNX``: Export all ops as regular ONNX ops (in the default opset domain). * ``OperatorExportTypes.ONNX_FALLTHROUGH``: Try to convert all ops @@ -303,9 +306,8 @@ def forward(self, x): from torch.onnx import utils return utils.export(model, args, f, export_params, verbose, training, - input_names, output_names, aten, - operator_export_type, opset_version, _retain_param_name, - do_constant_folding, example_outputs, + input_names, output_names, operator_export_type, opset_version, + _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format) diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py index 5b378ecc214ce..13bc4800a6700 100644 --- a/torch/onnx/symbolic_helper.py +++ b/torch/onnx/symbolic_helper.py @@ -298,7 +298,7 @@ def _select_helper(g, self, dim, index, apply_reshape=True): elif index_dim is not None and apply_reshape: if index_dim == 0: # Index is a scalar. Reshape it to a size 1 tensor. - index = g.op("Reshape", index, g.op("Constant", value_t=torch.LongTensor([1]))) + index = _reshape_helper(g, index, g.op("Constant", value_t=torch.LongTensor([1]))) index_scalar_type = index.type().scalarType() if index_scalar_type is None or index_scalar_type not in ["Long", "Int"]: @@ -367,7 +367,7 @@ def _topk_helper(g, input, k, dim, largest=True, sorted=False, out=None): if not _is_value(k): k = g.op("Constant", value_t=torch.tensor([k], dtype=torch.int64)) else: - k = g.op("Reshape", k, g.op("Constant", value_t=torch.tensor([1]))) + k = _reshape_helper(g, k, g.op("Constant", value_t=torch.tensor([1]))) if _export_onnx_opset_version <= 10: if not largest: _unimplemented("TopK", "Ascending is not supported") @@ -704,6 +704,48 @@ def _index_fill_reshape_helper(g, self, dim, index): expanded_index = expand(g, unsqueezed_index, expanded_index_shape, None) return expanded_index_shape, expanded_index +# When using reshape helper (opset_version >= 14), if reshape has -1, +# allowzero cannot be set to 1 +def _reshape_helper(g, input, shape, allowzero=0): + shape = _maybe_get_const(shape, "is") + if not _is_value(shape): + shape = g.op("Constant", value_t=torch.LongTensor(shape)) + if _export_onnx_opset_version <= 13: + return g.op("Reshape", input, shape) + else: + warnings.warn("allowzero=0 by default. In order to honor zero value in shape use allowzero=1") + return g.op("Reshape", input, shape, allowzero_i=allowzero) + +def _batchnorm_helper(g, input, weight, bias, running_mean, running_var): + from torch.onnx.symbolic_opset9 import _var_mean + batch_size = _get_tensor_dim_size(input, 0) + channel_size = _get_tensor_dim_size(input, 1) + + if weight is None or _is_none(weight): + if channel_size is None: + raise RuntimeError("Unsupported: ONNX export of batch_norm for unknown " + "channel size.") + weight_value = torch.tensor([1.] * channel_size).type( + "torch." + input.type().scalarType() + "Tensor") + weight = g.op("Constant", value_t=weight_value) + if bias is None or _is_none(bias): + if channel_size is None: + raise RuntimeError("Unsupported: ONNX export of batch_norm for unknown " + "channel size.") + bias_value = torch.tensor([0.] * channel_size).type( + "torch." + input.type().scalarType() + "Tensor") + bias = g.op("Constant", value_t=bias_value) + # If track_running_stats is set to False batch statistics are instead used during evaluation time + if running_mean is None or _is_none(running_mean) or running_var is None or _is_none(running_var): + assert batch_size is not None and channel_size is not None + reshape_in = _reshape_helper(g, input, + g.op("Constant", value_t=torch.tensor([batch_size, channel_size, -1], + dtype=torch.int64))) + trans_in = g.op("Transpose", reshape_in, perm_i=[0, 2, 1]) + running_var, running_mean = _var_mean(g, trans_in, + g.op("Constant", value_t=torch.tensor([0, 1], dtype=torch.int64)), + False, False) + return weight, bias, running_mean, running_var def _avgpool_helper(tuple_fn, padding, kernel_size, stride, divisor_override, name): if divisor_override and divisor_override.node().kind() != "prim::Constant": @@ -713,19 +755,23 @@ def _avgpool_helper(tuple_fn, padding, kernel_size, stride, divisor_override, na padding = tuple(tuple_fn(padding)) return padding -def assert_training_mode(op_mode, op_name): + +def check_training_mode(op_train_mode, op_name): global _training_mode - op_mode = True if op_mode == 1 else False - if op_mode != _training_mode: - op_mode = "training " if op_mode else "inference" + op_train_mode = True if op_train_mode == 1 else False + if _training_mode is not None and op_train_mode != _training_mode: + op_mode = "training " if op_train_mode else "inference" training_mode = "training " if _training_mode else "inference" # setting the model mode could result in op_mode != _training_mode # if the model is a FuncModule. In this case we warn the user of - # the state and export depending on training_mode + # the state and export depending on op_mode + # This is to support use-cases of fixing certain layer weights + # in training. warnings.warn("ONNX export mode is set to " + training_mode + " mode, but operator " + op_name + " is set to " + - op_mode + " mode. The model will be exported in " + - training_mode + ", as specified by the export mode.") + op_mode + " mode. The operators will be exported in " + + op_mode + ", as specified by the functional operator.") + def _flatten_helper(g, input, start_dim, end_dim, dim): input_size = g.op("Shape", input) @@ -787,8 +833,8 @@ def _handle_reduce_dim_none(g, self, op_name): _default_onnx_opset_version = 9 -_onnx_main_opset = 13 -_onnx_stable_opsets = [7, 8, 9, 10, 11, 12] +_onnx_main_opset = 14 +_onnx_stable_opsets = [7, 8, 9, 10, 11, 12, 13] _export_onnx_opset_version = _default_onnx_opset_version diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py index eaa49c29e1546..53440f15928ee 100644 --- a/torch/onnx/symbolic_opset11.py +++ b/torch/onnx/symbolic_opset11.py @@ -179,7 +179,7 @@ def index_put(g, self, indices_list_value, values, accumulate=False): rank = sym_help._get_tensor_rank(values) if rank is not None and rank == 0: values = expand(g, values, values_shape, None) - values = g.op("Reshape", values, values_shape) + values = sym_help._reshape_helper(g, values, values_shape) dtype = self.type().scalarType() if dtype is not None and dtype != values.type().scalarType(): @@ -266,12 +266,12 @@ def masked_select(g, self, mask): def masked_scatter(g, self, mask, source): - from torch.onnx.symbolic_opset9 import nonzero, expand_as, view, size + from torch.onnx.symbolic_opset9 import nonzero, expand_as, size index = nonzero(g, expand_as(g, mask, self)) # NOTE: source can have more elements than needed. # It could also have arbitrary shape. # This is not supported by ONNX::ScatterND, so we need to flatten and slice source tensor. - source = view(g, source, torch.LongTensor([-1])) + source = sym_help._reshape_helper(g, source, torch.LongTensor([-1])) source = sym_help._slice_helper(g, source, axes=torch.LongTensor([0]), starts=torch.LongTensor([0]), @@ -453,9 +453,9 @@ def _prepare_onnx_paddings(g, dim, pad): # paddings = [[..., 0, dim_n-1_begin, dim_n_begin], # [..., 0, dim_n-1_end, dim_n_end]] # Reshape back to 1-D paddings = [..., 0, dim_n - 1_begin, dim_n_begin, ..., 0, dim_n - 1_end, dim_n_end] - paddings = g.op("Reshape", paddings, g.op("Constant", value_t=torch.tensor([-1, 2]))) + paddings = sym_help._reshape_helper(g, paddings, g.op("Constant", value_t=torch.tensor([-1, 2]))) paddings = g.op("Transpose", torch.onnx.symbolic_opset10.flip(g, paddings, [0]), perm_i=[1, 0]) - paddings = g.op("Reshape", paddings, g.op("Constant", value_t=torch.tensor([-1]))) + paddings = sym_help._reshape_helper(g, paddings, g.op("Constant", value_t=torch.tensor([-1]))) padding_c = g.op("Cast", paddings, to_i=sym_help.cast_pytorch_to_onnx["Long"]) return padding_c @@ -695,7 +695,7 @@ def _get_im2col_indices_along_dim(g, input_d, kernel_size_d, dilation_d, padding # Broadcast and add kernel staring positions (indices) with # kernel_grid along dim d, to get block indices along dim d blocks_d_indices = sym_help._unsqueeze_helper(g, blocks_d_indices, [0]) # Reshape to [1, -1] - kernel_mask = g.op("Reshape", kernel_grid, g.op("Constant", value_t=torch.tensor([-1, 1]))) + kernel_mask = sym_help._reshape_helper(g, kernel_grid, g.op("Constant", value_t=torch.tensor([-1, 1]))) block_mask = g.op("Add", blocks_d_indices, kernel_mask) return block_mask @@ -766,7 +766,7 @@ def im2col(g, input, kernel_size, dilation, padding, stride): output = g.op("Gather", padded_input, blocks_row_indices, axis_i=2) output = g.op("Gather", output, blocks_col_indices, axis_i=4) output = g.op("Transpose", output, perm_i=[0, 1, 2, 4, 3, 5]) - return g.op("Reshape", output, output_shape) + return sym_help._reshape_helper(g, output, output_shape) def narrow(g, input, dim, start, length): @@ -894,109 +894,6 @@ def chunk(g, self, chunks, dim): chunk_vec = g.op("Concat", *chunk_vec, axis_i=0) return split(g, self, chunk_vec, dim) -def repeat_interleave(g, self, repeats, dim=None, output_size=None): - from torch.onnx.symbolic_opset9 import reshape - input = self - final_dim = dim - # if dim is None flatten - # By default, use the flattened input array, and return a flat output array - if sym_help._is_none(dim): - input = reshape(g, self, g.op("Constant", value_t=torch.tensor([-1]))) - dim = 0 - else: - dim = sym_help._maybe_get_scalar(dim) - - repeats_dim = sym_help._get_tensor_rank(repeats) - repeats_sizes = sym_help._get_tensor_sizes(repeats) - input_sizes = sym_help._get_tensor_sizes(input) - if repeats_dim is None: - raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown " - "repeats rank.") - if repeats_sizes is None: - raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown " - "repeats size.") - if input_sizes is None: - raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown " - "input size.") - # Handle cases where dim is negative - if dim < 0: - dim += len(input_sizes) - - output_sizes = input_sizes.copy() - perm_i = [0] - for idx, input_size in enumerate(input_sizes): - perm_i.append(idx + 1) - if input_size is None: - output_sizes[idx], input_sizes[idx] = 0, -1 - perm_i[0], perm_i[dim] = perm_i[dim], perm_i[0] - - # Cases when repeats is a single value tensor and dim has unknown input size - if (repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1)) and output_sizes[dim] == 0: - if not sym_help._is_tensor(repeats): - repeats = g.op("Constant", value_t=torch.LongTensor(repeats)) - reps = sym_help._size_helper(g, input, dim) - reps = unsqueeze(g, reps, 0) - repeats = g.op("Expand", repeats, reps) - # There are cases when the repeats are 1-d tensor with multiple repeats, but dim - # provided along one of the dynamic axes provided. A simple example would be - # input.shape -> [1, 1, *] where * represents the dynamic axes, and dim = 2 - # Now, repeat interleaving can be performed in pytorch when the value of * matches - # with the number of elements in repeat, for example if * -> 2, number of repeats - # should be 2 as well. - else: - return torch.onnx.symbolic_opset9.repeat_interleave(g, self, repeats, final_dim) - - reps_like = g.op("ConstantOfShape", g.op("Shape", repeats), - value_t=torch.tensor([1], dtype=torch.long)) - r_splits = split(g, repeats, reps_like, 0) - i_splits = split(g, input, reps_like, dim) - - output_sizes[dim], input_sizes[dim] = -1, 1 - - # Create a loop to iterate over each value along the dimension - # and perform individual interleaving using the repeats tensor - # Loop is of the following pattern - # input (trip_count, cond) - # int trip_count = ...; - # bool cond = ...; - # for (int i=0; i < trip_count && cond; ++i) { - # cond = ...; - # } - - # Loop conditions - loop_condition = g.op("Constant", value_t=torch.tensor(1)) - loop_condition = g.op("Cast", loop_condition, to_i=9) - loop_len = reps - loop = g.op("Loop", loop_len, loop_condition) - - # Loop inputs - loop_block = _add_block(loop.node()) - block_input_iter = _add_input_to_block(loop_block) - cond = _add_input_to_block(loop_block) - - r_split = loop_block.op("SequenceAt", r_splits, block_input_iter) - i_split = loop_block.op("SequenceAt", i_splits, block_input_iter) - - i_split = unsqueeze(loop_block, i_split, dim + 1) - r_concat = [loop_block.op("Constant", value_t=torch.LongTensor(input_sizes[:dim + 1])), - r_split, - loop_block.op("Constant", value_t=torch.LongTensor(input_sizes[dim + 1:]))] - r_concat = loop_block.op("Concat", *r_concat, axis_i=0) - i_split = expand(loop_block, i_split, r_concat, None) - i_split = reshape(loop_block, i_split, g.op("Constant", value_t=torch.LongTensor(output_sizes))) - - # Loop outputs - cond_out = loop_block.op("Cast", loop_condition, to_i=9) - _add_output_to_block(loop_block, cond_out) - _add_output_to_block(loop_block, i_split) - loop_out = loop.node().output() - - # In this loop, the outputs are scan outputs and are concatenated along - # the zero'th dimension (by default). In order to avoid this and concatenate - # along the dimension provided, some post-processing is required - loop_out = g.op("Transpose", loop_out, perm_i=perm_i) - return reshape(g, loop_out, g.op("Constant", value_t=torch.LongTensor(output_sizes))) - def normal(g, loc, scale, seed): # If you can sample from a given distribution with mean 0 and variance 1, then you can easily sample from a diff --git a/torch/onnx/symbolic_opset12.py b/torch/onnx/symbolic_opset12.py index 58420a2bc7749..ab39325709ea9 100644 --- a/torch/onnx/symbolic_opset12.py +++ b/torch/onnx/symbolic_opset12.py @@ -4,6 +4,7 @@ from torch.onnx.utils import _add_block, _add_input_to_block, _add_output_to_block from sys import maxsize from torch.onnx.symbolic_opset9 import permute, _reshape_from_tensor +import warnings # EDITING THIS FILE? READ THIS FIRST! @@ -25,11 +26,12 @@ def outer(g, input, other): @parse_args("v", "f", "i") def dropout(g, input, p, train): - sym_help.assert_training_mode(train, "dropout") + sym_help.check_training_mode(train, "dropout") # in eval mode, dropout is non-op - if the node's train param is set to False, dropout is non-op - if not sym_help._training_mode: + if not train: return input - + warnings.warn("Dropout is a training op and should not be exported in inference mode. " + "For inference, make sure to call eval() on the model and to export it with param training=False.") p = g.op("Constant", value_t=torch.tensor(p)) t = g.op("Constant", value_t=torch.tensor(True)) r, _ = g.op("Dropout", input, p, t, outputs=2) @@ -63,7 +65,7 @@ def nll_loss_nd(g, self, target, weight, reduction, ignore_index): return nll_loss(g, self, target, weight, reduction, ignore_index) -def cross_entropy_loss(g, self, target, weight, reduction, ignore_index): +def cross_entropy_loss(g, self, target, weight, reduction, ignore_index, label_smoothing): # none reduction : onnx::Constant[value={0}] # mean reduction : onnx::Constant[value={1}] # sum reduction : onnx::Constant[value={2}] @@ -71,6 +73,10 @@ def cross_entropy_loss(g, self, target, weight, reduction, ignore_index): reduction_vals = ["none", "mean", "sum"] reduction = reduction_vals[reduction] + label_smoothing = sym_help._maybe_get_const(label_smoothing, "f") + if label_smoothing > 0.0: + raise RuntimeError("Unsupported: ONNX does not support label_smoothing") + # in onnx SoftmaxCrossEntropyLoss specification, ignore_index is optional without default value. # therefore we need to set ignore_index attribute even if it is not specified (e.g. ignore_index=-100). ignore_index = sym_help._maybe_get_const(ignore_index, "i") @@ -123,8 +129,7 @@ def celu(g, self, alpha): def argmax(g, input, dim, keepdim): if sym_help._is_none(dim): - from torch.onnx.symbolic_opset9 import reshape - flattened = reshape(g, input, g.op("Constant", value_t=torch.tensor([-1]))) + flattened = sym_help._reshape_helper(g, input, g.op("Constant", value_t=torch.tensor([-1]))) return g.op("ArgMax", flattened, axis_i=0, keepdims_i=False, select_last_index_i=False) else: dim = _parse_arg(dim, "i") @@ -134,8 +139,7 @@ def argmax(g, input, dim, keepdim): def argmin(g, input, dim, keepdim): if sym_help._is_none(dim): - from torch.onnx.symbolic_opset9 import reshape - flattened = reshape(g, input, g.op("Constant", value_t=torch.tensor([-1]))) + flattened = sym_help._reshape_helper(g, input, g.op("Constant", value_t=torch.tensor([-1]))) return g.op("ArgMin", flattened, axis_i=0, keepdims_i=False, select_last_index_i=False) else: dim = _parse_arg(dim, "i") diff --git a/torch/onnx/symbolic_opset13.py b/torch/onnx/symbolic_opset13.py index 7f20833571a53..0baf785757702 100644 --- a/torch/onnx/symbolic_opset13.py +++ b/torch/onnx/symbolic_opset13.py @@ -5,7 +5,9 @@ import torch import torch.onnx.symbolic_helper as sym_help from torch.onnx.symbolic_helper import parse_args, _unimplemented -from torch.onnx.symbolic_opset9 import overload_by_arg_count, _maybe_cast_reduce_op_input, nonzero +from torch.onnx.symbolic_opset9 import overload_by_arg_count, _maybe_cast_reduce_op_input, nonzero, expand +from torch.onnx.symbolic_opset11 import unsqueeze +from torch.onnx.utils import _add_block, _add_input_to_block, _add_output_to_block # EDITING THIS FILE? READ THIS FIRST! @@ -196,3 +198,117 @@ def unsafe_chunk(g, self, chunks, dim, _outputs=None): # user's modules. splits = g.op("Constant", value_t=torch.tensor(splits, dtype=torch.long)) return g.op("Split", self, splits, axis_i=dim, outputs=_outputs) + +def repeat_interleave(g, self, repeats, dim=None, output_size=None): + input = self + final_dim = dim + # if dim is None flatten + # By default, use the flattened input array, and return a flat output array + if sym_help._is_none(dim): + input = sym_help._reshape_helper(g, self, g.op("Constant", value_t=torch.tensor([-1]))) + dim = 0 + else: + dim = sym_help._maybe_get_scalar(dim) + + repeats_dim = sym_help._get_tensor_rank(repeats) + repeats_sizes = sym_help._get_tensor_sizes(repeats) + input_sizes = sym_help._get_tensor_sizes(input) + if repeats_dim is None: + raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown " + "repeats rank.") + if repeats_sizes is None: + raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown " + "repeats size.") + if input_sizes is None: + raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown " + "input size.") + # Handle cases where dim is negative + if dim < 0: + dim += len(input_sizes) + + output_sizes = input_sizes.copy() + for idx, input_size in enumerate(input_sizes): + if input_size is None: + output_sizes[idx], input_sizes[idx] = 0, -1 + print(output_sizes, input_sizes) + + cond_dynamic_repeats = (repeats_dim == 1 and repeats_sizes[0] is None) + # If input size is dynamic or repeats vector is dynamic + if output_sizes[dim] == 0 or cond_dynamic_repeats: + reps = sym_help._size_helper(g, input, dim) + reps = unsqueeze(g, reps, 0) + # Check if repeats vector is a single integer value + # or a single dimension tensor with non-dynamic values + if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1): + if not sym_help._is_tensor(repeats): + repeats = g.op("Constant", value_t=torch.LongTensor(repeats)) + repeats = g.op("Expand", repeats, reps) + # Check if repeats is dynamic + # As repeats is dynamic, we use a where node as a substitute for the if statement + # If repests_dim = 1, expand repeats otherwise use original tensor + elif cond_dynamic_repeats: + repeat_dim = sym_help._size_helper(g, repeats, g.op("Constant", value_t=torch.LongTensor([0]))) + repeat_cond = g.op("Equal", repeat_dim, g.op("Constant", value_t=torch.LongTensor([1]))) + repeats = where(g, repeat_cond, g.op("Expand", repeats, reps), repeats) + # There are cases when the repeats are 1-d tensor with multiple repeats, but dim + # provided along one of the dynamic axes provided. A simple example would be + # input.shape -> [1, 1, *] where * represents the dynamic axes, and dim = 2 + # Now, repeat interleaving can be performed in pytorch when the value of * matches + # with the number of elements in repeat, for example if * -> 2, number of repeats + # should be 2 as well. + else: + return torch.onnx.symbolic_opset9.repeat_interleave(g, self, repeats, final_dim) + + reps_like = g.op("ConstantOfShape", g.op("Shape", repeats), + value_t=torch.tensor([1], dtype=torch.long)) + r_splits = split(g, repeats, reps_like, 0) + i_splits = split(g, input, reps_like, dim) + + output_sizes[dim], input_sizes[dim] = -1, 1 + + # Create a loop to iterate over each value along the dimension + # and perform individual interleaving using the repeats tensor + # Loop is of the following pattern + # input (trip_count, cond) + # int trip_count = ...; + # bool cond = ...; + # for (int i=0; i < trip_count && cond; ++i) { + # cond = ...; + # } + + # Loop conditions + loop_condition = g.op("Constant", value_t=torch.tensor(1)) + loop_condition = g.op("Cast", loop_condition, to_i=9) + loop_len = reps + + # Create an empty sequence to store final expansions + final_splits = g.op("SequenceEmpty") + loop = g.op("Loop", loop_len, loop_condition, final_splits) + + # Loop inputs + loop_block = _add_block(loop.node()) + block_input_iter = _add_input_to_block(loop_block) + cond = _add_input_to_block(loop_block) + final_splits = _add_input_to_block(loop_block) + + r_split = loop_block.op("SequenceAt", r_splits, block_input_iter) + i_split = loop_block.op("SequenceAt", i_splits, block_input_iter) + + i_split = unsqueeze(loop_block, i_split, dim + 1) + r_concat = [loop_block.op("Constant", value_t=torch.LongTensor(input_sizes[:dim + 1])), + r_split, + loop_block.op("Constant", value_t=torch.LongTensor(input_sizes[dim + 1:]))] + r_concat = loop_block.op("Concat", *r_concat, axis_i=0) + i_split = expand(loop_block, i_split, r_concat, None) + i_split = sym_help._reshape_helper(loop_block, i_split, + g.op("Constant", value_t=torch.LongTensor(output_sizes))) + final_splits = loop_block.op("SequenceInsert", final_splits, i_split) + + # Loop outputs + cond_out = loop_block.op("Cast", loop_condition, to_i=9) + _add_output_to_block(loop_block, cond_out) + _add_output_to_block(loop_block, final_splits) + + loop_out = loop.node().output() + loop_out = g.op("ConcatFromSequence", loop_out, axis_i=dim) + return loop_out diff --git a/torch/onnx/symbolic_opset14.py b/torch/onnx/symbolic_opset14.py new file mode 100644 index 0000000000000..d4775b553da8d --- /dev/null +++ b/torch/onnx/symbolic_opset14.py @@ -0,0 +1,54 @@ +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in symbolic_helper.py + +# This file exports ONNX ops for opset 14 +import torch + +import torch.onnx.symbolic_helper as sym_help +from torch.onnx.symbolic_helper import parse_args + +# Note [ONNX operators that are added/updated in opset 14] +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# New operators: +# HardSwish, Trilu +# +# Updated operators: +# Reshape +# Add, Sub, Mul, Div +# GRU, LSTM, RNN +# BatchNorm, Cumsum, Relu + +@parse_args("v") +def hardswish(g, self): + return g.op("HardSwish", self) + +@parse_args("v", "i") +def tril(g, self, diagonal, out=None): + k = g.op("Constant", value_t=torch.tensor(diagonal, dtype=torch.int64)) + return g.op("Trilu", self, k, upper_i=0) + +@parse_args("v", "i") +def triu(g, self, diagonal, out=None): + k = g.op("Constant", value_t=torch.tensor(diagonal, dtype=torch.int64)) + return g.op("Trilu", self, k, upper_i=1) + +@parse_args("v", "v") +def reshape(g, self, shape): + return sym_help._reshape_helper(g, self, shape) + +@parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i") +def batch_norm(g, input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled): + sym_help.check_training_mode(training, "batch_norm") + weight, bias, running_mean, running_var = sym_help._batchnorm_helper(g, input, weight, bias, running_mean, running_var) + out = g.op("BatchNormalization", input, weight, bias, running_mean, running_var, + epsilon_f=eps, + momentum_f=1 - momentum, + training_mode_i=0 if not training else 1, + outputs=1 if not training else 3) + if not training: + return out + else: + res, new_running_mean, new_running_var = out + new_running_mean.setType(running_mean.type()) + new_running_var.setType(running_var.type()) + return res diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index 36c1753ab252b..70bb8282570e2 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -70,11 +70,11 @@ def _shape_as_tensor(g, input): def _reshape_from_tensor(g, input, shape): if (isinstance(shape, list)): shape = g.op("Concat", *shape, axis_i=0) - return g.op("Reshape", input, shape) + return reshape(g, input, shape) def reshape(g, self, shape): - return view(g, self, shape) + return sym_help._reshape_helper(g, self, shape) def reshape_as(g, self, other): @@ -461,7 +461,7 @@ def expand(g, self, size, implicit): # Expand with -1 dim value means dim is unchanged. # Since onnx::expand supports two-way broadcasting, # -1 dim value can be exported to onnx as 1 - size = view(g, stack(g, size, 0), g.op("Constant", value_t=torch.tensor([-1]))) + size = sym_help._reshape_helper(g, stack(g, size, 0), g.op("Constant", value_t=torch.tensor([-1]))) dtype = 4 # dim type is int64 ones = ones_like(g, size, dtype) neg_ones = mul(g, ones, g.op("Constant", value_t=torch.tensor(-1))) @@ -566,17 +566,12 @@ def permute(g, self, dims): def view(g, self, size): - size = sym_help._maybe_get_const(size, "is") - if sym_help._is_value(size): - shape = size - else: - shape = g.op("Constant", value_t=torch.LongTensor(size)) - return g.op("Reshape", self, shape) + return reshape(g, self, size) def view_as(g, self, other): shape = g.op("Shape", other) - return g.op("Reshape", self, shape) + return reshape(g, self, shape) def prim_ConstantSplit(g, self, split_size, dim): @@ -1348,38 +1343,13 @@ def conv_transpose3d(g, input, weight, bias, stride, padding, output_padding, gr @parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i") def batch_norm(g, input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled): - sym_help.assert_training_mode(training, "batch_norm") - batch_size = sym_help._get_tensor_dim_size(input, 0) - channel_size = sym_help._get_tensor_dim_size(input, 1) - - if weight is None or sym_help._is_none(weight): - if channel_size is None: - raise RuntimeError("Unsupported: ONNX export of batch_norm for unknown " - "channel size.") - weight_value = torch.tensor([1.] * channel_size).type( - "torch." + input.type().scalarType() + "Tensor") - weight = g.op("Constant", value_t=weight_value) - if bias is None or sym_help._is_none(bias): - if channel_size is None: - raise RuntimeError("Unsupported: ONNX export of batch_norm for unknown " - "channel size.") - bias_value = torch.tensor([0.] * channel_size).type( - "torch." + input.type().scalarType() + "Tensor") - bias = g.op("Constant", value_t=bias_value) - # If track_running_stats is set to False batch statistics are instead used during evaluation time - if running_mean is None or sym_help._is_none(running_mean) or running_var is None or sym_help._is_none(running_var): - assert batch_size is not None and channel_size is not None - reshape_in = g.op("Reshape", input, - g.op("Constant", value_t=torch.tensor([batch_size, channel_size, -1], dtype=torch.int64))) - trans_in = g.op("Transpose", reshape_in, perm_i=[0, 2, 1]) - running_var, running_mean = _var_mean(g, trans_in, - g.op("Constant", value_t=torch.tensor([0, 1], dtype=torch.int64)), - False, False) + sym_help.check_training_mode(training, "batch_norm") + weight, bias, running_mean, running_var = sym_help._batchnorm_helper(g, input, weight, bias, running_mean, running_var) out = g.op("BatchNormalization", input, weight, bias, running_mean, running_var, epsilon_f=eps, momentum_f=1 - momentum, - outputs=1 if not sym_help._training_mode else 5) - if not sym_help._training_mode: + outputs=1 if not training else 5) + if not training: return out else: res, new_running_mean, new_running_var, saved_mean, saved_var = out @@ -1654,12 +1624,12 @@ def exp(g, self): @parse_args("v", "f", "i") def dropout(g, input, p, train): - sym_help.assert_training_mode(train, "dropout") + sym_help.check_training_mode(train, "dropout") # in eval mode, dropout is non-op - if the node's train param is set to False, dropout is non-op - if not sym_help._training_mode: + if not train: return input warnings.warn("Dropout is a training op and should not be exported in inference mode. " - "Make sure to call eval() on the model, and to export it with param training=False.") + "For inference, make sure to call eval() on the model and to export it with param training=False.") r, _ = g.op("Dropout", input, ratio_f=p, outputs=2) return r @@ -1771,7 +1741,7 @@ def tensor(g, data, dtype=None, device=None, requires_grad=False): input_list = list() for t in sym_help._unpack_list(data): shape_reference = g.op("Constant", value_t=torch.LongTensor([1])) - t = g.op("Reshape", t, shape_reference) + t = sym_help._reshape_helper(g, t, shape_reference) t = g.op("Cast", t, to_i=sym_help.scalar_type_to_onnx[dtype]) input_list.append(t) return g.op("Concat", *input_list, axis_i=0) @@ -2060,7 +2030,7 @@ def repeat_interleave(g, self, repeats, dim=None, output_size=None): # if dim is None flatten # By default, use the flattened input array, and return a flat output array if sym_help._is_none(dim): - input = reshape(g, self, g.op("Constant", value_t=torch.tensor([-1]))) + input = sym_help._reshape_helper(g, self, g.op("Constant", value_t=torch.tensor([-1]))) dim = 0 else: dim = sym_help._maybe_get_scalar(dim) @@ -2088,7 +2058,7 @@ def repeat_interleave(g, self, repeats, dim=None, output_size=None): if not sym_help._is_tensor(repeats): repeats = g.op("Constant", value_t=torch.LongTensor(repeats)) if input_sizes[dim] == 0: - return sym_help._onnx_opset_unsupported_detailed("repeat_interleave", 9, 11, + return sym_help._onnx_opset_unsupported_detailed("repeat_interleave", 9, 13, "Unsupported along dimension with unknown input size") else: reps = input_sizes[dim] @@ -2097,8 +2067,11 @@ def repeat_interleave(g, self, repeats, dim=None, output_size=None): # Cases where repeats is a 1 dim Tensor elif repeats_dim == 1: if input_sizes[dim] == 0: - return sym_help._onnx_opset_unsupported_detailed("repeat_interleave", 9, 11, + return sym_help._onnx_opset_unsupported_detailed("repeat_interleave", 9, 13, "Unsupported along dimension with unknown input size") + if repeats_sizes[0] is None: + return sym_help._onnx_opset_unsupported_detailed("repeat_interleave", 9, 13, + "Unsupported for cases with dynamic repeats") assert repeats_sizes[0] == input_sizes[dim], "repeats must have the same size as input along dim" reps = repeats_sizes[0] else: @@ -2115,7 +2088,7 @@ def repeat_interleave(g, self, repeats, dim=None, output_size=None): g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[dim + 1:]))] r_concat = g.op("Concat", *r_concat, axis_i=0) i_split = expand(g, i_split, r_concat, None) - i_split = reshape(g, i_split, g.op("Constant", value_t=torch.LongTensor(input_sizes))) + i_split = sym_help._reshape_helper(g, i_split, g.op("Constant", value_t=torch.LongTensor(input_sizes)), allowzero=0) final_splits.append(i_split) return g.op("Concat", *final_splits, axis_i=dim) @@ -2128,12 +2101,17 @@ def pixel_shuffle(g, self, upscale_factor): if any([i is None for i in dims[1:]]): return _unimplemented("pixel_shuffle", "only support static input shape, except for batch size") output_channel = dims[1] // upscale_factor // upscale_factor - after_view = view(g, self, g.op("Constant", value_t=torch.tensor([-1, output_channel, upscale_factor, - upscale_factor, dims[2], dims[3]]))) + after_view = sym_help._reshape_helper(g, self, + g.op("Constant", value_t=torch.tensor([-1, output_channel, + upscale_factor, upscale_factor, + dims[2], dims[3]])), + allowzero=0) after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3]) - return view(g, after_transpose, - g.op("Constant", value_t=torch.tensor([-1, output_channel, dims[2] * upscale_factor, - dims[3] * upscale_factor]))) + return sym_help._reshape_helper(g, after_transpose, + g.op("Constant", value_t=torch.tensor([-1, output_channel, + dims[2] * upscale_factor, + dims[3] * upscale_factor])), + allowzero=0) def _generic_rnn(g, variant, input, initial_states, all_weights, has_biases, @@ -2277,7 +2255,8 @@ def retrieve_state(x, start, end): # Transpose, and then combining it with hidden_size # with Reshape. prev_output = g.op("Transpose", prev_output, perm_i=[0, 2, 1, 3]) - prev_output = g.op("Reshape", prev_output, g.op("Constant", value_t=torch.LongTensor([0, 0, -1]))) + prev_output = sym_help._reshape_helper(g, prev_output, + g.op("Constant", value_t=torch.LongTensor([0, 0, -1])), allowzero=0) else: prev_output = sym_help._squeeze_helper(g, prev_output, [1]) @@ -2316,6 +2295,17 @@ def lstm(g, *args): return _lstm_full(g, *args) +def lstm_cell(g, self, hidden, w_ih, w_hh, b_ih, b_hh): + input = sym_help._unsqueeze_helper(g, self, [0]) + hidden = sym_help._unpack_list(hidden) + hidden = [sym_help._unsqueeze_helper(g, x, [0]) for x in hidden] + weight = (w_ih, w_hh, b_ih, b_hh) if sym_help._is_tensor(b_ih) else (w_ih, w_hh) + has_biases = True if sym_help._is_tensor(b_ih) else False + _, h_outs, c_outs = _generic_rnn(g, 'LSTM', input, hidden, weight, has_biases, num_layers=1, + dropout=0, train=0, bidirectional=False, batch_first=False) + return sym_help._squeeze_helper(g, h_outs, [0]), sym_help._squeeze_helper(g, c_outs, [0]) + + def _one_hidden_rnn(kind): @parse_args("v", "v", "v", "i", "i", "f", "i", "i", "i") def _rnn_full(g, input, hidden, weight_v, has_biases, num_layers, dropout, train, bidirectional, batch_first): @@ -2514,7 +2504,7 @@ def narrow(g, input, dim, start, length): def argmax(g, input, dim, keepdim): if sym_help._is_none(dim): - flattened = reshape(g, input, g.op("Constant", value_t=torch.tensor([-1]))) + flattened = sym_help._reshape_helper(g, input, g.op("Constant", value_t=torch.tensor([-1]))) return g.op("ArgMax", flattened, axis_i=0, keepdims_i=False) else: dim = _parse_arg(dim, "i") @@ -2524,7 +2514,7 @@ def argmax(g, input, dim, keepdim): def argmin(g, input, dim, keepdim): if sym_help._is_none(dim): - flattened = reshape(g, input, g.op("Constant", value_t=torch.tensor([-1]))) + flattened = sym_help._reshape_helper(g, input, g.op("Constant", value_t=torch.tensor([-1]))) return g.op("ArgMin", flattened, axis_i=0, keepdims_i=False) else: dim = _parse_arg(dim, "i") @@ -2857,7 +2847,7 @@ def try_mask_to_index(index): folded_adv_idx_shape_list = [g.op("Constant", value_t=torch.LongTensor([-1]))] \ + [dim_tensor_list[i] for i in range(rank) if i not in adv_idx_indices] folded_adv_idx_shape = g.op("Concat", *folded_adv_idx_shape_list, axis_i=0) - self = g.op("Reshape", self, folded_adv_idx_shape) + self = sym_help._reshape_helper(g, self, folded_adv_idx_shape) # Transpose folded advanced indexed axis to its original location. adv_idx_permute = list(range(1, adv_idx_indices[0] + 1)) \ @@ -2876,7 +2866,7 @@ def try_mask_to_index(index): *[dim_tensor_list[i] for i in range(rank) if i not in adv_idx_indices], axis_i=0) - return g.op("Reshape", self, final_shape) + return sym_help._reshape_helper(g, self, final_shape) @parse_args("v", "is", "i") @@ -2908,7 +2898,8 @@ def baddbmm(g, self, batch1, batch2, beta, alpha): def meshgrid(g, tensor_list): - tensors = [view(g, t, g.op("Constant", value_t=torch.LongTensor([-1]))) for t in sym_help._unpack_list(tensor_list)] + tensors = [sym_help._reshape_helper(g, t, g.op("Constant", value_t=torch.LongTensor([-1]))) + for t in sym_help._unpack_list(tensor_list)] tensors_shape = [g.op("Shape", t) for t in tensors] out_shape = g.op("Concat", *tensors_shape, axis_i=0) out = [] @@ -2948,7 +2939,8 @@ def group_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled): return _unimplemented("group_norm", "unknown input rank") # 0 in the shape list keeps dimension value unchanged. shape = [0, num_groups, -1] - input_reshaped = g.op("Reshape", input, g.op("Constant", value_t=torch.LongTensor(shape))) + input_reshaped = sym_help._reshape_helper(g, input, + g.op("Constant", value_t=torch.LongTensor(shape))) # C is always divisible by num_groups # Due to shape difference. we need to apply weight and bias after @@ -2959,7 +2951,7 @@ def group_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled): "torch." + input.type().scalarType() + "Tensor")) norm_reshaped = g.op("InstanceNormalization", input_reshaped, weight_, bias_, epsilon_f=eps) - norm = g.op("Reshape", norm_reshaped, g.op("Shape", input)) + norm = sym_help._reshape_helper(g, norm_reshaped, g.op("Shape", input)) if weight is None or weight.node().mustBeNone(): weight_value = torch.tensor([1.]).type( @@ -3016,7 +3008,7 @@ def item(g, self): def take(g, self, index): - self_flattened = g.op("Reshape", self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))) + self_flattened = sym_help._reshape_helper(g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))) out = index_select(g, self_flattened, 0, index) out = reshape_as(g, out, index) return out @@ -3060,7 +3052,7 @@ def kl_div(g, input, target, reduction, log_target): def as_strided(g, self, sizes, strides, offset=None): sizes = sym_help._maybe_get_const(sizes, "is") rank = len(strides) - self_1d = g.op("Reshape", self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))) + self_1d = sym_help._reshape_helper(g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))) ind: Optional[torch.Tensor] if not sym_help._is_value(sizes): ind = torch.tensor([0], dtype=torch.long) @@ -3077,7 +3069,8 @@ def as_strided(g, self, sizes, strides, offset=None): r_size = [1] * rank r_size[i] = -1 size = select(g, sizes, g.op("Constant", value_t=torch.tensor([0])), g.op("Constant", value_t=torch.tensor(i))) - tmp_ind = g.op("Reshape", arange(g, size, 4, None, None, None), g.op("Constant", value_t=torch.tensor(r_size))) + tmp_ind = sym_help._reshape_helper(g, arange(g, size, 4, None, None, None), + g.op("Constant", value_t=torch.tensor(r_size))) tmp_ind = g.op("Mul", tmp_ind, g.op("Constant", value_t=torch.tensor([stride]))) if ind is None: ind = tmp_ind @@ -3145,6 +3138,10 @@ def mv(g, self, vec): return matmul(g, self, vec) +def dot(g, self, other): + return matmul(g, self, other) + + @parse_args('v', 'v') def fill(g, self, value): dtype = self.type().scalarType() diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index f5dc2f2270165..41ba20f3ad102 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -71,15 +71,12 @@ def select_model_mode_for_export(model, mode): def export(model, args, f, export_params=True, verbose=False, training=None, - input_names=None, output_names=None, aten=False, - operator_export_type=None, opset_version=None, _retain_param_name=True, - do_constant_folding=True, example_outputs=None, strip_doc_string=True, - dynamic_axes=None, keep_initializers_as_inputs=None, custom_opsets=None, + input_names=None, output_names=None, operator_export_type=None, + opset_version=None, _retain_param_name=True, do_constant_folding=True, + example_outputs=None, strip_doc_string=True, dynamic_axes=None, + keep_initializers_as_inputs=None, custom_opsets=None, enable_onnx_checker=True, use_external_data_format=False): - if aten: - assert operator_export_type is None - operator_export_type = OperatorExportTypes.ONNX_ATEN - elif operator_export_type is None: + if operator_export_type is None: if torch.onnx.PYTORCH_ONNX_CAFFE2_BUNDLE: operator_export_type = OperatorExportTypes.ONNX_ATEN_FALLBACK else: @@ -529,18 +526,11 @@ def _model_to_graph(model, args, verbose=False, def export_to_pretty_string(model, args, f, export_params=True, verbose=False, training=None, - input_names=None, output_names=None, aten=False, - operator_export_type=None, export_type=ExportTypes.PROTOBUF_FILE, - example_outputs=None, google_printer=False, - opset_version=None, _retain_param_name=True, + input_names=None, output_names=None, operator_export_type=OperatorExportTypes.ONNX, + export_type=ExportTypes.PROTOBUF_FILE, example_outputs=None, + google_printer=False, opset_version=None, _retain_param_name=True, keep_initializers_as_inputs=None, custom_opsets=None, add_node_names=True, do_constant_folding=True): - if aten: - assert operator_export_type is None - assert aten - operator_export_type = OperatorExportTypes.ONNX_ATEN - elif operator_export_type is None: - operator_export_type = OperatorExportTypes.ONNX return _export_to_pretty_string(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, google_printer, diff --git a/torch/optim/adam.py b/torch/optim/adam.py index d7313be75f8fb..ea2ceaff67057 100644 --- a/torch/optim/adam.py +++ b/torch/optim/adam.py @@ -6,9 +6,37 @@ class Adam(Optimizer): r"""Implements Adam algorithm. - It has been proposed in `Adam: A Method for Stochastic Optimization`_. - The implementation of the L2 penalty follows changes proposed in - `Decoupled Weight Decay Regularization`_. + .. math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2 + \text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)} \\ + &\hspace{13mm} \lambda \text{ (weight decay)}, \: amsgrad \\ + &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)}, + v_0\leftarrow 0 \text{ (second moment)},\: \widehat{v_0}^{max}\leftarrow 0\\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\ + &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ + &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ + &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\ + &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\ + &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\ + &\hspace{5mm}\textbf{if} \: amsgrad \\ + &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max}, + \widehat{v_t}) \\ + &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/ + \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\ + &\hspace{5mm}\textbf{else} \\ + &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/ + \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\ + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + + For further details regarding the algorithm we refer to `Adam: A Method for Stochastic Optimization`_. Args: params (iterable): iterable of parameters to optimize or dicts defining @@ -25,8 +53,6 @@ class Adam(Optimizer): .. _Adam\: A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980 - .. _Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101 .. _On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ """ diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py index 78a8cfad0d637..42f7b511c54a5 100644 --- a/torch/optim/lr_scheduler.py +++ b/torch/optim/lr_scheduler.py @@ -328,7 +328,7 @@ def get_lr(self): return [group['lr'] * lmbda(self.last_epoch) for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups)] else: - return list(self.base_lrs) + return [group['lr'] for group in self.optimizer.param_groups] class StepLR(_LRScheduler): @@ -427,25 +427,78 @@ def _get_closed_form_lr(self): for base_lr in self.base_lrs] -class WarmUpLR(_LRScheduler): - """Decays the learning rate of each parameter group by either a small constant - or linearly increasing small warmup factor until the number of epoch reaches a - pre-defined milestone: warmup_iters. Notice that such decay can happen - simultaneously with other changes to the learning rate from outside this scheduler. +class ConstantLR(_LRScheduler): + """Decays the learning rate of each parameter group by a small constant factor until the + number of epoch reaches a pre-defined milestone: total_iters. Notice that such decay can + happen simultaneously with other changes to the learning rate from outside this scheduler. When last_epoch=-1, sets initial lr as lr. Args: optimizer (Optimizer): Wrapped optimizer. - warmup_factor (float): The number we multiply learning rate in the first epoch. - If the warming up method is constant, the multiplication factor of the - learning rate stays the same in all epochs, but, in the linear case, it - starts increasing in the following epochs. Default: 1./3. - warmup_iters (int): The number of warming up steps. Default: 5. - warmup_method (str): One of `constant` and `linear`. In `constant` mode, the - learning rate will be multiplied with a small constant until a milestone - defined in warmup_iters. In the `linear` case, the multiplication factor - starts with warmup_factor in the first epoch then linearly increases to - reach 1. in the epoch number warmup_iters. Default: `linear`. + factor (float): The number we multiply learning rate until the milestone. Default: 1./3. + total_iters (int): The number of steps that the scheduler decays the learning rate. + Default: 5. + last_epoch (int): The index of the last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + Example: + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.025 if epoch == 0 + >>> # lr = 0.025 if epoch == 1 + >>> # lr = 0.025 if epoch == 2 + >>> # lr = 0.025 if epoch == 3 + >>> # lr = 0.05 if epoch >= 4 + >>> scheduler = ConstantLR(self.opt, factor=0.5, total_iters=4) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__(self, optimizer, factor=1.0 / 3, total_iters=5, last_epoch=-1, verbose=False): + if factor > 1.0 or factor < 0: + raise ValueError('Constant multiplicative factor expected to be between 0 and 1.') + + self.factor = factor + self.total_iters = total_iters + super(ConstantLR, self).__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", UserWarning) + + if self.last_epoch == 0: + return [group['lr'] * self.factor for group in self.optimizer.param_groups] + + if (self.last_epoch > self.total_iters or + (self.last_epoch != self.total_iters)): + return [group['lr'] for group in self.optimizer.param_groups] + + if (self.last_epoch == self.total_iters): + return [group['lr'] * (1.0 / self.factor) for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + return [base_lr * (self.factor + (self.last_epoch >= self.total_iters) * (1 - self.factor)) + for base_lr in self.base_lrs] + + +class LinearLR(_LRScheduler): + """Decays the learning rate of each parameter group by linearly changing small + multiplicative factor until the number of epoch reaches a pre-defined milestone: total_iters. + Notice that such decay can happen simultaneously with other changes to the learning rate + from outside this scheduler. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + start_factor (float): The number we multiply learning rate in the first epoch. + The multiplication factor changes towards end_factor in the following epochs. + Default: 1./3. + end_factor (float): The number we multiply learning rate at the end of linear changing + process. Default: 1.0. + total_iters (int): The number of iterations that multiplicative factor reaches to 1. + Default: 5. last_epoch (int): The index of the last epoch. Default: -1. verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False``. @@ -457,24 +510,25 @@ class WarmUpLR(_LRScheduler): >>> # lr = 0.0375 if epoch == 2 >>> # lr = 0.04375 if epoch == 3 >>> # lr = 0.005 if epoch >= 4 - >>> scheduler = WarmUpLR(self.opt, warmup_factor=0.5, warmup_iters=4, warmup_method="linear") + >>> scheduler = LinearLR(self.opt, start_factor=0.5, total_iters=4) >>> for epoch in range(100): >>> train(...) >>> validate(...) >>> scheduler.step() """ - def __init__(self, optimizer, warmup_factor=1.0 / 3, warmup_iters=5, warmup_method="linear", - last_epoch=-1, verbose=False): - if warmup_method not in ("constant", "linear"): - raise ValueError( - "Only 'constant' or 'linear' warmup_method accepted, but " - "got {}".format(warmup_method) - ) - self.warmup_factor = warmup_factor - self.warmup_iters = warmup_iters - self.warmup_method = warmup_method - super(WarmUpLR, self).__init__(optimizer, last_epoch, verbose) + def __init__(self, optimizer, start_factor=1.0 / 3, end_factor=1.0, total_iters=5, last_epoch=-1, + verbose=False): + if start_factor > 1.0 or start_factor < 0: + raise ValueError('Starting multiplicative factor expected to be between 0 and 1.') + + if end_factor > 1.0 or end_factor < 0: + raise ValueError('Ending multiplicative factor expected to be between 0 and 1.') + + self.start_factor = start_factor + self.end_factor = end_factor + self.total_iters = total_iters + super(LinearLR, self).__init__(optimizer, last_epoch, verbose) def get_lr(self): if not self._get_lr_called_within_step: @@ -482,25 +536,18 @@ def get_lr(self): "please use `get_last_lr()`.", UserWarning) if self.last_epoch == 0: - return [group['lr'] * self.warmup_factor for group in self.optimizer.param_groups] + return [group['lr'] * self.start_factor for group in self.optimizer.param_groups] - if (self.last_epoch > self.warmup_iters or - (self.warmup_method == "constant" and self.last_epoch != self.warmup_iters)): + if (self.last_epoch > self.total_iters): return [group['lr'] for group in self.optimizer.param_groups] - if (self.warmup_method == "constant" and self.last_epoch == self.warmup_iters): - return [group['lr'] * (1.0 / self.warmup_factor) for group in self.optimizer.param_groups] - - return [group['lr'] * (1. + (1.0 - self.warmup_factor) / - (self.warmup_iters * self.warmup_factor + (self.last_epoch - 1) * (1 - self.warmup_factor))) + return [group['lr'] * (1. + (self.end_factor - self.start_factor) / + (self.total_iters * self.start_factor + (self.last_epoch - 1) * (self.end_factor - self.start_factor))) for group in self.optimizer.param_groups] def _get_closed_form_lr(self): - return [base_lr * (self.warmup_factor + - (1 - self.warmup_factor) * min(self.warmup_iters, self.last_epoch) / - self.warmup_iters * (self.warmup_method == "linear") + - (self.last_epoch >= self.warmup_iters) * (1 - self.warmup_factor) * - (self.warmup_method == "constant")) + return [base_lr * (self.start_factor + + (self.end_factor - self.start_factor) * min(self.total_iters, self.last_epoch) / self.total_iters) for base_lr in self.base_lrs] @@ -526,7 +573,7 @@ def get_lr(self): "please use `get_last_lr()`.", UserWarning) if self.last_epoch == 0: - return self.base_lrs + return [group['lr'] for group in self.optimizer.param_groups] return [group['lr'] * self.gamma for group in self.optimizer.param_groups] @@ -586,7 +633,7 @@ def get_lr(self): "please use `get_last_lr()`.", UserWarning) if self.last_epoch == 0: - return self.base_lrs + return [group['lr'] for group in self.optimizer.param_groups] elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0: return [group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2 @@ -603,6 +650,44 @@ def _get_closed_form_lr(self): for base_lr in self.base_lrs] +class ChainedScheduler(_LRScheduler): + """Chains list of learning rate schedulers. It takes a list of chainable learning + rate schedulers and performs consecutive step() functions belong to them by just + one call. + + Args: + schedulers (list): List of chained schedulers. + + Example: + >>> # Assuming optimizer uses lr = 1. for all groups + >>> # lr = 0.09 if epoch == 0 + >>> # lr = 0.081 if epoch == 1 + >>> # lr = 0.729 if epoch == 2 + >>> # lr = 0.6561 if epoch == 3 + >>> # lr = 0.59049 if epoch >= 4 + >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2) + >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9) + >>> scheduler = ChainedScheduler([scheduler1, scheduler2]) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__(self, schedulers): + for scheduler_idx in range(1, len(schedulers)): + if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer): + raise ValueError( + "ChainedScheduler expects all schedulers to belong to the same optimizer, but " + "got schedulers at index {} and {} to be different".format(0, scheduler_idx) + ) + self.schedulers = list(schedulers) + + def step(self): + for scheduler in self.schedulers: + scheduler.step() + + class ReduceLROnPlateau(object): """Reduce learning rate when a metric has stopped improving. Models often benefit from reducing the learning rate by a factor diff --git a/torch/optim/lr_scheduler.pyi b/torch/optim/lr_scheduler.pyi index 821407e3ccca6..9b1b8ea63eed7 100644 --- a/torch/optim/lr_scheduler.pyi +++ b/torch/optim/lr_scheduler.pyi @@ -18,8 +18,11 @@ class StepLR(_LRScheduler): class MultiStepLR(_LRScheduler): def __init__(self, optimizer: Optimizer, milestones: Iterable[int], gamma: float=..., last_epoch: int=...) -> None: ... -class WarmUpLR(_LRScheduler): - def __init__(self, optimizer: Optimizer, warmup_factor: float=..., warmup_iters: int=..., warmup_method: str=..., last_epoch: int=...) -> None: ... +class ConstantLR(_LRScheduler): + def __init__(self, optimizer: Optimizer, factor: float=..., total_iters: int=..., last_epoch: int=...) -> None: ... + +class LinearLR(_LRScheduler): + def __init__(self, optimizer: Optimizer, start_factor: float=..., end_factor: float=..., total_iters: int=..., last_epoch: int=...) -> None: ... class ExponentialLR(_LRScheduler): def __init__(self, optimizer: Optimizer, gamma: float, last_epoch: int=...) -> None: ... diff --git a/torch/optim/nadam.py b/torch/optim/nadam.py index 55a790610c5a5..deaaf20b1d710 100644 --- a/torch/optim/nadam.py +++ b/torch/optim/nadam.py @@ -6,7 +6,34 @@ class NAdam(Optimizer): r"""Implements NAdam algorithm. - It has been proposed in `Incorporating Nesterov Momentum into Adam`_. + .. math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \gamma_t \text{ (lr)}, \: \beta_1,\beta_2 \text{ (betas)}, + \: \theta_0 \text{ (params)}, \: f(\theta) \text{ (objective)} \\ + &\hspace{13mm} \: \lambda \text{ (weight decay)}, \:\psi \text{ (momentum decay)} \\ + &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)}, + v_0 \leftarrow 0 \text{ ( second moment)} \\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}if \: \lambda \neq 0 \\ + &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ + &\hspace{5mm} \mu_t \leftarrow \beta_1 \big(1 - \frac{1}{2} 0.96^{t \psi} \big) \\ + &\hspace{5mm} \mu_{t+1} \leftarrow \beta_1 \big(1 - \frac{1}{2} 0.96^{(t+1)\psi}\big)\\ + &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ + &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\ + &\hspace{5mm}\widehat{m_t} \leftarrow \mu_{t+1} m_t/(1-\prod_{i=1}^{t+1}\mu_i)\\[-1.ex] + & \hspace{11mm} + (1-\mu_t) g_t /(1-\prod_{i=1}^{t} \mu_{i}) \\ + &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\ + &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/ + \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\ + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + + For further details regarding the algorithm we refer to `Incorporating Nesterov Momentum into Adam`_. Args: params (iterable): iterable of parameters to optimize or dicts defining diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py index 02f1cc265937b..79f72f041822b 100644 --- a/torch/optim/optimizer.py +++ b/torch/optim/optimizer.py @@ -100,7 +100,8 @@ def state_dict(self): * state - a dict holding current optimization state. Its content differs between optimizer classes. - * param_groups - a dict containing all parameter groups + * param_groups - a list containing all parameter groups where each + parameter group is a dict """ # Save order indices instead of Tensors param_mappings = {} diff --git a/torch/optim/rmsprop.py b/torch/optim/rmsprop.py index 4aab0b3116fdb..dc72181b351f8 100644 --- a/torch/optim/rmsprop.py +++ b/torch/optim/rmsprop.py @@ -6,15 +6,44 @@ class RMSprop(Optimizer): r"""Implements RMSprop algorithm. - Proposed by G. Hinton in his - `course `_. - - The centered version first appears in `Generating Sequences + .. math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \alpha \text{ (alpha)},\: \gamma \text{ (lr)}, + \: \theta_0 \text{ (params)}, \: f(\theta) \text{ (objective)} \\ + &\hspace{13mm} \lambda \text{ (weight decay)},\: \mu \text{ (momentum)},\: centered\\ + &\textbf{initialize} : v_0 \leftarrow 0 \text{ (square average)}, \: + \textbf{b}_0 \leftarrow 0 \text{ (buffer)}, \: g^{ave}_0 \leftarrow 0 \\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}if \: \lambda \neq 0 \\ + &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ + &\hspace{5mm}v_t \leftarrow \alpha v_{t-1} + (1 - \alpha) g^2_t + \hspace{8mm} \\ + &\hspace{5mm} \tilde{v_t} \leftarrow v_t \\ + &\hspace{5mm}if \: centered \\ + &\hspace{10mm} g^{ave}_t \leftarrow g^{ave}_{t-1} \alpha + (1-\alpha) g_t \\ + &\hspace{10mm} \tilde{v_t} \leftarrow \tilde{v_t} - \big(g^{ave}_{t} \big)^2 \\ + &\hspace{5mm}if \: \mu > 0 \\ + &\hspace{10mm} \textbf{b}_t\leftarrow \mu \textbf{b}_{t-1} + + g_t/ \big(\sqrt{\tilde{v_t}} + \epsilon \big) \\ + &\hspace{10mm} \theta_t \leftarrow \theta_{t-1} - \gamma \textbf{b}_t \\ + &\hspace{5mm} else \\ + &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - + \gamma g_t/ \big(\sqrt{\tilde{v_t}} + \epsilon \big) \hspace{3mm} \\ + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + + For further details regarding the algorithm we refer to + `lecture notes `_ by G. Hinton. + and centered version `Generating Sequences With Recurrent Neural Networks `_. - The implementation here takes the square root of the gradient average before adding epsilon (note that TensorFlow interchanges these two operations). The effective - learning rate is thus :math:`\alpha/(\sqrt{v} + \epsilon)` where :math:`\alpha` + learning rate is thus :math:`\gamma/(\sqrt{v} + \epsilon)` where :math:`\gamma` is the scheduled learning rate and :math:`v` is the weighted moving average of the squared gradient. diff --git a/torch/overrides.py b/torch/overrides.py index 5a0ea6ca81737..1bb98507f18b1 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -360,6 +360,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.bucketize: lambda input, boundaries, out_int32=False, right=False, out=None: -1, torch.cartesian_prod: lambda *tensors: -1, torch.cat: lambda tensors, dim=0, out=None: -1, + torch.concat: lambda tensors, dim=0, out=None: -1, # alias for torch.cat torch.cdist: lambda x1, x2, p=2.0, compute_mode='use_mm_for_euclid_dist_if_necessary': -1, torch.ceil: lambda input, out=None: -1, torch.celu: lambda input, alhpa=1., inplace=False: -1, @@ -598,6 +599,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.masked_scatter: lambda input, mask, source: -1, torch.masked_select: lambda input, mask, out=None: -1, torch.matmul: lambda input, other, out=None: -1, + torch.linalg.matmul: lambda input, other, out=None: -1, # alias for torch.matmul torch.matrix_power: lambda input, n: -1, torch.linalg.matrix_power: lambda input, n, out=None: -1, torch.matrix_rank: lambda input, tol=None, symmetric=False: -1, @@ -677,7 +679,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.nn.functional.cosine_embedding_loss: (lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean': -1), torch.nn.functional.cross_entropy: (lambda input, target, weight=None, size_average=None, ignore_index=-100, - reduce=None, reduction="mean": -1), + reduce=None, reduction="mean", label_smoothing=0.0: -1), torch.nn.functional.ctc_loss: (lambda log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean', zero_infinity=False: -1), torch.nn.functional.dropout: lambda input, p=0.5, training=True, inplace=False: -1, @@ -1030,6 +1032,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: Tensor.retains_grad.__get__: lambda self: -1, Tensor.is_meta.__get__: lambda self: -1, Tensor.is_mlc.__get__: lambda self: -1, + Tensor.is_ort.__get__: lambda self: -1, Tensor.is_mkldnn.__get__: lambda self: -1, Tensor.is_quantized.__get__: lambda self: -1, Tensor.is_sparse.__get__: lambda self: -1, diff --git a/torch/quantization/fx/_lower_to_native_backend.py b/torch/quantization/fx/_lower_to_native_backend.py new file mode 100644 index 0000000000000..a5518996bc44e --- /dev/null +++ b/torch/quantization/fx/_lower_to_native_backend.py @@ -0,0 +1,14 @@ +from torch.fx import subgraph_rewriter +from .graph_module import QuantizedGraphModule +from .quantized_fusion_patterns_and_replacements import get_fbgemm_patterns_and_replacements + +def _lower_to_native_backend(model: QuantizedGraphModule) -> QuantizedGraphModule: + """ Lower a quantized reference model (with reference quantized operator patterns) + to the native backend in PyTorch (fbgemm/qnnpack), both backends shares the same + operator signature so they can be lowered with the same function + """ + module_dict = dict(model.named_modules()) + for pattern, replacement in get_fbgemm_patterns_and_replacements(): + subgraph_rewriter.replace_pattern(model, pattern, replacement) + model.graph.lint() + return model diff --git a/torch/quantization/fx/backend_config_dict/__init__.py b/torch/quantization/fx/backend_config_dict/__init__.py new file mode 100644 index 0000000000000..edb2b956851b7 --- /dev/null +++ b/torch/quantization/fx/backend_config_dict/__init__.py @@ -0,0 +1,4 @@ +from .fbgemm import get_fbgemm_backend_config_dict + +def validate_backend_config_dict(backend_config_dict): + return "quant_patterns" in backend_config_dict diff --git a/torch/quantization/fx/backend_config_dict/fbgemm.py b/torch/quantization/fx/backend_config_dict/fbgemm.py new file mode 100644 index 0000000000000..4f40b100f0b78 --- /dev/null +++ b/torch/quantization/fx/backend_config_dict/fbgemm.py @@ -0,0 +1,11 @@ +from ..pattern_utils import get_default_quant_patterns + +def get_fbgemm_backend_config_dict(): + """ Get the backend config dictionary for fbgemm backend + NOTE: Current api will change in the future, it's just to unblock experimentation for + new backends, please don't use it right now. + """ + # TODO: add output_activation_post_process_map + return { + "quant_patterns": get_default_quant_patterns() + } diff --git a/torch/quantization/fx/convert.py b/torch/quantization/fx/convert.py index 976ca0c6aeca7..867b0b24cf7ad 100644 --- a/torch/quantization/fx/convert.py +++ b/torch/quantization/fx/convert.py @@ -45,6 +45,8 @@ activation_dtype, ) +from .lower_to_fbgemm import lower_to_fbgemm + # weight prepacking ops WEIGHT_PREPACK_OPS = { torch._ops.ops.quantized.linear_prepack, @@ -335,11 +337,18 @@ def node_arg_is_quantized(node_arg: Any) -> bool: else: return False - def is_output_quantized(node: Node, obj: QuantizeHandler, qconfig: QConfigAny, modules: Dict[str, torch.nn.Module]) -> bool: + def is_output_quantized( + node: Node, obj: QuantizeHandler, qconfig: QConfigAny, + modules: Dict[str, torch.nn.Module], is_reference=False) -> bool: """ Check if output node is quantized or not """ assert modules is not None - # by default the output for a quantizable node is expected to be quantized - quantized = True + # for some ops the output is quantized only when `is_reference` is True + # and when `is_reference` is False, it has limited qconfig + # support, for example `add` + # ideally this check should not happen here, it should happen either in + # prepare or during lowering, we don't need this check + # after the default path is changed to produce reference patterns + quantized = obj.is_output_quantized(qconfig, is_reference) # Need to get correct quantized/non-quantized state forn the output # of FixedQParamsQuantizeHandler @@ -454,7 +463,7 @@ def insert_quantize_node(node: Node, modules: Dict[str, torch.nn.Module]) -> Non node, qconfig, modules, quantized_graph, node_name_to_scope, load_arg, is_reference=is_reference, convert_custom_config_dict=convert_custom_config_dict) if not is_observed_standalone_module_node: - quantized = is_output_quantized(node, obj, qconfig, modules) + quantized = is_output_quantized(node, obj, qconfig, modules, is_reference) if quantized: env[node.name][activation_dtype(qconfig)] = result @@ -528,4 +537,5 @@ def load_arg_remove(a: Argument) -> Argument: model = QuantizedGraphModule(model, act_post_process_removed_graph, preserved_attributes) if not is_reference: model = fold_weight(model, node_name_to_scope) + model = lower_to_fbgemm(model) return model diff --git a/torch/quantization/fx/lower_to_fbgemm.py b/torch/quantization/fx/lower_to_fbgemm.py new file mode 100644 index 0000000000000..fc76d135ee809 --- /dev/null +++ b/torch/quantization/fx/lower_to_fbgemm.py @@ -0,0 +1,8 @@ +from ._lower_to_native_backend import _lower_to_native_backend +from .graph_module import QuantizedGraphModule + +def lower_to_fbgemm(model: QuantizedGraphModule) -> QuantizedGraphModule: + """ Lower a quantized reference model (with reference quantized operator patterns) + to fbgemm + """ + return _lower_to_native_backend(model) diff --git a/torch/quantization/fx/lower_to_qnnpack.py b/torch/quantization/fx/lower_to_qnnpack.py new file mode 100644 index 0000000000000..0a0ea9cd248cd --- /dev/null +++ b/torch/quantization/fx/lower_to_qnnpack.py @@ -0,0 +1,8 @@ +from ._lower_to_native_backend import _lower_to_native_backend +from .graph_module import QuantizedGraphModule + +def lower_to_qnnpack(model: QuantizedGraphModule) -> QuantizedGraphModule: + """ Lower a quantized reference model (with reference quantized operator patterns) + to qnnpack + """ + return _lower_to_native_backend(model) diff --git a/torch/quantization/fx/match_utils.py b/torch/quantization/fx/match_utils.py index dd8501c9b8bf1..4aa9275870c26 100644 --- a/torch/quantization/fx/match_utils.py +++ b/torch/quantization/fx/match_utils.py @@ -9,9 +9,6 @@ QuantizeHandler, CustomModuleQuantizeHandler, StandaloneModuleQuantizeHandler, - BinaryOpQuantizeHandler, - binary_op_supported_dtypes, - binary_reference_op_supported_dtypes, ) from ..qconfig import ( QConfigAny, @@ -19,7 +16,6 @@ from .graph_module import ( is_observed_standalone_module, ) -from ..utils import get_qconfig_dtypes from typing import Any, Dict, List, Callable, Optional, Tuple, Set @@ -135,60 +131,15 @@ def record_match(pattern, node, matched): if node.name not in match_map and node.name not in all_matched: for pattern, value in patterns.items(): if is_match(modules, node, pattern): - skip_this_match = False - if value is BinaryOpQuantizeHandler: - - # to properly check for dtype support, we need to - # navigate to the base node of an add-relu or mul-relu - # pattern - base_node = node - if ( - (node.op == 'call_function' and - node.target is torch.nn.functional.relu) or - (node.op == 'call_module' and - isinstance(modules[node.target], torch.nn.ReLU)) - ): - base_node = node.args[0] - - this_node_qconfig = \ - qconfig_map[base_node.name] - if this_node_qconfig: - dtypes = get_qconfig_dtypes(this_node_qconfig) - # TODO(future PR): update the pattern to quantize - # handler logic to take this into account. - - - # This needs to handle 3 cases - # 1) op and dtype is in either [is_ref or non-ref] list -> don't skip - # 2) op is not in either list (i.e. relu) -> don't skip - # 3) op is in non-ref list, but not for dtype, and op+dtype not in is_ref list -> skip - - # note: the value of is_reference is unknown at prepare, so we have to cover both cases - # handle is_reference = False - skip_match_not_is_reference = ( - (base_node.target in binary_op_supported_dtypes) and - (dtypes not in binary_op_supported_dtypes[base_node.target]) - ) - - # handle is_reference = True - supported_is_reference = ( - (base_node.target in binary_reference_op_supported_dtypes) and - (dtypes in binary_reference_op_supported_dtypes[base_node.target]) - ) - - # only skip if not reference says skip and is_reference doesn't support - skip_this_match = skip_match_not_is_reference and not supported_is_reference - - if not skip_this_match: - matched: List[Any] = [] - record_match(pattern, node, matched) - for n in matched: - match_map[n.name] = ( - node, matched, pattern, value(node, modules), # type: ignore[operator] - qconfig_map[n.name]) - all_matched.add(n.name) - # break after finding the first match - break + matched: List[Any] = [] + record_match(pattern, node, matched) + for n in matched: + match_map[n.name] = ( + node, matched, pattern, value(node, modules), # type: ignore[operator] + qconfig_map[n.name]) + all_matched.add(n.name) + # break after finding the first match + break # add custom module instances to the match result assert modules is not None diff --git a/torch/quantization/fx/prepare.py b/torch/quantization/fx/prepare.py index ab137487b3cc8..d2bb96ab7a5c0 100644 --- a/torch/quantization/fx/prepare.py +++ b/torch/quantization/fx/prepare.py @@ -15,7 +15,7 @@ ) from torch.fx.node import Argument -from ..qconfig import QConfigAny +from ..qconfig import QConfigAny, qconfig_equals from .qconfig_utils import ( convert_dict_to_ordered_dict, generate_qconfig_map, @@ -42,7 +42,6 @@ from .pattern_utils import ( MatchResult, - get_default_quant_patterns, get_default_output_activation_post_process_map, ) @@ -84,10 +83,13 @@ weight_dtype, ) +from .backend_config_dict import get_fbgemm_backend_config_dict +from .backend_config_dict import validate_backend_config_dict + from typing import Any, Callable, Dict, List, Optional, Tuple, Union def is_activation_post_process_node(node: Node, modules: Dict[str, torch.nn.Module]) -> bool: - return node.op == "call_module" and \ + return isinstance(node, torch.fx.Node) and node.op == "call_module" and \ is_activation_post_process(modules[str(node.target)]) def node_arg_is_weight(node: Node, arg: Any) -> bool: @@ -195,7 +197,7 @@ def update_qconfig_for_fusion( # Raise an error if the modules in the fused module have # different qconfigs specified in the qconfig_dict for op in ops: - if object_type_dict.get(op, None) != fused_qconfig: + if not qconfig_equals(object_type_dict.get(op, None), fused_qconfig): raise LookupError("During fusion, we need to specify the same " + f"qconfigs for both modules in {module_type}.") @@ -324,7 +326,7 @@ def maybe_insert_input_observer_for_arg_or_kwarg( graph, node_name_to_target_dtype, qhandler, prepare_custom_config_dict) new_arg_to_return.append(new_inner_arg) - return new_arg_to_return + return type(arg)(new_arg_to_return) if not isinstance(arg, Node): return arg @@ -772,6 +774,8 @@ def maybe_make_input_output_share_observers( # we need to navigate up to the first observer iteration_guard = 0 while not is_activation_post_process_node(first_arg_arg, modules): + if not isinstance(first_arg_arg, Node): + return False # did not find an activation_post_process for the op if first_arg_arg.op == "placeholder": return False @@ -1112,6 +1116,7 @@ def prepare( node_name_to_scope: Dict[str, Tuple[str, type]], prepare_custom_config_dict: Optional[Dict[str, Any]] = None, equalization_qconfig_dict: Optional[Dict[str, Any]] = None, + backend_config_dict: Optional[Dict[str, Any]] = None, is_standalone_module: bool = False) -> ObservedGraphModule: """ standalone_module means it a submodule that is not inlined in parent module, and will be quantized separately as one unit. @@ -1137,6 +1142,10 @@ def prepare( prepare_custom_config_dict = {} if equalization_qconfig_dict is None: equalization_qconfig_dict = {} + if backend_config_dict is None: + backend_config_dict = get_fbgemm_backend_config_dict() + + validate_backend_config_dict(backend_config_dict) additional_quant_patterns = \ prepare_custom_config_dict.get("additional_quant_pattern", {}) @@ -1150,8 +1159,9 @@ def prepare( # ((, ): # ), # } + quant_patterns = backend_config_dict["quant_patterns"] patterns: Dict[Pattern, QuantizeHandler] = get_combined_dict( - get_default_quant_patterns(), additional_quant_patterns) + quant_patterns, additional_quant_patterns) convert_dict_to_ordered_dict(qconfig_dict) convert_dict_to_ordered_dict(equalization_qconfig_dict) diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py index a68eea2bbf44c..3f54a6a1e2743 100644 --- a/torch/quantization/fx/quantization_patterns.py +++ b/torch/quantization/fx/quantization_patterns.py @@ -155,6 +155,15 @@ def get_activation_ctr( """ return qconfig.activation + def is_output_quantized(self, qconfig, is_reference): + """ Returns true if the output node of convert is quantized + when is_reference is False, we would return float node when a certain dtype + combination is not supported (since fbgemm/qnnpack only support certain dtype + combinations), so the output may be float, but when is_reference is True, + we support all dtype combinations so the output will always be quantized. + """ + return True + @abstractmethod def convert(self, @@ -180,34 +189,52 @@ def convert(self, # tuple (activation_dtype, weight_dtype, compute_dtype) # these are supported types for common binary ops like add/mul etc. -binary_op_all_dtypes = [ +all_dtypes = [ (torch.quint8, torch.qint8, None), (torch.float16, torch.float16, None), ] -binary_op_float16_dtypes = [ +fp16_dtypes = [ (torch.float16, torch.float16, None) ] -binary_op_int8_dtypes = [ +int8_dtypes = [ (torch.quint8, torch.qint8, None), ] binary_op_supported_dtypes : Dict[Union[Callable, str], List[Tuple[torch.dtype, torch.dtype, None]]] = { - operator.add: binary_op_all_dtypes, - torch.add: binary_op_all_dtypes, - operator.mul: binary_op_all_dtypes, - torch.mul: binary_op_all_dtypes, - torch.bmm: binary_op_float16_dtypes, - torch.sub: binary_op_float16_dtypes, - operator.sub: binary_op_float16_dtypes, - torch.div: binary_op_float16_dtypes, - operator.truediv: binary_op_float16_dtypes, - torch.sum: binary_op_float16_dtypes + operator.add: all_dtypes, + torch.add: all_dtypes, + operator.mul: all_dtypes, + torch.mul: all_dtypes, + torch.bmm: fp16_dtypes, + torch.sub: fp16_dtypes, + operator.sub: fp16_dtypes, + torch.div: fp16_dtypes, + operator.truediv: fp16_dtypes, } -binary_reference_op_supported_dtypes : Dict[Union[Callable, str], List[Tuple[torch.dtype, torch.dtype, None]]] = { - torch.bmm: binary_op_int8_dtypes, - operator.add: binary_op_int8_dtypes, - torch.add: binary_op_int8_dtypes, - operator.mul: binary_op_int8_dtypes, - torch.mul: binary_op_int8_dtypes, + +default_op_supported_dtypes = { + torch.nn.ConvTranspose1d: int8_dtypes, + torch.nn.ConvTranspose2d: int8_dtypes, + torch.nn.ELU: int8_dtypes, + torch.nn.LeakyReLU: int8_dtypes, + torch.nn.Hardswish: int8_dtypes, + torch.nn.InstanceNorm1d: int8_dtypes, + torch.nn.InstanceNorm2d: int8_dtypes, + torch.nn.InstanceNorm3d: int8_dtypes, + torch.nn.LayerNorm: all_dtypes, + torch.nn.SiLU: fp16_dtypes, + torch.nn.Mish: fp16_dtypes, + torch.nn.GELU: int8_dtypes, + torch.nn.Softmax: int8_dtypes, + torch.nn.functional.elu: int8_dtypes, + torch.nn.functional.hardswish: int8_dtypes, + torch.nn.functional.instance_norm: int8_dtypes, + torch.nn.functional.layer_norm: all_dtypes, + torch.nn.functional.leaky_relu: int8_dtypes, + torch.nn.functional.silu: fp16_dtypes, + torch.nn.functional.mish: fp16_dtypes, + torch.nn.functional.gelu: int8_dtypes, + torch.nn.functional.softmax: int8_dtypes, + torch.sum: fp16_dtypes, } QAT_CONV_MODULE_CLASSES = \ @@ -266,7 +293,6 @@ def _get_name(): @register_quant_pattern(torch.sub) @register_quant_pattern(torch.mul) @register_quant_pattern(torch.div) -@register_quant_pattern(torch.sum) @register_quant_pattern(torch.bmm) @register_quant_pattern((torch.nn.ReLU, operator.add)) @register_quant_pattern((torch.nn.ReLU, operator.mul)) @@ -344,6 +370,13 @@ def input_output_observed(self): # for x + y where x and y are scalars, we do not observe anything return self.num_tensor_args > 0 + def is_output_quantized(self, qconfig, is_reference): + dtypes = get_qconfig_dtypes(qconfig) + if not is_reference: + return self.binary_op in binary_op_supported_dtypes and \ + dtypes in binary_op_supported_dtypes[self.binary_op] + return True + def convert(self, node: Node, qconfig: QConfigAny, @@ -361,11 +394,14 @@ def convert(self, dtypes = get_qconfig_dtypes(qconfig) - if is_reference and self.binary_op in binary_reference_op_supported_dtypes and \ - dtypes in binary_reference_op_supported_dtypes[self.binary_op]: - if dtypes in binary_op_int8_dtypes: - # make sure both inputs are quantized to torch.quint8 - load_arg(quantized={0: torch.quint8, 1: torch.quint8})(self.binary_op_node.args) + if is_reference: + act_dtype = activation_dtype(qconfig) + if act_dtype == torch.float: + return quantized_graph.node_copy(node, load_arg(quantized=torch.float)) + else: + if self.num_tensor_args == 2: + # make sure both inputs are quantized to act_dtype + load_arg(quantized={0: act_dtype, 1: act_dtype})(self.binary_op_node.args) args = load_arg(quantized=torch.float)(self.binary_op_node.args) kwargs = load_arg(quantized=torch.float)(self.binary_op_node.kwargs) op_out = quantized_graph.node_copy(self.binary_op_node, load_arg(quantized=torch.float)) @@ -384,12 +420,6 @@ def modified_load_arg(n: Node): return quantize_node( op_out, activation_post_process, node, modules, quantized_graph, node_name_to_scope, is_input=False) - else: - warnings.warn( - "No implementation found for dtype combination: {}" - "for op {} with is_reference={} despite it being listed as supported" - "this should not happen".format(dtypes, self.binary_op, is_reference)) - return quantized_graph.node_copy(node, load_arg(quantized=torch.float)) elif not is_reference and self.binary_op in binary_op_supported_dtypes and \ dtypes in binary_op_supported_dtypes[self.binary_op]: if dtypes in [(torch.quint8, torch.qint8, None)]: @@ -445,15 +475,10 @@ def modified_load_arg(n: Node): "dtype combination: {} is not " "supported by {} for is_reference={}. " "Supported non-reference dtype combinations are: {} " - "Supported reference dtype combinations are: {}" "".format(dtypes, self.binary_op, is_reference, - binary_op_supported_dtypes[self.binary_op], - ( - [] if self.binary_op not in binary_reference_op_supported_dtypes.keys() - else binary_reference_op_supported_dtypes[self.binary_op] - ) + binary_op_supported_dtypes[self.binary_op] ) ) if self.relu_node: @@ -613,19 +638,22 @@ def convert(self, # and qparam is a dictionary of # {"qscheme": ..., "scale": ..., "zero_point": ...} for per tensor quantization or # {"qscheme": ..., "scale": ..., "zero_point": ..., "axis": ...} for per channel quantization + float_conv = self.conv + fused_conv = None if isinstance( - self.conv, + float_conv, QAT_CONV_MODULE_CLASSES): # case 1. converting qat conv module to # a float conv module, we need to attch # weight fake_quant to the conv module, # weight fake_quant is assumed to be run during # QAT so we don't need to run it again here - float_conv = self.conv.to_float() + float_conv = self.conv.to_float() # type: ignore[operator] # change qat conv to conv parent_name, name = _parent_name(self.conv_node.target) setattr(modules[parent_name], name, float_conv) if isinstance(float_conv, torch.nn.intrinsic._FusedModule): + fused_conv = float_conv float_conv = float_conv[0] weight_post_process = self.conv.weight_fake_quant else: @@ -633,15 +661,28 @@ def convert(self, # to float conv module, we need to attach # weight observer to the conv module and run it # with conv weight - float_conv = self.conv - if isinstance(self.conv, torch.nn.intrinsic._FusedModule): - float_conv = self.conv[0] + if isinstance(float_conv, torch.nn.intrinsic._FusedModule): + fused_conv = float_conv + float_conv = float_conv[0] # type: ignore[index] assert qconfig is not None weight_post_process = qconfig.weight() # run weight observer - weight_post_process(float_conv.weight) + weight_post_process(float_conv.weight) # type: ignore[operator] weight_qparams = get_qparam_dict(weight_post_process) - _to_reference(float_conv, weight_qparams) + # hardcoded for now, TODO: expose the api to user, + # we can have a map from module to reference module + # and allow user to register new ones + qconv_cls = get_static_quant_module_class( + type(float_conv), is_reference=is_reference) + ref_conv = qconv_cls.from_float(float_conv, weight_qparams) # type: ignore[attr-defined] + # if the parent is a fused conv (Sequential), we can replace the first + # item to ref conv, otherwise we can update + # the conv instance in the module tree + if fused_conv is not None: + fused_conv[0] = ref_conv + else: + parent_name, name = _parent_name(self.conv_node.target) + setattr(modules[parent_name], name, ref_conv) op_out = quantized_graph.create_node( 'call_module', self.conv_node.target, @@ -844,6 +885,7 @@ def convert(self, # Get the float linear and attach qscheme and qparams # the the module float_linear = self.linear + fused_linear = None if isinstance(float_linear, (torch.nn.qat.Linear, torch.nn.intrinsic.qat.LinearReLU)): float_linear = float_linear.to_float() # change qat linear to linear @@ -851,10 +893,12 @@ def convert(self, setattr(modules[parent_name], name, float_linear) # Attach weight fake quant to the linear module if isinstance(float_linear, torch.nn.intrinsic.LinearReLU): + fused_linear = float_linear float_linear = float_linear[0] weight_post_process = self.linear.weight_fake_quant else: if isinstance(float_linear, torch.nn.intrinsic.LinearReLU): + fused_linear = float_linear float_linear = self.linear[0] # type: ignore[index] # Attach the weight observer to the module weight_post_process = qconfig.weight() # type: ignore[union-attr] @@ -862,7 +906,21 @@ def convert(self, weight_post_process(float_linear.weight) # type: ignore[operator] weight_qparams = get_qparam_dict(weight_post_process) - _to_reference(float_linear, weight_qparams) + # TODO: include the configuration in backend_config_dict + # we can have a map from module to reference module + # and allow user to register new ones + qlinear_cls = get_static_quant_module_class( + type(float_linear), is_reference=is_reference) + ref_linear = qlinear_cls.from_float(float_linear, weight_qparams) + + # if the parent is a fused linear (Sequential), we can replace the first + # item to ref linear, otherwise we can update + # the linear instance in the module tree + if fused_linear is not None: + fused_linear[0] = ref_linear + else: + parent_name, name = _parent_name(self.linear_node.target) + setattr(modules[parent_name], name, ref_linear) op_out = quantized_graph.create_node( 'call_module', self.linear_node.target, @@ -997,9 +1055,17 @@ def convert(self, elif dtypes in [(torch.float32, torch.qint8, torch.quint8), (torch.float32, torch.float16, None)]: # choose linear dynamic or linear dynamic fp16 op based on weight dtype - qlinear_op = torch.ops.quantized.linear_dynamic \ - if weight_dtype == torch.qint8 \ - else torch.ops.quantized.linear_dynamic_fp16 + if weight_dtype == torch.qint8: + if self.relu_node: + qlinear_op = torch.ops.quantized.linear_relu_dynamic + else: + qlinear_op = torch.ops.quantized.linear_dynamic + else: + if self.relu_node: + qlinear_op = torch.ops.quantized.linear_relu_dynamic_fp16 + else: + qlinear_op = torch.ops.quantized.linear_dynamic_fp16 + linear_input = load_arg(quantized=torch.float)(self.linear_node.args[0]) qlinear_args = (linear_input, packed_weight) # type: ignore[assignment] op_out = quantized_graph.create_node( @@ -1008,8 +1074,6 @@ def convert(self, # TODO: may need to change the key to Node regenerate the map in each transformation, # since we might not be able to rely on the name node_name_to_scope[op_out.name] = node_name_to_scope[self.linear_node.name] - if self.relu_node: - op_out = quantized_graph.create_node("call_function", torch.nn.functional.relu, (op_out,), {}) return op_out else: assert dtypes == (torch.float16, torch.float16, None) @@ -1226,6 +1290,7 @@ def convert(self, # until they receive a proper fp16 kernel. To use the reference pattern, use a custom qconfig # @register_quant_pattern(torch.nn.functional.gelu) # @register_quant_pattern(torch.nn.functional.softmax) +@register_quant_pattern(torch.sum) class DefaultNodeQuantizeHandler(QuantizeHandler): """ Common quantized op, first input and first output will be quantized """ @@ -1239,6 +1304,13 @@ def __init__( elif node.op == "call_module": self.op = type(modules[str(node.target)]) + def is_output_quantized(self, qconfig, is_reference): + dtypes = get_qconfig_dtypes(qconfig) + if not is_reference: + return self.op in default_op_supported_dtypes and \ + dtypes in default_op_supported_dtypes[self.op] + return True + def convert(self, node: Node, qconfig: QConfigAny, @@ -1256,46 +1328,12 @@ def convert(self, convert_custom_config_dict = {} additional_static_quant_mapping = convert_custom_config_dict.get("static", {}) - all_dtypes = [ - (torch.quint8, torch.qint8, None), - (torch.float16, torch.float16, None) - ] - int8_dtypes = [ - (torch.quint8, torch.qint8, None) - ] - fp16_dtypes = [ - (torch.float16, torch.float16, None) - ] - supported_dtypes = { - torch.nn.ConvTranspose1d: int8_dtypes, - torch.nn.ConvTranspose2d: int8_dtypes, - torch.nn.ELU: int8_dtypes, - torch.nn.LeakyReLU: int8_dtypes, - torch.nn.Hardswish: int8_dtypes, - torch.nn.InstanceNorm1d: int8_dtypes, - torch.nn.InstanceNorm2d: int8_dtypes, - torch.nn.InstanceNorm3d: int8_dtypes, - torch.nn.LayerNorm: all_dtypes, - torch.nn.SiLU: fp16_dtypes, - torch.nn.Mish: fp16_dtypes, - torch.nn.GELU: int8_dtypes, - torch.nn.Softmax: int8_dtypes, - torch.nn.functional.elu: int8_dtypes, - torch.nn.functional.hardswish: int8_dtypes, - torch.nn.functional.instance_norm: int8_dtypes, - torch.nn.functional.layer_norm: all_dtypes, - torch.nn.functional.leaky_relu: int8_dtypes, - torch.nn.functional.silu: fp16_dtypes, - torch.nn.functional.mish: fp16_dtypes, - torch.nn.functional.gelu: int8_dtypes, - torch.nn.functional.softmax: int8_dtypes, - } dtypes = get_qconfig_dtypes(qconfig) - if not is_reference and dtypes not in supported_dtypes[self.op]: + if not is_reference and dtypes not in default_op_supported_dtypes[self.op]: warnings.warn( "dtype combination: {} is not " "supported by {} " - "supported dtype combinations are: {}".format(dtypes, self.op, supported_dtypes[self.op])) + "supported dtype combinations are: {}".format(dtypes, self.op, default_op_supported_dtypes[self.op])) return quantized_graph.node_copy(node, load_arg(quantized=torch.float)) # TODO: make helper functions for (torch.quint8, torch.qint8, None) if not is_reference: @@ -1448,6 +1486,9 @@ def convert(self, @register_quant_pattern(torch.nn.AvgPool3d) @register_quant_pattern(torch.nn.Dropout) @register_quant_pattern(torch.nn.Hardtanh) +@register_quant_pattern(torch.nn.MaxPool1d) +@register_quant_pattern(torch.nn.MaxPool2d) +@register_quant_pattern(torch.nn.MaxPool3d) @register_quant_pattern(torch.nn.ReLU) @register_quant_pattern(torch.nn.ReLU6) @register_quant_pattern(torch.adaptive_avg_pool1d) @@ -1457,12 +1498,16 @@ def convert(self, @register_quant_pattern(torch.nn.functional.hardtanh) @register_quant_pattern(torch.nn.functional.hardtanh_) @register_quant_pattern(torch.nn.functional.interpolate) +@register_quant_pattern(torch.nn.functional.max_pool1d) +@register_quant_pattern(torch.nn.functional.max_pool2d) +@register_quant_pattern(torch.nn.functional.max_pool3d) @register_quant_pattern(torch.nn.functional.relu) @register_quant_pattern(torch.nn.functional.relu6) @register_quant_pattern(torch.avg_pool1d) @register_quant_pattern(torch._C._nn.avg_pool2d) @register_quant_pattern(torch._C._nn.avg_pool3d) @register_quant_pattern(torch.clamp) +@register_quant_pattern(torch.flatten) @register_quant_pattern(torch.max) @register_quant_pattern(torch.mean) @register_quant_pattern(torch.min) @@ -1497,7 +1542,9 @@ def convert(self, load_arg: Callable, is_reference: bool = False, convert_custom_config_dict: Dict[str, Any] = None) -> Node: - if is_reference: + # always produce reference pattern for relu + is_relu = node.op == "call_function" and node.target == torch.nn.functional.relu + if is_reference or is_relu: # when activation dtype is torch.float, the node does not require # observation # e.g. dynamic quantization or weight_only quantization @@ -1555,15 +1602,8 @@ def convert(self, # module attribute like module._QUANTIZED_INPUT_INDEXES return quantized_graph.node_copy(node, load_arg(quantized=None)) -@register_quant_pattern(torch.nn.MaxPool1d) -@register_quant_pattern(torch.nn.MaxPool2d) -@register_quant_pattern(torch.nn.MaxPool3d) @register_quant_pattern(torch.nn.Identity) -@register_quant_pattern(torch.nn.functional.max_pool1d) -@register_quant_pattern(torch.nn.functional.max_pool2d) -@register_quant_pattern(torch.nn.functional.max_pool3d) @register_quant_pattern(torch.chunk) -@register_quant_pattern(torch.flatten) @register_quant_pattern(torch.transpose) @register_quant_pattern(torch.repeat_interleave) @register_quant_pattern(torch.sort) diff --git a/torch/quantization/fx/quantized_fusion_patterns_and_replacements.py b/torch/quantization/fx/quantized_fusion_patterns_and_replacements.py new file mode 100644 index 0000000000000..07c109ec4f922 --- /dev/null +++ b/torch/quantization/fx/quantized_fusion_patterns_and_replacements.py @@ -0,0 +1,31 @@ +import torch + +def relu_inplace_pattern(x, scale, zero_point): + x = x.dequantize() + x = torch.nn.functional.relu(x, inplace=True) + x = torch.quantize_per_tensor(x, scale, zero_point, torch.quint8) + return x + +def relu_non_inplace_pattern(x, scale, zero_point): + x = x.dequantize() + x = torch.nn.functional.relu(x, inplace=False) + x = torch.quantize_per_tensor(x, scale, zero_point, torch.quint8) + return x + +def relu_replacement(x, scale, zero_point): + x = torch.nn.functional.relu(x) + return x + + +def _get_all_patterns_and_replacements(): + return [ + (relu_inplace_pattern, relu_replacement), + (relu_non_inplace_pattern, relu_replacement) + ] + + +def get_fbgemm_patterns_and_replacements(): + return _get_all_patterns_and_replacements() + +def get_qnnpack_patterns_and_replacements(): + return _get_all_patterns_and_replacements() diff --git a/torch/quantization/ns/graph_passes.py b/torch/quantization/ns/graph_passes.py index 36e737e3baf4b..51eb6c24ef3fb 100644 --- a/torch/quantization/ns/graph_passes.py +++ b/torch/quantization/ns/graph_passes.py @@ -361,6 +361,7 @@ def _insert_copy_of_subgraph_a_after_input_node_c( if isinstance(input_node_c, Node): graph_c = input_node_c.graph else: + assert isinstance(input_node_c, list) graph_c = input_node_c[0].graph # create a sequential list of the subgraphs' nodes from start to end, @@ -450,6 +451,7 @@ def _insert_copy_of_node_a_after_input_node_c( if isinstance(input_node_c, Node): graph_c = input_node_c.graph else: + assert isinstance(input_node_c, list) graph_c = input_node_c[0].graph # generically handle all args and kwargs except for the input diff --git a/torch/quantization/ns/mappings.py b/torch/quantization/ns/mappings.py index 2a7c859347f3d..e97d77119d00e 100644 --- a/torch/quantization/ns/mappings.py +++ b/torch/quantization/ns/mappings.py @@ -8,6 +8,7 @@ import torch.nn.quantized as nnq import torch.nn.quantized.dynamic as nnqd import torch.nn.intrinsic.quantized as nniq +import torch.nn.intrinsic.quantized.dynamic as nniqd import torch.nn.intrinsic.qat as nniqat import torch.nn.intrinsic as nni import torch.nn.qat as nnqat @@ -70,6 +71,7 @@ def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]: nnq.Linear, nni.LinearReLU, nniq.LinearReLU, + nniqd.LinearReLU, nnqat.Linear, nnqd.Linear, nniqat.LinearReLU, @@ -419,6 +421,7 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]: # uncomment below # operator.add, # operator.mul, + torch.sum, ]) FUNS_IO_TYPE_FP16: Set[NSNodeTargetType] = set() @@ -528,6 +531,7 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]: nniqat.ConvReLU2d, nniqat.ConvReLU3d, nniqat.LinearReLU, + nniqd.LinearReLU, ]) MODS_IO_TYPE_INT8: Set[NSNodeTargetType] = set([ diff --git a/torch/quantization/ns/utils.py b/torch/quantization/ns/utils.py index 678f60a00c8cc..62397d0de0f94 100644 --- a/torch/quantization/ns/utils.py +++ b/torch/quantization/ns/utils.py @@ -317,15 +317,15 @@ def get_arg_indices_of_inputs_to_log(node: Node) -> List[int]: def get_target_type_str(node: Node, gm: GraphModule) -> str: """ Returns a string representation of the type of the function or module - pointed to by this node, or '' for other op types. + pointed to by this node, or '' for other node types. """ target_type = "" if node.op in ("call_function", "call_method"): - target_type = str(node.target) + target_type = torch.typename(node.target) elif node.op == "call_module": assert isinstance(node.target, str) target_mod = getattr_from_fqn(gm, node.target) - target_type = str(type(target_mod)) + target_type = torch.typename(target_mod) return target_type diff --git a/torch/quantization/ns/weight_utils.py b/torch/quantization/ns/weight_utils.py index 724cdc7a40ae6..36e183efe1d8e 100644 --- a/torch/quantization/ns/weight_utils.py +++ b/torch/quantization/ns/weight_utils.py @@ -231,6 +231,8 @@ def extract_weight_from_node( op_to_type_to_weight_extraction_fn = get_op_to_type_to_weight_extraction_fn() ref_node_type = get_target_type_str(node, gm) + # for extracting weights, these are always the same + prev_node_type = ref_node_type if node.op == 'call_function': function_mapping = op_to_type_to_weight_extraction_fn['call_function'] @@ -241,7 +243,7 @@ def extract_weight_from_node( 'type': res_type, 'values': [weight], 'prev_node_name': node.name, - 'prev_node_target_type': str(node.target), + 'prev_node_target_type': prev_node_type, 'ref_node_name': node.name, 'ref_node_target_type': ref_node_type, 'index_within_arg': 0, @@ -261,7 +263,7 @@ def extract_weight_from_node( 'type': res_type, 'values': [weight], 'prev_node_name': node.name, - 'prev_node_target_type': str(type(mod)), + 'prev_node_target_type': prev_node_type, 'ref_node_name': node.name, 'ref_node_target_type': ref_node_type, 'index_within_arg': 0, diff --git a/torch/quantization/qconfig.py b/torch/quantization/qconfig.py index 15eb174f021b9..ae89b4a50b70a 100644 --- a/torch/quantization/qconfig.py +++ b/torch/quantization/qconfig.py @@ -209,3 +209,20 @@ def configure_constructor_to_put_obs_on_module_device(original_constructor): return QConfig(activation, weight) else: return QConfigDynamic(activation, weight) + + +def qconfig_equals(q1: QConfigAny, q2: QConfigAny): + # functools.partial has no __eq__ operator defined so '==' defaults to 'is' + def partial_equals(p1, p2): + same = p1.func == p2.func + same = same and p1.args == p2.args + return same and p1.keywords == p2.keywords + + if q1 is None or q2 is None: + return q1 == q2 + else: + assert q1 is not None and q2 is not None + try: + return partial_equals(q1.activation.p, q2.activation.p) and partial_equals(q1.weight.p, q2.weight.p) + except AttributeError: + return q1 == q2 diff --git a/torch/quantization/quantization_mappings.py b/torch/quantization/quantization_mappings.py index 6179398b7398a..6851ba7bd447d 100644 --- a/torch/quantization/quantization_mappings.py +++ b/torch/quantization/quantization_mappings.py @@ -6,7 +6,7 @@ import torch.nn.functional as F import torch.nn.intrinsic as nni import torch.nn.intrinsic.quantized as nniq -import torch.nn.intrinsic.quantized._reference as nniqr +import torch.nn.intrinsic.quantized.dynamic as nniqd import torch.nn.intrinsic.qat as nniqat import torch.nn.quantized as nnq import torch.nn.quantized._reference as nnqr @@ -24,27 +24,10 @@ # Default map for swapping float module to reference quantized modules DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = { + nn.Linear: nnqr.Linear, nn.Conv1d: nnqr.Conv1d, nn.Conv2d: nnqr.Conv2d, nn.Conv3d: nnqr.Conv3d, - nn.Linear: nnqr.Linear, - nni.ConvReLU1d: nniqr.ConvReLU1d, - nni.ConvReLU2d: nniqr.ConvReLU2d, - nni.ConvReLU3d: nniqr.ConvReLU3d, - nni.LinearReLU: nniqr.LinearReLU, - # QAT Modules - nnqat.Linear: nnqr.Linear, - nnqat.Conv2d: nnqr.Conv2d, - nnqat.Conv3d: nnqr.Conv3d, - nniqat.ConvBn1d: nnqr.Conv1d, - nniqat.ConvBn2d: nnqr.Conv2d, - nniqat.ConvBn3d: nnqr.Conv3d, - nniqat.ConvBnReLU1d: nniqr.ConvReLU1d, - nniqat.ConvBnReLU2d: nniqr.ConvReLU2d, - nniqat.ConvBnReLU3d: nniqr.ConvReLU3d, - nniqat.ConvReLU2d: nniqr.ConvReLU2d, - nniqat.ConvReLU3d: nniqr.ConvReLU3d, - nniqat.LinearReLU: nniqr.LinearReLU, } # Default map for swapping float module to quantized ones @@ -122,6 +105,7 @@ nn.GRU: nnqd.GRU, nn.LSTMCell: nnqd.LSTMCell, nn.RNNCell: nnqd.RNNCell, + nni.LinearReLU: nniqd.LinearReLU, } # Allowlist for propagating the qconfig diff --git a/torch/quantization/quantize_fx.py b/torch/quantization/quantize_fx.py index aa8edbba64e49..2dd98ea6ffe4c 100644 --- a/torch/quantization/quantize_fx.py +++ b/torch/quantization/quantize_fx.py @@ -140,8 +140,9 @@ def create_node(self, kind : str, target : Target, return node def _prepare_fx(model: torch.nn.Module, qconfig_dict: Any, - prepare_custom_config_dict: Dict[str, Any] = None, - equalization_qconfig_dict: Dict[str, Any] = None, + prepare_custom_config_dict: Optional[Dict[str, Any]] = None, + equalization_qconfig_dict: Optional[Dict[str, Any]] = None, + backend_config_dict: Optional[Dict[str, Any]] = None, is_standalone_module: bool = False) -> ObservedGraphModule: r""" Internal helper function for prepare_fx Args: @@ -203,7 +204,8 @@ def _prepare_fx(model: torch.nn.Module, qconfig_dict: Any, def _prepare_standalone_module_fx( model: torch.nn.Module, qconfig_dict: Any, - prepare_custom_config_dict: Dict[str, Any] = None) -> GraphModule: + prepare_custom_config_dict: Dict[str, Any] = None, + backend_config_dict: Dict[str, Any] = None) -> GraphModule: r""" [Internal use only] Prepare a standalone module, so that it can be used when quantizing the parent module. standalone_module means it a submodule that is not inlined in parent module, @@ -224,7 +226,7 @@ def _prepare_standalone_module_fx( same as input_quantized_idxs configuration provided for the standalone module """ - return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict, is_standalone_module=True) + return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict, backend_config_dict, is_standalone_module=True) def fuse_fx(model: torch.nn.Module, fuse_custom_config_dict: Dict[str, Any] = None) -> GraphModule: @@ -265,8 +267,9 @@ def fuse_fx(model: torch.nn.Module, def prepare_fx( model: torch.nn.Module, qconfig_dict: Any, - prepare_custom_config_dict: Dict[str, Any] = None, - equalization_qconfig_dict: Dict[str, Any] = None) -> ObservedGraphModule: + prepare_custom_config_dict: Optional[Dict[str, Any]] = None, + equalization_qconfig_dict: Optional[Dict[str, Any]] = None, + backend_config_dict: Optional[Dict[str, Any]] = None) -> ObservedGraphModule: r""" Prepare a model for post training static quantization Args: @@ -392,6 +395,11 @@ def prepare_fx( with a similar structure as qconfig_dict except it will contain configurations specific to equalization techniques such as input-weight equalization. + `backend_config_dict`: a dictionary that specifies how operators are quantized + in a backend, this includes how the operaetors are observed, + supported fusion patterns, how quantize/dequantize ops are + inserted, supported dtypes etc. The structure of the dictionary is still WIP + and will change in the future, please don't use right now. Return: @@ -420,16 +428,18 @@ def calibrate(model, data_loader): torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_fx") assert not model.training, 'prepare_fx only works for models in ' + \ 'eval mode' - return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict, equalization_qconfig_dict) + return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict, equalization_qconfig_dict, backend_config_dict) def prepare_qat_fx( model: torch.nn.Module, qconfig_dict: Any, - prepare_custom_config_dict: Dict[str, Any] = None) -> ObservedGraphModule: + prepare_custom_config_dict: Optional[Dict[str, Any]] = None, + backend_config_dict: Optional[Dict[str, Any]] = None) -> ObservedGraphModule: r""" Prepare a model for quantization aware training Args: `model`: torch.nn.Module model, must be in train mode `qconfig_dict`: see :func:`~torch.quantization.prepare_fx` `prepare_custom_config_dict`: see :func:`~torch.quantization.prepare_fx` + `backend_config_dict`: see :func:`~torch.quantization.prepare_fx` Return: A GraphModule with fake quant modules (configured by qconfig_dict), ready for @@ -457,7 +467,7 @@ def train_loop(model, train_data): torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_qat_fx") assert model.training, 'prepare_qat_fx only works for models in ' + \ 'train mode' - return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict) + return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict, backend_config_dict) def _convert_fx( graph_module: GraphModule, is_reference: bool, diff --git a/torch/random.py b/torch/random.py index d774634478697..f5156bf48730d 100644 --- a/torch/random.py +++ b/torch/random.py @@ -1,4 +1,5 @@ import contextlib +from typing import Generator import warnings from torch._C import default_generator @@ -65,7 +66,7 @@ def initial_seed() -> int: @contextlib.contextmanager -def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="devices"): +def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="devices") -> Generator: """ Forks the RNG, so that when you return, the RNG is reset to the state that it was previously in. diff --git a/torch/special/__init__.py b/torch/special/__init__.py index 1f3b3fc5dc899..2fea9c6cb1b04 100644 --- a/torch/special/__init__.py +++ b/torch/special/__init__.py @@ -1,9 +1,12 @@ -import sys - import torch from torch._C import _add_docstr, _special # type: ignore[attr-defined] from torch._torch_docs import common_args, multi_dim_common +__all__ = ['entr', 'psi', 'digamma', 'gammaln', 'polygamma', 'erf', 'erfc', 'erfinv', + 'erfcx', 'logit', 'logsumexp', 'expit', 'exp2', 'expm1', 'xlog1py', 'xlogy', + 'i0', 'i0e', 'i1', 'i1e', 'ndtr', 'ndtri', 'log1p', 'sinc', 'round', 'log_softmax', + 'zeta', 'multigammaln'] + Tensor = torch.Tensor entr = _add_docstr(_special.special_entr, diff --git a/torch/testing/__init__.py b/torch/testing/__init__.py index 526d02c71e322..7ea18a4f9cea2 100644 --- a/torch/testing/__init__.py +++ b/torch/testing/__init__.py @@ -1,4 +1,5 @@ from ._core import * # noqa: F403 from ._asserts import * # noqa: F403 +from ._creation import * # noqa: F403 from ._check_kernel_launches import * # noqa: F403 from ._deprecated import * # noqa: F403 diff --git a/torch/testing/_asserts.py b/torch/testing/_asserts.py index 2de2cc0735529..073e2e2230820 100644 --- a/torch/testing/_asserts.py +++ b/torch/testing/_asserts.py @@ -44,52 +44,6 @@ def _get_default_rtol_and_atol(actual: Tensor, expected: Tensor) -> Tuple[float, return max(actual_rtol, expected_rtol), max(actual_atol, expected_atol) -def _check_complex_components_individually( - check_tensors: Callable[..., Optional[_TestingErrorMeta]] -) -> Callable[..., Optional[_TestingErrorMeta]]: - """Decorates real-valued tensor check functions to handle complex components individually. - - If the inputs are not complex, this decorator is a no-op. - - Args: - check_tensors (Callable[[Tensor, Tensor], Optional[_TestingErrorMeta]]): Tensor check function for real-valued - tensors. - """ - - @functools.wraps(check_tensors) - def wrapper( - actual: Tensor, expected: Tensor, *, equal_nan: Union[str, bool], **kwargs: Any - ) -> Optional[_TestingErrorMeta]: - if equal_nan == "relaxed": - relaxed_complex_nan = True - equal_nan = True - else: - relaxed_complex_nan = False - - if actual.dtype not in (torch.complex32, torch.complex64, torch.complex128): - return check_tensors(actual, expected, equal_nan=equal_nan, **kwargs) - - if relaxed_complex_nan: - actual, expected = [ - t.clone().masked_fill( - t.real.isnan() | t.imag.isnan(), complex(float("NaN"), float("NaN")) # type: ignore[call-overload] - ) - for t in (actual, expected) - ] - - error_meta = check_tensors(actual.real, expected.real, equal_nan=equal_nan, **kwargs) - if error_meta: - return error_meta - - error_meta = check_tensors(actual.imag, expected.imag, equal_nan=equal_nan, **kwargs) - if error_meta: - return error_meta - - return None - - return wrapper - - def _check_sparse_coo_members_individually( check_tensors: Callable[..., Optional[_TestingErrorMeta]] ) -> Callable[..., Optional[_TestingErrorMeta]]: @@ -430,10 +384,24 @@ def append_difference(msg: str, *, type: str, difference: float, index: Tuple[in return msg.strip() +def _get_comparison_dtype(dtype: torch.dtype) -> torch.dtype: + """Selects the comparison dtype based on the input dtype. + + Returns: + Highest precision dtype of the same dtype category as the input. :class:`torch.bool` is treated as integral + dtype. + """ + if dtype.is_complex: + return torch.complex128 + elif dtype.is_floating_point: + return torch.float64 + else: + return torch.int64 + + @_check_quantized @_check_sparse_coo_members_individually @_check_sparse_csr_members_individually -@_check_complex_components_individually def _check_values_close( actual: Tensor, expected: Tensor, @@ -457,7 +425,7 @@ def _check_values_close( Returns: (Optional[AssertionError]): If check did not pass. """ - dtype = torch.float64 if actual.dtype.is_floating_point else torch.int64 + dtype = _get_comparison_dtype(actual.dtype) actual = actual.to(dtype) expected = expected.to(dtype) mismatches = ~torch.isclose(actual, expected, rtol=rtol, atol=atol, equal_nan=equal_nan) @@ -740,7 +708,7 @@ def assert_close( allow_subclasses: bool = True, rtol: Optional[float] = None, atol: Optional[float] = None, - equal_nan: Union[bool, str] = False, + equal_nan: bool = False, check_device: bool = True, check_dtype: bool = True, check_stride: bool = False, @@ -761,9 +729,6 @@ def assert_close( (``-inf`` and ``inf``) are only considered close if and only if they are equal. ``NaN``'s are only considered equal to each other if :attr:`equal_nan` is ``True``. - If :attr:`actual` and :attr:`expected` are complex-valued, they are considered close if both their real and - imaginary components are considered close according to the definition above. - If :attr:`actual` and :attr:`expected` are sparse (either having COO or CSR layout), their strided members are checked individually. Indices, namely ``indices`` for COO or ``crow_indices`` and ``col_indices`` for CSR layout, are always checked for equality whereas the values are checked for closeness according to the definition above. @@ -795,8 +760,7 @@ def assert_close( default values based on the :attr:`~torch.Tensor.dtype` are selected with the below table. atol (Optional[float]): Absolute tolerance. If specified :attr:`rtol` must also be specified. If omitted, default values based on the :attr:`~torch.Tensor.dtype` are selected with the below table. - equal_nan (Union[bool, str]): If ``True``, two ``NaN`` values will be considered equal. If ``"relaxed"``, - complex values are considered as ``NaN`` if either the real **or** imaginary component is ``NaN``. + equal_nan (Union[bool, str]): If ``True``, two ``NaN`` values will be considered equal. check_device (bool): If ``True`` (default), asserts that corresponding tensors are on the same :attr:`~torch.Tensor.device`. If this check is disabled, tensors on different :attr:`~torch.Tensor.device`'s are moved to the CPU before being compared. @@ -956,20 +920,6 @@ def assert_close( Relative difference: nan (up to 1.3e-06 allowed) >>> torch.testing.assert_close(actual, expected, equal_nan=True) - >>> # If equal_nan=True, the real and imaginary NaN's of complex inputs have to match. - >>> expected = torch.tensor(complex(float("NaN"), 0)) - >>> actual = torch.tensor(complex(0, float("NaN"))) - >>> torch.testing.assert_close(actual, expected, equal_nan=True) - Traceback (most recent call last): - ... - AssertionError: Scalars are not close! - - Absolute difference: nan (up to 1e-05 allowed) - Relative difference: nan (up to 1.3e-06 allowed) - >>> # If equal_nan="relaxed", however, then complex numbers are treated as NaN if any - >>> # of the real or imaginary components is NaN. - >>> torch.testing.assert_close(actual, expected, equal_nan="relaxed") - >>> expected = torch.tensor([1.0, 2.0, 3.0]) >>> actual = torch.tensor([1.0, 4.0, 5.0]) >>> # The default mismatch message can be overwritten. diff --git a/torch/testing/_core.py b/torch/testing/_core.py index 9a5fb0c643097..b3cc6f163c49f 100644 --- a/torch/testing/_core.py +++ b/torch/testing/_core.py @@ -6,35 +6,14 @@ import random import math import cmath -from typing import cast, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import operator FileCheck = torch._C.FileCheck __all__ = [ "FileCheck", - "all_types", - "all_types_and", - "all_types_and_complex", - "all_types_and_complex_and", - "all_types_and_half", - "assert_allclose", - "complex_types", - "empty_types", - "floating_and_complex_types", - "floating_and_complex_types_and", - "floating_types", - "floating_types_and", - "double_types", - "floating_types_and_half", - "get_all_complex_dtypes", - "get_all_dtypes", "get_all_device_types", - "get_all_fp_dtypes", - "get_all_int_dtypes", - "get_all_math_dtypes", - "integral_types", - "integral_types_and", "make_non_contiguous", ] @@ -42,9 +21,7 @@ # False otherwise. # TODO: implement numpy-like issubdtype def is_integral(dtype: torch.dtype) -> bool: - # Skip complex/quantized types - dtypes = [x for x in get_all_dtypes() if x not in get_all_complex_dtypes()] - return dtype in dtypes and not dtype.is_floating_point + return dtype in (torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64) def is_quantized(dtype: torch.dtype) -> bool: return dtype in (torch.quint8, torch.qint8, torch.qint32, torch.quint4x2) @@ -79,27 +56,12 @@ def _unravel_index(flat_index, shape): # Two tensors are "equal" if they are "close", in the sense of torch.allclose. # The only exceptions are complex tensors and bool tensors. # -# Complex tensors are "equal" if both the -# real and complex parts (separately) are close. This is divergent from -# torch.allclose's behavior, which compares the absolute values of the -# complex numbers instead. -# -# Using torch.allclose would be a less strict -# comparison that would allow large complex values with -# significant real or imaginary differences to be considered "equal," -# and would make setting rtol and atol for complex tensors distinct from -# other tensor types. -# # Bool tensors are equal only if they are identical, regardless of # the rtol and atol values. # # The `equal_nan` can be True or False, which maps to the True or False -# in `torch.allclose`. `equal_nan` can also be "relaxed", which means -# the complex will be compared in the relaxed mode: -# 2 + nan j == 3 + nan j ---> False when equal_nan=True -# True when equal_nan="relaxed" -def _compare_tensors_internal(a: torch.Tensor, b: torch.Tensor, *, rtol, atol, equal_nan: Union[str, bool]) -> _compare_return_type: - assert equal_nan in {True, False, "relaxed"} +# in `torch.allclose`. +def _compare_tensors_internal(a: torch.Tensor, b: torch.Tensor, *, rtol, atol, equal_nan) -> _compare_return_type: debug_msg : Optional[str] # Integer (including bool) comparisons are identity comparisons # when rtol is zero and atol is less than one @@ -130,48 +92,19 @@ def _compare_tensors_internal(a: torch.Tensor, b: torch.Tensor, *, rtol, atol, e _unravel_index(greatest_diff_index, a.shape))) return (False, debug_msg) - # Compares complex tensors' real and imaginary parts separately. - # (see NOTE Test Framework Tensor "Equality") - if a.is_complex(): - if equal_nan == "relaxed": - a = a.clone() - b = b.clone() - a.real[a.imag.isnan()] = math.nan - a.imag[a.real.isnan()] = math.nan - b.real[b.imag.isnan()] = math.nan - b.imag[b.real.isnan()] = math.nan - - real_result, debug_msg = _compare_tensors_internal(a.real, b.real, - rtol=rtol, atol=atol, - equal_nan=equal_nan) - - if not real_result: - debug_msg = "Real parts failed to compare as equal! " + cast(str, debug_msg) - return (real_result, debug_msg) - - imag_result, debug_msg = _compare_tensors_internal(a.imag, b.imag, - rtol=rtol, atol=atol, - equal_nan=equal_nan) - - if not imag_result: - debug_msg = "Imaginary parts failed to compare as equal! " + cast(str, debug_msg) - return (imag_result, debug_msg) - - return (True, None) - # All other comparisons use torch.allclose directly - if torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=(equal_nan in {"relaxed", True})): + if torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan): return (True, None) # Gathers debug info for failed float tensor comparison # NOTE: converts to float64 to best represent differences - a_flat = a.to(torch.float64).flatten() - b_flat = b.to(torch.float64).flatten() + a_flat = a.to(torch.float64 if not a.dtype.is_complex else torch.complex128).flatten() + b_flat = b.to(torch.float64 if not a.dtype.is_complex else torch.complex128).flatten() diff = torch.abs(a_flat - b_flat) # Masks close values # NOTE: this avoids (inf - inf) oddities when computing the difference - close = torch.isclose(a_flat, b_flat, rtol, atol, (equal_nan in {"relaxed", True})) + close = torch.isclose(a_flat, b_flat, rtol, atol, equal_nan) diff[close] = 0 nans = torch.isnan(diff) num_nans = nans.sum() @@ -213,7 +146,7 @@ def _helper(a, b, s) -> _compare_return_type: # Special-case for infinity comparisons # NOTE: if b is inf then allowed_diff will be inf when rtol is not 0 - if ((math.isinf(a) or math.isinf(b)) and a != b): + if ((cmath.isinf(a) or cmath.isinf(b)) and a != b): result = False msg = None @@ -229,47 +162,8 @@ def _helper(a, b, s) -> _compare_return_type: ) return result, msg - if isinstance(a, complex) or isinstance(b, complex): - a = complex(a) - b = complex(b) - - if equal_nan == "relaxed": - if cmath.isnan(a) and cmath.isnan(b): - return (True, None) - - result, msg = _helper(a.real, b.real, " the real part ") - - if not result: - return (False, msg) - - return _helper(a.imag, b.imag, " the imaginary part ") - return _helper(a, b, " ") -def assert_allclose(actual, expected, rtol=None, atol=None, equal_nan=True, msg='') -> None: - if not isinstance(actual, torch.Tensor): - actual = torch.tensor(actual) - if not isinstance(expected, torch.Tensor): - expected = torch.tensor(expected, dtype=actual.dtype) - if expected.shape != actual.shape: - raise AssertionError("expected tensor shape {0} doesn't match with actual tensor " - "shape {1}!".format(expected.shape, actual.shape)) - if rtol is None or atol is None: - if rtol is not None or atol is not None: - raise ValueError("rtol and atol must both be specified or both be unspecified") - rtol, atol = _get_default_tolerance(actual, expected) - - result, debug_msg = _compare_tensors_internal(actual, expected, - rtol=rtol, atol=atol, - equal_nan=equal_nan) - - if result: - return - - if msg is None or msg == '': - msg = debug_msg - - raise AssertionError(msg) def make_non_contiguous(tensor: torch.Tensor) -> torch.Tensor: if tensor.numel() <= 1: # can't make non-contiguous @@ -300,125 +194,5 @@ def make_non_contiguous(tensor: torch.Tensor) -> torch.Tensor: return input.data -# Functions and classes for describing the dtypes a function supports -# NOTE: these helpers should correspond to PyTorch's C++ dispatch macros - -# Verifies each given dtype is a torch.dtype -def _validate_dtypes(*dtypes): - for dtype in dtypes: - assert isinstance(dtype, torch.dtype) - return dtypes - -# class for tuples corresponding to a PyTorch dispatch macro -class _dispatch_dtypes(tuple): - def __add__(self, other): - assert isinstance(other, tuple) - return _dispatch_dtypes(tuple.__add__(self, other)) - -_empty_types = _dispatch_dtypes(()) -def empty_types(): - return _empty_types - -_floating_types = _dispatch_dtypes((torch.float32, torch.float64)) -def floating_types(): - return _floating_types - -_floating_types_and_half = _floating_types + (torch.half,) -def floating_types_and_half(): - return _floating_types_and_half - -def floating_types_and(*dtypes): - return _floating_types + _validate_dtypes(*dtypes) - -_floating_and_complex_types = _floating_types + (torch.cfloat, torch.cdouble) -def floating_and_complex_types(): - return _floating_and_complex_types - -def floating_and_complex_types_and(*dtypes): - return _floating_and_complex_types + _validate_dtypes(*dtypes) - -_double_types = _dispatch_dtypes((torch.float64, torch.complex128)) -def double_types(): - return _double_types - -_integral_types = _dispatch_dtypes((torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)) -def integral_types(): - return _integral_types - -def integral_types_and(*dtypes): - return _integral_types + _validate_dtypes(*dtypes) - -_all_types = _floating_types + _integral_types -def all_types(): - return _all_types - -def all_types_and(*dtypes): - return _all_types + _validate_dtypes(*dtypes) - -_complex_types = _dispatch_dtypes((torch.cfloat, torch.cdouble)) -def complex_types(): - return _complex_types - -_all_types_and_complex = _all_types + _complex_types -def all_types_and_complex(): - return _all_types_and_complex - -def all_types_and_complex_and(*dtypes): - return _all_types_and_complex + _validate_dtypes(*dtypes) - -_all_types_and_half = _all_types + (torch.half,) -def all_types_and_half(): - return _all_types_and_half - -def get_all_dtypes(include_half=True, - include_bfloat16=True, - include_bool=True, - include_complex=True, - include_complex32=False - ) -> List[torch.dtype]: - dtypes = get_all_int_dtypes() + get_all_fp_dtypes(include_half=include_half, include_bfloat16=include_bfloat16) - if include_bool: - dtypes.append(torch.bool) - if include_complex: - dtypes += get_all_complex_dtypes(include_complex32) - return dtypes - -def get_all_math_dtypes(device) -> List[torch.dtype]: - return get_all_int_dtypes() + get_all_fp_dtypes(include_half=device.startswith('cuda'), - include_bfloat16=False) + get_all_complex_dtypes() - -def get_all_complex_dtypes(include_complex32=False) -> List[torch.dtype]: - return [torch.complex32, torch.complex64, torch.complex128] if include_complex32 else [torch.complex64, torch.complex128] - - -def get_all_int_dtypes() -> List[torch.dtype]: - return [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64] - - -def get_all_fp_dtypes(include_half=True, include_bfloat16=True) -> List[torch.dtype]: - dtypes = [torch.float32, torch.float64] - if include_half: - dtypes.append(torch.float16) - if include_bfloat16: - dtypes.append(torch.bfloat16) - return dtypes - - def get_all_device_types() -> List[str]: return ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] - -# 'dtype': (rtol, atol) -_default_tolerances = { - 'float64': (1e-5, 1e-8), # NumPy default - 'float32': (1e-4, 1e-5), # This may need to be changed - 'float16': (1e-3, 1e-3), # This may need to be changed -} - - -def _get_default_tolerance(a, b=None) -> Tuple[float, float]: - if b is None: - dtype = str(a.dtype).split('.')[-1] # e.g. "float32" - return _default_tolerances.get(dtype, (0, 0)) - a_tol = _get_default_tolerance(a) - b_tol = _get_default_tolerance(b) - return (max(a_tol[0], b_tol[0]), max(a_tol[1], b_tol[1])) diff --git a/torch/testing/_creation.py b/torch/testing/_creation.py new file mode 100644 index 0000000000000..4eb10d1d5d26b --- /dev/null +++ b/torch/testing/_creation.py @@ -0,0 +1,155 @@ +""" +This module contains tensor creation utilities. +""" + +import torch +from typing import Optional, List, Tuple, Union, cast +import math + +__all__ = [ + "make_tensor", +] + +def make_tensor( + shape: Union[torch.Size, List[int], Tuple[int, ...]], + device: Union[str, torch.device], + dtype: torch.dtype, + *, + low: Optional[float] = None, + high: Optional[float] = None, + requires_grad: bool = False, + noncontiguous: bool = False, + exclude_zero: bool = False +) -> torch.Tensor: + r"""Creates a tensor with the given :attr:`shape`, :attr:`device`, and :attr:`dtype`, and filled with + values uniformly drawn from ``[low, high)``. + + If :attr:`low` or :attr:`high` are specified and are outside the range of the :attr:`dtype`'s representable + finite values then they are clamped to the lowest or highest representable finite value, respectively. + If ``None``, then the following table describes the default values for :attr:`low` and :attr:`high`, + which depend on :attr:`dtype`. + + +---------------------------+------------+----------+ + | ``dtype`` | ``low`` | ``high`` | + +===========================+============+==========+ + | boolean type | ``0`` | ``2`` | + +---------------------------+------------+----------+ + | unsigned integral type | ``0`` | ``10`` | + +---------------------------+------------+----------+ + | signed integral types | ``-9`` | ``10`` | + +---------------------------+------------+----------+ + | floating types | ``-9`` | ``9`` | + +---------------------------+------------+----------+ + | complex types | ``-9`` | ``9`` | + +---------------------------+------------+----------+ + + Args: + shape (Tuple[int, ...]): A sequence of integers defining the shape of the output tensor. + device (Union[str, torch.device]): The device of the returned tensor. + dtype (:class:`torch.dtype`): The data type of the returned tensor. + low (Optional[Number]): Sets the lower limit (inclusive) of the given range. If a number is provided it is + clamped to the least representable finite value of the given dtype. When ``None`` (default), + this value is determined based on the :attr:`dtype` (see the table above). Default: ``None``. + high (Optional[Number]): Sets the upper limit (exclusive) of the given range. If a number is provided it is + clamped to the greatest representable finite value of the given dtype. When ``None`` (default) this value + is determined based on the :attr:`dtype` (see the table above). Default: ``None``. + requires_grad (Optional[bool]): If autograd should record operations on the returned tensor. Default: ``False``. + noncontiguous (Optional[bool]): If `True`, the returned tensor will be noncontiguous. This argument is + ignored if the constructed tensor has fewer than two elements. + exclude_zero (Optional[bool]): If ``True`` then zeros are replaced with the dtype's small positive value + depending on the :attr:`dtype`. For bool and integer types zero is replaced with one. For floating + point types it is replaced with the dtype's smallest positive normal number (the "tiny" value of the + :attr:`dtype`'s :func:`~torch.finfo` object), and for complex types it is replaced with a complex number + whose real and imaginary parts are both the smallest positive normal number representable by the complex + type. Default ``False``. + + Raises: + ValueError: If ``low > high``. + ValueError: If either :attr:`low` or :attr:`high` is ``nan``. + TypeError: If :attr:`dtype` isn't supported by this function. + + Examples: + >>> from torch.testing import make_tensor + >>> # Creates a float tensor with values in [-1, 1) + >>> make_tensor((3,), device='cpu', dtype=torch.float32, low=-1, high=1) + tensor([ 0.1205, 0.2282, -0.6380]) + >>> # Creates a bool tensor on CUDA + >>> make_tensor((2, 2), device='cuda', dtype=torch.bool) + tensor([[False, False], + [False, True]], device='cuda:0') + """ + def _modify_low_high(low, high, lowest, highest, default_low, default_high, dtype): + """ + Modifies (and raises ValueError when appropriate) low and high values given by the user (input_low, input_high) if required. + """ + def clamp(a, l, h): + return min(max(a, l), h) + + low = low if low is not None else default_low + high = high if high is not None else default_high + + # Checks for error cases + if low != low or high != high: + raise ValueError("make_tensor: one of low or high was NaN!") + if low > high: + raise ValueError("make_tensor: low must be weakly less than high!") + + low = clamp(low, lowest, highest) + high = clamp(high, lowest, highest) + + if dtype in [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]: + return math.floor(low), math.ceil(high) + + return low, high + + _integral_types = [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64] + _floating_types = [torch.float16, torch.bfloat16, torch.float32, torch.float64] + _complex_types = [torch.cfloat, torch.cdouble] + + if dtype is torch.bool: + result = torch.randint(0, 2, shape, device=device, dtype=dtype) + elif dtype is torch.uint8: + ranges = (torch.iinfo(dtype).min, torch.iinfo(dtype).max) + low, high = cast(Tuple[int, int], _modify_low_high(low, high, ranges[0], ranges[1], 0, 10, dtype)) + result = torch.randint(low, high, shape, device=device, dtype=dtype) + elif dtype in _integral_types: + ranges = (torch.iinfo(dtype).min, torch.iinfo(dtype).max) + low, high = _modify_low_high(low, high, ranges[0], ranges[1], -9, 10, dtype) + result = torch.randint(low, high, shape, device=device, dtype=dtype) # type: ignore[call-overload] + elif dtype in _floating_types: + ranges_floats = (torch.finfo(dtype).min, torch.finfo(dtype).max) + low, high = _modify_low_high(low, high, ranges_floats[0], ranges_floats[1], -9, 9, dtype) + rand_val = torch.rand(shape, device=device, dtype=dtype) + result = high * rand_val + low * (1 - rand_val) + elif dtype in _complex_types: + float_dtype = torch.float if dtype is torch.cfloat else torch.double + ranges_floats = (torch.finfo(float_dtype).min, torch.finfo(float_dtype).max) + low, high = _modify_low_high(low, high, ranges_floats[0], ranges_floats[1], -9, 9, dtype) + real_rand_val = torch.rand(shape, device=device, dtype=float_dtype) + imag_rand_val = torch.rand(shape, device=device, dtype=float_dtype) + real = high * real_rand_val + low * (1 - real_rand_val) + imag = high * imag_rand_val + low * (1 - imag_rand_val) + result = torch.complex(real, imag) + else: + raise TypeError(f"The requested dtype '{dtype}' is not supported by torch.testing.make_tensor()." + " To request support, file an issue at: https://github.com/pytorch/pytorch/issues") + + if noncontiguous and result.numel() > 1: + result = torch.repeat_interleave(result, 2, dim=-1) + result = result[..., ::2] + + if exclude_zero: + if dtype in _integral_types or dtype is torch.bool: + replace_with = torch.tensor(1, device=device, dtype=dtype) + elif dtype in _floating_types: + replace_with = torch.tensor(torch.finfo(dtype).tiny, device=device, dtype=dtype) + else: # dtype in _complex_types: + float_dtype = torch.float if dtype is torch.cfloat else torch.double + float_eps = torch.tensor(torch.finfo(float_dtype).tiny, device=device, dtype=float_dtype) + replace_with = torch.complex(float_eps, float_eps) + result[result == 0] = replace_with + + if dtype in _floating_types + _complex_types: + result.requires_grad = requires_grad + + return result diff --git a/torch/testing/_deprecated.py b/torch/testing/_deprecated.py index 7355aeea1a292..60c6384ad13cb 100644 --- a/torch/testing/_deprecated.py +++ b/torch/testing/_deprecated.py @@ -5,22 +5,32 @@ import functools import warnings -from typing import Any, Callable +from typing import Any, Callable, Dict, Optional, Tuple, Union import torch +from . import _dtype_getters -__all__ = ["rand", "randn"] +__all__ = [ + "rand", + "randn", + "assert_allclose", +] -def warn_deprecated(instructions: str) -> Callable: + +def warn_deprecated(instructions: Union[str, Callable[[str, Tuple[Any, ...], Dict[str, Any], Any], str]]) -> Callable: def outer_wrapper(fn: Callable) -> Callable: - msg = f"torch.testing.{fn.__name__} is deprecated and will be removed in the future. {instructions.strip()}" + name = fn.__name__ + head = f"torch.testing.{name}() is deprecated and will be removed in a future release. " @functools.wraps(fn) def inner_wrapper(*args: Any, **kwargs: Any) -> Any: + return_value = fn(*args, **kwargs) + tail = instructions(name, args, kwargs, return_value) if callable(instructions) else instructions + msg = (head + tail).strip() warnings.warn(msg, FutureWarning) - return fn(*args, **kwargs) + return return_value return inner_wrapper @@ -29,3 +39,65 @@ def inner_wrapper(*args: Any, **kwargs: Any) -> Any: rand = warn_deprecated("Use torch.rand instead.")(torch.rand) randn = warn_deprecated("Use torch.randn instead.")(torch.randn) + + +_DTYPE_PRECISIONS = { + torch.float16: (1e-3, 1e-3), + torch.float32: (1e-4, 1e-5), + torch.float64: (1e-5, 1e-8), +} + + +def _get_default_rtol_and_atol(actual: torch.Tensor, expected: torch.Tensor) -> Tuple[float, float]: + actual_rtol, actual_atol = _DTYPE_PRECISIONS.get(actual.dtype, (0.0, 0.0)) + expected_rtol, expected_atol = _DTYPE_PRECISIONS.get(expected.dtype, (0.0, 0.0)) + return max(actual_rtol, expected_rtol), max(actual_atol, expected_atol) + + +# TODO: include the deprecation as soon as torch.testing.assert_close is stable +# @warn_deprecated( +# "Use torch.testing.assert_close instead. " +# "For detailed upgrade instructions see https://github.com/pytorch/pytorch/issues/61844." +# ) +def assert_allclose( + actual: Any, + expected: Any, + rtol: Optional[float] = None, + atol: Optional[float] = None, + equal_nan: bool = True, + msg: str = "", +) -> None: + if not isinstance(actual, torch.Tensor): + actual = torch.tensor(actual) + if not isinstance(expected, torch.Tensor): + expected = torch.tensor(expected, dtype=actual.dtype) + + if rtol is None and atol is None: + rtol, atol = _get_default_rtol_and_atol(actual, expected) + + torch.testing.assert_close( + actual, + expected, + rtol=rtol, + atol=atol, + equal_nan=equal_nan, + check_device=True, + check_dtype=False, + check_stride=False, + check_is_coalesced=False, + msg=msg or None, + ) + + +def _dtype_getter_instructions(name: str, args: Tuple[Any, ...], kwargs: Dict[str, Any], return_value: Any) -> str: + return f"This call to {name}(...) can be replaced with {return_value}." + + +# We iterate over all public dtype getters and expose them here with an added deprecation warning +for name in _dtype_getters.__all__: + if name.startswith("_"): + continue + fn = getattr(_dtype_getters, name) + + globals()[name] = warn_deprecated(_dtype_getter_instructions)(fn) + __all__.append(name) diff --git a/torch/testing/_dtype_getters.py b/torch/testing/_dtype_getters.py new file mode 100644 index 0000000000000..d16ca04f25778 --- /dev/null +++ b/torch/testing/_dtype_getters.py @@ -0,0 +1,138 @@ +"""This module exist to be able to deprecate the dtype getters publicly without doing so internally. The deprecated +public versions are defined in torch.testing._deprecated and exposed from torch.testing. The non-deprecated internal +versions should be imported from torch.testing._internal.dtype_getters +""" + +from typing import List + +import torch + +__all__ = [ + "_validate_dtypes", + "_dispatch_dtypes", + "all_types", + "all_types_and", + "all_types_and_complex", + "all_types_and_complex_and", + "all_types_and_half", + "complex_types", + "empty_types", + "floating_and_complex_types", + "floating_and_complex_types_and", + "floating_types", + "floating_types_and", + "double_types", + "floating_types_and_half", + "get_all_complex_dtypes", + "get_all_dtypes", + "get_all_fp_dtypes", + "get_all_int_dtypes", + "get_all_math_dtypes", + "integral_types", + "integral_types_and", +] + +# Functions and classes for describing the dtypes a function supports +# NOTE: these helpers should correspond to PyTorch's C++ dispatch macros + +# Verifies each given dtype is a torch.dtype +def _validate_dtypes(*dtypes): + for dtype in dtypes: + assert isinstance(dtype, torch.dtype) + return dtypes + +# class for tuples corresponding to a PyTorch dispatch macro +class _dispatch_dtypes(tuple): + def __add__(self, other): + assert isinstance(other, tuple) + return _dispatch_dtypes(tuple.__add__(self, other)) + +_empty_types = _dispatch_dtypes(()) +def empty_types(): + return _empty_types + +_floating_types = _dispatch_dtypes((torch.float32, torch.float64)) +def floating_types(): + return _floating_types + +_floating_types_and_half = _floating_types + (torch.half,) +def floating_types_and_half(): + return _floating_types_and_half + +def floating_types_and(*dtypes): + return _floating_types + _validate_dtypes(*dtypes) + +_floating_and_complex_types = _floating_types + (torch.cfloat, torch.cdouble) +def floating_and_complex_types(): + return _floating_and_complex_types + +def floating_and_complex_types_and(*dtypes): + return _floating_and_complex_types + _validate_dtypes(*dtypes) + +_double_types = _dispatch_dtypes((torch.float64, torch.complex128)) +def double_types(): + return _double_types + +_integral_types = _dispatch_dtypes((torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)) +def integral_types(): + return _integral_types + +def integral_types_and(*dtypes): + return _integral_types + _validate_dtypes(*dtypes) + +_all_types = _floating_types + _integral_types +def all_types(): + return _all_types + +def all_types_and(*dtypes): + return _all_types + _validate_dtypes(*dtypes) + +_complex_types = _dispatch_dtypes((torch.cfloat, torch.cdouble)) +def complex_types(): + return _complex_types + +_all_types_and_complex = _all_types + _complex_types +def all_types_and_complex(): + return _all_types_and_complex + +def all_types_and_complex_and(*dtypes): + return _all_types_and_complex + _validate_dtypes(*dtypes) + +_all_types_and_half = _all_types + (torch.half,) +def all_types_and_half(): + return _all_types_and_half + +# The functions below are used for convenience in our test suite and thus have no corresponding C++ dispatch macro + +def get_all_dtypes(include_half=True, + include_bfloat16=True, + include_bool=True, + include_complex=True, + include_complex32=False + ) -> List[torch.dtype]: + dtypes = get_all_int_dtypes() + get_all_fp_dtypes(include_half=include_half, include_bfloat16=include_bfloat16) + if include_bool: + dtypes.append(torch.bool) + if include_complex: + dtypes += get_all_complex_dtypes(include_complex32) + return dtypes + +def get_all_math_dtypes(device) -> List[torch.dtype]: + return get_all_int_dtypes() + get_all_fp_dtypes(include_half=device.startswith('cuda'), + include_bfloat16=False) + get_all_complex_dtypes() + +def get_all_complex_dtypes(include_complex32=False) -> List[torch.dtype]: + return [torch.complex32, torch.complex64, torch.complex128] if include_complex32 else [torch.complex64, torch.complex128] + + +def get_all_int_dtypes() -> List[torch.dtype]: + return [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64] + + +def get_all_fp_dtypes(include_half=True, include_bfloat16=True) -> List[torch.dtype]: + dtypes = [torch.float32, torch.float64] + if include_half: + dtypes.append(torch.float16) + if include_bfloat16: + dtypes.append(torch.bfloat16) + return dtypes diff --git a/torch/testing/_internal/autocast_test_lists.py b/torch/testing/_internal/autocast_test_lists.py index 754ccca11ed9d..8350845e4ef19 100644 --- a/torch/testing/_internal/autocast_test_lists.py +++ b/torch/testing/_internal/autocast_test_lists.py @@ -307,7 +307,6 @@ def __init__(self, dev): ("conv1d", conv_args_fp32[0]), ("conv2d", conv_args_fp32[1]), ("conv3d", conv_args_fp32[2]), - ("log_softmax", pointwise0_fp32 + (0,)), ("bmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32), torch.randn((n, n, n), device=dev, dtype=torch.float32))), ("mm", mat0_fp32 + mat1_fp32), @@ -319,24 +318,22 @@ def __init__(self, dev): torch.randn((n, n, n), device=dev, dtype=torch.float32))), ] self.torch_fp32 = [ + ("conv_transpose1d", conv_args_bf16[0]), + ("conv_transpose2d", conv_args_bf16[1]), ("conv_transpose3d", conv_args_bf16[2]), ("batch_norm", dummy_bf16[2], {"weight": None, "bias": None, "running_mean": torch.rand((n), dtype=torch.float32), "running_var": torch.rand((n), dtype=torch.float32), "training": False, "momentum": 0.1, "eps": 1e-5, "cudnn_enabled": False}), - ("max_pool2d", dummy_bf16[2], {"kernel_size": (3, 2), "stride": (1, 1)}), ("dropout", dummy_bf16[2], {"p": 0.1, "train": False}), ("binary_cross_entropy_with_logits", mat0_bf16 + (torch.rand((n, n), device=dev, dtype=torch.bfloat16),)), - ("pow", ((pointwise0_bf16[0] + 1.).clamp(0.0, 100.0),) + pointwise1_bf16), - ("pow", ((pointwise0_bf16[0] + 1.).clamp(0.0, 100.0),) + (1.7,)), - ("instance_norm", dummy_bf16[2], {"weight": None, "bias": None, "running_mean": torch.rand((n), dtype=torch.float32), - "running_var": torch.rand((n), dtype=torch.float32), "use_input_stats": False, + ("instance_norm", dummy_bf16[1], {"weight": None, "bias": None, "running_mean": None, + "running_var": None, "use_input_stats": True, "momentum": 0.1, "eps": 1e-5, "cudnn_enabled": False}), ] self.nn_bf16 = [ ("linear", mat0_fp32 + mat1_fp32), ] self.nn_fp32 = [ - ("adaptive_avg_pool2d", dummy_bf16[2], {"output_size": (3, 2)}), ("avg_pool2d", dummy_bf16[2], {"kernel_size": (3, 2), "stride": (1, 1)}), ("avg_pool3d", dummy_bf16[3], {"kernel_size": (3, 3, 3), "stride": (1, 1, 1)}), ("gelu", dummy_bf16[3]), @@ -348,9 +345,8 @@ def __init__(self, dev): ("upsample_trilinear3d", dummy_bf16[4], {"output_size": (n, n, n), "align_corners": False}), ("binary_cross_entropy", (torch.rand((n, n), device=dev, dtype=torch.bfloat16),) + (torch.rand((n, n), device=dev, dtype=torch.bfloat16),)), - ("smooth_l1_loss", mat0_bf16 + mat1_bf16), ("reflection_pad1d", dummy_bf16[2], {"padding": (3, 3)}), - ("std", dummy_bf16[2]), + ("smooth_l1_loss", mat0_bf16 + mat1_bf16), ] self.torch_need_autocast_promote = [ ("cat", (pointwise0_bf16 + pointwise1_fp32,)), diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py index 8ec6e71d121ff..23e431d66bec2 100644 --- a/torch/testing/_internal/common_device_type.py +++ b/torch/testing/_internal/common_device_type.py @@ -14,8 +14,7 @@ skipCUDANonDefaultStreamIf, TEST_WITH_ASAN, TEST_WITH_UBSAN, TEST_WITH_TSAN, \ IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, DeterministicGuard, TEST_SKIP_NOARCH from torch.testing._internal.common_cuda import _get_torch_cuda_version -from torch.testing import \ - (get_all_dtypes) +from torch.testing._internal.common_dtype import get_all_dtypes try: import psutil # type: ignore[import] diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index fdad0ad0222fa..01e167f528af2 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -467,10 +467,6 @@ def _start_processes(self, proc) -> None: self.pid_to_pipe[process.pid] = parent_conn self.processes.append(process) - def _fork_processes(self) -> None: - proc = torch.multiprocessing.get_context("fork").Process - self._start_processes(proc) - def _spawn_processes(self) -> None: proc = torch.multiprocessing.get_context("spawn").Process self._start_processes(proc) @@ -526,10 +522,6 @@ def _run(cls, rank: int, test_name: str, file_name: str, parent_pipe) -> None: self.file_name = file_name self.run_test(test_name, parent_pipe, signal_send_pipe, event_listener_thread) - # exit to avoid run teardown() for fork processes - # Use os._exit() as it is the recommended way for child processes. - os._exit(0) - def run_test( self, test_name: str, parent_pipe, signal_pipe=None, event_listener_thread=None ) -> None: diff --git a/torch/testing/_internal/common_dtype.py b/torch/testing/_internal/common_dtype.py new file mode 100644 index 0000000000000..0ce2d80a18b4a --- /dev/null +++ b/torch/testing/_internal/common_dtype.py @@ -0,0 +1,4 @@ +"""The content of torch/testing/_dtype_getters.py should be moved here as soon as the deprecation period is over. +""" + +from torch.testing._dtype_getters import * # noqa: F401, F403 diff --git a/torch/testing/_internal/common_jit.py b/torch/testing/_internal/common_jit.py index 80cb4d0331889..89533a6d7fb9d 100644 --- a/torch/testing/_internal/common_jit.py +++ b/torch/testing/_internal/common_jit.py @@ -7,7 +7,7 @@ import torch.jit.quantized # Testing utils -from torch.testing import floating_and_complex_types_and +from torch.testing._internal.common_dtype import floating_and_complex_types_and from torch.testing._internal.common_utils import TestCase, \ freeze_rng_state, TemporaryFileName, enable_profiling_mode_for_profiling_tests, is_iterable_of_tensors from torch.testing._internal.common_utils import enable_profiling_mode # noqa: F401 @@ -136,7 +136,7 @@ def get_recording_tensors(args): for g2, g2_test in zip(grads2, grads2_test): if g2 is None and g2_test is None: continue - self.assertTrue(torch.allclose(g2, g2_test, atol=5e-4, rtol=1e-4)) + self.assertEqual(g2, g2_test, atol=5e-4, rtol=1e-4) class JitCommonTestCase(TestCase): def createFunctionFromGraph(self, trace): diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index b281c5e474c41..0db9bb508ee40 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -6,22 +6,24 @@ import operator import random import numbers +import unittest +import os import torch import numpy as np from torch._six import inf import collections.abc -from typing import List, Sequence, Tuple, Union +from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, Dict -from torch.testing import \ - (make_non_contiguous, floating_types, floating_types_and, complex_types, - floating_and_complex_types, floating_and_complex_types_and, - all_types_and_complex_and, all_types_and, all_types_and_complex, - integral_types_and, all_types, double_types) -from .._core import _dispatch_dtypes +from torch.testing import make_non_contiguous, make_tensor +from torch.testing._internal.common_dtype import ( + _dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types, + floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and, + all_types, double_types, +) from torch.testing._internal.common_device_type import \ - (expectedFailure, onlyOnCPUAndCUDA, skipIf, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfNoCusolver, + (onlyOnCPUAndCUDA, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfNoCusolver, skipCPUIfNoLapack, skipCPUIfNoFFT, skipCUDAIfRocm, precisionOverride, toleranceOverride, tol) from torch.testing._internal.common_cuda import CUDA11OrLater, SM53OrLater, SM60OrLater from torch.testing._internal.common_utils import \ @@ -31,7 +33,7 @@ random_symmetric_pd_matrix, make_symmetric_matrices, make_symmetric_pd_matrices, random_square_matrix_of_rank, random_fullrank_matrix_distinct_singular_value, - TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, make_tensor, TEST_SCIPY, + TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, TEST_SCIPY, torch_to_numpy_dtype_dict, TEST_WITH_ASAN, GRADCHECK_NONDET_TOL,) import torch.testing._internal.opinfo_helper as opinfo_helper @@ -42,6 +44,15 @@ import scipy.special +# Reasonable testing sizes for dimensions +L = 20 +M = 10 +S = 5 + +# Unique value to distinguish default from anything else +_NOTHING = object() + + class DecorateInfo(object): """Describes which test, or type of tests, should be wrapped in the given decorators when testing an operator. Any test that matches all provided @@ -86,11 +97,12 @@ def __init__( active_if: whether tests matching the above arguments should be skipped expected_failure: whether to assert that skipped tests fail """ - decorator = expectedFailure(device_type) if expected_failure else skipIf(True, "Skipped!") + decorator = unittest.expectedFailure if expected_failure else unittest.skip("Skipped!") super().__init__(decorators=decorator, cls_name=cls_name, test_name=test_name, device_type=device_type, dtypes=dtypes, active_if=active_if) + class SampleInput(object): """Represents sample inputs to a function.""" @@ -184,6 +196,7 @@ def _np(t): sample_np_input, np_args, np_kwargs = to_numpy(self.input), to_numpy(self.args), to_numpy(self.kwargs) return (sample_np_input, np_args, np_kwargs) + class AliasInfo(object): """Class holds alias information. For example, torch.abs -> torch.absolute, torch.Tensor.absolute, torch.Tensor.absolute_ @@ -199,9 +212,6 @@ def __call__(self, *args, **kwargs): return self.op(*args, **kwargs) -_NOTHING = object() # Unique value to distinguish default from anything else - - # Extension of getattr to support qualified names # e.g. _getattr_qual(torch, 'linalg.norm') -> torch.linalg.norm def _getattr_qual(obj, name, default=_NOTHING): @@ -769,9 +779,164 @@ def default_test_dtypes(self, device_type): else supported.intersection(self._default_test_dtypes)) -L = 20 -M = 10 -S = 5 +def _generate_reduction_inputs(device, dtype, requires_grad): + """Generates input tensors for testing reduction operators""" + yield make_tensor([], device, dtype, requires_grad=requires_grad) + yield make_tensor([2], device, dtype, requires_grad=requires_grad) + yield make_tensor([2, 3], device, dtype, requires_grad=requires_grad, noncontiguous=True) + yield make_tensor([3, 2, 1, 5], device, dtype, requires_grad=requires_grad) + + +def _generate_reduction_kwargs(ndim, supports_multiple_dims=True): + """Generates a subset of all valid dim and keepdim kwargs given ndim that + is appropriate for testing reduction operators. + """ + + # Test default dim and keepdim + yield {} + + # Test reducing inner and outer most dimensions + yield {'dim': 0, 'keepdim': True} + yield {'dim': -1, 'keepdim': False} + + # Test reducing middle dimension + if ndim > 2: + yield {'dim': ndim // 2, 'keepdim': True} + + if supports_multiple_dims: + # Test reducing all dimensions + yield {'dim': tuple(range(ndim)), 'keepdim': False} + + # Test reducing both first and last dimensions + if ndim > 1: + yield {'dim': (0, -1), 'keepdim': True} + + # Test reducing every other dimension starting with the second + if ndim > 3: + yield {'dim': tuple(range(1, ndim, 2)), 'keepdim': False} + + +def sample_inputs_reduction(op_info, device, dtype, requires_grad, **kwargs): + """Sample inputs for reduction operators.""" + + # TODO(@heitorschueroff) Once all reduction operators are using + # ReductionOpInfo use op_info.supports_multiple_dims directly. + supports_multiple_dims: bool = kwargs.get('supports_multiple_dims', True) + + # TODO(@heitorschueroff) Once all reduction operators are using ReductionOpInfo + # use op_info.genearte_args_kwargs directly. + generate_args_kwargs = kwargs.get('generate_args_kwargs', lambda *args, **kwargs: (yield tuple(), {})) + + inputs: List[SampleInput] = [] + for t in _generate_reduction_inputs(device, dtype, requires_grad): + for reduction_kwargs in _generate_reduction_kwargs(t.ndim, supports_multiple_dims): + for args, kwargs in generate_args_kwargs(t, **reduction_kwargs): + kwargs.update(reduction_kwargs) + inputs.append(SampleInput(t, args=args, kwargs=kwargs)) + + return inputs + + +# NOTE [Reductions]: +# +# For testing purposes, we relax the definition of a reduction operator +# as defined in the docstring below. We do this to capture operators with +# a similar API so they can be tested automatically. However... +# +# Strictly speaking a reduction operator is an operator that can reduce an +# array to a single scalar value and that can be computed from the partial +# result of reducing subarrays. This usually means that the reduction operation +# should be commutative and associative. This definition is important when it +# comes to implementation as it determines how a reduction can be parallelized. +# +# For example, many summary statistics such as median, mode and quantile cannot +# be computed from partial results because these are sorting and counting based +# algorithms that need information that would be lost in the reduced value. +class ReductionOpInfo(OpInfo): + """Reduction operator information. + + An operator is a reduction operator if it reduces one or more dimensions of + the input tensor to a single value. Reduction operators must implement the + following signature: + + - `op(input, *args, *, dim=None, keepdim=False, **kwargs) -> Tensor` + + ReductionOpInfo tests that reduction operators implement a consistent API. + Optional features such as reducing over multiple dimensions are captured in + the optional keyword parameters of the ReductionOpInfo constructor. + + If a reduction operator does not yet implement the full required API of + reduction operators, this should be documented by skipping the failing + tests rather than adding optional parameters to ReductionOpInfo. + + NOTE + The API for reduction operators has not yet been finalized and some + requirements may change. + + See tests in test/test_reductions.py + """ + + def __init__( + self, name, *, + + # The identity value for the operator if it has one. + identity: Optional[Any] = None, + + # The nan policy for the operator if it implements one. + # - propagate: NaN values are propagated to the output + # - omit: NaN values are discarded during the reduction + nan_policy: Optional[str] = None, + + # Whether the operator supports reducing multiple dimensions. + supports_multiple_dims: bool = True, + + # Whether the operator promotes integral to floating point dtypes. + promotes_int_to_float: bool = False, + + # Whether the operator promotes all integral dtypes to int64. + promotes_int_to_int64: bool = False, + + # If a specific dtype is given, then the operator always returns that + # dtype irrespective of the input dtype. If None, the operator returns + # the dtype according to the type promotion rules above. + result_dtype: Optional[torch.dtype] = None, + + # ReductionOpInfo tests generate their own input, dim and keepdim + # arguments and call this function to generate tuples of extra args and + # kwargs to use when calling the op. This is required for operators that + # have other required parameters besides the input tensor. + generate_args_kwargs: Callable = lambda t, dim=None, keepdim=False: (yield tuple(), {}), + + # Options from the OpInfo base class + **kwargs, + ): + assert nan_policy in (None, 'propagate', 'omit') + + # These are mutually exclusive options + assert not (result_dtype and promotes_int_to_float) + assert not (result_dtype and promotes_int_to_int64) + assert not (promotes_int_to_float and promotes_int_to_int64) + + # Default sample_inputs_func for ReductionOpInfo which augments sample + # inputs from sample_inputs_reduction with the args and kwargs from + # generate_args_kwargs. This is only used if sample_inputs_func is None. + def sample_inputs_func(*args, **kwargs): + kwargs['supports_multiple_dims'] = supports_multiple_dims + kwargs['generate_args_kwargs'] = generate_args_kwargs + return sample_inputs_reduction(*args, **kwargs) + + # Override OpInfo defaults and call base class __init__ + kwargs.setdefault('inplace_variant', None) + kwargs.setdefault('sample_inputs_func', sample_inputs_func) + super(ReductionOpInfo, self).__init__(name, **kwargs) + + self.identity = identity + self.nan_policy = nan_policy + self.supports_multiple_dims = supports_multiple_dims + self.promotes_int_to_float = promotes_int_to_float + self.promotes_int_to_int64 = promotes_int_to_int64 + self.result_dtype = result_dtype + self.generate_args_kwargs = generate_args_kwargs def sample_inputs_unary(op_info, device, dtype, requires_grad, **kwargs): @@ -858,6 +1023,7 @@ def sample_inputs_tensor_split(op_info, device, dtype, requires_grad, **kwargs): (torch.tensor([1, 2, 3]),), (torch.tensor(1),), (torch.tensor([1, 2, 3]), 1), + (torch.tensor([1, 4, 2, 5, 3, 6])[::2], 1), # Cases with list of indices. ((2, 4),), ((2, 4), 1), @@ -1084,6 +1250,26 @@ def sample_inputs_linalg_norm(op_info, device, dtype, requires_grad): dim=(0, 1)))) return inputs +def sample_inputs_cosine_similarity(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as input_shape, dict of dim and eps + cases: Tuple[tuple, dict] = ( # type: ignore[assignment] + ((S, S), {'dim': 1}), + ((S, 2), {'dim': -1}), + ((S,), {'dim': 0, 'eps': 0.5}), + ((), {'dim': 0}), + ((S, S, M), {'dim': 2}), + ((S, S), {}) + ) + + def generator(): + for input_shape, kwargs in cases: + yield SampleInput(make_arg(input_shape), args=(make_arg(input_shape),), kwargs=kwargs) + # Test for Broadcasting + yield SampleInput(make_arg((1, 2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -1}) + + return list(generator()) def sample_inputs_nn_activation_relu(op_info, device, dtype, requires_grad, **kwargs): make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) @@ -1266,53 +1452,151 @@ def sample_inputs_linalg_vector_norm(op_info, device, dtype, requires_grad, **kw return inputs -# In order to use the kwarg alpha, partials should be used in an OpInfo's sample_inputs_func -# eg. sample_inputs_func=partial(sample_inputs_binary_pwise, alpha=2) -# Then one sample input would also be generated corresponding to the value of alpha provided. -# In the future, kwargs 'alpha_floating', 'alpha_integral' & 'alpha_complex' can be used to -# specify scalars of floating, integral & complex types as values for "alpha". -# Keyword argument `rhs_exclude_zero` is used to exclude zero values from rhs tensor argument -# This is necessary for operations like `true_divide`, where divide by zero throws an exception. -def sample_inputs_binary_pwise(op_info, device, dtype, requires_grad, extra_kwargs=None, **kwargs): - if extra_kwargs is None: - extra_kwargs = {} - - scalar = 3.14 + 3.14j if dtype.is_complex else (3.14 if dtype.is_floating_point else 3) - scalar = 1 if dtype is torch.bool else scalar - tests_list = [ - ((S, S, S), (S, S, S), False), - ((S, S, S), (S, S), False), - ((), (), False), - ((S, S, S), (), False), - ((S, S, S), scalar, False), - ((), scalar, False) - ] - tests_with_lhs_broadcasting = [ - ((S, S), (S, S, S), True), - ((), (S, S, S), True), - ((S, 1, S), (M, S), True), + +# Metadata class for binary "universal functions (ufuncs)" that accept two +# tensor and have common properties +class BinaryUfuncInfo(OpInfo): + """Operator information for 'universal binary functions (binary ufuncs).' + These are functions of two tensors with common properties like: + - they are elementwise functions + - the output shape is determined by the input shape + - they typically have method and inplace variants + - they typically support the out kwarg + - they typically have NumPy or SciPy references + See NumPy's universal function documentation + (https://numpy.org/doc/stable/reference/ufuncs.html) for more details + about the concept of ufuncs. + """ + def __init__(self, name, *, lhs_make_tensor_kwargs=None, rhs_make_tensor_kwargs=None, **kwargs): + super().__init__(name, **kwargs) + + # [lr]hs_make_tensor_kwargs are part of the OpInfo to be able to dynamically generate valid samples later on. + if lhs_make_tensor_kwargs is None: + lhs_make_tensor_kwargs = {} + self.lhs_make_tensor_kwargs = lhs_make_tensor_kwargs + + if rhs_make_tensor_kwargs is None: + rhs_make_tensor_kwargs = {} + self.rhs_make_tensor_kwargs = rhs_make_tensor_kwargs + + +def _resolve_binay_pwise_kwargs( + op_info, *, op_kwargs=None, lhs_make_tensor_kwargs=None, rhs_make_tensor_kwargs=None +): + """Resolves default values for :func:`sample_inputs_binary_pwise`. + + By default :attr:`op_kwargs`, :attr:`lhs_make_tensor_kwargs`, and :attr:`rhs_make_tensor_kwargs` are just empty + dictionaries. In case :attr:`op_info` is a :class:`BinaryUfuncInfo`, :attr:`BinaryUfuncInfo.lhs_make_tensor_kwargs` + and :attr:`BinaryUfuncInfo.rhs_make_tensor_kwargs` will be used as defaults. + """ + if op_kwargs is None: + op_kwargs = {} + if lhs_make_tensor_kwargs is None: + lhs_make_tensor_kwargs = op_info.lhs_make_tensor_kwargs if isinstance(op_info, BinaryUfuncInfo) else {} + if rhs_make_tensor_kwargs is None: + rhs_make_tensor_kwargs = op_info.rhs_make_tensor_kwargs if isinstance(op_info, BinaryUfuncInfo) else {} + + return op_kwargs, lhs_make_tensor_kwargs, rhs_make_tensor_kwargs + + +def sample_inputs_binary_pwise( + op_info, + device, + dtype, + requires_grad, + *, + python_scalars=False, + op_kwargs=None, + lhs_make_tensor_kwargs=None, + rhs_make_tensor_kwargs=None, + **kwargs, +): + op_kwargs, lhs_make_tensor_kwargs, rhs_make_tensor_kwargs = _resolve_binay_pwise_kwargs( + op_info, + op_kwargs=op_kwargs, + lhs_make_tensor_kwargs=lhs_make_tensor_kwargs, + rhs_make_tensor_kwargs=rhs_make_tensor_kwargs, + ) + + scalar = make_tensor((), device=device, dtype=dtype, **rhs_make_tensor_kwargs) + if python_scalars: + scalar = scalar.item() # type: ignore[assignment] + + shapes = [ + ((), scalar), + ((S,), scalar), + ((S, 1), (S,)), + ((M, S), scalar), + ((S, M, S), (M, S)), + ((S, M, S), (S, M, S)), + ((M, 1, S), (M, S)), + ((M, 1, S), (1, M, S)), ] - test_cases = tests_list + tests_with_lhs_broadcasting # type: ignore[operator] - samples = [] - for first_shape, shape_or_scalar, broadcasts_input in test_cases: - arg = shape_or_scalar - - if isinstance(shape_or_scalar, tuple): - exclude_zero = kwargs.get('rhs_exclude_zero', False) - arg = make_tensor(shape_or_scalar, device=device, dtype=dtype, - requires_grad=requires_grad, exclude_zero=exclude_zero) - samples.append(SampleInput(make_tensor(first_shape, device=device, dtype=dtype, - requires_grad=requires_grad), - args=(arg,), kwargs=extra_kwargs, - broadcasts_input=broadcasts_input)) - # Adds an extra sample using "alpha" if it's passed in kwargs - if 'alpha' in kwargs: - a = make_tensor((S, S, S), device=device, dtype=dtype, requires_grad=requires_grad) - b = make_tensor((S, S, S), device=device, dtype=dtype, requires_grad=requires_grad) - extra_kwargs['alpha'] = kwargs['alpha'] - sample = SampleInput(a, args=(b,), kwargs=extra_kwargs) - samples.append(sample) - return tuple(samples) + + sample_inputs = [] + for shape_lhs, shape_rhs_or_scalar in shapes: + lhs = make_tensor( + shape_lhs, + device=device, + dtype=dtype, + requires_grad=requires_grad, + **lhs_make_tensor_kwargs, + ) + if isinstance(shape_rhs_or_scalar, tuple): + # shape + rhs = make_tensor( + shape_rhs_or_scalar, + device=device, + dtype=dtype, + requires_grad=requires_grad, + **rhs_make_tensor_kwargs, + ) + broadcasts_input = torch.broadcast_shapes(shape_lhs, shape_rhs_or_scalar) != shape_lhs + else: + # scalar + rhs = shape_rhs_or_scalar # type: ignore[assignment] + broadcasts_input = False + + sample_inputs.append(SampleInput(lhs, args=(rhs,), kwargs=op_kwargs, broadcasts_input=broadcasts_input)) + return sample_inputs + + +def sample_inputs_add_sub( + op_info, + device, + dtype, + requires_grad, + python_scalars=False, + alpha=1, + op_kwargs=None, + lhs_make_tensor_kwargs=None, + rhs_make_tensor_kwargs=None, + **kwargs, +): + op_kwargs, lhs_make_tensor_kwargs, rhs_make_tensor_kwargs = _resolve_binay_pwise_kwargs( + op_info, + op_kwargs=op_kwargs, + lhs_make_tensor_kwargs=lhs_make_tensor_kwargs, + rhs_make_tensor_kwargs=rhs_make_tensor_kwargs, + ) + + sample_inputs = sample_inputs_binary_pwise( + op_info, + device, + dtype, + requires_grad, + python_scalars=python_scalars, + op_kwargs=op_kwargs, + lhs_make_tensor_kwargs=lhs_make_tensor_kwargs, + rhs_make_tensor_kwargs=rhs_make_tensor_kwargs, + **kwargs, + ) + + lhs = make_tensor((S, S), device=device, dtype=dtype, requires_grad=requires_grad, **lhs_make_tensor_kwargs) + rhs = make_tensor((S, S), device=device, dtype=dtype, requires_grad=requires_grad, **rhs_make_tensor_kwargs) + sample_inputs.append(SampleInput(lhs, args=(rhs,), kwargs=dict(op_kwargs, alpha=alpha), broadcasts_input=False)) + + return sample_inputs def sample_inputs_t(op_info, device, dtype, requires_grad, **kwargs): @@ -1323,15 +1607,29 @@ def sample_inputs_t(op_info, device, dtype, requires_grad, **kwargs): def sample_inputs_mm(op_info, device, dtype, requires_grad, **kwargs): - args_list = ( - ((S, M), (M, S)), - ) - inputs = tuple(SampleInput(make_tensor(first_shape, device, dtype, - requires_grad=requires_grad), - args=(make_tensor(second_shape, device, dtype, - requires_grad=requires_grad),)) - for first_shape, second_shape in args_list) - return inputs + first_shape, second_shape = (S, M), (M, S) + sample_inputs = [] + sample_inputs.append( + SampleInput(make_tensor(first_shape, device, dtype, + requires_grad=requires_grad), + args=(make_tensor(second_shape, device, dtype, + requires_grad=requires_grad),))) + + if dtype.is_complex: + sample_inputs.append( + SampleInput(make_tensor(first_shape, device, dtype, + requires_grad=requires_grad), + args=( + make_tensor(second_shape, device, dtype, + requires_grad=requires_grad).conj(),))) + + sample_inputs.append( + SampleInput(make_tensor(first_shape, device, dtype, + requires_grad=requires_grad).transpose(0, 1), + args=( + make_tensor(second_shape, device, dtype, + requires_grad=requires_grad).transpose(0, 1).conj(),))) + return sample_inputs def sample_inputs_addmm(op_info, device, dtype, requires_grad, **kwargs): alpha_val = kwargs.get('alpha', 2 + 3j if dtype.is_complex else 0.6) @@ -1344,15 +1642,40 @@ def sample_inputs_addmm(op_info, device, dtype, requires_grad, **kwargs): ((), (2, 2), (2, 3), True) ] test_cases = tests_list + tests_with_lhs_broadcasting # type: ignore[operator] - inputs = tuple(SampleInput(make_tensor(shape_a, device, dtype, requires_grad=requires_grad), - args=(make_tensor(shape_b, device, dtype, - requires_grad=requires_grad), - make_tensor(shape_c, device, dtype, - requires_grad=requires_grad)), - kwargs={'alpha': alpha_val, 'beta': beta_val}, - broadcasts_input=broadcasts_input) - for shape_a, shape_b, shape_c, broadcasts_input in test_cases) - return inputs + + sample_inputs = [] + + for shape_a, shape_b, shape_c, broadcasts_input in test_cases: + sample_inputs.append( + SampleInput( + make_tensor(shape_a, device, dtype, requires_grad=requires_grad), + args=( + make_tensor(shape_b, device, dtype, + requires_grad=requires_grad), + make_tensor(shape_c, device, dtype, + requires_grad=requires_grad)), + kwargs={'alpha': alpha_val, 'beta': beta_val}, + broadcasts_input=broadcasts_input)) + + if dtype.is_complex: + shape = (3, 3) + sample_inputs.append( + SampleInput(make_tensor(shape, device, dtype, requires_grad=requires_grad), + args=( + make_tensor(shape, device, dtype, + requires_grad=requires_grad).t().conj(), + make_tensor(shape, device, dtype, + requires_grad=requires_grad)), + kwargs={'alpha': alpha_val, 'beta': beta_val},)) + sample_inputs.append( + SampleInput(make_tensor(shape, device, dtype, requires_grad=requires_grad), + args=( + make_tensor(shape, device, dtype, + requires_grad=requires_grad), + make_tensor(shape, device, dtype, + requires_grad=requires_grad).t().conj()), + kwargs={'alpha': alpha_val, 'beta': beta_val},)) + return sample_inputs def sample_inputs_mv(self, device, dtype, requires_grad, **kwargs): return ( @@ -1375,14 +1698,24 @@ def sample_inputs_bmm(self, device, dtype, requires_grad, **kwargs): ) def sample_inputs_dot_vdot(self, device, dtype, requires_grad, **kwargs): - return ( - SampleInput( + sample_inputs = [] + sample_inputs.append(SampleInput( + make_tensor((S, ), device, dtype, low=None, high=None, requires_grad=requires_grad), + args=( + make_tensor((S, ), device, dtype, low=None, high=None, requires_grad=requires_grad), + ) + )) + if dtype.is_complex: + # dot/vdot for (conj(input), conj(arg_tensor)) and (conj(input), arg_tensor) + # is tested in test_conj_view (which tests operations with only conjugated input tensor + # -- not conjugated arg tensors) + sample_inputs.append(SampleInput( make_tensor((S, ), device, dtype, low=None, high=None, requires_grad=requires_grad), args=( - make_tensor((S, ), device, dtype, low=None, high=None, requires_grad=requires_grad), + torch.conj(make_tensor((S, ), device, dtype, low=None, high=None, requires_grad=requires_grad)), ) - ), - ) + )) + return sample_inputs def sample_inputs_addmv(op_info, device, dtype, requires_grad, **kwargs): make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) @@ -1474,6 +1807,23 @@ def sample_inputs_baddbmm(op_info, device, dtype, requires_grad, **kwargs): sample_inputs.append(SampleInput(args[0], args=(args[1], args[2]), kwargs=dict(beta=beta * (1 + 2j), alpha=alpha * (2 + 3j)), broadcasts_input=broadcasts_input)) + + if dtype.is_complex: + shapes = [(S, S, S), (S, M, S), (S, S, M)] + args = (make_tensor(shapes[0], device, dtype, + low=None, high=None, + requires_grad=requires_grad), + make_tensor(shapes[1], device, dtype, + low=None, high=None, + requires_grad=requires_grad), + make_tensor(shapes[2], device, dtype, + low=None, high=None, + requires_grad=requires_grad)) + sample_inputs.append( + SampleInput( + args[0].transpose(-1, 1), args=(args[1].transpose(-1, 1).conj(), args[2].transpose(-1, 1).conj()), + kwargs=dict(beta=beta * (1 + 2j), alpha=alpha * (2 + 3j)),)) + return tuple(sample_inputs) def sample_inputs_addr(op_info, device, dtype, requires_grad, **kwargs): @@ -1830,6 +2180,25 @@ def sample_inputs_stack(op_info, device, dtype, requires_grad, **kwargs): return (SampleInput(tensors, args=(0,)),) +def sample_inputs_cat_concat(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + cases: Tuple[tuple, tuple, dict] = ( # type: ignore[assignment] + ((S, S), (S, S), {'dim': -1}), + ((S, S), (S, S), {'dim': 1}), + ((M, S), (S, S), {'dim': 0}), # different shapes + ((1, 2, 3), (1, 2, 3), {'dim': -2}), + ((0,), (0,), {'dim': 0}), # empty tensor + ((0, S), (S, S), {'dim': 0}), + ((1,), (1,), {}) # dim not passed, fallback to default + ) + + def generator(): + for input_shape1, input_shape2, kwargs in cases: + yield SampleInput([make_arg(input_shape1), make_arg(input_shape2)], kwargs=kwargs) + + return list(generator()) + def sample_inputs_hstack_dstack_vstack(op_info, device, dtype, requires_grad, **kwargs): tensors = [ make_tensor((S, S), device, dtype, requires_grad=requires_grad), @@ -1897,28 +2266,6 @@ def sample_inputs_take_along_dim(op_info, device, dtype, requires_grad, **kwargs ) -def sample_inputs_amax_amin(op_info, device, dtype, requires_grad, **kwargs): - # Ordered as (input shape, kwargs) - test_cases: Tuple[tuple, dict] = ( # type: ignore[assignment] - ((S, S, S), {}), - ((S, S, S), {'dim': 1}), - ((S, S, S), {'dim': (1, 2,)}), - ((S, S, S), {'dim': 1, 'keepdim': True}), - ((), {'dim': 0}), - ((), {}), - ((), {'dim': 0, 'keepdim': True}), - ) - - samples: List[SampleInput] = [] - for shape, kwargs in test_cases: - samples.append(SampleInput( - make_tensor(shape, device, dtype, requires_grad=requires_grad), - kwargs=kwargs)) - - return samples - -# TODO (@heitorschueroff) Once aminmax supports multiple dims this should -# be combined with the above test. def sample_inputs_aminmax(op_info, device, dtype, requires_grad, **kwargs): test_cases: Tuple[tuple, dict] = ( # type: ignore[assignment] ((S, S, S), {}), @@ -1937,33 +2284,6 @@ def sample_inputs_aminmax(op_info, device, dtype, requires_grad, **kwargs): return samples -def sample_inputs_argmax_argmin(op_info, device, dtype, requires_grad, **kwargs): - test_cases = ( - ((2, 2, 2), ()), - ((2, 2, 2), (0,)), - ((2, 2, 2), (1,)), - ((2, 2, 2), (2,)), - ((2, 2, 2), (2, True,)), - ((2, 2, 2), (None,)), - ((), (0,)), - ((), ()), - ((), (None, True,)), - ((1,), ()), - ((1,), (0,)), - ((1,), (0, True)), - ((2,), ()), - ((2,), (0,)), - ((2,), (0, True)), - ((2, 2, 3), ()), - ((2, 2, 3), (0,)), - ((2, 2, 3), (1,)), - ((2, 2, 3), (None, True)), - ) - return tuple(SampleInput((make_tensor(size, device, dtype, - requires_grad=requires_grad)), - args=args) - for size, args in test_cases) - def sample_inputs_diff(op_info, device, dtype, requires_grad, **kwargs): test_cases = ( ((1,), 0, None, None), @@ -2304,6 +2624,42 @@ def generator(): return list(generator()) +def sample_inputs_layer_norm(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as input shape, normalized_shape and a kwarg dict for eps + cases: Tuple[Tuple[int], Tuple[int], dict] = ( # type: ignore[assignment] + ((1, 2, 3), (1, 2, 3), {'eps': 0.5}), + ((2, 2, 3), (2, 3), {'eps': -0.5}), + ((1,), (1,), {}), + ((1, 2), (2,), {}), + ((0, 1), (1,), {}), + ) + + def generator(): + for input_shape, normalized_shape, kwargs in cases: + # Shape of weight and bias should be the same as normalized_shape + weight = make_arg(normalized_shape) + bias = make_arg(normalized_shape) + yield SampleInput( + make_arg(input_shape), + args=(normalized_shape, weight, bias), + kwargs=kwargs + ) + # Without any optional args + yield SampleInput(make_arg((1, 2)), args=((2,),)) + + # TODO: @krshrimali, once to_numpy method in SampleInput class is modified to take None inputs, + # enable these inputs; see https://github.com/pytorch/pytorch/pull/63276#discussion_r691950400 + + # With weight and a `None` bias + # yield SampleInput(make_arg((1, 2)), args=((2,), make_arg((2,)), None)) + + # With `None` weight and bias (tests failing for this, see the link above) + # yield SampleInput(make_arg((1, 2)), args=((2,), None, make_arg((2,)))) + + return list(generator()) + def sample_inputs_hardswish(self, device, dtype, requires_grad): N = 5 # make sure we are testing -3 -> 3 range. default is -10 -> 10 so maybe unnecessary ? @@ -2311,6 +2667,48 @@ def sample_inputs_hardswish(self, device, dtype, requires_grad): requires_grad=requires_grad, low=-5, high=5)) for _ in range(1, N)] return tensors +def sample_inputs_interpolate(mode, self, device, dtype, requires_grad): + N, C = 2, 3 + D = 4 + S = 3 + L = 5 + + align_corners_options: Tuple[Any, ...] = (None,) + if mode in ('linear', 'bilinear', 'bicubic', 'trilinear'): + align_corners_options = (True, False, None) + ranks_for_mode = { + 'nearest': [1, 2, 3], + 'linear': [1], + 'bilinear': [2], + 'bicubic': [2], + 'trilinear': [3], + 'area': [1, 2, 3] + } + + def shape(size, rank, with_batch_channel=True): + if with_batch_channel: + return tuple([N, C] + ([size] * rank)) + return tuple([size] * rank) + + make_arg = partial(make_tensor, device=device, dtype=dtype, + requires_grad=requires_grad, low=-1, high=1) + + sample_inputs = [] + for align_corners in align_corners_options: + for rank in ranks_for_mode[mode]: + sample_inputs.extend([ + SampleInput(make_arg(shape(D, rank)), + args=(shape(S, rank, False), None, mode, align_corners)), + SampleInput(make_arg(shape(D, rank)), + args=(shape(L, rank, False), None, mode, align_corners)), + SampleInput(make_arg(shape(D, rank)), + args=(None, 1.7, mode, align_corners)), + SampleInput(make_arg(shape(D, rank)), + args=(None, 0.6, mode, align_corners)), + ]) + + return sample_inputs + def sample_inputs_gelu(self, device, dtype, requires_grad): N = 5 tensors = [SampleInput(make_tensor((N * 2, N * 2), device=device, dtype=dtype, @@ -2342,56 +2740,6 @@ def sample_inputs_max_min_reduction_no_dim(op_info, device, dtype, requires_grad requires_grad=requires_grad),)) return inputs -# Generates input tensors for testing reduction ops -def _generate_reduction_inputs(device, dtype, requires_grad): - yield make_tensor((), device, dtype, requires_grad=requires_grad) - yield make_tensor((2,), device, dtype, requires_grad=requires_grad) - yield make_tensor((2, 3), device, dtype, requires_grad=requires_grad, noncontiguous=True) - yield make_tensor((3, 2, 1, 2, 2), device, dtype, requires_grad=requires_grad) - -# Generates a subset of possible dim and keepdim kwargs for a tensor -# with ndim dims appropriate for testing. If supports_multiple_dims -# is True (default) then dim kwarg can be a list of dims. -def _generate_reduction_kwargs(ndim, supports_multiple_dims=True): - for keepdim in [True, False]: - # Always test reducing inner and outer most dimensions - yield {'dim': 0, 'keepdim': keepdim} - yield {'dim': -1, 'keepdim': keepdim} - - # Also reduce middle dimension - if ndim > 2: - yield {'dim': ndim // 2, 'keepdim': keepdim} - - if supports_multiple_dims: - # Always test reducing all dims - yield {'dim': tuple(range(ndim)), 'keepdim': keepdim} - - # Test reducing both first and last dimensions - if ndim > 1: - yield {'dim': (0, ndim - 1), 'keepdim': keepdim} - - # Test reducing every other dimension starting with the second - if ndim > 3: - yield {'dim': tuple(range(1, ndim, 2)), 'keepdim': keepdim} - -# Wraps sample_inputs_reduction function to provide the additional supports_multiple_dims args -def sample_inputs_reduction_wrapper(supports_multiple_dims): - # Generates sample inputs for reduction ops that contain the input tensor - # and dim and keepdim kwargs. If a reduction op needs to test additional - # args/kwargs then create a separate sample_inputs function - def fn(op_info, device, dtype, requires_grad): - inputs = [] - - for t in _generate_reduction_inputs(device, dtype, requires_grad): - # Add case without dim and keepdim kwargs - inputs.append(SampleInput(t)) - for kwargs in _generate_reduction_kwargs(t.ndim, supports_multiple_dims): - inputs.append(SampleInput(t, kwargs=kwargs)) - - return inputs - - return fn - def sample_inputs_reduction_quantile(op_info, device, dtype, requires_grad): test_quantiles = (0.5, make_tensor((2,), device, dtype, low=0, high=1)) test_interpolations = ['linear', 'midpoint'] @@ -2403,12 +2751,22 @@ def sample_inputs_reduction_quantile(op_info, device, dtype, requires_grad): inputs.append(SampleInput(t, args=(quantiles,))) for kwargs in _generate_reduction_kwargs(t.ndim, supports_multiple_dims=False): # Interpolation kwarg for now is only supported when providing both dim and keepdim + kwargs.setdefault('dim', 0) + kwargs.setdefault('keepdim', False) for interpolation in test_interpolations: kwargs['interpolation'] = interpolation inputs.append(SampleInput(t, args=(quantiles,), kwargs=kwargs)) return inputs +def sample_inputs_reduction_count_nonzero(*args, **kwargs): + """Sample inputs for count_nonzero""" + samples: List[SampleInput] = sample_inputs_reduction(*args, **kwargs) + # count_nonzero does not support keepdim yet + for sample in samples: + sample.kwargs.pop('keepdim', None) + return samples + def sample_inputs_leaky_relu(op_info, device, dtype, requires_grad): N = 10 tensors = [SampleInput(make_tensor((N, N), device=device, dtype=dtype, @@ -4033,19 +4391,6 @@ def sample_inputs_logit(op_info, device, dtype, requires_grad, **kwargs): return samples -def sample_inputs_floor_divide(op_info, device, dtype, requires_grad, **kwargs): - lhs = make_tensor((S, S, S), device, dtype, low=None, high=None, requires_grad=requires_grad) - rhs = make_tensor((S, S, S), device, dtype, low=None, high=None, requires_grad=requires_grad) - # Avoid integer divide by 0 - if not (dtype.is_floating_point or dtype.is_complex): - rhs[rhs == 0] = 1 - - return [ - SampleInput(lhs, args=(rhs,)), - SampleInput(lhs, args=(rhs[0],)), - SampleInput(lhs, args=(3.14,)), - ] - def sample_inputs_isin(op_info, device, dtype, requires_grad): element = make_tensor((L,), device, dtype, low=None, high=None, requires_grad=requires_grad) indices = torch.randint(0, L, size=[S]) @@ -4986,6 +5331,22 @@ def sample_inputs_softplus(op_info, device, dtype, requires_grad, **kwargs): SampleInput(make_input(low=1), kwargs=dict(threshold=1)), ] +def sample_inputs_tensorinv(op_info, device, dtype, requires_grad, **kwargs): + def make_input(): + input = make_fullrank_matrices_with_distinct_singular_values(12, 12, device=device, dtype=dtype) + return input.requires_grad_(requires_grad) + + # lhs / rhs shape can have any number of dimensions as long as their product equals 12 + shapes = [ + ((2, 2, 3), (12, 1)), + ((4, 3), (6, 1, 2)), + ] + + return [ + SampleInput(make_input().reshape(*shape_lhs, *shape_rhs), kwargs=dict(ind=len(shape_lhs))) + for shape_lhs, shape_rhs in shapes + ] + def sample_inputs_mse_loss(op_info, device, dtype, requires_grad, **kwargs): _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) @@ -5034,6 +5395,36 @@ def sample_inputs_grid_sample(op_info, device, dtype, requires_grad, **kwargs): return sample_inputs +def sample_inputs_nll_loss(op_info, device, dtype, requires_grad, **kwargs): + batch_size, num_classes = shape = (2, 3) + + input_shape_and_kwargs: List[Tuple[Tuple[int, ...], Dict[str, Any]]] = [ + ((*shape, 1), dict()), + ((*shape, 1, 2), dict()), + ((*shape, 1, 2, 3), dict()), + (shape, dict(weight=make_tensor((num_classes,), device=device, dtype=dtype).abs())), + (shape, dict(ignore_index=num_classes // 2)), + (shape, dict(reduction="sum")), + (shape, dict(reduction="mean")), + ] + + sample_inputs = [] + for input_shape, kwargs in input_shape_and_kwargs: + input = make_tensor(input_shape, device=device, dtype=dtype, requires_grad=requires_grad) + + target = make_tensor( + (batch_size, *input_shape[2:]), + low=0, + high=num_classes, + device=device, + dtype=torch.long, + requires_grad=requires_grad + ) + + sample_inputs.append(SampleInput(input, args=(target,), kwargs=kwargs)) + + return sample_inputs + foreach_unary_op_db: List[OpInfo] = [ ForeachFuncInfo('exp'), ForeachFuncInfo('acos'), @@ -5316,6 +5707,21 @@ def reference_mse_loss(input, target, reduction="mean"): return se +def reference_layer_norm(inp: np.ndarray, normalized_shape: Tuple[int], weight=None, bias=None, eps=1e-5): + feature_size = np.prod(normalized_shape) + inp_view = inp.reshape(-1, feature_size) # type: ignore[call-overload] + mean = inp_view.mean(axis=-1, keepdims=True) + var = inp_view.var(axis=-1, ddof=0, keepdims=True) + Y = (inp_view - mean) / np.sqrt(var + eps) + if weight is None and bias is not None: + Y = Y + bias.reshape(-1) + elif weight is not None and bias is None: + Y = Y * weight.reshape(-1) + elif weight is not None and bias is not None: + Y = Y * weight.reshape(-1) + bias.reshape(-1) + return Y.reshape(*inp.shape) + + def gradcheck_wrapper_hermitian_input(op, input, *args, **kwargs): """Gradcheck wrapper for functions that take Hermitian matrices as input. @@ -5424,29 +5830,29 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): SkipInfo('TestGradients', 'test_forward_mode_AD', dtypes=[torch.cdouble]), )), - OpInfo('add', - # NumPy has no builtin reference for the alpha kwarg, but it is easy enough to emulate - ref=lambda input, other, *, alpha=1: np.add(input, np.multiply(alpha, other)), - dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), - assert_autodiffed=True, - sample_inputs_func=partial(sample_inputs_binary_pwise, alpha=2), - supports_inplace_autograd=False, - supports_forward_ad=True), - OpInfo('mul', - aliases=('multiply',), - dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool), - assert_autodiffed=True, - supports_forward_ad=True, - sample_inputs_func=sample_inputs_binary_pwise), - OpInfo('sub', - # NumPy has no builtin reference for the alpha kwarg, but it is easy enough to emulate - ref=lambda input, other, *, alpha=1: np.subtract(input, np.multiply(alpha, other)), - aliases=('subtract',), - dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16), - assert_autodiffed=True, - supports_forward_ad=True, - sample_inputs_func=partial(sample_inputs_binary_pwise, alpha=2), - supports_inplace_autograd=False), + BinaryUfuncInfo('add', + # NumPy has no builtin reference for the alpha kwarg, but it is easy enough to emulate + ref=lambda input, other, *, alpha=1: np.add(input, np.multiply(alpha, other)), + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), + assert_autodiffed=True, + sample_inputs_func=partial(sample_inputs_add_sub, alpha=2), + supports_inplace_autograd=False, + supports_forward_ad=True), + BinaryUfuncInfo('mul', + aliases=('multiply',), + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool), + assert_autodiffed=True, + supports_forward_ad=True, + sample_inputs_func=sample_inputs_binary_pwise), + BinaryUfuncInfo('sub', + # NumPy has no builtin reference for the alpha kwarg, but it is easy enough to emulate + ref=lambda input, other, *, alpha=1: np.subtract(input, np.multiply(alpha, other)), + aliases=('subtract',), + dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16), + assert_autodiffed=True, + supports_forward_ad=True, + sample_inputs_func=partial(sample_inputs_add_sub, alpha=2), + supports_inplace_autograd=False), OpInfo('addmm', # This addmm OpInfo is for when alpha and beta are not both equal to 1. # alpha=beta=1 is tested in the following opinfo, because that special case will @@ -5517,6 +5923,13 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): *[torch.bfloat16] if SM53OrLater else [], torch.complex64, torch.complex128), supports_forward_ad=True, + decorators=[ + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}), + 'TestCommon', 'test_variant_consistency_eager', device_type='cuda'), + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}), + 'TestMathBits', 'test_conj_view', device_type='cuda')], skips=( # FIXME: bfloat16 backward support likely depends on CUDA11+ # and SM53+ @@ -5526,14 +5939,14 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): ), sample_inputs_func=sample_inputs_baddbmm), OpInfo('dot', - dtypes=all_types_and_complex_and(torch.float16), + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []), assert_autodiffed=True, sample_inputs_func=sample_inputs_dot_vdot, supports_forward_ad=True, ), OpInfo('vdot', - dtypes=all_types_and_complex_and(torch.float16), + dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []), sample_inputs_func=sample_inputs_dot_vdot, supports_forward_ad=True, @@ -5576,6 +5989,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): gradcheck_nondet_tol=GRADCHECK_NONDET_TOL), OpInfo('addcmul', dtypes=all_types_and_complex(), + dtypesIfCPU=all_types_and_complex_and(torch.bfloat16), dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16), assert_autodiffed=True, supports_forward_ad=True, @@ -5586,6 +6000,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): sample_inputs_func=sample_inputs_addcmul_addcdiv), OpInfo('addcdiv', dtypes=floating_and_complex_types(), + dtypesIfCPU=floating_and_complex_types_and(torch.bfloat16), dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), supports_inplace_autograd=False, supports_forward_ad=True, @@ -5593,22 +6008,6 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): # TODO: update sample inputs with for_inplace_variant kwarg to support this test SkipInfo('TestCommon', 'test_variant_consistency_eager'),), sample_inputs_func=sample_inputs_addcmul_addcdiv), - OpInfo('amax', - ref=lambda a, dim=None, keepdim=False, **kwargs: np.amax(a, axis=dim, keepdims=keepdim, **kwargs), - dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), - sample_inputs_func=sample_inputs_amax_amin,), - OpInfo('amin', - ref=lambda a, dim=None, keepdim=False, **kwargs: np.amin(a, axis=dim, keepdims=keepdim, **kwargs), - dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), - sample_inputs_func=sample_inputs_amax_amin), - OpInfo('argmax', - dtypes=all_types_and(torch.float16, torch.bfloat16), - supports_autograd=False, - sample_inputs_func=sample_inputs_argmax_argmin,), - OpInfo('argmin', - dtypes=all_types_and(torch.float16, torch.bfloat16), - supports_autograd=False, - sample_inputs_func=sample_inputs_argmax_argmin,), UnaryUfuncInfo('asin', aliases=('arcsin', ), ref=np.arcsin, @@ -5915,6 +6314,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): UnaryUfuncInfo('cosh', ref=np_unary_ufunc_integer_promotion_wrapper(np.cosh), dtypes=all_types_and_complex_and(torch.bool), + dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16), dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), safe_casts_outputs=True, assert_autodiffed=True, @@ -6001,41 +6401,43 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), supports_forward_ad=True, sample_inputs_func=sample_inputs_diff), - OpInfo('div', - aliases=('divide',), - variant_test_name='no_rounding_mode', - dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), - sample_inputs_func=partial(sample_inputs_binary_pwise, rhs_exclude_zero=True), - supports_forward_ad=True, - assert_autodiffed=True), - OpInfo('div', - aliases=('divide',), - variant_test_name='trunc_rounding', - dtypes=all_types_and(torch.half, torch.bfloat16), - sample_inputs_func=partial(sample_inputs_binary_pwise, extra_kwargs={ - "rounding_mode": 'trunc'}, rhs_exclude_zero=True), - supports_forward_ad=True, - skips=( - # Reference: https://github.com/pytorch/pytorch/issues/59174 - SkipInfo('TestJit', 'test_variant_consistency_jit'), - ), - assert_autodiffed=True), - OpInfo('div', - aliases=('divide',), - variant_test_name='floor_rounding', - dtypes=all_types_and(torch.half, torch.bfloat16), - sample_inputs_func=partial(sample_inputs_binary_pwise, extra_kwargs={ - "rounding_mode": 'floor'}, rhs_exclude_zero=True), - supports_forward_ad=True, - skips=( - # Reference: https://github.com/pytorch/pytorch/issues/59174 - SkipInfo('TestJit', 'test_variant_consistency_jit'), - ), - assert_autodiffed=True), - OpInfo('true_divide', - dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), - supports_forward_ad=True, - sample_inputs_func=partial(sample_inputs_binary_pwise, rhs_exclude_zero=True)), + BinaryUfuncInfo('div', + aliases=('divide',), + variant_test_name='no_rounding_mode', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_binary_pwise, + supports_forward_ad=True, + assert_autodiffed=True, + rhs_make_tensor_kwargs=dict(exclude_zero=True)), + BinaryUfuncInfo('div', + aliases=('divide',), + variant_test_name='trunc_rounding', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + sample_inputs_func=partial(sample_inputs_binary_pwise, rounding_mode="trunc"), + supports_forward_ad=True, + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/59174 + SkipInfo('TestJit', 'test_variant_consistency_jit'), + ), + assert_autodiffed=True, + rhs_make_tensor_kwargs=dict(exclude_zero=True)), + BinaryUfuncInfo('div', + aliases=('divide',), + variant_test_name='floor_rounding', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + sample_inputs_func=partial(sample_inputs_binary_pwise, rounding_mode="floor"), + supports_forward_ad=True, + skips=( + # Reference: https://github.com/pytorch/pytorch/issues/59174 + SkipInfo('TestJit', 'test_variant_consistency_jit'), + ), + assert_autodiffed=True, + rhs_make_tensor_kwargs=dict(exclude_zero=True)), + BinaryUfuncInfo('true_divide', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + supports_forward_ad=True, + sample_inputs_func=sample_inputs_binary_pwise, + rhs_make_tensor_kwargs=dict(exclude_zero=True)), UnaryUfuncInfo('exp', ref=np_unary_ufunc_integer_promotion_wrapper(np.exp), dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), @@ -6288,15 +6690,17 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): dtypes=all_types_and(torch.bool, torch.bfloat16), dtypesIfCUDA=all_types_and(torch.bool, torch.bfloat16, torch.float16), safe_casts_outputs=True), - OpInfo('floor_divide', - dtypes=all_types_and(torch.half, torch.bfloat16), - sample_inputs_func=sample_inputs_floor_divide, - supports_autograd=False, - ), + BinaryUfuncInfo('floor_divide', + dtypes=all_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_binary_pwise, + supports_autograd=False, + rhs_make_tensor_kwargs=dict(exclude_zero=True), + ), UnaryUfuncInfo('frexp', op=torch.frexp, ref=np.frexp, dtypes=floating_types_and(torch.half), + dtypesIfCPU=floating_types_and(torch.half, torch.bfloat16), # skip testing torch.frexp as it is not supported by ROCm platform yet decorators=[skipCUDAIfRocm], supports_out=False, @@ -6712,8 +7116,9 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): supports_out=False, ), OpInfo('matmul', + aliases=('linalg.matmul',), dtypes=floating_types(), - dtypesIfCPU=all_types_and_complex(), + dtypesIfCPU=all_types_and_complex_and(torch.bfloat16), dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []), dtypesIfROCM=floating_types_and(torch.half, torch.bfloat16), backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16, @@ -6724,7 +7129,6 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): skips=( # matmul does not correctly warn when resizing out= inputs SkipInfo('TestCommon', 'test_out'), - SkipInfo('TestCommon', 'test_conj_view', device_type='cpu'), )), OpInfo('max', op=torch.max, @@ -6755,19 +7159,19 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): dtypesIfCUDA=all_types_and(torch.float16), # TODO: some signatures of median do support out supports_out=False, - sample_inputs_func=sample_inputs_reduction_wrapper(False)), + sample_inputs_func=partial(sample_inputs_reduction, supports_multiple_dims=False)), OpInfo('nanmedian', dtypes=all_types(), dtypesIfCPU=all_types_and(torch.bfloat16), dtypesIfCUDA=all_types_and(torch.float16), # TODO: some signatures of nanmedian do support out supports_out=False, - sample_inputs_func=sample_inputs_reduction_wrapper(False)), + sample_inputs_func=partial(sample_inputs_reduction, supports_multiple_dims=False)), OpInfo('var_mean', dtypes=floating_and_complex_types_and(torch.half), dtypesIfCPU=floating_and_complex_types_and(torch.half, torch.bfloat16), dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), - sample_inputs_func=sample_inputs_reduction_wrapper(False), + sample_inputs_func=partial(sample_inputs_reduction, supports_multiple_dims=False), backward_dtypes=floating_types_and(torch.half), backward_dtypesIfCPU=floating_types_and(torch.half, torch.bfloat16), backward_dtypesIfCUDA=floating_types_and(torch.half), @@ -6786,7 +7190,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): dtypes=floating_and_complex_types_and(torch.half), dtypesIfCPU=floating_and_complex_types_and(torch.half, torch.bfloat16), dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), - sample_inputs_func=sample_inputs_reduction_wrapper(False), + sample_inputs_func=partial(sample_inputs_reduction, supports_multiple_dims=False), backward_dtypes=floating_types_and(torch.half), backward_dtypesIfCPU=floating_types_and(torch.half, torch.bfloat16), backward_dtypesIfCUDA=floating_types_and(torch.half), @@ -6861,21 +7265,12 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): supports_out=False, supports_forward_ad=True, sample_inputs_func=sample_inputs_max_min_reduction_no_dim,), - OpInfo('sum', - dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool), - supports_out=False, - supports_forward_ad=True, - sample_inputs_func=sample_inputs_reduction_wrapper(supports_multiple_dims=True)), - OpInfo('nansum', - dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), - supports_out=False, - sample_inputs_func=sample_inputs_reduction_wrapper(supports_multiple_dims=True)), # TODO(@heitorschueroff) Add test for dtype kwarg OpInfo('mean', dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), assert_autodiffed=True, supports_forward_ad=True, - sample_inputs_func=sample_inputs_reduction_wrapper(supports_multiple_dims=True), + sample_inputs_func=sample_inputs_reduction, # Need to skip out test because one of the overload for mean does not support it # TODO(@heitorschueroff) fix this when implementing ReductionInfo skips=(SkipInfo('TestCommon', 'test_out'),)), @@ -6934,6 +7329,13 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): # FIXME: aminmax does not check for safe casting to output SkipInfo('TestCommon', 'test_out'), )), + OpInfo('nn.functional.cosine_similarity', + aten_name="cosine_similarity", + dtypes=floating_types_and(torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + supports_forward_ad=True, + sample_inputs_func=sample_inputs_cosine_similarity), OpInfo('nn.functional.adaptive_avg_pool2d', dtypes=floating_types(), dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), @@ -6967,6 +7369,21 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): SkipInfo('TestJit', 'test_variant_consistency_jit'), ), supports_out=False,), + OpInfo('nn.functional.layer_norm', + aten_name='layer_norm', + aliases=('layer_norm',), + ref=reference_layer_norm, + dtypes=floating_types_and(torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + decorators=[ + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1e-03)}), + 'TestCommon', 'test_reference_testing' + ), + unittest.skipIf("tbb" in os.getenv("BUILD_ENVIRONMENT", ""), "This test makes TBB Sad"), + ], + sample_inputs_func=sample_inputs_layer_norm,), OpInfo('nn.functional.pad', variant_test_name='constant', aten_name='constant_pad_nd', @@ -7026,8 +7443,81 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): OpInfo('nn.functional.unfold', aten_name='im2col', dtypes=floating_types_and(torch.half), + dtypesIfCPU=floating_types_and(torch.half, torch.bfloat16), sample_inputs_func=sample_inputs_nn_unfold, skips=( + # JIT alias info internal asserts here + SkipInfo('TestJit', 'test_variant_consistency_jit'), + ), + supports_out=False), + OpInfo('nn.functional.interpolate', + aten_name="interpolate", + variant_test_name='nearest', + supports_autograd=True, + dtypesIfCPU=floating_types_and(torch.uint8), + dtypesIfCUDA=floating_types_and(torch.half, torch.uint8), + sample_inputs_func=partial(sample_inputs_interpolate, 'nearest'), + skips=( + # JIT alias info internal asserts here + SkipInfo('TestJit', 'test_variant_consistency_jit'), + ), + supports_out=False), + OpInfo('nn.functional.interpolate', + aten_name="interpolate", + variant_test_name='linear', + supports_autograd=True, + dtypesIfCUDA=floating_types_and(torch.half), + sample_inputs_func=partial(sample_inputs_interpolate, 'linear'), + skips=( + # JIT alias info internal asserts here + SkipInfo('TestJit', 'test_variant_consistency_jit'), + ), + supports_out=False), + OpInfo('nn.functional.interpolate', + aten_name="interpolate", + variant_test_name='bilinear', + supports_autograd=True, + dtypesIfCUDA=floating_types_and(torch.half), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + sample_inputs_func=partial(sample_inputs_interpolate, 'bilinear'), + skips=( + # JIT alias info internal asserts here + SkipInfo('TestJit', 'test_variant_consistency_jit'), + ), + supports_out=False), + OpInfo('nn.functional.interpolate', + aten_name="interpolate", + variant_test_name='bicubic', + supports_autograd=True, + dtypesIfCUDA=floating_types_and(torch.half), + sample_inputs_func=partial(sample_inputs_interpolate, 'bicubic'), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + skips=( + # JIT alias info internal asserts here + SkipInfo('TestJit', 'test_variant_consistency_jit'), + ), + supports_out=False), + OpInfo('nn.functional.interpolate', + aten_name="interpolate", + variant_test_name='trilinear', + supports_autograd=True, + dtypesIfCUDA=floating_types_and(torch.half), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + sample_inputs_func=partial(sample_inputs_interpolate, 'trilinear'), + skips=( + # JIT alias info internal asserts here + SkipInfo('TestJit', 'test_variant_consistency_jit'), + ), + supports_out=False), + OpInfo('nn.functional.interpolate', + aten_name="interpolate", + variant_test_name='area', + supports_autograd=True, + dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + sample_inputs_func=partial(sample_inputs_interpolate, 'area'), + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + skips=( + # JIT alias info internal asserts here SkipInfo('TestJit', 'test_variant_consistency_jit'), ), supports_out=False), @@ -7036,6 +7526,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): aten_name="leaky_relu", dtypes=floating_types(), sample_inputs_func=sample_inputs_leaky_relu, + dtypesIfCPU=floating_types_and(torch.bfloat16), dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), supports_autograd=True, assert_autodiffed=True, @@ -7227,16 +7718,6 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): supports_forward_ad=True, skips=( SkipInfo('TestMathBits', 'test_conj_view', device_type='cuda'),),), - OpInfo('prod', - dtypes=all_types_and_complex_and(torch.bool), - dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), - skips=( - # prod does not support the (Tensor, *, out) overload - SkipInfo('TestCommon', 'test_out', - dtypes=[torch.float32]), - ), - sample_inputs_func=sample_inputs_prod, - gradcheck_nondet_tol=GRADCHECK_NONDET_TOL), OpInfo('qr', op=torch.qr, dtypes=floating_and_complex_types(), @@ -7316,6 +7797,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): UnaryUfuncInfo('sinh', ref=np_unary_ufunc_integer_promotion_wrapper(np.sinh), dtypes=all_types_and_complex_and(torch.bool), + dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16), dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), safe_casts_outputs=True, assert_autodiffed=True, @@ -7428,7 +7910,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): OpInfo('__rmatmul__', op=torch.Tensor.__rmatmul__, dtypes=floating_types(), - dtypesIfCPU=all_types_and_complex(), + dtypesIfCPU=all_types_and_complex_and(torch.bfloat16), dtypesIfCUDA=floating_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else [], torch.complex64, torch.complex128), backward_dtypesIfCUDA=floating_types_and(torch.float16, @@ -7437,6 +7919,10 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): assert_autodiffed=True, sample_inputs_func=sample_inputs_matmul, supports_out=False, + decorators=[ + DecorateInfo( + toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}), + 'TestMathBits', 'test_conj_view')], skips=( SkipInfo('TestJit', 'test_variant_consistency_jit',), )), @@ -7557,7 +8043,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), # "tanh_backward_cpu" not implemented for 'BFloat16' - backward_dtypesIfCPU=all_types_and_complex_and(torch.bool), + backward_dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16), assert_autodiffed=True, safe_casts_outputs=True, supports_forward_ad=True, @@ -7573,6 +8059,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): active_if=(IS_MACOS or IS_WINDOWS)), )), OpInfo('tensor_split', + ref=np.array_split, dtypes=all_types_and_complex_and(torch.bool), dtypesIfCPU=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), @@ -7636,6 +8123,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): UnaryUfuncInfo('nan_to_num', ref=np.nan_to_num, dtypes=all_types_and(torch.half, torch.bool), + dtypesIfCPU=all_types_and(torch.half, torch.bool, torch.bfloat16), dtypesIfCUDA=all_types_and(torch.half, torch.bool, torch.bfloat16), supports_forward_ad=True, # Passing numpy_kwargs via sample_kwargs, as numpy does comparison @@ -8114,17 +8602,11 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): OpInfo('stack', dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), sample_inputs_func=sample_inputs_stack, - assert_autodiffed=True, - skips=( - # stack does not correctly warn when resizing out= inputs - SkipInfo('TestCommon', 'test_out'),),), + assert_autodiffed=True), OpInfo('hstack', dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), sample_inputs_func=sample_inputs_hstack_dstack_vstack, - supports_forward_ad=True, - skips=( - # hstack does not correctly warn when resizing out= inputs - SkipInfo('TestCommon', 'test_out'),),), + supports_forward_ad=True), OpInfo('hypot', dtypes=floating_types(), dtypesIfCPU=floating_types_and(torch.bfloat16), @@ -8141,24 +8623,31 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): # JIT tests don't work with Tensor keyword arguments # https://github.com/pytorch/pytorch/issues/58507 SkipInfo('TestJit', 'test_variant_consistency_jit'),),), + OpInfo('cat', + ref=lambda input_seq, dim=0, **kwargs: np.concatenate(input_seq, axis=dim, **kwargs), + aliases=('concat',), + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_cat_concat, + supports_forward_ad=True, + assert_autodiffed=True, + skips=( + # RuntimeError: Arguments for call not valid. + # Expected a value of type 'List[Tensor]' for argument + # 'tensors' but instead found type 'Tensor (inferred)'. + SkipInfo('TestJit', 'test_jit_alias_remapping'),)), OpInfo('vstack', aliases=('row_stack',), dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), sample_inputs_func=sample_inputs_hstack_dstack_vstack, supports_forward_ad=True, skips=( - # vstack does not correctly warn when resizing out= inputs - SkipInfo('TestCommon', 'test_out'), # RuntimeError: _fn() Expected a value of type # 'Tensor (inferred)' for argument 't0' but instead found type 'tuple'. - SkipInfo('TestJit', 'test_jit_alias_remapping'))), + SkipInfo('TestJit', 'test_jit_alias_remapping'),)), OpInfo('dstack', dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), sample_inputs_func=sample_inputs_hstack_dstack_vstack, - supports_forward_ad=True, - skips=( - # dstack does not correctly warn when resizing out= inputs - SkipInfo('TestCommon', 'test_out'),)), + supports_forward_ad=True), OpInfo('unfold', op=lambda x, *args: x.unfold(*args), dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), @@ -8673,6 +9162,19 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): ), ), ), + OpInfo( + "linalg.tensorinv", + ref=np.linalg.tensorinv, + dtypes=floating_and_complex_types(), + skips=( + # RuntimeError: aliasOp != torch::jit::getOperatorAliasMap().end() + # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":159, + # please report a bug to PyTorch. + SkipInfo('TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + ), + sample_inputs_func=sample_inputs_tensorinv, + supports_forward_ad=True, + ), OpInfo( "nn.functional.mse_loss", ref=reference_mse_loss, @@ -8706,13 +9208,192 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): ), ), ), + ReductionOpInfo( + 'all', + identity=True, + supports_multiple_dims=False, + supports_out=False, + supports_autograd=False, + result_dtype=torch.bool, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + skips=( + # FIXME: does not support passing keepdim without dim + SkipInfo('TestReductions', 'test_dim_default_keepdim'), + # FIXME: does not support dim=None + SkipInfo('TestReductions', 'test_dim_none'), + SkipInfo('TestReductions', 'test_dim_none_keepdim'), + # FIXME: uint8 input returns uint8 instead of bool + SkipInfo('TestReductions', 'test_result_dtype', dtypes=[torch.uint8]), + ), + ), + ReductionOpInfo( + 'any', + identity=False, + supports_multiple_dims=False, + supports_out=False, + supports_autograd=False, + result_dtype=torch.bool, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + skips=( + # FIXME: does not support passing keepdim without dim + SkipInfo('TestReductions', 'test_dim_default_keepdim'), + # FIXME: does not support dim=None + SkipInfo('TestReductions', 'test_dim_none'), + SkipInfo('TestReductions', 'test_dim_none_keepdim'), + # FIXME: uint8 input returns uint8 instead of bool + SkipInfo('TestReductions', 'test_result_dtype', dtypes=[torch.uint8]), + ), + ), + ReductionOpInfo( + 'amax', + nan_policy='propagate', + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + ref=lambda a, dim=None, keepdim=False, **kwargs: np.amax(a, axis=dim, keepdims=keepdim, **kwargs), + skips=( + # FIXME: sum reduces all dimensions when dim=[] + SkipInfo('TestReductions', 'test_dim_empty'), + SkipInfo('TestReductions', 'test_dim_empty_keepdim'), + ), + ), + ReductionOpInfo( + 'amin', + nan_policy='propagate', + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + ref=lambda a, dim=None, keepdim=False, **kwargs: np.amin(a, axis=dim, keepdims=keepdim, **kwargs), + skips=( + # FIXME: sum reduces all dimensions when dim=[] + SkipInfo('TestReductions', 'test_dim_empty'), + SkipInfo('TestReductions', 'test_dim_empty_keepdim'), + ), + ), + ReductionOpInfo( + 'argmax', + supports_multiple_dims=False, + supports_autograd=False, + result_dtype=torch.int64, + dtypes=all_types_and(torch.float16, torch.bfloat16), + skips=( + # FIXME: keepdim parameter is ignored when dim=None + SkipInfo('TestReductions', 'test_dim_default_keepdim'), + SkipInfo('TestReductions', 'test_dim_none_keepdim'), + ), + ), + ReductionOpInfo( + 'argmin', + supports_multiple_dims=False, + supports_autograd=False, + result_dtype=torch.int64, + dtypes=all_types_and(torch.float16, torch.bfloat16), + skips=( + # FIXME: keepdim parameter is ignored when dim=None + SkipInfo('TestReductions', 'test_dim_default_keepdim'), + SkipInfo('TestReductions', 'test_dim_none_keepdim'), + ), + ), + ReductionOpInfo( + 'count_nonzero', + identity=0, + supports_out=False, + supports_autograd=False, + result_dtype=torch.int64, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_reduction_count_nonzero, + skips=( + # FIXME: count_nonzero does not accept keepdim kwarg + SkipInfo('TestReductions', 'test_dim_default_keepdim'), + SkipInfo('TestReductions', 'test_dim_none_keepdim'), + SkipInfo('TestReductions', 'test_dim_single_keepdim'), + SkipInfo('TestReductions', 'test_dim_empty_keepdim'), + SkipInfo('TestReductions', 'test_dim_multi_keepdim'), + SkipInfo('TestReductions', 'test_dim_multi_unsorted_keepdim'), + SkipInfo('TestReductions', 'test_dim_offbounds_keepdim'), + # FIXME: dim=[] reduces all dimensions + SkipInfo('TestReductions', 'test_dim_empty'), + ), + ), + ReductionOpInfo( + 'prod', + identity=1, + nan_policy='propagate', + supports_multiple_dims=False, + supports_out=False, + promotes_int_to_int64=True, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + dtypes=all_types_and_complex_and(torch.bool), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_prod, + skips=( + # FIXME: prod does not support passing keepdim without passing dim + SkipInfo('TestReductions', 'test_dim_default_keepdim'), + # FIXME: prod reduces all dimensions when dim=[] + SkipInfo('TestReductions', 'test_dim_empty'), + SkipInfo('TestReductions', 'test_dim_empty_keepdim'), + # FIXME: prod does not support passing None to dim + SkipInfo('TestReductions', 'test_dim_none'), + SkipInfo('TestReductions', 'test_dim_none_keepdim'), + ), + ), + ReductionOpInfo( + 'sum', + identity=0, + nan_policy='propagate', + supports_out=False, + supports_forward_ad=True, + promotes_int_to_int64=True, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + skips=( + # FIXME: sum does not support passing keepdim without passing dim + SkipInfo('TestReductions', 'test_dim_default_keepdim'), + # FIXME: sum reduces all dimensions when dim=[] + SkipInfo('TestReductions', 'test_dim_empty'), + SkipInfo('TestReductions', 'test_dim_empty_keepdim'), + # FIXME: sum does not support passing None to dim + SkipInfo('TestReductions', 'test_dim_none'), + SkipInfo('TestReductions', 'test_dim_none_keepdim'), + ), + ), + ReductionOpInfo( + 'nansum', + identity=0, + nan_policy='omit', + supports_out=False, + promotes_int_to_int64=True, + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), + skips=( + # FIXME: nansum does not support passing keepdim without passing dim + SkipInfo('TestReductions', 'test_dim_default_keepdim'), + # FIXME: nansum reduces all dimensions when dim=[] + SkipInfo('TestReductions', 'test_dim_empty'), + SkipInfo('TestReductions', 'test_dim_empty_keepdim'), + # FIXME: nansum does not support passing None to dim + SkipInfo('TestReductions', 'test_dim_none'), + SkipInfo('TestReductions', 'test_dim_none_keepdim'), + ), + ), + OpInfo( + "nn.functional.nll_loss", + ref=_NOTHING, + dtypesIfCPU=floating_types_and(torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + sample_inputs_func=sample_inputs_nll_loss, + skips=( + SkipInfo( + "TestJit", + "test_variant_consistency_jit", + dtypes=(torch.float32,), + ), + ), + ), ] # Common operator groupings unary_ufuncs = [op for op in op_db if isinstance(op, UnaryUfuncInfo)] +binary_ufuncs = [op for op in op_db if isinstance(op, BinaryUfuncInfo)] spectral_funcs = [op for op in op_db if isinstance(op, SpectralFuncInfo)] sparse_unary_ufuncs = [op for op in op_db if isinstance(op, UnaryUfuncInfo) and op.supports_sparse is True] shape_funcs = [op for op in op_db if isinstance(op, ShapeFuncInfo)] +reduction_ops = [op for op in op_db if isinstance(op, ReductionOpInfo)] # TODO: review porting these to make_tensor def index_variable(shape, max_indices, device=torch.device('cpu')): @@ -8769,8 +9450,8 @@ def _compare_trilu_indices( # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 self.assertEqualIgnoreType( torch.ones(row, col, device='cpu') - .tril(offset).nonzero().to(dtype).transpose(0, 1), - torch.tril_indices(row, col, offset, dtype=dtype, device=device)) + .triu(offset).nonzero().to(dtype).transpose(0, 1), + torch.triu_indices(row, col, offset, dtype=dtype, device=device)) def _compare_large_trilu_indices( diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py index 088e66f962592..a1059f6b718f4 100644 --- a/torch/testing/_internal/common_modules.py +++ b/torch/testing/_internal/common_modules.py @@ -2,11 +2,12 @@ from copy import deepcopy from functools import wraps, partial from itertools import chain -from torch.testing import floating_types +from torch.testing import make_tensor +from torch.testing._internal.common_dtype import floating_types from torch.testing._internal.common_device_type import ( _TestParametrizer, _dtype_test_suffix, _update_param_kwargs, skipIf) -from torch.testing._internal.common_nn import nllloss_reference -from torch.testing._internal.common_utils import make_tensor +from torch.testing._internal.common_nn import nllloss_reference, get_reduction +from torch.testing._internal.common_utils import freeze_rng_state from types import ModuleType from typing import List, Tuple, Type, Set, Dict @@ -46,6 +47,7 @@ class modules(_TestParametrizer): """ PROTOTYPE: Decorator for specifying a list of modules over which to run a test. """ + def __init__(self, module_info_list): self.module_info_list = module_info_list @@ -199,8 +201,103 @@ def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs): return module_inputs +def no_batch_dim_reference_fn(m, p, *args, **kwargs): + """Reference function for modules supporting no batch dimensions. + + The module is passed the input and target in batched form with a single item. + The output is squeezed to compare with the no-batch input. + """ + single_batch_input_args = [input.unsqueeze(0) for input in args] + with freeze_rng_state(): + return m(*single_batch_input_args).squeeze(0) + + +def no_batch_dim_reference_criterion_fn(m, *args, **kwargs): + """Reference function for criterion supporting no batch dimensions.""" + output = no_batch_dim_reference_fn(m, *args, **kwargs) + reduction = get_reduction(m) + if reduction == 'none': + return output.squeeze(0) + # reduction is 'sum' or 'mean' which results in a 0D tensor + return output + + +def generate_regression_criterion_inputs(make_input): + return [ + ModuleInput( + constructor_input=FunctionInput(reduction=reduction), + forward_input=FunctionInput(make_input(shape=(4, )), make_input(shape=4,)), + reference_fn=no_batch_dim_reference_criterion_fn, + desc='no_batch_dim_{}'.format(reduction) + ) for reduction in ['none', 'mean', 'sum']] + + +def module_inputs_torch_nn_AvgPool1d(module_info, device, dtype, requires_grad, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput(kernel_size=2), + forward_input=FunctionInput(make_input(shape=(3, 6))), + desc='no_batch_dim', + reference_fn=no_batch_dim_reference_fn)] + + +def module_inputs_torch_nn_ELU(module_info, device, dtype, requires_grad, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput(alpha=2.), + forward_input=FunctionInput(make_input(shape=(3, 2, 5))), + reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2 * (i.exp() - 1))), + ModuleInput(constructor_input=FunctionInput(alpha=2.), + forward_input=FunctionInput(make_input(shape=())), + desc='scalar'), + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input(shape=(3,))), + desc='no_batch_dim', + reference_fn=no_batch_dim_reference_fn)] + + +def module_inputs_torch_nn_CELU(module_info, device, dtype, requires_grad, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput(alpha=2.), + forward_input=FunctionInput(make_input(shape=(3, 2, 5))), + reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2. * ((.5 * i).exp() - 1))), + ModuleInput(constructor_input=FunctionInput(alpha=2.), + forward_input=FunctionInput(make_input(shape=())), + reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2 * (i.exp() - 1)), + desc='scalar'), + ModuleInput(constructor_input=FunctionInput(alpha=2.), + forward_input=FunctionInput(make_input(shape=(3,))), + desc='no_batch_dim', + reference_fn=no_batch_dim_reference_fn)] + + +def module_inputs_torch_nn_L1Loss(module_info, device, dtype, requires_grad, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input(shape=(2, 3, 4)), + make_input(shape=(2, 3, 4))), + reference_fn=lambda m, p, i, t: 1. / i.numel() * sum((a - b).abs().sum() + for a, b in zip(i, t))), + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input(shape=()), make_input(shape=())), + reference_fn=lambda m, p, i, t: 1. / i.numel() * (i - t).abs().sum(), + desc='scalar')] + generate_regression_criterion_inputs(make_input) + + # Database of ModuleInfo entries in alphabetical order. module_db: List[ModuleInfo] = [ + ModuleInfo(torch.nn.AvgPool1d, + module_inputs_func=module_inputs_torch_nn_AvgPool1d), + ModuleInfo(torch.nn.ELU, + module_inputs_func=module_inputs_torch_nn_ELU), + ModuleInfo(torch.nn.L1Loss, + module_inputs_func=module_inputs_torch_nn_L1Loss), ModuleInfo(torch.nn.Linear, module_inputs_func=module_inputs_torch_nn_Linear), ModuleInfo(torch.nn.NLLLoss, diff --git a/torch/testing/_internal/common_nn.py b/torch/testing/_internal/common_nn.py index aeaf6616e28b1..b22b6ab1d2ec5 100644 --- a/torch/testing/_internal/common_nn.py +++ b/torch/testing/_internal/common_nn.py @@ -97,6 +97,7 @@ def get_weight(m): # - `test_cpp_api_parity`: if `False`, skips the C++ parity test for this test dict. Default: True. # - `has_parity`: if `False`, expects this test dict to fail the C++ parity test. Default: True. + module_tests = [ dict( module_name='Linear', @@ -1308,6 +1309,7 @@ def single_batch_reference_fn(input, parameters, module): with freeze_rng_state(): return module(single_batch_input).squeeze(0) + new_module_tests = [ poissonnllloss_no_reduce_test(), bceloss_no_reduce_test(), @@ -2246,6 +2248,14 @@ def single_batch_reference_fn(input, parameters, module): cpp_constructor_args='torch::nn::LPPool1dOptions(2, 2).stride(3)', input_size=(1, 3, 7), ), + dict( + module_name='LPPool1d', + constructor_args=(2, 2, 3), + cpp_constructor_args='torch::nn::LPPool1dOptions(2, 2).stride(3)', + input_size=(3, 7), + reference_fn=single_batch_reference_fn, + desc='no_batch_dim', + ), dict( module_name='LocalResponseNorm', constructor_args=(3, ), @@ -2764,6 +2774,14 @@ def single_batch_reference_fn(input, parameters, module): input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4), check_gradgrad=False, ), + dict( + module_name='Embedding', + constructor_args=(4, 3), + cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3)', + input_fn=lambda: torch.empty(1, 512, dtype=torch.long).random_(4).expand(7, 512), + check_gradgrad=False, + desc='discontiguous' + ), dict( module_name='EmbeddingBag', constructor_args=(4, 3), @@ -3641,6 +3659,28 @@ def single_batch_reference_fn(input, parameters, module): fullname='log_softmax_scalar', pickle=False, ), + dict( + module_name='Softmax2d', + input_size=(3, 4, 5), + reference_fn=single_batch_reference_fn, + desc='no_batch_dim', + ), + dict( + module_name='Softmax', + constructor_args=(-1,), + cpp_constructor_args='torch::nn::SoftmaxOptions(-1)', + input_size=(4, 5), + reference_fn=single_batch_reference_fn, + desc='no_batch_dim', + ), + dict( + module_name='LogSoftmax', + constructor_args=(-1,), + cpp_constructor_args='torch::nn::LogSoftmaxOptions(1)', + input_size=(4, 5), + reference_fn=single_batch_reference_fn, + desc='no_batch_dim', + ), dict( @@ -3801,6 +3841,14 @@ def single_batch_reference_fn(input, parameters, module): input_size=(), desc='scalar', ), + dict( + module_name='Softmin', + constructor_args=(-1,), + cpp_constructor_args='torch::nn::SoftminOptions(-1)', + input_size=(3, 4, 10), + reference_fn=single_batch_reference_fn, + desc='no_batch_dim', + ), dict( module_name='Tanhshrink', input_size=(), @@ -3967,6 +4015,22 @@ def single_batch_reference_fn(input, parameters, module): with_tf32=True, tf32_precision=0.005, ), + dict( + module_name='Flatten', + cpp_constructor_args='torch::nn::FlattenOptions().start_dim(-3).end_dim(-1)', + constructor_args=(-3, -1), + input_size=(3, 4, 5), + reference_fn=single_batch_reference_fn, + desc="no_batch_dim", + ), + dict( + module_name='Unflatten', + cpp_constructor_args='torch::nn::UnflattenOptions(-2, {2, 2})', + constructor_args=(-2, torch.Size([2, 2])), + input_size=(3, 4, 5), + reference_fn=single_batch_reference_fn, + desc="no_batch_dim", + ), ] # add conv padding mode tests: @@ -4009,7 +4073,7 @@ def single_batch_reference_fn(input, parameters, module): # Check that non linear activations work with no batch dimensions non_linear_activations_no_batch = [ 'ELU', 'Hardshrink', 'Hardsigmoid', 'Hardtanh', 'Hardswish', 'LeakyReLU', - 'LogSigmoid', 'PReLU', 'ReLU', 'ReLU6', 'RReLU', 'SELU', 'CELU', 'GELU', + 'LogSigmoid', 'PReLU', 'ReLU', 'ReLU6', 'RReLU', 'SELU', 'CELU', 'GELU', 'GLU', 'Sigmoid', 'SiLU', 'Mish', 'Softplus', 'Softshrink', 'Softsign', 'Tanh', 'Tanhshrink', 'Threshold' ] @@ -4025,7 +4089,7 @@ def single_batch_reference_fn(input, parameters, module): for non_linear_activation in non_linear_activations_no_batch: activation_test_info = dict( module_name=non_linear_activation, - input_size=(3,), + input_size=(4,), reference_fn=single_batch_reference_fn, desc='no_batch_dim', test_cpp_api_parity=False, @@ -4047,6 +4111,7 @@ def kldivloss_reference(input, target, reduction='mean'): return result.sum() / result.size(0) return result + def kldivloss_log_target_reference(input, target, reduction='mean'): result = torch.exp(target) * (target - input) if reduction == 'mean': @@ -4084,7 +4149,8 @@ def nlllossNd_reference(input, target, weight=None, ignore_index=-100, return output -def cross_entropy_loss_prob_target_reference(input, target, weight=None, reduction='mean'): +def cross_entropy_loss_prob_target_reference(input, target, weight=None, reduction='mean', + label_smoothing=0.0): assert input.dim() >= 2 input = torch.log_softmax(input, 1) @@ -4093,6 +4159,10 @@ def cross_entropy_loss_prob_target_reference(input, target, weight=None, reducti weight = torch.ones(C).type_as(input) weight = weight.view(1, C, *(1 for _ in input.shape[2:])) + if label_smoothing > 0.0: + assert label_smoothing <= 1.0 + target = (target * (1 - label_smoothing) + label_smoothing / C) + output = -(input * target * weight).sum(dim=1) if reduction == 'mean': return output.mean() @@ -4101,20 +4171,61 @@ def cross_entropy_loss_prob_target_reference(input, target, weight=None, reducti return output -def cross_entropy_loss_reference(input, target, weight=None, ignore_index=-100, reduction='mean'): +def cross_entropy_loss_indices_target_reference(input, target, weight=None, ignore_index=-100, + reduction='mean', label_smoothing=0.0): + log_softmax_input = torch.log_softmax(input, 1) + nllloss = F.nll_loss( + log_softmax_input, + target, + weight, + ignore_index=ignore_index, + reduction=reduction) + + if label_smoothing == 0.0: + return nllloss + + assert 0.0 < label_smoothing <= 1.0 + + input = torch.log_softmax(input, 1) + C = input.size(1) + if weight is not None: + input = input * weight.view(1, C, *(1 for _ in input.shape[2:])) + + smooth_loss = -torch.sum(input, 1) + + if ignore_index >= 0: + ignore_mask = target == ignore_index + smooth_loss.masked_fill_(ignore_mask, 0.0) + + if reduction == 'mean': + if weight is not None: + # TODO: This code can path can be removed if #61309 is resolved + # loss is normalized by the weights to be consistent with nll_loss_nd + ret = torch.sum(smooth_loss) / weight.gather(0, target.flatten()).sum() + else: + ret = torch.mean(smooth_loss) + elif reduction == 'sum': + ret = torch.sum(smooth_loss) + else: + ret = smooth_loss + + return (1 - label_smoothing) * nllloss + ret * (label_smoothing / C) + + +def cross_entropy_loss_reference(input, target, weight=None, ignore_index=-100, reduction='mean', + label_smoothing=0.0): if input.shape == target.shape: return cross_entropy_loss_prob_target_reference( input, target, weight=weight, - reduction=reduction) + reduction=reduction, + label_smoothing=label_smoothing) else: - return nlllossNd_reference( - torch.log_softmax(input, 1), - target, - weight, - ignore_index=ignore_index, - reduction=reduction) + return cross_entropy_loss_indices_target_reference( + input, target, weight=weight, reduction=reduction, + ignore_index=ignore_index, label_smoothing=label_smoothing + ) def nllloss_reference(input, target, weight=None, ignore_index=-100, @@ -4874,6 +4985,141 @@ def padding3d_circular(input, pad): desc='4d_prob_target', check_bfloat16=False, ), + dict( + fullname='CrossEntropyLoss_2d_prob_target_smoothing_sum_reduction', + constructor=lambda *args, **kwargs: nn.CrossEntropyLoss(reduction='sum', + label_smoothing=0.15), + cpp_constructor_args='torch::nn::CrossEntropyLossOptions().label_smoothing(0.15).reduction(torch::kSum)', + input_size=(5, 3), + target_fn=lambda: torch.rand(5, 3).softmax(dim=1), + reference_fn=lambda i, t, m: + loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), label_smoothing=0.15), + check_bfloat16=False, + ), + dict( + fullname='CrossEntropyLoss_2d_prob_target_smoothing', + constructor=lambda *args: nn.CrossEntropyLoss(label_smoothing=0.15), + cpp_constructor_args='torch::nn::CrossEntropyLossOptions().label_smoothing(0.15)', + input_size=(5, 3), + target_fn=lambda: torch.rand(5, 3).softmax(dim=1), + reference_fn=lambda i, t, m: + loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), label_smoothing=0.15), + check_bfloat16=False, + ), + dict( + fullname='CrossEntropyLoss_2d_prob_target_smoothing_weight', + constructor_args_fn=lambda: (torch.rand(3).abs(),), + constructor=lambda weight: nn.CrossEntropyLoss(weight, label_smoothing=0.15), + cpp_constructor_args='torch::nn::CrossEntropyLossOptions().label_smoothing(0.15).weight(torch::rand(3).abs())', + input_size=(5, 3), + target_fn=lambda: torch.rand(5, 3).softmax(dim=1), + reference_fn=lambda i, t, m: + loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), weight=get_weight(m), label_smoothing=0.15), + check_bfloat16=False, + ), + dict( + fullname='CrossEntropyLoss_3d_prob_target_smoothing_sum_reduction', + constructor=lambda *args: nn.CrossEntropyLoss(reduction='sum', + label_smoothing=0.15), + cpp_constructor_args='torch::nn::CrossEntropyLossOptions().label_smoothing(0.15).reduction(torch::kSum)', + input_size=(5, 3, 4), + target_fn=lambda: torch.rand(5, 3, 4).softmax(dim=1), + reference_fn=lambda i, t, m: + loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), label_smoothing=0.15), + check_bfloat16=False, + ), + dict( + fullname='CrossEntropyLoss_3d_prob_target_smoothing', + constructor=lambda *args: nn.CrossEntropyLoss(label_smoothing=0.15), + cpp_constructor_args='torch::nn::CrossEntropyLossOptions().label_smoothing(0.15)', + input_size=(5, 3, 4), + target_fn=lambda: torch.rand(5, 3, 4).softmax(dim=1), + reference_fn=lambda i, t, m: + loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), label_smoothing=0.15), + check_bfloat16=False, + ), + dict( + fullname='CrossEntropyLoss_3d_indices_target_smoothing', + constructor=lambda *args: nn.CrossEntropyLoss(label_smoothing=0.15), + cpp_constructor_args='torch::nn::CrossEntropyLossOptions().label_smoothing(0.15)', + input_size=(2, 3, 5), + target_fn=lambda: torch.rand(2, 5).mul(3).floor().long(), + reference_fn=lambda i, t, m: + loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), label_smoothing=0.15), + check_bfloat16=False, + ), + dict( + fullname='CrossEntropyLoss_3d_indices_target_smoothing_ignore_index', + constructor=lambda *args: nn.CrossEntropyLoss(label_smoothing=0.15, ignore_index=1), + cpp_constructor_args='torch::nn::CrossEntropyLossOptions().label_smoothing(0.15).ignore_index(1)', + input_size=(2, 3, 5), + target_fn=lambda: torch.rand(2, 5).mul(3).floor().long(), + reference_fn=lambda i, t, m: + loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), label_smoothing=0.15, ignore_index=1), + check_bfloat16=False, + ), + dict( + fullname='CrossEntropyLoss_3d_indices_target_smoothing_sum_reduction', + constructor=lambda *args: nn.CrossEntropyLoss(reduction='sum', label_smoothing=0.15), + cpp_constructor_args='torch::nn::CrossEntropyLossOptions().label_smoothing(0.15).reduction(torch::kSum)', + input_size=(2, 3, 5), + target_fn=lambda: torch.rand(2, 5).mul(3).floor().long(), + reference_fn=lambda i, t, m: + loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), label_smoothing=0.15), + check_bfloat16=False, + ), + dict( + fullname='CrossEntropyLoss_3d_indices_target_smoothing_sum_reduction_ignore_index', + constructor=lambda *args: nn.CrossEntropyLoss(reduction='sum', label_smoothing=0.15, + ignore_index=1), + cpp_constructor_args='torch::nn::CrossEntropyLossOptions().label_smoothing(0.15).reduction(torch::kSum).ignore_index(1)', + input_size=(2, 3, 5), + target_fn=lambda: torch.rand(2, 5).mul(3).floor().long(), + reference_fn=lambda i, t, m: + loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), label_smoothing=0.15, ignore_index=1), + check_bfloat16=False, + ), + dict( + fullname='CrossEntropyLoss_2d_indices_target_smoothing', + constructor=lambda *args: nn.CrossEntropyLoss(label_smoothing=0.15), + cpp_constructor_args='torch::nn::CrossEntropyLossOptions().label_smoothing(0.15)', + input_size=(15, 10), + target_fn=lambda: torch.empty(15).uniform_().mul(10).floor().long(), + reference_fn=lambda i, t, m: + loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), label_smoothing=0.15), + check_bfloat16=False, + ), + dict( + fullname='CrossEntropyLoss_2d_indices_target_smoothing_sum_reduction', + constructor=lambda *args: nn.CrossEntropyLoss(reduction='sum', label_smoothing=0.15), + cpp_constructor_args='torch::nn::CrossEntropyLossOptions().label_smoothing(0.15).reduction(torch::kSum)', + input_size=(15, 10), + target_fn=lambda: torch.empty(15).uniform_().mul(10).floor().long(), + reference_fn=lambda i, t, m: + loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), label_smoothing=0.15), + check_bfloat16=False, + ), + dict( + fullname='CrossEntropyLoss_2d_indices_target_smoothing_ignore_index', + constructor=lambda *args: nn.CrossEntropyLoss(label_smoothing=0.15, ignore_index=3), + cpp_constructor_args='torch::nn::CrossEntropyLossOptions().label_smoothing(0.15).ignore_index(3)', + input_size=(15, 10), + target_fn=lambda: torch.empty(15).uniform_().mul(10).floor().long(), + reference_fn=lambda i, t, m: + loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), label_smoothing=0.15, ignore_index=3), + check_bfloat16=False, + ), + dict( + fullname='CrossEntropyLoss_2d_indices_target_smoothing_weight', + constructor_args_fn=lambda: (torch.rand(10).abs(),), + constructor=lambda weight: nn.CrossEntropyLoss(weight, label_smoothing=0.15), + cpp_constructor_args='torch::nn::CrossEntropyLossOptions().label_smoothing(0.15).weight(torch::rand(10).abs())', + input_size=(15, 10), + target_fn=lambda: torch.empty(15).uniform_().mul(10).floor().long(), + reference_fn=lambda i, t, m: + loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), weight=get_weight(m), label_smoothing=0.15), + check_bfloat16=False, + ), dict( module_name='CrossEntropyLoss', constructor_args_fn=lambda: (torch.rand(3),), @@ -5174,6 +5420,7 @@ def single_batch_reference_criterion_fn(*args): ('HingeEmbeddingLoss', lambda: torch.randn(9), lambda: torch.tensor([-1, 1, 1] * 3)), ('MultiLabelMarginLoss', lambda: torch.randn(4), lambda: torch.tensor([3, 0, -1, 1])), ('SoftMarginLoss', lambda: torch.randn(9), lambda: torch.tensor([-1, 1, 1] * 3)), + ('NLLLoss', lambda: F.log_softmax(torch.randn(3), dim=0), lambda: torch.tensor(1)), ] classification_criterion_no_batch_extra_info: Dict[str, dict] = { 'MultiLabelMarginLoss': {'check_gradgrad': False}, @@ -5572,6 +5819,7 @@ def test_cuda(self, test_case): self.test_noncontig(test_case, gpu_module, gpu_input_tuple) + class InputVariableMixin(object): def _get_input(self): input = TestBase._get_input(self, False) # type: ignore[arg-type] @@ -5880,8 +6128,10 @@ def convert_dtype(obj, dtype, requires_grad=False): test_case.assertEqualIgnoreType(cpu_output, gpu_output, atol=1e-1 if dtype in {torch.half, torch.bfloat16} else 4e-4, rtol=0) - cpu_gradInput = test_case._backward_criterion(cpu_module, cpu_input, cpu_output, cpu_target, extra_args=extra_args) - gpu_gradInput = test_case._backward_criterion(gpu_module, gpu_input, gpu_output, gpu_target, extra_args=extra_args) + cpu_gradInput = test_case._backward_criterion( + cpu_module, cpu_input, cpu_output, cpu_target, extra_args=extra_args) + gpu_gradInput = test_case._backward_criterion( + gpu_module, gpu_input, gpu_output, gpu_target, extra_args=extra_args) # dtype used to be able to be None, so set precision in this way instead of a precision map # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 test_case.assertEqualIgnoreType(cpu_gradInput, gpu_gradInput, diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py index 2470b5392de11..77512f7ef445a 100644 --- a/torch/testing/_internal/common_quantization.py +++ b/torch/testing/_internal/common_quantization.py @@ -5,6 +5,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +import torch.nn.intrinsic.quantized.dynamic as nniqd import torch.nn.quantized as nnq import torch.nn.quantized.dynamic as nnqd from torch.nn.intrinsic import _FusedModule @@ -422,6 +423,13 @@ def checkDynamicQuantizedLinear(self, mod, dtype): self.assertEqual(type(mod), nnqd.Linear) self.assertEqual(mod._packed_params.dtype, dtype) + def checkDynamicQuantizedLinearRelu(self, mod, dtype): + r"""Checks that mod has been swapped for an nnqd.Linear + module, the bias is float. + """ + self.assertEqual(type(mod), nniqd.LinearReLU) + self.assertEqual(mod._packed_params.dtype, dtype) + def check_eager_serialization(self, ref_model, loaded_model, x): # Check state dict serialization and torch.save APIs model_dict = ref_model.state_dict() @@ -975,12 +983,12 @@ def _compare_script_and_mobile(self, mobile_module_result = mobile_module(input) - torch.testing.assert_allclose(script_module_result, mobile_module_result) + torch.testing.assert_close(script_module_result, mobile_module_result) mobile_module_forward_result = mobile_module.forward(input) - torch.testing.assert_allclose(script_module_result, mobile_module_forward_result) + torch.testing.assert_close(script_module_result, mobile_module_forward_result) mobile_module_run_method_result = mobile_module.run_method("forward", input) - torch.testing.assert_allclose(script_module_result, mobile_module_run_method_result) + torch.testing.assert_close(script_module_result, mobile_module_run_method_result) except AssertionError as e: if retry == max_retry: raise e diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index fed9a005a55c5..0a265b52401b6 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -23,6 +23,7 @@ import random import contextlib import shutil +import threading from pathlib import Path import socket import subprocess @@ -43,13 +44,13 @@ import numpy as np -from torch.testing import floating_types_and, integral_types, complex_types, get_all_dtypes import expecttest from .._core import \ (_compare_tensors_internal, _compare_scalars_internal, _compare_return_type) import torch import torch.cuda +from torch.testing import make_tensor from torch._utils_internal import get_writable_path from torch._six import string_classes from torch import Tensor @@ -156,7 +157,7 @@ def _get_test_report_path(): return os.path.join('test-reports', test_source) -parser = argparse.ArgumentParser(add_help=False) +parser = argparse.ArgumentParser() parser.add_argument('--subprocess', action='store_true', help='whether to run each test in a subprocess') parser.add_argument('--seed', type=int, default=1234) @@ -173,6 +174,15 @@ def _get_test_report_path(): parser.add_argument('--import-slow-tests', type=str, nargs='?', const=SLOW_TESTS_FILE) parser.add_argument('--import-disabled-tests', type=str, nargs='?', const=DISABLED_TESTS_FILE) +# Only run when -h or --help flag is active to display both unittest and parser help messages. +def run_unittest_help(argv): + unittest.main(argv=argv) + +if '-h' in sys.argv or '--help' in sys.argv: + help_thread = threading.Thread(target=run_unittest_help, args=(sys.argv,)) + help_thread.start() + help_thread.join() + args, remaining = parser.parse_known_args() if args.jit_executor == 'legacy': GRAPH_EXECUTOR = ProfilingMode.LEGACY @@ -906,8 +916,10 @@ def check_if_enable(test: unittest.TestCase): platform_to_conditional: Dict = { "mac": IS_MACOS, "macos": IS_MACOS, + "win": IS_WINDOWS, "windows": IS_WINDOWS, - "linux": IS_LINUX + "linux": IS_LINUX, + "rocm": TEST_WITH_ROCM } if platforms == [] or any([platform_to_conditional[platform] for platform in platforms]): raise unittest.SkipTest( @@ -1937,103 +1949,7 @@ def f_retry(*args, **kwargs): return deco_retry -# Methods for matrix and tensor generation - -def make_tensor(size, device: torch.device, dtype: torch.dtype, *, low=None, high=None, - requires_grad: bool = False, noncontiguous: bool = False, - exclude_zero: bool = False) -> torch.Tensor: - """ Creates a random tensor with the given size, device and dtype. - - Default values for low and high: - * boolean type: low = 0, high = 2 - * uint8 type: low = 0, high = 9 - * floating and integral types: low = -9 and high = 9 - * complex types, for each real and imaginary part: low = -9, high = 9 - If low/high are specified and within dtype limits: low = low, high = high - If low/high are specified but exceed the limits: low = dtype_min, high = dtype_max - If low is -inf and/or high is inf: low = dtype_min, high = dtype_max - If low is inf or nan and/or high is -inf or nan: ValueError raised - - If noncontiguous=True, a noncontiguous tensor with the given size will be returned unless the size - specifies a tensor with a 1 or 0 elements in which case the noncontiguous parameter is ignored because - it is not possible to create a noncontiguous Tensor with a single element. - - If exclude_zero is passed with True (default is False), all the matching values (with zero) in - created tensor are replaced with a tiny (smallest positive representable number) value if floating type, - [`tiny` + `tiny`.j] if complex type and 1 if integer/boolean type. - """ - def _modify_low_high(low, high, lowest, highest, default_low, default_high, dtype): - """ - Modifies (and raises ValueError when appropriate) low and high values given by the user (input_low, input_high) if required. - """ - def clamp(a, l, h): - return min(max(a, l), h) - - low = low if low is not None else default_low - high = high if high is not None else default_high - - # Checks for error cases - if low != low or high != high: - raise ValueError("make_tensor: one of low or high was NaN!") - if low > high: - raise ValueError("make_tensor: low must be weakly less than high!") - - low = clamp(low, lowest, highest) - high = clamp(high, lowest, highest) - - if dtype in integral_types(): - return math.floor(low), math.ceil(high) - - return low, high - - if dtype is torch.bool: - result = torch.randint(0, 2, size, device=device, dtype=dtype) - elif dtype is torch.uint8: - ranges = (torch.iinfo(dtype).min, torch.iinfo(dtype).max) - low, high = _modify_low_high(low, high, ranges[0], ranges[1], 0, 9, dtype) - result = torch.randint(low, high, size, device=device, dtype=dtype) - elif dtype in integral_types(): - ranges = (torch.iinfo(dtype).min, torch.iinfo(dtype).max) - low, high = _modify_low_high(low, high, ranges[0], ranges[1], -9, 9, dtype) - result = torch.randint(low, high, size, device=device, dtype=dtype) - elif dtype in floating_types_and(torch.half, torch.bfloat16): - ranges_floats = (torch.finfo(dtype).min, torch.finfo(dtype).max) - low, high = _modify_low_high(low, high, ranges_floats[0], ranges_floats[1], -9, 9, dtype) - rand_val = torch.rand(size, device=device, dtype=dtype) - result = high * rand_val + low * (1 - rand_val) - else: - assert dtype in complex_types() - float_dtype = torch.float if dtype is torch.cfloat else torch.double - ranges_floats = (torch.finfo(float_dtype).min, torch.finfo(float_dtype).max) - low, high = _modify_low_high(low, high, ranges_floats[0], ranges_floats[1], -9, 9, dtype) - real_rand_val = torch.rand(size, device=device, dtype=float_dtype) - imag_rand_val = torch.rand(size, device=device, dtype=float_dtype) - real = high * real_rand_val + low * (1 - real_rand_val) - imag = high * imag_rand_val + low * (1 - imag_rand_val) - result = torch.complex(real, imag) - - if noncontiguous and result.numel() > 1: - result = torch.repeat_interleave(result, 2, dim=-1) - result = result[..., ::2] - - if exclude_zero: - if dtype in integral_types() or dtype is torch.bool: - replace_with = torch.tensor(1, device=device, dtype=dtype) - elif dtype in floating_types_and(torch.half, torch.bfloat16): - replace_with = torch.tensor(torch.finfo(dtype).tiny, device=device, dtype=dtype) - elif dtype in complex_types(): - float_dtype = torch.float if dtype is torch.cfloat else torch.double - float_eps = torch.tensor(torch.finfo(float_dtype).tiny, device=device, dtype=float_dtype) - replace_with = torch.complex(float_eps, float_eps) - else: - raise ValueError(f"Invalid dtype passed, supported dtypes are: {get_all_dtypes()}") - result[result == 0] = replace_with - - if dtype in floating_types_and(torch.half, torch.bfloat16) or\ - dtype in complex_types(): - result.requires_grad = requires_grad - - return result +# Methods for matrix generation def random_square_matrix_of_rank(l, rank, dtype=torch.double, device='cpu'): assert rank <= l @@ -2533,13 +2449,6 @@ def disable_gc(): else: yield -def has_breakpad() -> bool: - # If not on a special build, check that the library was actually linked in - try: - torch._C._get_minidump_directory() # type: ignore[attr-defined] - return True - except RuntimeError as e: - return False def find_library_location(lib_name: str) -> Path: # return the shared library file in the installed folder if exist, @@ -2590,6 +2499,22 @@ def get_tensors_from(args, kwargs): return set([arg for arg in args if isinstance(arg, Tensor)] + [v for v in kwargs.values() if isinstance(v, Tensor)]) + +def has_breakpad(): + # We always build with breakpad in CI + if IS_IN_CI: + return True + + # If not on a special build, check that the library was actually linked in + try: + torch._C._get_minidump_directory() # type: ignore[attr-defined] + return True + except RuntimeError as e: + if "Minidump handler is uninintialized" in str(e): + return True + return False + + def sandcastle_skip_if(condition, reason): """ Similar to unittest.skipIf, however in the sandcastle environment it just diff --git a/torch/testing/_internal/dist_utils.py b/torch/testing/_internal/dist_utils.py index bdb21a7941c17..284a541444cdd 100644 --- a/torch/testing/_internal/dist_utils.py +++ b/torch/testing/_internal/dist_utils.py @@ -171,8 +171,6 @@ def wait_until_owners_and_forks_on_rank( def initialize_pg(init_method, rank: int, world_size: int) -> None: # This is for tests using `dist.barrier`. - # For `RpcAgent` other than `ProcessGroupAgent`, - # no `_default_pg` is initialized. if not dist.is_initialized(): dist.init_process_group( backend="gloo", diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 54a22b01bd667..613e23ede8f84 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -19,7 +19,6 @@ import torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook as powerSGD import torch.distributed.algorithms.model_averaging.averagers as averagers import torch.distributed.algorithms.model_averaging.utils as model_averaging_utils -import torch.distributed.algorithms.quantization as quant import torch.nn as nn import torch.nn.functional as F from torch._utils_internal import TEST_MASTER_ADDR as MASTER_ADDR @@ -29,7 +28,6 @@ from torch.distributed.algorithms.ddp_comm_hooks import ( quantization as quantization_hooks, ) -from torch.distributed.algorithms.quantization import DQuantType from torch.distributed.distributed_c10d import ( get_world_size, _get_default_group, @@ -66,14 +64,13 @@ sandcastle_skip_if, ) -if not IS_WINDOWS: - import torch.distributed.optim.post_localSGD_optimizer as post_localSGD_optimizer - from torch.distributed.optim.functional_sgd import _FunctionalSGD - from torch.distributed.optim.functional_adam import _FunctionalAdam - _SUPPORTED_OPTIM_MAPPING = { - _FunctionalSGD: torch.optim.SGD, - _FunctionalAdam: torch.optim.Adam - } +from torch.distributed.optim import functional_optim_map + +from torch.distributed.optim.functional_sgd import _FunctionalSGD +from torch.distributed.optim.functional_adam import _FunctionalAdam +from torch.distributed.optim.functional_adamw import _FunctionalAdamW + +import torch.distributed.optim.post_localSGD_optimizer as post_localSGD_optimizer from torch.utils.data.distributed import DistributedSampler @@ -2765,15 +2762,12 @@ def test_gather_full_group(self): # ALL GATHER def _test_all_gather_helper( - self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float, qtype=None + self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float ): for dest in group: tensor = _build_tensor(dest + 1, rank, dtype=dtype) tensors = [_build_tensor(dest + 1, -1, dtype=dtype) for i in group] - if qtype is not None: - allgather = quant.auto_quantize(dist.all_gather, qtype, quant_loss=None) - else: - allgather = dist.all_gather + allgather = dist.all_gather if cuda: tensor = tensor.cuda(rank_to_GPU[rank][0]) tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors] @@ -2839,12 +2833,6 @@ def test_all_gather_full_group(self): group, group_id, rank = self._init_full_group_test() self._test_all_gather_helper(group, group_id, rank) - @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors") - @sandcastle_skip_if(BACKEND == "mpi", "all_gather_quantized does not support MPI") - def test_all_gather_quantized(self): - group, group_id, rank = self._init_global_test() - self._test_all_gather_helper(group, group_id, rank, dtype=torch.float32, qtype=DQuantType.FP16) - def _run_all_gather_coalesced_and_verify( self, output_tensor_lists, input_tensors, expected_tensors, group_id ): @@ -3047,7 +3035,6 @@ def _test_all_to_all_helper( cuda=False, rank_to_GPU=None, dtype=torch.float, - qtype=None ): if group_id is not None: size = len(group) @@ -3068,11 +3055,7 @@ def _test_all_to_all_helper( t.cuda(rank_to_GPU[rank][0]) for t in expected_tensors ] out_tensors = [t.cuda(rank_to_GPU[rank][0]) for t in out_tensors] - if(qtype is not None): - quantize_alltoall = quant.auto_quantize(dist.all_to_all, qtype, quant_loss=None) - quantize_alltoall(out_tensors, in_tensors, group=group_id) - else: - dist.all_to_all(out_tensors, in_tensors, group=group_id) + dist.all_to_all(out_tensors, in_tensors, group=group_id) for t1, t2 in zip(out_tensors, expected_tensors): self.assertEqual(t1, t2) self._barrier() @@ -3155,20 +3138,6 @@ def test_all_to_all(self): group, group_id, rank = self._init_global_test() self._test_all_to_all_helper(group, group_id, rank) - @sandcastle_skip_if(BACKEND != "nccl", "Only NCCL supports all_to_all") - @skip_if_rocm - def test_all_to_all_quantized(self): - group, group_id, rank = self._init_global_test() - rank_to_GPU = self._init_multigpu_helper() - self._test_all_to_all_helper( - group, - group_id, - rank, - cuda=True, - rank_to_GPU=rank_to_GPU, - dtype=torch.float32, - qtype=DQuantType.FP16) - @sandcastle_skip_if(BACKEND != "nccl", "Only NCCL supports CUDA all_to_all") @skip_if_rocm def test_all_to_all_cuda(self): @@ -3791,6 +3760,31 @@ def test_DistributedDataParallel_requires_grad(self): ) self._barrier() + @sandcastle_skip_if( + BACKEND == "nccl", + "Gloo-only test" + ) + def test_ddp_create_graph(self): + class Model(nn.Module): + def __init__(self): + super().__init__() + self.p = nn.Parameter(torch.tensor(1.)) + + def forward(self): + return self.p.pow(2) + + model = Model() + ddp_model = torch.nn.parallel.DistributedDataParallel(model) + for _ in range(6): + # Verify DDP doesn't throw when ran with create_graph=True. + # Although we do warn about potential issues, please see + # https://github.com/pytorch/pytorch/issues/63929 for details. + ddp_model().backward(create_graph=True) + # grad tensors should require grad. + self.assertTrue( + all([param.requires_grad for param in ddp_model.parameters()]) + ) + @sandcastle_skip_if( BACKEND != "nccl" and BACKEND != "gloo", "Only NCCL and GLOO backend support DistributedDataParallel", @@ -3947,7 +3941,8 @@ def _test_ddp_hook_with_optimizer_parity( if static_graph: ddp_model_with_no_hook._set_static_graph() - optimizer_no_hook = _SUPPORTED_OPTIM_MAPPING.get(functional_optim_cls)( + mapping = {v: k for k, v in functional_optim_map.items()} + optimizer_no_hook = mapping.get(functional_optim_cls)( ddp_model_with_no_hook.parameters(), *functional_optim_args, **functional_optim_kwargs, @@ -4003,9 +3998,27 @@ def _test_ddp_hook_with_optimizer_parity( BACKEND != "nccl" and BACKEND != "gloo", "Only Nccl & Gloo backend support DistributedDataParallel", ) + @skip_if_lt_x_gpu(2) + @skip_if_rocm + def test_ddp_hook_with_optimizer_parity_adamw(self): + for grad_as_bucket_view, static_graph in itertools.product( + [True, False], [True, False] + ): + adamw_lr = 1e-2 + adamw_betas = (0.9, 0.99) + adamw_eps = 1e-6 + self._test_ddp_hook_with_optimizer_parity( + grad_as_bucket_view, + static_graph, + _FunctionalAdamW, + adamw_lr, + betas=adamw_betas, + eps=adamw_eps, + ) + @sandcastle_skip_if( - IS_WINDOWS, - "FunctionalAdam not yet supported with Windows, see https://github.com/pytorch/pytorch/issues/62137" + BACKEND != "nccl" and BACKEND != "gloo", + "Only Nccl & Gloo backend support DistributedDataParallel", ) @skip_if_lt_x_gpu(2) @skip_if_rocm @@ -4029,10 +4042,6 @@ def test_ddp_hook_with_optimizer_parity_adam(self): BACKEND != "nccl" and BACKEND != "gloo", "Only Nccl & Gloo backend support DistributedDataParallel", ) - @sandcastle_skip_if( - IS_WINDOWS, - "FunctionalSGD not yet supported with Windows, see https://github.com/pytorch/pytorch/issues/62137" - ) @skip_if_lt_x_gpu(2) @skip_if_rocm def test_ddp_hook_with_optimizer_parity_sgd(self): @@ -4093,20 +4102,13 @@ def _test_ddp_hook_parity(self, state, hook): grad_hook = net_with_hook.module.weight.grad avg_hook = grad_hook.clone() # Verify hook grad with expected. - # Cannot use exact match here due to a very small accuracy loss, - # e.g. 1e-05, for powerSGD hook case. - assert_func = ( - self.assertEqual - if hook == default.allreduce_hook - else torch.testing.assert_allclose - ) - assert_func( - avg_hook[0, 0], + self.assertEqual( + avg_hook[0, 0].item(), expected_grad, msg=f"Expected hook grad of {expected_grad} but got {avg_hook[0, 0]}", ) # Verify hook grad with vanilla allreduce - assert_func( + self.assertEqual( avg_hook[0, 0], avg[0, 0], msg=f"Expected hook grad to be close to allreduce {avg[0, 0]}, but got {avg_hook[0, 0]}", @@ -4602,9 +4604,6 @@ def _test_DistributedDataParallel_SyncBatchNorm( BACKEND != "nccl" and BACKEND != "gloo", "Only NCCL and GLOO backend support DistributedDataParallel", ) - @sandcastle_skip_if( - IS_WINDOWS, "PostLocalSGDOptimizer not yet supported with Windows." - ) def test_post_localSGD_optimizer_parity(self, grad_is_view=False): learning_rate = 0.03 period = 4 @@ -4911,8 +4910,8 @@ def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_Running_Value( model.module.running_mean, model.module.running_var, ) - torch.testing.assert_allclose(running_mean, all_input_var.mean(1)) - torch.testing.assert_allclose(running_var, all_input_var.var(1)) + torch.testing.assert_close(running_mean, all_input_var.mean(1)) + torch.testing.assert_close(running_var, all_input_var.var(1)) @sandcastle_skip_if( BACKEND != "nccl" and BACKEND != "gloo", @@ -5075,6 +5074,12 @@ def parse_env(var): ddp_logging_data.get("gloo_device_transport"), parse_env("GLOO_DEVICE_TRANSPORT"), ) + default_gloo_threads = 2 + self.assertEqual( + ddp_logging_data.get("gloo_num_threads"), + default_gloo_threads, + ) + self.assertEqual(ddp_logging_data.get("nccl_socket_ifname"), None) self.assertEqual(ddp_logging_data.get("nccl_blocking_wait"), None) self.assertEqual(ddp_logging_data.get("nccl_async_error_handling"), None) diff --git a/torch/testing/_internal/distributed/nn/api/remote_module_test.py b/torch/testing/_internal/distributed/nn/api/remote_module_test.py index fb1d5fbbc4f75..997006353bfbd 100644 --- a/torch/testing/_internal/distributed/nn/api/remote_module_test.py +++ b/torch/testing/_internal/distributed/nn/api/remote_module_test.py @@ -1,6 +1,5 @@ #!/usr/bin/python3 import enum -import pickle from typing import Tuple import torch @@ -467,7 +466,7 @@ def test_remote_module_py_pickle_not_supported_script(self): dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE] ): with TemporaryFileName() as fname: - with self.assertRaises(pickle.PickleError): + with self.assertRaisesRegex(torch.jit.Error, "can only be pickled when using RPC"): torch.save(remote_module, fname) diff --git a/torch/testing/_internal/distributed/rpc/dist_autograd_test.py b/torch/testing/_internal/distributed/rpc/dist_autograd_test.py index e50c30d4974b7..2ba25a591ae0f 100644 --- a/torch/testing/_internal/distributed/rpc/dist_autograd_test.py +++ b/torch/testing/_internal/distributed/rpc/dist_autograd_test.py @@ -64,13 +64,35 @@ def _torch_ones(sizes, requires_grad=False): # rref tensor equals to the given grad. def _compare_owner_value(context_id, rref, grad): grads = dist_autograd.get_gradients(context_id) - return torch.equal(grads[rref.local_value()], grad) + x = grads[rref.local_value()] + if x.is_sparse: + assert grad.is_sparse + x = x.to_dense() + grad = grad.to_dense() + else: + assert not grad.is_sparse + return torch.equal(x, grad) def create_tensor(): return torch.ones((3, 3), requires_grad=True) +def build_sparse_tensor(coalesce=False, requires_grad=True, dtype=torch.float32, device=None): + i = [[0, 1, 1], [2, 0, 2]] + v = [3.2, 4.1, 5.3] + tensor = torch.sparse_coo_tensor(i, v, (3, 3), requires_grad=requires_grad, dtype=dtype, device=device) + if coalesce: + tensor = tensor.coalesce() + return tensor + +def build_sparse_one_gradient(dtype=torch.float32): + i = [[0, 1, 1], [2, 0, 2]] + v = [1, 1, 1] + tensor = torch.sparse_coo_tensor(i, v, (3, 3), dtype=dtype) + return tensor + + @torch.jit.script def create_torchscript_tensor() -> torch.Tensor: return torch.ones((3, 3)).requires_grad_() @@ -88,6 +110,9 @@ def my_rref_add(rref_t1, t2): ret = torch.add(rref_t1.local_value(), t2) return ret +def my_sum(t): + return torch.sparse.sum(t) if t.is_sparse else t.sum() + @torch.jit.script def my_script_add(t1, t2): @@ -146,7 +171,8 @@ def _all_contexts_cleaned_up(timeout_seconds=10): def _run_trainer(rref_t1, t2, ps, rank_diff): with dist_autograd.context() as context_id: ret = rpc.rpc_sync(ps, my_rref_add, args=(rref_t1, t2)) - dist_autograd.backward(context_id, [ret.sum()]) + loss = my_sum(ret) + dist_autograd.backward(context_id, [loss]) # prevent deleting dist autograd context rpc.rpc_sync(ps, _set_rpc_done, args=(context_id, rank_diff)) rpc.rpc_sync(ps, _check_rpc_done, args=(0,)) @@ -156,7 +182,8 @@ def _run_trainer(rref_t1, t2, ps, rank_diff): def _run_trainer_torchscript(rref_t1, t2, ps, rank_diff): with dist_autograd.context() as context_id: ret = rpc.rpc_sync(ps, my_script_ref_add, args=(rref_t1, t2)) - dist_autograd.backward(context_id, [ret.sum()]) + loss = my_sum(ret) + dist_autograd.backward(context_id, [loss]) # prevent deleting dist autograd context rpc.rpc_sync(ps, _set_rpc_done, args=(context_id, rank_diff)) rpc.rpc_sync(ps, _check_rpc_done, args=(0,)) @@ -379,14 +406,18 @@ def _verify_graph_for_nested_rpc_call(self, ctx): "torch::distributed::autograd::RecvRpcBackward", next_funcs[0][0].name() ) - def _test_graph(self, fn, exec_mode): + def _test_graph(self, fn, exec_mode, sparse): dst_rank = (self.rank + 1) % self.world_size initialize_pg(self.file_init_method, self.rank, self.world_size) with dist_autograd.context() as context_id: - t1 = torch.ones(3, 3, requires_grad=True) - t2 = torch.zeros(3, 3, requires_grad=True) + if sparse: + t1 = build_sparse_tensor() + t2 = build_sparse_tensor() + else: + t1 = torch.ones(3, 3, requires_grad=True) + t2 = torch.zeros(3, 3, requires_grad=True) if ExecMode.RPC_SYNC == exec_mode: ret = rpc.rpc_sync(worker_name(dst_rank), fn, args=(t1, t2)) elif ExecMode.REMOTE == exec_mode: @@ -436,29 +467,49 @@ def _test_graph(self, fn, exec_mode): @dist_init def test_graph_for_builtin_call(self): - self._test_graph(torch.add, ExecMode.RPC_SYNC) + self._test_graph(torch.add, ExecMode.RPC_SYNC, False) + + @dist_init + def test_graph_for_builtin_call_sparse(self): + self._test_graph(torch.add, ExecMode.RPC_SYNC, True) @dist_init def test_graph_for_python_call(self): - self._test_graph(my_py_add, ExecMode.RPC_SYNC) + self._test_graph(my_py_add, ExecMode.RPC_SYNC, False) + + @dist_init + def test_graph_for_python_call_sparse(self): + self._test_graph(my_py_add, ExecMode.RPC_SYNC, True) @dist_init def test_graph_for_builtin_remote_call(self): - self._test_graph(torch.add, ExecMode.REMOTE) + self._test_graph(torch.add, ExecMode.REMOTE, False) + + @dist_init + def test_graph_for_builtin_remote_call_sparse(self): + self._test_graph(torch.add, ExecMode.REMOTE, True) @dist_init def test_graph_for_python_remote_call(self): - self._test_graph(my_py_add, ExecMode.REMOTE) + self._test_graph(my_py_add, ExecMode.REMOTE, False) + + @dist_init + def test_graph_for_python_remote_call_sparse(self): + self._test_graph(my_py_add, ExecMode.REMOTE, True) # 3-layer nested calls - def _test_graph_for_py_nested_call(self, exec_mode): + def _test_graph_for_py_nested_call(self, exec_mode, sparse): dst_rank = (self.rank + 1) % self.world_size initialize_pg(self.file_init_method, self.rank, self.world_size) with dist_autograd.context() as context_id: - t1 = torch.ones(3, 3, requires_grad=True) - t2 = torch.zeros(3, 3, requires_grad=True) + if sparse: + t1 = build_sparse_tensor(requires_grad=True) + t2 = build_sparse_tensor(requires_grad=True) + else: + t1 = torch.ones(3, 3, requires_grad=True) + t2 = torch.zeros(3, 3, requires_grad=True) nest_dst_rank = (dst_rank + 1) % self.world_size if ExecMode.RPC_SYNC == exec_mode: ret = rpc.rpc_sync( @@ -531,21 +582,33 @@ def _test_graph_for_py_nested_call(self, exec_mode): @dist_init def test_graph_for_py_nested_call(self): - self._test_graph_for_py_nested_call(ExecMode.RPC_SYNC) + self._test_graph_for_py_nested_call(ExecMode.RPC_SYNC, False) + + @dist_init + def test_graph_for_py_nested_call_sparse(self): + self._test_graph_for_py_nested_call(ExecMode.RPC_SYNC, True) @dist_init def test_graph_for_py_nested_remote_call(self): - self._test_graph_for_py_nested_call(ExecMode.REMOTE) + self._test_graph_for_py_nested_call(ExecMode.REMOTE, False) + + @dist_init + def test_graph_for_py_nested_remote_call_sparse(self): + self._test_graph_for_py_nested_call(ExecMode.REMOTE, True) # Rank0->Rank1->Rank0 - def _test_graph_for_py_nested_call_itself(self, exec_mode): + def _test_graph_for_py_nested_call_itself(self, exec_mode, sparse): dst_rank = (self.rank + 1) % self.world_size initialize_pg(self.file_init_method, self.rank, self.world_size) with dist_autograd.context() as context_id: - t1 = torch.ones(3, 3, requires_grad=True) - t2 = torch.zeros(3, 3, requires_grad=True) + if sparse: + t1 = build_sparse_tensor(requires_grad=True) + t2 = build_sparse_tensor(requires_grad=True) + else: + t1 = torch.ones(3, 3, requires_grad=True) + t2 = torch.zeros(3, 3, requires_grad=True) if ExecMode.RPC_SYNC == exec_mode: ret = rpc.rpc_sync( worker_name(dst_rank), @@ -610,18 +673,30 @@ def _test_graph_for_py_nested_call_itself(self, exec_mode): @dist_init def test_graph_for_py_nested_call_itself(self): - self._test_graph_for_py_nested_call_itself(ExecMode.RPC_SYNC) + self._test_graph_for_py_nested_call_itself(ExecMode.RPC_SYNC, False) + + @dist_init + def test_graph_for_py_nested_call_itself_sparse(self): + self._test_graph_for_py_nested_call_itself(ExecMode.RPC_SYNC, True) @dist_init def test_graph_for_py_nested_remote_call_itself(self): - self._test_graph_for_py_nested_call_itself(ExecMode.REMOTE) + self._test_graph_for_py_nested_call_itself(ExecMode.REMOTE, False) - def _test_no_graph_with_tensors_not_require_grad(self, exec_mode): + @dist_init + def test_graph_for_py_nested_remote_call_itself_sparse(self): + self._test_graph_for_py_nested_call_itself(ExecMode.REMOTE, True) + + def _test_no_graph_with_tensors_not_require_grad(self, exec_mode, sparse): initialize_pg(self.file_init_method, self.rank, self.world_size) dst_rank = (self.rank + 1) % self.world_size with dist_autograd.context() as context_id: - t1 = torch.ones(3, 3, requires_grad=False) - t2 = torch.zeros(3, 3, requires_grad=False) + if sparse: + t1 = build_sparse_tensor(requires_grad=False) + t2 = build_sparse_tensor(requires_grad=False) + else: + t1 = torch.ones(3, 3, requires_grad=False) + t2 = torch.zeros(3, 3, requires_grad=False) if ExecMode.RPC_SYNC == exec_mode: ret = rpc.rpc_sync( worker_name(dst_rank), torch.add, args=(t1, t2) @@ -656,11 +731,19 @@ def _test_no_graph_with_tensors_not_require_grad(self, exec_mode): @dist_init def test_no_graph_with_tensors_not_require_grad(self): - self._test_no_graph_with_tensors_not_require_grad(ExecMode.RPC_SYNC) + self._test_no_graph_with_tensors_not_require_grad(ExecMode.RPC_SYNC, False) + + @dist_init + def test_no_graph_with_tensors_not_require_grad_sparse(self): + self._test_no_graph_with_tensors_not_require_grad(ExecMode.RPC_SYNC, True) @dist_init def test_no_graph_with_tensors_not_require_grad_remote(self): - self._test_no_graph_with_tensors_not_require_grad(ExecMode.REMOTE) + self._test_no_graph_with_tensors_not_require_grad(ExecMode.REMOTE, False) + + @dist_init + def test_no_graph_with_tensors_not_require_grad_remote_sparse(self): + self._test_no_graph_with_tensors_not_require_grad(ExecMode.REMOTE, True) def _test_grad_only_on_return_value(self, exec_mode): initialize_pg(self.file_init_method, self.rank, self.world_size) @@ -699,13 +782,16 @@ def test_grad_only_on_return_value(self): def test_grad_only_on_return_value_remote(self): self._test_grad_only_on_return_value(ExecMode.REMOTE) - def _test_rpc_complex_args(self, exec_mode): + def _test_rpc_complex_args(self, exec_mode, sparse): with dist_autograd.context() as context_id: num_tensors = 10 tensors = [] for i in range(num_tensors): - tensors.append(torch.ones(3, 3, requires_grad=(i % 2 == 0))) - + if sparse: + tensor = build_sparse_tensor(requires_grad=(i % 2 == 0)) + else: + tensor = torch.ones(3, 3, requires_grad=(i % 2 == 0)) + tensors.append(tensor) dst_rank = self._next_rank() if ExecMode.RPC_SYNC == exec_mode: ret = rpc.rpc_sync( @@ -739,11 +825,19 @@ def _test_rpc_complex_args(self, exec_mode): @dist_init def test_rpc_complex_args(self): - self._test_rpc_complex_args(ExecMode.RPC_SYNC) + self._test_rpc_complex_args(ExecMode.RPC_SYNC, False) + + @dist_init + def test_rpc_complex_args_sparse(self): + self._test_rpc_complex_args(ExecMode.RPC_SYNC, True) @dist_init def test_remote_complex_args(self): - self._test_rpc_complex_args(ExecMode.REMOTE) + self._test_rpc_complex_args(ExecMode.REMOTE, False) + + @dist_init + def test_remote_complex_args_sparse(self): + self._test_rpc_complex_args(ExecMode.REMOTE, True) def context_cleanup_test_helper(self, rpc_args, func, nested=False): initialize_pg(self.file_init_method, self.rank, self.world_size) @@ -788,11 +882,22 @@ def test_context_cleanup_tensor_with_grad(self): t2 = torch.zeros(3, 3, requires_grad=True) self.context_cleanup_test_helper(rpc_args=(t1, t2), func=torch.add) + @dist_init + def test_context_cleanup_tensor_with_grad_sparse(self): + t1 = build_sparse_tensor(requires_grad=True) + t2 = build_sparse_tensor(requires_grad=True) + self.context_cleanup_test_helper(rpc_args=(t1, t2), func=torch.add) + @dist_init def test_context_cleanup_tensor_no_grad(self): t1 = torch.ones(3, 3, requires_grad=False) self.context_cleanup_test_helper(rpc_args=(t1, t1), func=torch.add) + @dist_init + def test_context_cleanup_tensor_no_grad_sparse(self): + t1 = build_sparse_tensor(requires_grad=False) + self.context_cleanup_test_helper(rpc_args=(t1, t1), func=torch.add) + @dist_init def test_context_cleanup_no_tensors(self): self.context_cleanup_test_helper(rpc_args=(1, 1), func=my_scalar_add) @@ -807,6 +912,16 @@ def test_context_cleanup_nested_rpc(self): rpc_args=args, func=my_py_nested_call, nested=True ) + @dist_init + def test_context_cleanup_nested_rpc_sparse(self): + t1 = build_sparse_tensor(requires_grad=True) + t2 = build_sparse_tensor(requires_grad=True) + dst_rank = (self.rank + 1) % self.world_size + args = (t1, t2, dst_rank, self.world_size, 0) + self.context_cleanup_test_helper( + rpc_args=args, func=my_py_nested_call, nested=True + ) + @dist_init def test_worker_ids_recorded(self): dst_ranks = {rank for rank in range(self.world_size) if rank != self.rank} @@ -876,23 +991,21 @@ def test_error_in_context(self): worker_name(self._next_rank()), torch.matmul, args=(t1, t2) ) - @dist_init - def test_backward_no_grad_on_tensor(self): - t1 = torch.rand((3, 3), requires_grad=True) - t2 = torch.rand((3, 3), requires_grad=True) + def _backward_no_grad_on_tensor(self, t1, t2, sparse): with dist_autograd.context() as context_id: - loss = rpc.rpc_sync( + ret = rpc.rpc_sync( worker_name(self._next_rank()), torch.add, - args=(t1, t2)).sum() - + args=(t1, t2)) + loss = my_sum(ret) dist_autograd.backward(context_id, [loss], retain_graph=True) self.assertIsNone(t1.grad) self.assertIsNone(t2.grad) # Now populate .grad with local autograd engine and # verify dist autograd doesn't mess with it. - loss_local = torch.add(t1, t2).sum() + ret = torch.add(t1, t2) + loss_local = my_sum(ret) loss_local.backward() self.assertIsNotNone(t1.grad) self.assertIsNotNone(t2.grad) @@ -903,18 +1016,31 @@ def test_backward_no_grad_on_tensor(self): self.assertEqual(t1_grad_before, t1.grad) self.assertEqual(t2_grad_before, t2.grad) - def _test_backward_simple(self, dst): - # Run the same code locally and with dist autograd and verify gradients - # are same. - local_grads = None - t1 = torch.rand((3, 3), requires_grad=True) - t2 = torch.rand((3, 3), requires_grad=True) + @dist_init + def test_backward_no_grad_on_tensor(self): + self._backward_no_grad_on_tensor( + torch.rand((3, 3), requires_grad=True), + torch.rand((3, 3), requires_grad=True), + False + ) + + @dist_init + def test_backward_no_grad_on_tensor_sparse(self): + self._backward_no_grad_on_tensor( + build_sparse_tensor(requires_grad=True), + build_sparse_tensor(requires_grad=True), + True + ) + + # Run the same code locally and with dist autograd and verify gradients + # are same. + def _backward_simple(self, dst, t1, t2, local_grads, sparse): for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]: with dist_autograd.context() as context_id: ret = self._exec_func_with_dst( dst, exec_mode, torch.add, t1, t2 ) - loss = ret.sum() + loss = my_sum(ret) ret = self._verify_backwards( exec_mode, [loss], context_id, local_grads, t1, t2 ) @@ -922,29 +1048,62 @@ def _test_backward_simple(self, dst): @dist_init def test_backward_simple(self): - self._test_backward_simple(self._next_rank()) + self._backward_simple( + self._next_rank(), + torch.rand((3, 3), requires_grad=True), + torch.rand((3, 3), requires_grad=True), + None, + False + ) + + @dist_init + def test_backward_simple_sparse(self): + self._backward_simple( + self._next_rank(), + build_sparse_tensor(requires_grad=True), + build_sparse_tensor(requires_grad=True), + None, + True + ) @dist_init def test_backward_simple_self(self): - self._test_backward_simple(self.rank) + self._backward_simple( + self.rank, + torch.rand((3, 3), requires_grad=True), + torch.rand((3, 3), requires_grad=True), + None, + False + ) + + @dist_init + def test_backward_simple_self_sparse(self): + self._backward_simple( + self.rank, + build_sparse_tensor(requires_grad=True), + build_sparse_tensor(requires_grad=True), + None, + True + ) # The current rank first creates a tensor on the rref_owner, and then passes # the rref with another tensor to the callee to run either my_rref_add or # my_nested_rref_add, depending on whether the callee is the rref owner. # The grad of tensor lives on the current rank, and the grad of the rref # tensor lives on the rref owner. - def _test_backward_rref(self, callee, rref_owner): - local_grads = None - t1 = torch.ones((3, 3), requires_grad=True) - t2 = torch.zeros((3, 3), requires_grad=True) - + def _backward_rref(self, callee, rref_owner, t1, t2, local_grads, sparse): local_ret = torch.add(t1, t2) - local_ret.sum().backward() + local_ret = my_sum(local_ret) + local_ret.backward() with dist_autograd.context() as context_id: - rref_t1 = rpc.remote( - rref_owner, _torch_ones, args=((3, 3),), kwargs={"requires_grad": True} - ) - + if sparse: + rref_t1 = rpc.remote( + rref_owner, build_sparse_tensor, args=(False, True,) + ) + else: + rref_t1 = rpc.remote( + rref_owner, _torch_ones, args=((3, 3),), kwargs={"requires_grad": True} + ) if callee == rref_owner: rref = rpc.remote(callee, my_rref_add, args=(rref_t1, t2)) else: @@ -952,7 +1111,8 @@ def _test_backward_rref(self, callee, rref_owner): callee, my_nested_rref_add, args=(rref_owner, rref_t1, t2) ) ret = rref.to_here() - dist_autograd.backward(context_id, [ret.sum()]) + ret = my_sum(ret) + dist_autograd.backward(context_id, [ret]) # verify grads on caller grads = dist_autograd.get_gradients(context_id) @@ -972,20 +1132,81 @@ def _test_backward_rref(self, callee, rref_owner): def test_backward_rref(self): callee = worker_name(self._next_rank()) rref_owner = callee - self._test_backward_rref(callee, rref_owner) + self._backward_rref( + callee, + rref_owner, + torch.rand((3, 3), requires_grad=True), + torch.rand((3, 3), requires_grad=True), + None, + False + ) + + @dist_init + def test_backward_rref_sparse(self): + callee = worker_name(self._next_rank()) + rref_owner = callee + self._backward_rref( + callee, + rref_owner, + build_sparse_tensor(requires_grad=True), + build_sparse_tensor(requires_grad=True), + None, + True + ) @dist_init def test_backward_rref_multi(self): if self.rank > 0: callee = "worker0" rref_owner = callee - self._test_backward_rref(callee, rref_owner) + self._backward_rref( + callee, + rref_owner, + torch.rand((3, 3), requires_grad=True), + torch.rand((3, 3), requires_grad=True), + None, + False + ) + + @dist_init + def test_backward_rref_multi_sparse(self): + if self.rank > 0: + callee = "worker0" + rref_owner = callee + self._backward_rref( + callee, + rref_owner, + build_sparse_tensor(requires_grad=True), + build_sparse_tensor(requires_grad=True), + None, + True + ) @dist_init def test_backward_rref_nested(self): callee = worker_name((self.rank + 1) % self.world_size) rref_owner = worker_name((self.rank + 2) % self.world_size) - self._test_backward_rref(callee, rref_owner) + self._backward_rref( + callee, + rref_owner, + torch.rand((3, 3), requires_grad=True), + torch.rand((3, 3), requires_grad=True), + None, + False + ) + + @dist_init + def test_backward_rref_nested_sparse(self): + callee = worker_name((self.rank + 1) % self.world_size) + rref_owner = worker_name((self.rank + 2) % self.world_size) + self._backward_rref( + callee, + rref_owner, + build_sparse_tensor(requires_grad=True), + build_sparse_tensor(requires_grad=True), + None, + True + ) # In this test, every rank will serve as a parameter server (ps) and a # driver, and then kicks off trainers on the other three ranks. So, we have: @@ -996,13 +1217,16 @@ def test_backward_rref_nested(self): # # These four test ps-trainer groups run on completely separate autograd # graphs, but they share the same set of underlying RpcAgents. - def _test_trainer_ps(self, create_ref_fn, trainer_fn): - local_grads = None - t1 = torch.ones((3, 3), requires_grad=True) - t2 = torch.zeros((3, 3), requires_grad=True) + def _test_trainer_ps(self, create_ref_fn, trainer_fn, sparse): + if sparse: + t1 = build_sparse_tensor(requires_grad=True) + t2 = build_sparse_tensor(requires_grad=True) + else: + t1 = torch.ones((3, 3), requires_grad=True) + t2 = torch.zeros((3, 3), requires_grad=True) local_ret = torch.add(t1, t2) - local_ret.sum().backward() + my_sum(local_ret).backward() # create rref on self rref_t1 = rpc.remote( @@ -1045,7 +1269,19 @@ def _test_trainer_ps(self, create_ref_fn, trainer_fn): @dist_init def test_trainer_ps(self): - self._test_trainer_ps(create_tensor, _run_trainer) + self._test_trainer_ps( + create_tensor, + _run_trainer, + False + ) + + @dist_init + def test_trainer_ps_sparse(self): + self._test_trainer_ps( + build_sparse_tensor, + _run_trainer, + True + ) @dist_init def test_trainer_ps_torchscript_functions(self): @@ -1056,17 +1292,9 @@ def test_trainer_ps_torchscript_functions(self): import torch.distributed.rpc.api as api api._ignore_rref_leak = True - self._test_trainer_ps(create_torchscript_tensor, _run_trainer_torchscript) - - @dist_init - def test_backward_multiple_round_trips(self): - local_grads = None - t1 = torch.rand((3, 3), requires_grad=True) - t2 = torch.rand((3, 3)) - t3 = torch.rand((3, 3), requires_grad=True) - t4 = torch.rand((3, 3)) - t5 = torch.rand((3, 3), requires_grad=True) + self._test_trainer_ps(create_torchscript_tensor, _run_trainer_torchscript, False) + def _backward_multiple_round_trips(self, t1, t2, t3, t4, t5, local_grads): for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]: with dist_autograd.context() as context_id: # Multiple RPCs between different nodes. @@ -1074,15 +1302,42 @@ def test_backward_multiple_round_trips(self): val = self._exec_func(exec_mode, torch.mul, t3, val) s1 = self._exec_func(exec_mode, torch.stack, (t4, val)) s2 = self._exec_func(exec_mode, torch.stack, (t5, val)) - val = self._exec_func(exec_mode, torch.bmm, s1, s2) - val = self._exec_func(exec_mode, torch.matmul, val, val) - loss = val.sum() + if s1.is_sparse: + val = self._exec_func(exec_mode, torch.mul, s1, s2) + val = self._exec_func(exec_mode, torch.mul, val, val) + loss = torch.sparse.sum(val) + else: + val = self._exec_func(exec_mode, torch.bmm, s1, s2) + val = self._exec_func(exec_mode, torch.matmul, val, val) + loss = val.sum() ret = self._verify_backwards( exec_mode, [loss], context_id, local_grads, t1, t2, t3, t4, t5 ) local_grads = ret if ret else local_grads + @dist_init + def test_backward_multiple_round_trips(self): + self._backward_multiple_round_trips( + torch.rand((3, 3), requires_grad=True), + torch.rand((3, 3)), + torch.rand((3, 3), requires_grad=True), + torch.rand((3, 3)), + torch.rand((3, 3), requires_grad=True), + None + ) + + @dist_init + def test_backward_multiple_round_trips_sparse(self): + self._backward_multiple_round_trips( + build_sparse_tensor(requires_grad=True), + build_sparse_tensor(requires_grad=False), + build_sparse_tensor(requires_grad=True), + build_sparse_tensor(requires_grad=False), + build_sparse_tensor(requires_grad=True), + None + ) + @dist_init def test_backward_different_tensor_dims(self): local_grads = None @@ -1317,41 +1572,60 @@ def test_backward_multiple_roots(self): exec_mode, [r1, r2, r3, r4], context_id, local_grads, t1, t2 ) - @dist_init - def test_backward_different_dtypes(self): + def _backward_different_dtypes(self, t1, t2): local_grads = None - t1 = torch.rand((3, 3), requires_grad=True, dtype=torch.float32) - t2 = torch.rand((3, 3), requires_grad=True, dtype=torch.float64) for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]: with dist_autograd.context() as context_id: - loss = self._exec_func(exec_mode, torch.add, t1, t2).sum() - + loss = self._exec_func(exec_mode, torch.add, t1, t2) + loss = my_sum(loss) local_grads = self._verify_backwards( exec_mode, [loss], context_id, local_grads, t1, t2 ) @dist_init - def test_backward_simple_python_udf(self): - # Run the same code locally and with dist autograd and verify gradients - # are same. + def test_backward_different_dtypes(self): + self._backward_different_dtypes( + torch.rand((3, 3), requires_grad=True, dtype=torch.float32), + torch.rand((3, 3), requires_grad=True, dtype=torch.float64) + ) + + @dist_init + def test_backward_different_dtypes_sparse(self): + self._backward_different_dtypes( + build_sparse_tensor(requires_grad=True, dtype=torch.float32), + build_sparse_tensor(requires_grad=True, dtype=torch.float64) + ) + + # Run the same code locally and with dist autograd and verify gradients + # are same. + def _backward_simple_python_udf(self, t1, t2): local_grads = None - t1 = torch.rand((3, 3), requires_grad=True) - t2 = torch.rand((3, 3), requires_grad=True) for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]: with dist_autograd.context() as context_id: ret = self._exec_func(exec_mode, my_py_add, t1, t2) - loss = ret.sum() + loss = my_sum(ret) local_grads = self._verify_backwards( exec_mode, [loss], context_id, local_grads, t1, t2 ) @dist_init - def test_backward_simple_script_call(self): - # Run the same code locally and with dist autograd and verify gradients - # are same. + def test_backward_simple_python_udf(self): + self._backward_simple_python_udf( + torch.rand(3, 3, requires_grad=True), + torch.rand(3, 3, requires_grad=True) + ) + + @dist_init + def test_backward_simple_python_udf_sparse(self): + self._backward_simple_python_udf( + build_sparse_tensor(requires_grad=True), + build_sparse_tensor(requires_grad=True) + ) + + # Run the same code locally and with dist autograd and verify gradients + # are same. + def _backward_simple_script_call(self, t1, t2): local_grads = None - t1 = torch.rand((3, 3), requires_grad=True) - t2 = torch.rand((3, 3), requires_grad=True) for exec_mode in [ ExecMode.LOCAL, ExecMode.RPC_SYNC, @@ -1360,12 +1634,26 @@ def test_backward_simple_script_call(self): ]: with dist_autograd.context() as context_id: forward_ret = self._exec_func(exec_mode, my_script_add, t1, t2) - loss = forward_ret.sum() + loss = my_sum(forward_ret) ret = self._verify_backwards( exec_mode, [loss], context_id, local_grads, t1, t2 ) local_grads = ret if ret else local_grads + @dist_init + def test_backward_simple_script_call(self): + self._backward_simple_script_call( + torch.rand(3, 3, requires_grad=True), + torch.rand(3, 3, requires_grad=True) + ) + + @dist_init + def test_backward_simple_script_call_sparse(self): + self._backward_simple_script_call( + build_sparse_tensor(requires_grad=True), + build_sparse_tensor(requires_grad=True) + ) + @staticmethod def _complex_python_udf(t1, t2): t3 = torch.nn.functional.linear(t1, t2) @@ -1463,8 +1751,7 @@ def test_backward_node_failure_python_udf(self): dist_autograd.backward(context_id, [res.sum()]) # Mark rank 0 is done in the store, since the RPC framework on - # some nodes might be broken at this point (listenLoop() in - # ProcessGroupAgent might've exited). + # some nodes might be broken at this point. store.set('test_backward_node_failure_python_udf_rank0_done', "True") else: # Wait for backward to finish on rank 0. @@ -1475,32 +1762,45 @@ def _nested_python_udf(t1, t2, dst): t3 = t1 * t2 t4 = t1 + t2 res = rpc.rpc_sync(worker_name(dst), my_py_add, args=(t3, t4)) - return torch.linalg.multi_dot([t1, t2, t3, t4, res]) + return t1 * t2 * t3 * t4 * res - @dist_init - def test_backwards_nested_python_udf(self): - # Run equivalent of _nested_python_udf locally. - t1 = torch.rand((3, 3), requires_grad=True) - t2 = torch.rand((3, 3), requires_grad=True) + def _backwards_nested_python_udf(self, t1, t2): t3 = t1 * t2 t4 = t1 + t2 res = t3 + t4 - loss = torch.linalg.multi_dot([t1, t2, t3, t4, res]).sum() + ret = t1 * t2 * t3 * t4 * res + loss = my_sum(ret) torch.autograd.backward([loss]) # Now run distributed autograd. with dist_autograd.context() as context_id: - loss = rpc.rpc_sync( + ret = rpc.rpc_sync( worker_name(self._next_rank()), DistAutogradTest._nested_python_udf, args=(t1, t2, self._next_rank()), ) - dist_autograd.backward(context_id, [loss.sum()]) - + loss = my_sum(ret) + dist_autograd.backward(context_id, [loss]) grads = dist_autograd.get_gradients(context_id) self.assertEqual(t1.grad, grads[t1]) self.assertEqual(t2.grad, grads[t2]) + @dist_init + def test_backwards_nested_python_udf(self): + # Run equivalent of _nested_python_udf locally. + self._backwards_nested_python_udf( + torch.rand(3, 3, requires_grad=True), + torch.rand(3, 3, requires_grad=True) + ) + + @dist_init + def test_backwards_nested_python_udf_sparse(self): + # Run equivalent of _nested_python_udf locally. + self._backwards_nested_python_udf( + build_sparse_tensor(requires_grad=True), + build_sparse_tensor(requires_grad=True) + ) + _test_clean_context_backward_context_id = None class MyBackwardFunc(Function): @@ -1595,8 +1895,7 @@ def _call_remote_embedding(cls, embedding_rref, input, offsets, per_sample_weigh def _get_grad(cls, embedding_rref, context_id): embedding = embedding_rref.local_value() grad_map = dist_autograd.get_gradients(context_id) - # Can't send sparse tensors over RPC: https://github.com/pytorch/pytorch/issues/30807 - return grad_map[embedding.weight].to_dense() + return grad_map[embedding.weight] @dist_init def test_embedding_bag_with_no_grad_tensors(self): @@ -1638,26 +1937,24 @@ def test_embedding_bag_with_no_grad_tensors(self): args=(remote_embedding, context_id), ) - self.assertEqual(local_grad.to_dense(), remote_grad) + self.assertEqual(local_grad, remote_grad) @classmethod - def _mixed_requires_grad(cls, t1, t2): + def _mixed_requires_grad_operaton(cls, t1, t2): if t2.requires_grad: return t1 - t2 else: return t1 * t2 - @dist_init - def test_mixed_requires_grad(self): + def _mixed_requires_grad(self, t1, t2): for exec_mode in [ExecMode.RPC_SYNC, ExecMode.REMOTE]: - t1 = torch.rand((3, 3), requires_grad=True) - t2 = torch.rand((3, 3), requires_grad=False) with dist_autograd.context() as context_id: ret = self._exec_func( - exec_mode, DistAutogradTest._mixed_requires_grad, t1, t2 + exec_mode, DistAutogradTest._mixed_requires_grad_operaton, t1, t2 ) self.assertEqual(t1 * t2, ret) - dist_autograd.backward(context_id, [ret.sum()]) + loss = my_sum(ret) + dist_autograd.backward(context_id, [loss]) self.assertTrue(t1.requires_grad) self.assertFalse(t2.requires_grad) grads = dist_autograd.get_gradients(context_id) @@ -1665,6 +1962,20 @@ def test_mixed_requires_grad(self): self.assertNotIn(t2, grads) self.assertEqual(t2, grads[t1]) + @dist_init + def test_mixed_requires_grad(self): + self._mixed_requires_grad( + torch.rand(3, 3, requires_grad=True), + torch.rand(3, 3, requires_grad=False) + ) + + @dist_init + def test_mixed_requires_grad_sparse(self): + self._mixed_requires_grad( + build_sparse_tensor(requires_grad=True), + build_sparse_tensor(requires_grad=False) + ) + class TestDebugInfoFunc(Function): @staticmethod def forward(ctx, input): @@ -1802,37 +2113,59 @@ def test_backward_accumulate_grads(self): @staticmethod def _test_nested_backward_accumulate_grads(t1, t2, dst_rank): - return rpc.rpc_sync(worker_name(dst_rank), torch.matmul, args=(t1, t2)) + return rpc.rpc_sync(worker_name(dst_rank), torch.add, args=(t1, t2)) - @dist_init - def test_nested_backward_accumulate_grads(self): - t1 = torch.rand((3, 3), requires_grad=True) - t2 = torch.rand((3, 3), requires_grad=True) + def _nested_backward_accumulate_grads(self, t1, t2): with dist_autograd.context() as context_id: - loss = rpc.rpc_sync( + ret = rpc.rpc_sync( worker_name(self._next_rank()), DistAutogradTest._test_nested_backward_accumulate_grads, args=(t1, t2, self._next_rank()), - ).sum() - + ) + loss = my_sum(ret) # Run backward twice. dist_autograd.backward(context_id, [loss], retain_graph=True) dist_autograd.backward(context_id, [loss]) @dist_init - def test_multiple_backward(self): - t1 = torch.rand((3, 3), requires_grad=True) - t2 = torch.rand((3, 3), requires_grad=True) + def test_nested_backward_accumulate_grads(self): + self._nested_backward_accumulate_grads( + torch.rand(3, 3, requires_grad=True), + torch.rand(3, 3, requires_grad=True) + ) + + @dist_init + def test_nested_backward_accumulate_grads_sparse(self): + self._nested_backward_accumulate_grads( + build_sparse_tensor(requires_grad=True), + build_sparse_tensor(requires_grad=True) + ) + + def _multiple_backward(self, t1, t2): with dist_autograd.context() as context_id: - loss = rpc.rpc_sync( + ret = rpc.rpc_sync( worker_name(self._next_rank()), torch.add, - args=(t1, t2)).sum() - + args=(t1, t2)) + loss = my_sum(ret) # Run backward in a loop multiple times. for i in range(1000): dist_autograd.backward(context_id, [loss], retain_graph=True) + @dist_init + def test_multiple_backward(self): + self._multiple_backward( + torch.rand(3, 3, requires_grad=True), + torch.rand(3, 3, requires_grad=True) + ) + + @dist_init + def test_multiple_backward_sparse(self): + self._multiple_backward( + build_sparse_tensor(requires_grad=True), + build_sparse_tensor(requires_grad=True) + ) + @dist_init(clean_shutdown=False) def test_multiple_backward_with_errors(self): initialize_pg(self.file_init_method, self.rank, self.world_size) @@ -2136,15 +2469,13 @@ def test_thread_local_context_id(self): class CudaDistAutogradTest(CommonDistAutogradTest): - @skip_if_lt_x_gpu(1) - @dist_init - def test_gpu_simple(self): - t1 = torch.rand(3, 3, requires_grad=True, device="cuda:0") - t2 = torch.rand(3, 3, requires_grad=True, device="cuda:0") - (t1 + t2).sum().backward() + + def _gpu_simple(self, t1, t2): + my_sum(t1 + t2).backward() with dist_autograd.context() as context_id: t3 = t1 + t2 - dist_autograd.backward(context_id, [t3.sum()]) + loss = my_sum(t3) + dist_autograd.backward(context_id, [loss]) grads = dist_autograd.get_gradients(context_id) self.assertEqual(2, len(grads)) self.assertEqual(t1.grad, grads[t1]) @@ -2152,9 +2483,22 @@ def test_gpu_simple(self): @skip_if_lt_x_gpu(1) @dist_init - def test_gpu_to_cpu_continuation(self): - t1 = torch.rand(3, 3, requires_grad=True, device="cuda:0") - t2 = torch.rand(3, 3, requires_grad=True) + def test_gpu_simple(self): + self._gpu_simple( + torch.rand(3, 3, requires_grad=True, device="cuda:0"), + torch.rand(3, 3, requires_grad=True, device="cuda:0") + ) + + @skip_if_lt_x_gpu(1) + @dist_init + def test_gpu_simple_sparse(self): + self._gpu_simple( + build_sparse_tensor(requires_grad=True, device="cuda:0"), + build_sparse_tensor(requires_grad=True, device="cuda:0") + ) + + + def _gpu_to_cpu_continuation(self, t1, t2): # Run a few iterations. for i in range(3): t1.grad = None @@ -2169,16 +2513,29 @@ def test_gpu_to_cpu_continuation(self): t6 = t5.cuda(0) + t4 t7 = self._exec_func(exec_mode, torch.add, t6.cpu(), t5) # Autograd graph consists of CPU -> GPU -> CPU execution. + loss = my_sum(t7) ret = self._verify_backwards( - exec_mode, [t7.sum()], context_id, local_grads, t1, t2 + exec_mode, [loss], context_id, local_grads, t1, t2 ) local_grads = ret if ret else local_grads @skip_if_lt_x_gpu(1) @dist_init - def test_gpu_to_cpu_continuation_gpu_root(self): - t1 = torch.rand(3, 3, requires_grad=True, device="cuda:0") - t2 = torch.rand(3, 3, requires_grad=True) + def test_gpu_to_cpu_continuation(self): + self._gpu_to_cpu_continuation( + torch.rand(3, 3, requires_grad=True, device="cuda:0"), + torch.rand(3, 3, requires_grad=True) + ) + + @skip_if_lt_x_gpu(1) + @dist_init + def test_gpu_to_cpu_continuation_sparse(self): + self._gpu_to_cpu_continuation( + build_sparse_tensor(requires_grad=True, device="cuda:0"), + build_sparse_tensor(requires_grad=True) + ) + + def _gpu_to_cpu_continuation_gpu_root(self, t1, t2): # Run a few iterations. for i in range(3): t1.grad = None @@ -2192,11 +2549,28 @@ def test_gpu_to_cpu_continuation_gpu_root(self): t5 = self._exec_func(exec_mode, torch.add, t4.cpu(), t2) t6 = t5.cuda(0) + t4 # Autograd graph consists of CPU -> GPU -> CPU execution. + loss = my_sum(t6) ret = self._verify_backwards( - exec_mode, [t6.sum()], context_id, local_grads, t1, t2 + exec_mode, [loss], context_id, local_grads, t1, t2 ) local_grads = ret if ret else local_grads + @skip_if_lt_x_gpu(1) + @dist_init + def test_gpu_to_cpu_continuation_gpu_root(self): + self._gpu_to_cpu_continuation_gpu_root( + torch.rand(3, 3, requires_grad=True, device="cuda:0"), + torch.rand(3, 3, requires_grad=True) + ) + + @skip_if_lt_x_gpu(1) + @dist_init + def test_gpu_to_cpu_continuation_gpu_root_sparse(self): + self._gpu_to_cpu_continuation_gpu_root( + build_sparse_tensor(requires_grad=True, device="cuda:0"), + build_sparse_tensor(requires_grad=True) + ) + class FaultyAgentDistAutogradTest(RpcAgentTestFixture): # Reusing a simplified helper function from DistAutogradTest to ensure @@ -2258,8 +2632,7 @@ def gradients(self, ctx_id): class TensorPipeCudaDistAutogradTest(RpcAgentTestFixture): - @skip_if_lt_x_gpu(4) - def test_device_maps_backward_pass(self): + def _device_maps_backward_pass(self, t1, t2): options = self.rpc_backend_options dst = worker_name((self.rank + 1) % self.world_size) @@ -2274,19 +2647,36 @@ def test_device_maps_backward_pass(self): rpc_backend_options=options, ) - t1 = torch.rand(10, device=self.rank, requires_grad=True) - t2 = torch.rand(10, device=self.rank, requires_grad=True) with dist_autograd.context() as context_id: res = rpc.rpc_sync(dst, torch.add, args=(t1, t2)) - dist_autograd.backward(context_id, [res.sum()]) + loss = my_sum(res) + dist_autograd.backward(context_id, [loss]) grads = dist_autograd.get_gradients(context_id) - self.assertEqual(torch.ones(10), grads[t1]) - self.assertEqual(torch.ones(10), grads[t2]) + if t1.is_sparse: + self.assertEqual(build_sparse_one_gradient(), grads[t1]) + self.assertEqual(build_sparse_one_gradient(), grads[t2]) + else: + self.assertEqual(torch.ones(10), grads[t1]) + self.assertEqual(torch.ones(10), grads[t2]) self.assertEqual(t1.device, grads[t1].device) self.assertEqual(t2.device, grads[t2].device) rpc.shutdown() + @skip_if_lt_x_gpu(4) + def test_device_maps_backward_pass(self): + self._device_maps_backward_pass( + torch.rand(10, requires_grad=True, device=self.rank), + torch.ones(10, requires_grad=True, device=self.rank) + ) + + @skip_if_lt_x_gpu(4) + def test_device_maps_backward_pass_sparse(self): + self._device_maps_backward_pass( + build_sparse_tensor(requires_grad=True, device=self.rank), + build_sparse_tensor(requires_grad=True, device=self.rank) + ) + class MyRemoteCompute(torch.nn.Module): def __init__(self): super().__init__() @@ -2303,9 +2693,7 @@ def __init__(self, next_stage): def forward(self, input): return self.next_stage.rpc_sync().forward(input) - @skip_if_lt_x_gpu(4) - def test_dist_autograd_sync_streams(self): - + def _dist_autograd_sync_streams(self, sparse): options = self.rpc_backend_options dst = worker_name((self.rank + 1) % self.world_size) @@ -2323,17 +2711,20 @@ def test_dist_autograd_sync_streams(self): remote_compute = rpc.remote(dst, TensorPipeCudaDistAutogradTest.MyRemoteCompute) local_compute = TensorPipeCudaDistAutogradTest.MyLocalCompute(remote_compute) for _ in range(10): - input = torch.rand([1000, 10000], device=self.rank, requires_grad=True) + if sparse: + input = build_sparse_tensor(requires_grad=True, device=self.rank) + else: + input = torch.rand([1000, 10000], device=self.rank, requires_grad=True) # Run local autograd result = input * 2.0 r = random.random() - loss = result.sum() * r + loss = my_sum(result) * r loss.backward() # Run distributed autograd with dist_autograd.context() as context_id: result = local_compute(input) - loss = result.sum() * r + loss = my_sum(result) * r dist_autograd.backward(context_id, [loss]) # Compare grads. @@ -2343,7 +2734,14 @@ def test_dist_autograd_sync_streams(self): rpc.shutdown() @skip_if_lt_x_gpu(4) - def test_gradients_synchronizations(self): + def test_dist_autograd_sync_streams(self): + self._dist_autograd_sync_streams(False) + + @skip_if_lt_x_gpu(4) + def test_dist_autograd_sync_streams_sparse(self): + self._dist_autograd_sync_streams(True) + + def _gradients_synchronizations(self, x): options = self.rpc_backend_options for peer_rank in range(self.world_size): options.set_device_map(worker_name(peer_rank), {self.rank: peer_rank}) @@ -2367,8 +2765,8 @@ def test_gradients_synchronizations(self): WrapperModule, args=(layers[rank - 1], rank) )) + x = x.to(0) - x = torch.randn(5000, 2000).to(0) # local iteration local_model = nn.Sequential(*local_layers) local_model(x).sum().backward() @@ -2390,3 +2788,15 @@ def test_gradients_synchronizations(self): self.assertEqual(g1, g2) rpc.shutdown() + + @skip_if_lt_x_gpu(4) + def test_gradients_synchronizations(self): + self._gradients_synchronizations( + torch.randn(5000, 2000) + ) + + @skip_if_lt_x_gpu(4) + def test_gradients_synchronizations_sparse(self): + self._gradients_synchronizations( + torch.randn(5000, 2000).to_sparse() + ) diff --git a/torch/testing/_internal/distributed/rpc/faulty_rpc_agent_test_fixture.py b/torch/testing/_internal/distributed/rpc/faulty_rpc_agent_test_fixture.py index ae151137a4705..24f7ab81c5594 100644 --- a/torch/testing/_internal/distributed/rpc/faulty_rpc_agent_test_fixture.py +++ b/torch/testing/_internal/distributed/rpc/faulty_rpc_agent_test_fixture.py @@ -50,8 +50,6 @@ def setup_fault_injection(self, faulty_messages, messages_to_delay): def get_shutdown_error_regex(self): error_regexes = [ - "Encountered exception in ProcessGroupAgent::enqueueSend", - "Encountered exception in ProcessGroupAgent::listenLoop()", "Exception in thread pool task", "Connection reset by peer", "Connection closed by peer" diff --git a/torch/testing/_internal/distributed/rpc/rpc_test.py b/torch/testing/_internal/distributed/rpc/rpc_test.py index c95b7216c4a67..23759f1e292ad 100644 --- a/torch/testing/_internal/distributed/rpc/rpc_test.py +++ b/torch/testing/_internal/distributed/rpc/rpc_test.py @@ -194,6 +194,14 @@ def my_slow_method(self, my_tensor_arg): return torch.add(self.a, my_tensor_arg) +def _run_func_in_mode(to, fn, mode, args=None, kwargs=None): + if mode == RPCExecMode.SYNC: + return rpc.rpc_sync(to, fn, args=args, kwargs=kwargs) + elif mode == RPCExecMode.ASYNC: + return rpc.rpc_async(to, fn, args=args, kwargs=kwargs).wait() + elif mode == RPCExecMode.REMOTE: + return rpc.remote(to, fn, args=args, kwargs=kwargs).to_here() + def _call_method_on_rref(method, rref, *args, **kwargs): return method(rref.local_value(), *args, **kwargs) @@ -209,10 +217,13 @@ def add_rref_to_value(rref, value): def run_nested_pickle(pickle_cls_instance, tensor): return pickle_cls_instance.t + tensor -def build_sparse_tensor(): +def build_sparse_tensor(coalesce=False): i = [[0, 1, 1], [2, 0, 2]] v = [3, 4, 5] - return torch.sparse_coo_tensor(i, v, (2, 3)) + tensor = torch.sparse_coo_tensor(i, v, (2, 3)) + if coalesce: + tensor = tensor.coalesce() + return tensor def build_complex_tensors(): a = torch.ones(3, 3) @@ -238,6 +249,12 @@ def my_function(a, b, c): def my_tensor_function(a, b): return a + b +def my_container_sum(a): + result = a[0] + for tensor in a[1:]: + result += tensor + return result + def my_sleep_func(seconds=1): time.sleep(seconds) @@ -275,6 +292,14 @@ def nested_rpc(dst): return rpc.rpc_sync(dst, torch.add, args=(torch.ones(2, 2), 1)) +def nested_rpc_sparse(dst): + return rpc.rpc_sync( + dst, + torch.add, + args=(build_sparse_tensor(), build_sparse_tensor()) + ) + + def multi_layer_nested_async_rpc(dst, world_size, ttl): # this method returns immediately without blocking the callee, but will # generate additional requests. @@ -296,10 +321,29 @@ def nested_rref(dst): ) +def nested_rref_sparse(dst): + return ( + rpc.remote( + dst, + torch.add, + args=(build_sparse_tensor(), build_sparse_tensor()) + ), + rpc.remote( + dst, + torch.add, + args=(build_sparse_tensor(), build_sparse_tensor()) + ), + ) + + def nested_remote(dst): rref = rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 3)) return rref.to_here() +def nested_remote_sparse(dst): + rref = rpc.remote(dst, torch.add, args=(build_sparse_tensor(), build_sparse_tensor())) + return rref.to_here() + def rref_forward_chain(dst, world_size, rref, ttl): if ttl > 0: @@ -328,6 +372,12 @@ def heavy_rpc(tensor): return 0 +def heavy_rpc_sparse(tensor): + for i in range(1, 100): + tensor *= i + tensor = tensor / (i + 1) + return 0 + @torch.jit.script def heavy_rpc_torchscript(tensor): for i in range(1, 100): @@ -600,6 +650,57 @@ def __init__(self, init_method): load_tests = load_tests +class MyEmbeddingBagModel(torch.nn.Module): + def __init__(self, sparse): + super().__init__() + self.eb = torch.nn.EmbeddingBag( + 10, + 10, + sparse=sparse + ) + + def forward(self, x): + return self.eb(x) + + +class MyParameterServer: + def __init__(self, trainers): + self.lock = Lock() + self.trainers = trainers + self.iteration = 0 + self.updates = 0 + self.futures = [] + self.total = None + self.gradient = None + + @staticmethod + def get_gradient(rref): + return rref.local_value().gradient + + @staticmethod + @rpc.functions.async_execution + def average(rref, riteration, tensor): + self = rref.local_value() + fut = torch.futures.Future() + with self.lock: + if riteration > self.iteration: + self.iteration = riteration + self.updates = 0 + self.futures.clear() + self.futures.append(fut) + if self.total is None: + self.total = tensor + else: + self.total += tensor + self.updates += 1 + if self.trainers == self.updates: + self.gradient = self.total / float(self.trainers) + for fut in self.futures: + result = self.total / float(self.trainers) + fut.set_result(result) + return fut + + class RpcTest(RpcAgentTestFixture): @dist_init def test_worker_id(self): @@ -641,62 +742,157 @@ def test_self_add(self): def test_send_to_rank(self): dst_rank = (self.rank + 1) % self.world_size + # Test dense tensor for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: - ret = self._run_func_in_mode(dst_rank, torch.add, exec_mode, args=(torch.ones(2, 2), 1)) + ret = _run_func_in_mode(dst_rank, torch.add, exec_mode, args=(torch.ones(2, 2), 1)) self.assertEqual(ret, torch.ones(2, 2) + 1) + # Test sparse tensor + for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: + x = build_sparse_tensor() + y = build_sparse_tensor() + expected_tensor = (x + y) + ret = _run_func_in_mode(dst_rank, torch.add, exec_mode, args=(x, y)) + self.assertEqual(expected_tensor, ret) + + for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: + x = build_sparse_tensor(coalesce=True) + y = build_sparse_tensor(coalesce=True) + expected_tensor = (x + y) + ret = _run_func_in_mode(dst_rank, torch.add, exec_mode, args=(x, y)) + self.assertEqual(expected_tensor, ret) + # Test invalid ranks for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: with self.assertRaises(RuntimeError): - self._run_func_in_mode(self.world_size + 1, torch.add, exec_mode, args=(torch.ones(2, 2), 1)) + _run_func_in_mode(self.world_size + 1, torch.add, exec_mode, args=(torch.ones(2, 2), 1)) for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: with self.assertRaises(RuntimeError): - self._run_func_in_mode(-1, torch.add, exec_mode, args=(torch.ones(2, 2), 1)) + _run_func_in_mode(-1, torch.add, exec_mode, args=(torch.ones(2, 2), 1)) for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: with self.assertRaises(ValueError): - self._run_func_in_mode(dst_rank + 0.5, torch.add, exec_mode, args=(torch.ones(2, 2), 1)) + _run_func_in_mode(dst_rank + 0.5, torch.add, exec_mode, args=(torch.ones(2, 2), 1)) for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: with self.assertRaises(ValueError): - self._run_func_in_mode(dst_rank - 0.5, torch.add, exec_mode, args=(torch.ones(2, 2), 1)) + _run_func_in_mode(dst_rank - 0.5, torch.add, exec_mode, args=(torch.ones(2, 2), 1)) + + def _self_py_udf_remote(self, worker_info, x, y, z): + rref = rpc.remote(worker_info, my_function, args=(x, y, z)) + self.assertEqual(rref.to_here(), x + y + z) @dist_init def test_self_py_udf_remote(self): - self_worker_info = rpc.get_worker_info() - rref = rpc.remote(self_worker_info, my_function, args=(torch.ones(2, 2), 1, 3)) - self.assertEqual(rref.to_here(), torch.ones(2, 2) + 1 + 3) + self._self_py_udf_remote( + rpc.get_worker_info(), + torch.ones(2, 2), + 1, + 3 + ) + + @dist_init + def test_self_py_udf_remote_sparse(self): + self._self_py_udf_remote( + rpc.get_worker_info(), + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor() + ) - def _test_self_remote_rref_as_rpc_arg(self, dst): + + def _self_remote_rref_as_rpc_arg(self, dst, x, y, z): self_worker_info = rpc.get_worker_info() - rref = rpc.remote(self_worker_info, my_function, args=(torch.ones(2, 2), 1, 3)) - fut = rpc.rpc_async(dst, add_rref_to_value, args=(rref, torch.ones(2, 2))) - ret = rpc.rpc_sync(dst, add_rref_to_value, args=(rref, torch.ones(2, 2) + 1)) - self.assertEqual(ret, torch.ones(2, 2) + 1 + 3 + torch.ones(2, 2) + 1) - self.assertEqual(fut.wait(), torch.ones(2, 2) + 1 + 3 + torch.ones(2, 2)) + rref = rpc.remote(self_worker_info, my_function, args=(x, y, z)) + fut = rpc.rpc_async(dst, add_rref_to_value, args=(rref, x)) + ret = rpc.rpc_sync(dst, add_rref_to_value, args=(rref, x + y)) + self.assertEqual(ret, x + y + z + x + y) + self.assertEqual(fut.wait(), x + y + z + x) @dist_init def test_self_remote_rref_as_rpc_arg(self): dst = worker_name((self.rank + 1) % self.world_size) - self._test_self_remote_rref_as_rpc_arg(dst) + self._self_remote_rref_as_rpc_arg( + dst, + torch.ones(2, 2), + 1, + 3 + ) + + @dist_init + def test_self_remote_rref_as_rpc_arg_sparse(self): + dst = worker_name((self.rank + 1) % self.world_size) + self._self_remote_rref_as_rpc_arg( + dst, + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor() + ) @dist_init def test_self_remote_rref_as_self_rpc_arg(self): - self._test_self_remote_rref_as_rpc_arg(rpc.get_worker_info()) + self._self_remote_rref_as_rpc_arg( + rpc.get_worker_info(), + torch.ones(2, 2), + 1, + 3 + ) + + @dist_init + def test_self_remote_rref_as_self_rpc_arg_sparse(self): + self._self_remote_rref_as_rpc_arg( + rpc.get_worker_info(), + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor() + ) - def _test_self_remote_rref_as_remote_arg(self, dst): + def _self_remote_rref_as_remote_arg(self, dst, x, y, z): self_worker_info = rpc.get_worker_info() - rref = rpc.remote(self_worker_info, my_function, args=(torch.ones(2, 2), 1, 3)) - ret_rref = rpc.remote(dst, add_rref_to_value, args=(rref, torch.ones(2, 2))) + rref = rpc.remote(self_worker_info, my_function, args=(x, y, z)) + ret_rref = rpc.remote(dst, add_rref_to_value, args=(rref, x)) self.assertEqual( - ret_rref.to_here(), torch.ones(2, 2) + 1 + 3 + torch.ones(2, 2) + ret_rref.to_here(), x + y + z + x ) @dist_init def test_self_remote_rref_as_remote_arg(self): dst = worker_name((self.rank + 1) % self.world_size) - self._test_self_remote_rref_as_remote_arg(dst) + self._self_remote_rref_as_remote_arg( + dst, + torch.ones(2, 2), + 1, + 3 + ) + + @dist_init + def test_self_remote_rref_as_remote_arg_sparse(self): + dst = worker_name((self.rank + 1) % self.world_size) + self._self_remote_rref_as_remote_arg( + dst, + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor() + ) + + @dist_init + def test_self_remote_rref_as_self_remote_arg(self): + self._self_remote_rref_as_remote_arg( + rpc.get_worker_info(), + torch.ones(2, 2), + 1, + 3 + ) + + @dist_init + def test_self_remote_rref_as_self_remote_arg_sparse(self): + self._self_remote_rref_as_remote_arg( + rpc.get_worker_info(), + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor() + ) @dist_init def test_rref_proxy_non_exist(self): @@ -816,10 +1012,6 @@ def test_rref_proxy_class(self): def test_rref_proxy_class_self(self): self._test_rref_proxy_class(rpc.get_worker_info()) - @dist_init - def test_self_remote_rref_as_self_remote_arg(self): - self._test_self_remote_rref_as_remote_arg(rpc.get_worker_info()) - @mock.patch.object(torch.distributed.autograd, "_init") @mock.patch.object(torch.distributed.rpc.api, "_set_and_start_rpc_agent") @dist_init(setup_rpc=False) @@ -911,7 +1103,7 @@ def test_reinit(self): ) rpc.shutdown() - def test_world_size_one(self): + def _world_size_one(self, a, b): if self.rank == 0: rpc.init_rpc( name="me", @@ -921,32 +1113,51 @@ def test_world_size_one(self): rpc_backend_options=self.rpc_backend_options, ) - expect = torch.ones(2, 2) * 2 - result = rpc.rpc_sync( - "me", - my_tensor_function, - args=(torch.ones(2, 2), torch.ones(2, 2)) - ) - self.assertEqual(expect, result) - - expect = torch.ones(3, 3) * 2 - result = rpc.rpc_async( - "me", - my_tensor_function, - args=(torch.ones(3, 3), torch.ones(3, 3)) - ).wait() - self.assertEqual(expect, result) + def _rpc_sync(x, y): + expect = x * 2 + result = rpc.rpc_sync( + "me", + my_tensor_function, + args=(x, y) + ) + self.assertEqual(expect, result) + + def _rpc_async(x, y): + expect = x * 2 + result = rpc.rpc_async( + "me", + my_tensor_function, + args=(x, y) + ).wait() + self.assertEqual(expect, result) + + def _remote(x, y): + expect = x * 2 + result = rpc.remote( + "me", + my_tensor_function, + args=(x, y) + ).to_here() + self.assertEqual(expect, result) - expect = torch.ones(4, 4) * 2 - result = rpc.remote( - "me", - my_tensor_function, - args=(torch.ones(4, 4), torch.ones(4, 4)) - ).to_here() - self.assertEqual(expect, result) + _rpc_sync(a, b) + _rpc_async(a, b) + _remote(a, b) rpc.shutdown() + def test_world_size_one(self): + self._world_size_one( + torch.ones(2, 2), + torch.ones(2, 2) + ) + + def test_world_size_one_sparse(self): + self._world_size_one( + build_sparse_tensor(), + build_sparse_tensor() + ) + @dist_init(setup_rpc=False) def test_invalid_names(self): from torch.distributed.rpc import WorkerInfo @@ -1027,19 +1238,46 @@ def test_nonzero(self): ret = rpc.rpc_sync(worker_name(dst_rank), torch.nonzero, args=(x,)) self.assertEqual(ret, x.nonzero()) - @dist_init - def test_multi_rpc(self): + def _multi_rpc(self, sparse): dst_rank = (self.rank + 1) % self.world_size for i in range(20): n = i + self.rank + 1 + if sparse: + x = build_sparse_tensor() * n + y = build_sparse_tensor() * n + else: + x = torch.ones(2, 2) + y = torch.ones(2, 2) ret = rpc.rpc_sync( worker_name(dst_rank), torch.add, - args=(torch.ones(n, n), torch.ones(n, n)), + args=(x, y), ) - self.assertEqual(ret, torch.ones(n, n) * 2) + self.assertEqual(ret, x * 2) + + @dist_init + def test_multi_rpc(self): + self._multi_rpc(False) + + @dist_init + def test_multi_rpc_sparse(self): + self._multi_rpc(True) + + @dist_init + def test_future_wait_twice(self): + dst = worker_name((self.rank + 1) % self.world_size) + futs = [] + for i in range(20): + futs.append(rpc.rpc_async(dst, raise_func)) - def _run_uneven_workload(self, num_repeat=30): + with self.assertRaisesRegex(ValueError, "Expected error"): + torch.futures.wait_all(futs) + + for fut in futs: + with self.assertRaisesRegex(ValueError, "Expected error"): + fut.wait() + + def _run_uneven_workload(self, f, x, num_repeat=30): # worker0 drives and waits for worker1 and worker2 # throughout the test. if self.rank == 0: @@ -1049,7 +1287,7 @@ def _run_uneven_workload(self, num_repeat=30): dst = "worker1" futs = [] for _ in range(num_repeat): - fut = rpc.rpc_async(dst, heavy_rpc, args=(torch.ones(100, 100),)) + fut = rpc.rpc_async(dst, f, args=(x,)) futs.append(fut) for fut in torch.futures.collect_all(futs).wait(): @@ -1061,13 +1299,13 @@ def _run_uneven_workload(self, num_repeat=30): dst = "worker2" futs = [] for _ in range(num_repeat): - fut = rpc.rpc_async(dst, heavy_rpc, args=(torch.ones(100, 100),)) + fut = rpc.rpc_async(dst, f, args=(x,)) futs.append(fut) for val in torch.futures.wait_all(futs): self.assertEqual(val, 0) - def test_wait_all_workers(self): + def _wait_all_workers(self, f, x): initialize_pg(self.file_init_method, self.rank, self.world_size) rpc.init_rpc( name="worker%d" % self.rank, @@ -1077,7 +1315,7 @@ def test_wait_all_workers(self): rpc_backend_options=self.rpc_backend_options, ) - self._run_uneven_workload() + self._run_uneven_workload(f, x) # worker0 calls this at the end after waiting for RPC responses. # worker1/2 calls this immediately and has some works after it. @@ -1089,7 +1327,13 @@ def test_wait_all_workers(self): dist.barrier() rpc.shutdown(graceful=False) - def test_wait_all_workers_twice(self): + def test_wait_all_workers_dense(self): + self._wait_all_workers(heavy_rpc, torch.ones(100, 100)) + + def test_wait_all_workers_sparse(self): + self._wait_all_workers(heavy_rpc_sparse, build_sparse_tensor()) + + def _wait_all_workers_twice(self, f, x): initialize_pg(self.file_init_method, self.rank, self.world_size) rpc.init_rpc( name="worker%d" % self.rank, @@ -1099,7 +1343,7 @@ def test_wait_all_workers_twice(self): rpc_backend_options=self.rpc_backend_options, ) - self._run_uneven_workload() + self._run_uneven_workload(f, x) # worker0 calls this at the end after waiting for RPC responses. # worker1/2 calls this immediately and has some works after it. @@ -1112,6 +1356,12 @@ def test_wait_all_workers_twice(self): dist.barrier() rpc.shutdown(graceful=False) + def test_wait_all_workers_twice_dense(self): + self._wait_all_workers_twice(heavy_rpc, torch.ones(100, 100)) + + def test_wait_all_workers_twice_sparse(self): + self._wait_all_workers_twice(heavy_rpc_sparse, build_sparse_tensor()) + @dist_init def test_all_gather(self): info = rpc.get_worker_info() @@ -1197,7 +1447,7 @@ def test_rpc_barrier_multithreaded(self): @dist_init def test_graceful_shutdown_with_uneven_workload(self): """Test graceful termination.""" - self._run_uneven_workload() + self._run_uneven_workload(heavy_rpc, torch.ones(100, 100)) @dist_init(setup_rpc=False) def test_shutdown_followed_by_rpc(self): @@ -2067,6 +2317,16 @@ def test_py_tensors_in_container(self): ) self.assertEqual(ret, my_complex_tensor_function(a, b, c)) + @dist_init + def test_py_sparse_tensors_in_container(self): + n = self.rank + 1 + dst_rank = n % self.world_size + a = [build_sparse_tensor(), build_sparse_tensor()] + ret = rpc.rpc_sync( + worker_name(dst_rank), my_container_sum, args=(a,) + ) + self.assertEqual(ret, my_container_sum(a)) + @dist_init def test_py_nested_pickle(self): n = self.rank + 1 @@ -2123,16 +2383,23 @@ def test_py_raise_in_user_func_escaped_str(self): else: self.assertTrue(False, "expected raise_func_escape to raise ValueError.") - @dist_init - def test_nested_rpc(self): + def _nested_rpc(self, f, expected): n = self.rank + 1 dst_rank = n % self.world_size ret = rpc.rpc_sync( worker_name(dst_rank), - nested_rpc, + f, args=(worker_name(self.rank),), ) - self.assertEqual(ret, torch.ones(2, 2) + 1) + self.assertEqual(ret, expected) + + @dist_init + def test_nested_rpc(self): + self._nested_rpc(nested_rpc, torch.ones(2, 2) + 1) + + @dist_init + def test_nested_rpc_sparse(self): + self._nested_rpc(nested_rpc_sparse, build_sparse_tensor() * 2) def _stress_test_rpc(self, f, repeat=1000, args=()): n = self.rank + 1 @@ -2160,31 +2427,65 @@ def test_stress_light_rpc(self): def test_stress_heavy_rpc(self): self._stress_test_rpc(heavy_rpc, repeat=20, args=(torch.ones(100, 100),)) + @dist_init + def test_stress_heavy_rpc_sparse(self): + self._stress_test_rpc(heavy_rpc_sparse, repeat=20, args=(build_sparse_tensor(),)) + @dist_init def test_stress_heavy_rpc_torchscript(self): self._stress_test_rpc(heavy_rpc_torchscript, repeat=20, args=(torch.ones(100, 100),)) - @dist_init - def test_builtin_remote_ret(self): + def _builtin_remote_ret(self, x, y, expected): n = self.rank + 1 dst_rank = n % self.world_size rref = rpc.remote( worker_name(dst_rank), torch.add, - args=(torch.ones(n, n), torch.ones(n, n)), + args=(x, y), ) - self.assertEqual(rref.to_here(), torch.ones(n, n) * 2) + self.assertEqual(rref.to_here(), expected) @dist_init - def test_builtin_remote_self(self): + def test_builtin_remote_ret(self): + self._builtin_remote_ret( + torch.ones(2, 2), + torch.ones(2, 2), + torch.ones(2, 2) * 2 + ) + + @dist_init + def test_builtin_remote_ret_sparse(self): + self._builtin_remote_ret( + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor() * 2 + ) + + def _builtin_remote_self(self, x, y, expected): rref = rpc.remote( worker_name(self.rank), torch.add, - args=(torch.ones(2, 2), torch.ones(2, 2)), + args=(x, y), + ) + self.assertEqual(rref.local_value(), expected) + + @dist_init + def test_builtin_remote_self(self): + self._builtin_remote_self( + torch.ones(2, 2), + torch.ones(2, 2), + torch.ones(2, 2) * 2 + ) + + @dist_init + def test_builtin_remote_self_sparse(self): + self._builtin_remote_self( + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor() * 2 ) - self.assertEqual(rref.local_value(), torch.ones(2, 2) * 2) - def _test_multi_remote_call(self, fn, args_fn=lambda x: (), kwargs_fn=lambda x: {}): + def _test_multi_remote_call(self, fn, sparse, args_fn=lambda x, y: (), kwargs_fn=lambda x, y: {}): m = 10 n = self.rank + 1 dst_rank = n % self.world_size @@ -2196,21 +2497,35 @@ def _test_multi_remote_call(self, fn, args_fn=lambda x: (), kwargs_fn=lambda x: rpc.remote( worker_name(dst_rank), fn, - args=args_fn(n), - kwargs=kwargs_fn(n), + args=args_fn(n, sparse), + kwargs=kwargs_fn(n, sparse), ) ) - expected.append(fn(*args_fn(n), **kwargs_fn(n))) + expected.append(fn(*args_fn(n, sparse), **kwargs_fn(n, sparse))) for i in range(m): self.assertEqual(rrefs[i].to_here(), expected[i]) + @staticmethod + def _multi_args_fn(n, sparse=False): + if sparse: + return (build_sparse_tensor(), build_sparse_tensor()) + else: + return (torch.ones(n, n), torch.ones(n, n)) + @dist_init def test_multi_builtin_remote_ret(self): - def args_fn(n): - return (torch.ones(n, n), torch.ones(n, n)) + self._test_multi_remote_call( + torch.add, False, + args_fn=RpcTest._multi_args_fn + ) - self._test_multi_remote_call(torch.add, args_fn=args_fn) + @dist_init + def test_multi_builtin_remote_ret_sparse(self): + self._test_multi_remote_call( + torch.add, True, + args_fn=RpcTest._multi_args_fn + ) @dist_init def test_py_udf_remote(self): @@ -2223,82 +2538,177 @@ def test_py_udf_remote(self): ) self.assertEqual(rref.to_here(), my_function(n, n + 1, n + 2)) - @dist_init - def test_multi_py_udf_remote(self): - def kwargs_fn(n): + @staticmethod + def _multi_kwargs_fn(n, sparse=False): + if sparse: + return { + "a": build_sparse_tensor(), + "b": build_sparse_tensor(), + "c": build_sparse_tensor() + } + else: return {"a": torch.ones(n, n), "b": torch.ones(n, n), "c": torch.ones(n, n)} - self._test_multi_remote_call(my_function, kwargs_fn=kwargs_fn) + @dist_init + def test_multi_py_udf_remote(self): + self._test_multi_remote_call( + my_function, + False, + kwargs_fn=RpcTest._multi_kwargs_fn + ) @dist_init - def test_py_rref_args(self): + def test_multi_py_udf_remote_sparse(self): + self._test_multi_remote_call( + my_function, + True, + kwargs_fn=RpcTest._multi_kwargs_fn + ) + + def _py_rref_args(self, a, b, x, y, expected): n = self.rank + 1 dst_rank = n % self.world_size rref_a = rpc.remote( - worker_name(dst_rank), torch.add, args=(torch.ones(n, n), 2) + worker_name(dst_rank), torch.add, args=(a, b) ) rref_b = rpc.remote( - worker_name(dst_rank), torch.add, args=(torch.ones(n, n), 1) + worker_name(dst_rank), torch.add, args=(x, y) ) rref_c = rpc.remote( worker_name(dst_rank), my_rref_function, args=(rref_a, rref_b) ) - self.assertEqual(rref_c.to_here(), torch.ones(n, n) + 4) + self.assertEqual(rref_c.to_here(), expected) @dist_init - def test_py_rref_args_user_share(self): + def test_py_rref_args(self): + self._py_rref_args( + torch.ones(2, 2), + 1, + torch.ones(2, 2), + 2, + torch.ones(2, 2) * 2 + 3) + + @dist_init + def test_py_rref_args_sparse(self): + self._py_rref_args( + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor() * 4 + ) + + def _py_rref_args_user_share(self, a, b, c, x, y, z, expected): n = self.rank + 1 owner_rank = n % self.world_size user_rank = (n + 1) % self.world_size rref_a = rpc.remote( - worker_name(owner_rank), my_function, args=(torch.ones(n, n), 2, 0) + worker_name(owner_rank), my_function, args=(a, b, c) ) rref_b = rpc.remote( - worker_name(owner_rank), my_function, args=(torch.ones(n, n), 1, 0) + worker_name(owner_rank), my_function, args=(x, y, z) ) rref_c = rpc.remote( worker_name(user_rank), my_rref_function, args=(rref_a, rref_b) ) - self.assertEqual(rref_c.to_here(), torch.ones(n, n) + 4) + self.assertEqual(rref_c.to_here(), expected) @dist_init - def test_py_rpc_rref_args(self): + def test_py_rref_args_user_share(self): + self._py_rref_args_user_share( + torch.ones(2, 2), + 1, + 2, + torch.ones(2, 2), + 3, + 4, + torch.ones(2, 2) * 2 + 10 + ) + + @dist_init + def test_py_rref_args_user_share_sparse(self): + self._py_rref_args_user_share( + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor() * 6 + ) + + def _py_rpc_rref_args(self, a, b, c, x, y, z, expected): n = self.rank + 1 dst_rank = n % self.world_size rref_a = rpc.remote( - worker_name(dst_rank), my_function, args=(torch.ones(n, n), 2, 0) + worker_name(dst_rank), my_function, args=(a, b, c) ) rref_b = rpc.remote( - worker_name(dst_rank), my_function, args=(torch.ones(n, n), 1, 0) + worker_name(dst_rank), my_function, args=(x, y, z) ) c = rpc.rpc_sync( worker_name(dst_rank), my_rref_function, args=(rref_a, rref_b) ) + self.assertEqual(c, expected) - self.assertEqual(c, torch.ones(n, n) + 4) + @dist_init + def test_py_rpc_rref_args(self): + self._py_rpc_rref_args( + torch.ones(2, 2), + 1, + 2, + torch.ones(2, 2), + 3, + 4, + torch.ones(2, 2) * 2 + 10 + ) @dist_init - def test_nested_remote(self): + def test_py_rpc_rref_args_sparse(self): + self._py_rpc_rref_args( + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor() * 6 + ) + + def _nested_remote(self, f, expected): n = self.rank + 1 dst_rank1 = n % self.world_size dst_rank2 = (n + 1) % self.world_size rref = rpc.remote( worker_name(dst_rank1), - nested_remote, + f, args=(worker_name(dst_rank2),), ) - self.assertEqual(rref.to_here(), torch.ones(2, 2) + 3) + self.assertEqual(rref.to_here(), expected) @dist_init - def test_nested_rref(self): + def test_nested_remote(self): + self._nested_remote( + nested_remote, + torch.ones(2, 2) + 3 + ) + + @dist_init + def test_nested_remote_sparse(self): + self._nested_remote( + nested_remote_sparse, + build_sparse_tensor() + build_sparse_tensor() + ) + + def _nested_rref(self, f, expected1, expected2): n = self.rank + 1 dst_rank1 = n % self.world_size dst_rank2 = (n + 1) % self.world_size rref_of_rrefs = rpc.remote( worker_name(dst_rank1), - nested_rref, + f, args=(worker_name(dst_rank2),), ) @@ -2308,11 +2718,26 @@ def test_nested_rref(self): rrefs = rref_of_rrefs.to_here() self.assertEqual(len(rrefs), 2) - self.assertEqual(rrefs[0].to_here(), torch.ones(2, 2) + 1) - self.assertEqual(rrefs[1].to_here(), torch.ones(2, 2) + 2) + self.assertEqual(rrefs[0].to_here(), expected1) + self.assertEqual(rrefs[1].to_here(), expected2) @dist_init - def test_nested_rref_stress(self): + def test_nested_rref(self): + self._nested_rref( + nested_rref, + torch.ones(2, 2) + 1, + torch.ones(2, 2) + 2 + ) + + @dist_init + def test_nested_rref_sparse(self): + self._nested_rref( + nested_rref_sparse, + build_sparse_tensor() * 2, + build_sparse_tensor() * 2 + ) + + def _nested_rref_stress(self, f, expected1, expected2): n = self.rank + 1 dst_rank1 = n % self.world_size dst_rank2 = (n + 1) % self.world_size @@ -2321,7 +2746,7 @@ def test_nested_rref_stress(self): all_rrefs.append( rpc.remote( worker_name(dst_rank1), - nested_rref, + f, args=(worker_name(dst_rank2),), ) ) @@ -2330,8 +2755,24 @@ def test_nested_rref_stress(self): rref_of_rrefs = all_rrefs[i] rrefs = rref_of_rrefs.to_here() self.assertEqual(len(rrefs), 2) - self.assertEqual(rrefs[0].to_here(), torch.ones(2, 2) + 1) - self.assertEqual(rrefs[1].to_here(), torch.ones(2, 2) + 2) + self.assertEqual(rrefs[0].to_here(), expected1) + self.assertEqual(rrefs[1].to_here(), expected2) + + @dist_init + def test_nested_rref_stress(self): + self._nested_rref_stress( + nested_rref, + torch.ones(2, 2) + 1, + torch.ones(2, 2) + 2 + ) + + @dist_init + def test_nested_rref_stress_sparse(self): + self._nested_rref_stress( + nested_rref_sparse, + build_sparse_tensor() * 2, + build_sparse_tensor() * 2 + ) @dist_init def test_multi_layer_nested_async_rpc(self): @@ -2883,7 +3324,7 @@ def test_handle_send_exceptions(self): ) rpc._set_rpc_timeout(10) # This barrier is needed to ensure that some workers do not exit before - # others have been brought up, for non ProcessGroupAgent backends. + # others have been brought up. initialize_pg(self.file_init_method, self.rank, self.world_size) dist.barrier() if self.rank == 1: @@ -3210,7 +3651,7 @@ def test_function_not_on_callee(self): # Ensure that we have the attribute on this module. Otherwise, the test could fail due to a caller-side pickling error. self.assertTrue(hasattr(this_module, "foo_add")) with self.assertRaisesRegex( - AttributeError, "RPC pickler does not serialize" + RuntimeError, "RPC pickler does not serialize" ): rpc.rpc_sync(callee_worker, foo_add, args=()) @@ -3592,17 +4033,9 @@ def test_future_in_rpc(self): def test_future_nested_callback(self): self._test_future_cb(add_use_future_nested_cb) - def _run_func_in_mode(self, to, fn, mode, args=None, kwargs=None): - if mode == RPCExecMode.SYNC: - return rpc.rpc_sync(to, fn, args=args, kwargs=kwargs) - elif mode == RPCExecMode.ASYNC: - return rpc.rpc_async(to, fn, args=args, kwargs=kwargs).wait() - elif mode == RPCExecMode.REMOTE: - return rpc.remote(to, fn, args=args, kwargs=kwargs).to_here() - def _test_async_function_raise(self, mode): with self.assertRaisesRegex(RuntimeError, "Expected error"): - self._run_func_in_mode( + _run_func_in_mode( worker_name((self.rank + 1) % self.world_size), async_raise_func, mode @@ -3626,7 +4059,7 @@ def _test_async_function_wrong_return_type(self, mode): "torch\\.futures\\.Future object," ) with self.assertRaisesRegex(RuntimeError, errMsg): - self._run_func_in_mode( + _run_func_in_mode( worker_name((self.rank + 1) % self.world_size), async_wrong_type, mode @@ -3657,7 +4090,7 @@ def _test_async_function(self, fn, mode=RPCExecMode.SYNC): dst2 = worker_name((self.rank + 2) % self.world_size) args = (dst2, torch.ones(2, 2), 1, 2) - ret = self._run_func_in_mode(dst1, fn, mode, args=args) + ret = _run_func_in_mode(dst1, fn, mode, args=args) self.assertEqual(ret, torch.ones(2, 2) + 3) @dist_init @@ -3750,7 +4183,7 @@ def _test_async_function_multi(self, fn, mode=RPCExecMode.SYNC): num = 20 step = 3 args = (dst2, torch.ones(2, 2), num, step) - ret = self._run_func_in_mode(dst1, fn, mode, args=args) + ret = _run_func_in_mode(dst1, fn, mode, args=args) self.assertEqual(ret, torch.ones(2, 2) + num * step) @dist_init @@ -3794,7 +4227,7 @@ def _test_return_future(self, mode): RuntimeError, "Can not pickle torch.futures.Future" ): - self._run_func_in_mode( + _run_func_in_mode( worker_name((self.rank + 1) % self.world_size), return_future, mode @@ -4096,6 +4529,46 @@ def rref_error(): dist.barrier() + def _trainer_func(self, rref, sparse): + m = MyEmbeddingBagModel(sparse=sparse) + loss_fn = nn.MSELoss() + for i in range(10): + outputs = m(torch.rand(10, 10).long()) + loss_fn(outputs, torch.rand(10, 10)).backward() + gradient = list(m.parameters())[0].grad + fut = rref.rpc_async().average(rref, i, gradient) + gradient = fut.wait() + if gradient.is_sparse: + gradient = gradient.to_dense().double() + ps_gradient = rref.rpc_sync().get_gradient(rref) + if ps_gradient.is_sparse: + ps_gradient = ps_gradient.to_dense().double() + self.assertTrue(torch.equal(gradient, ps_gradient)) + + def _my_parameter_server(self, sparse): + ps_rref = RRef(MyParameterServer(self.world_size - 1)) + futures = [] + for index in range(1, self.world_size): + futures.append( + rpc.rpc_async( + worker_name((self.rank + index) % self.world_size), + self._trainer_func, + args=( + ps_rref, + sparse + ), + ) + ) + torch.futures.wait_all(futures) + + @dist_init + def test_my_parameter_server(self): + self._my_parameter_server(False) + + @dist_init + def test_my_parameter_server_sparse(self): + self._my_parameter_server(True) + class CudaRpcTest(RpcAgentTestFixture): @@ -4744,13 +5217,33 @@ def test_device_maps_gpu(self): rpc_backend_options=options, ) - ret = rpc.rpc_sync( - dst, - TensorPipeAgentCudaRpcTest._gpu_add, - args=(torch.zeros(2).to(0), torch.ones(2).to(0)) - ) - self.assertEqual(ret.device, torch.device(1)) - self.assertEqual(ret, (torch.zeros(2) + torch.ones(2)).to(1)) + # Test dense tensor + for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: + x = torch.ones(2, 2) + y = torch.ones(2, 2) + expected_tensor = (x + y) + ret = _run_func_in_mode(dst, TensorPipeAgentCudaRpcTest._gpu_add, exec_mode, args=(x.to(0), y.to(0))) + self.assertEqual(ret.device, torch.device(1)) + self.assertEqual(ret, expected_tensor.to(1)) + + # Test sparse tensor uncoalesced + for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: + x = build_sparse_tensor() + y = build_sparse_tensor() + expected_tensor = (x + y) + ret = _run_func_in_mode(dst, TensorPipeAgentCudaRpcTest._gpu_add, exec_mode, args=(x.to(0), y.to(0))) + self.assertEqual(ret.device, torch.device(1)) + self.assertEqual(ret, expected_tensor.to(1)) + + # Test sparse tensor coalesced + for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]: + x = build_sparse_tensor().coalesce() + y = build_sparse_tensor().coalesce() + expected_tensor = (x + y) + ret = _run_func_in_mode(dst, TensorPipeAgentCudaRpcTest._gpu_add, exec_mode, args=(x.to(0), y.to(0))) + self.assertEqual(ret.device, torch.device(1)) + self.assertEqual(ret, expected_tensor.to(1)) + rpc.shutdown() @staticmethod @@ -5249,8 +5742,7 @@ def test_device_maps_missing_config_remote(self): def test_device_maps_missing_config_remote_response(self): self._test_device_maps_missing_config_response(RPCExecMode.REMOTE) - @skip_if_lt_x_gpu(2) - def test_device_maps_remote(self): + def _device_maps_remote(self, x, y, expected): options = self.rpc_backend_options dst = worker_name((self.rank + 1) % self.world_size) options.set_device_map(dst, {1: 0}) @@ -5266,14 +5758,29 @@ def test_device_maps_remote(self): rref = rpc.remote( dst, TensorPipeAgentCudaRpcTest._add_to_gpu, - args=(torch.zeros(2), 1) + args=(x, y) ) - self.assertEqual(rref.to_here().device.index, 1) - self.assertEqual(rref.to_here(), torch.ones(2).to(1)) + self.assertEqual(rref.to_here(), expected.to(1)) rpc.shutdown() + @skip_if_lt_x_gpu(2) + def test_device_maps_remote(self): + self._device_maps_remote( + torch.ones(3, 3), + torch.ones(3, 3), + torch.ones(3, 3) + torch.ones(3, 3) + ) + + @skip_if_lt_x_gpu(2) + def test_device_maps_remote_sparse(self): + self._device_maps_remote( + build_sparse_tensor(), + build_sparse_tensor(), + build_sparse_tensor() + build_sparse_tensor() + ) + @staticmethod def _slow_add_on_user_stream(x, y): s0 = torch.cuda.current_stream(x.device) diff --git a/torch/testing/_internal/distributed/rpc_utils.py b/torch/testing/_internal/distributed/rpc_utils.py index b5cf9f73548c1..dd11c8dc450e0 100644 --- a/torch/testing/_internal/distributed/rpc_utils.py +++ b/torch/testing/_internal/distributed/rpc_utils.py @@ -2,13 +2,11 @@ import os import sys import unittest -from enum import Flag, auto from typing import Dict, List, Type from torch.testing._internal.common_distributed import MultiProcessTestCase from torch.testing._internal.common_utils import ( TEST_WITH_DEV_DBG_ASAN, - TEST_WITH_TSAN, find_free_port, IS_SANDCASTLE, ) @@ -75,25 +73,12 @@ def _check_and_unset_tcp_init(): # The tests for the RPC module need to cover multiple possible combinations: # - different aspects of the API, each one having its own suite of tests; # - different agents (ProcessGroup, TensorPipe, ...); -# - and subprocesses launched with either fork or spawn. # To avoid a combinatorial explosion in code size, and to prevent forgetting to # add a combination, these are generated automatically by the code in this file. -# Here, we collect all the test suites that we need to cover and the two multi- -# processing methods. We then have one separate file for each agent, from which +# Here, we collect all the test suites that we need to cover. +# We then have one separate file for each agent, from which # we call the generate_tests function of this file, passing to it a fixture for -# the agent, which then gets mixed-in with each test suite and each mp method. - - -@unittest.skipIf(TEST_WITH_TSAN, "TSAN and fork() is broken") -class ForkHelper(MultiProcessTestCase): - def setUp(self): - super().setUp() - _check_and_set_tcp_init() - self._fork_processes() - - def tearDown(self): - _check_and_unset_tcp_init() - super().tearDown() +# the agent, which then gets mixed-in with each test suite. @unittest.skipIf( TEST_WITH_DEV_DBG_ASAN, "Skip ASAN as torch + multiprocessing spawn have known issues" @@ -109,17 +94,6 @@ def tearDown(self): super().tearDown() -class MultiProcess(Flag): - FORK = auto() - SPAWN = auto() - - -MP_HELPERS_AND_SUFFIXES = { - MultiProcess.FORK: (ForkHelper, "WithFork"), - MultiProcess.SPAWN: (SpawnHelper, "WithSpawn"), -} - - # This list contains test suites that are agent-agnostic and that only verify # compliance with the generic RPC interface specification. These tests should # *not* make use of implementation details of a specific agent (options, @@ -175,7 +149,6 @@ def generate_tests( prefix: str, mixin: Type[RpcAgentTestFixture], tests: List[Type[RpcAgentTestFixture]], - mp_type_filter: MultiProcess, module_name: str, ) -> Dict[str, Type[RpcAgentTestFixture]]: """Mix in the classes needed to autogenerate the tests based on the params. @@ -183,36 +156,25 @@ def generate_tests( Takes a series of test suites, each written against a "generic" agent (i.e., derived from the abstract RpcAgentTestFixture class), as the `tests` args. Takes a concrete subclass of RpcAgentTestFixture, which specializes it for a - certain agent, as the `mixin` arg. Produces all combinations of them, and of - the multiprocessing start methods (fork or spawn), possibly filtered using - the `mp_type_filter`. Returns a dictionary of class names to class type + certain agent, as the `mixin` arg. Produces all combinations of them. + Returns a dictionary of class names to class type objects which can be inserted into the global namespace of the calling - module. The name of each test will be a concatenation of the `prefix` arg, - the original name of the test suite, and a suffix of either `WithFork` or - `WithSpawn`. The `module_name` should be the name of the calling module so + module. The name of each test will be a concatenation of the `prefix` arg + and the original name of the test suite. + The `module_name` should be the name of the calling module so that the classes can be fixed to make it look like they belong to it, which is necessary for pickling to work on them. """ ret: Dict[str, Type[RpcAgentTestFixture]] = {} for test_class in tests: - for mp_type in MultiProcess: - if mp_type & mp_type_filter: - mp_helper, suffix = MP_HELPERS_AND_SUFFIXES[mp_type] - if IS_SANDCASTLE: - if mp_helper == SpawnHelper and TEST_WITH_DEV_DBG_ASAN: - print( - f'Skipping test {test_class} on sandcastle for the following reason: ' - 'Skip dev-asan as torch + multiprocessing spawn have known issues', file=sys.stderr) - continue - elif mp_helper == ForkHelper and TEST_WITH_TSAN: - print( - f'Skipping test {test_class} on sandcastle for the following reason: ' - 'TSAN and fork() is broken' - ) - continue - - name = f"{prefix}{test_class.__name__}{suffix}" - class_ = type(name, (test_class, mixin, mp_helper), dict()) - class_.__module__ = module_name - ret[name] = class_ + if IS_SANDCASTLE and TEST_WITH_DEV_DBG_ASAN: + print( + f'Skipping test {test_class} on sandcastle for the following reason: ' + 'Skip dev-asan as torch + multiprocessing spawn have known issues', file=sys.stderr) + continue + + name = f"{prefix}{test_class.__name__}" + class_ = type(name, (test_class, mixin, SpawnHelper), dict()) + class_.__module__ = module_name + ret[name] = class_ return ret diff --git a/torch/testing/_internal/jit_metaprogramming_utils.py b/torch/testing/_internal/jit_metaprogramming_utils.py index a21717bc5f9a1..75b1615d065d5 100644 --- a/torch/testing/_internal/jit_metaprogramming_utils.py +++ b/torch/testing/_internal/jit_metaprogramming_utils.py @@ -109,18 +109,49 @@ ('bilinear', (S, S, S), ((S, S, M), torch.zeros(M, S, M),),), ('embedding', torch.tensor([[1, 2, 4, 5], [4, 3, 2, 5]]), (torch.rand(6, 3), ), '', (True,)), ('embedding_bag', torch.tensor([1, 2, 4, 2]), (torch.rand(5, 3), torch.tensor([0, 4]),),), - ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), ), - '', (False, 'aten::_batch_norm_impl_index')), + ('batch_norm', (S, S), + (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), None, None, True, ), + 'training', (True, 'aten::_batch_norm_impl_index')), + ('batch_norm', (0, S, S, S), + (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), + non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ), + 'size_zero', (True, 'aten::_batch_norm_impl_index')), + ('batch_norm', (0, S, S, S), + (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), + non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ), + 'size_zero_inference', (True, 'aten::_batch_norm_impl_index')), + ('batch_norm', (S, S), + (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), + non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ), + 'with_weight_and_bias_training', (True, 'aten::_batch_norm_impl_index')), + ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), + None, non_differentiable(torch.ones(S)), True, ), + 'with_only_bias_training', (True, 'aten::_batch_norm_impl_index')), + ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), + non_differentiable(torch.randn(S)), None, True, ), + 'with_only_weight_training', (True, 'aten::_batch_norm_impl_index')), + ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), + None, None, False, ), + 'inference', (True, 'aten::_batch_norm_impl_index')), + ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), + non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), False, ), + 'with_weight_and_bias_inference', (True, 'aten::_batch_norm_impl_index')), + ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), + None, non_differentiable(torch.ones(S)), False, ), + 'with_only_bias_inference', (True, 'aten::_batch_norm_impl_index')), + ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), + non_differentiable(torch.randn(S)), None, False, ), + 'with_only_weight_inference', (True, 'aten::_batch_norm_impl_index')), ('instance_norm', (S, S, S), (non_differentiable(torch.zeros(S)), non_differentiable(torch.ones(S))),), ('layer_norm', (S, S, S, S), ([5],), '', - (True, ['aten::native_layer_norm'])), + (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])), ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),), 'with_only_weight', - (True, ['aten::native_layer_norm'])), + (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])), ('layer_norm', (S, S, S, S), ([5], None, non_differentiable(torch.rand(S)),), 'with_only_bias', - (True, ['aten::native_layer_norm'])), + (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])), ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)), non_differentiable(torch.rand(S))), 'with_weight_and_bias', - (True, ['aten::native_layer_norm'])), + (False, ['aten::contiguous', 'aten::_batch_norm_impl_index', 'aten::addcmul'])), ('group_norm', (S, S, S), (1, torch.rand(5),),), ('local_response_norm', (S, S, S), (2, ),), ('nll_loss', F.log_softmax(torch.randn(3, 5), dim=0), (torch.tensor([1, 0, 4]),), '',), diff --git a/torch/testing/_internal/jit_utils.py b/torch/testing/_internal/jit_utils.py index 6086572039033..4c521a8e4d9d5 100644 --- a/torch/testing/_internal/jit_utils.py +++ b/torch/testing/_internal/jit_utils.py @@ -375,35 +375,53 @@ def assertRaisesRegexWithHighlight(self, exception, regex, highlight): return _AssertRaisesRegexWithHighlightContext(self, exception, regex, highlight) def checkScriptRaisesRegex(self, script, inputs, exception, regex, - outputs=None, capture_output=False, profiling=ProfilingMode.PROFILING): + name=None, outputs=None, capture_output=False, + frames_up=1, profiling=ProfilingMode.PROFILING): """ Checks that a given function will throw the correct exception, - when executed with normal python, the string frontend, and the AST frontend + when executed with normal python, the string frontend, and the + AST frontend. Logic taken from `checkScript` (see comments there + for details) """ - with enable_profiling_mode_for_profiling_tests(): - # normal python + # Normal Python with self.assertRaisesRegex(exception, regex): - script(*inputs) - # string frontend + if isinstance(script, str): + frame = self.get_frame_vars(frames_up) + the_locals: Dict[str, Any] = {} + execWrapper(script, glob=frame, loc=the_locals) + frame.update(the_locals) + + python_fn = frame[name] + else: + python_fn = script + + python_fn(*inputs) + + # String frontend with self.assertRaisesRegex(exception, regex): - source = textwrap.dedent(inspect.getsource(script)) - cu = torch.jit.CompilationUnit(source) - ge = getattr(cu, script.__name__) - # profiling run + if isinstance(script, str): + cu = torch.jit.CompilationUnit(script, _frames_up=frames_up) + string_frontend = getattr(cu, name) + else: + source = textwrap.dedent(inspect.getsource(script)) + cu = torch.jit.CompilationUnit(source, _frames_up=frames_up) + string_frontend = getattr(cu, script.__name__) + with self.assertRaisesRegex(exception, regex): - ge(*inputs) + string_frontend(*inputs) # optimized run - ge(*inputs) - # python AST frontend - with self.assertRaisesRegex(exception, regex): - ge = torch.jit.script(script) - # profiling run + string_frontend(*inputs) + + # Python AST frontend + if not isinstance(script, str): with self.assertRaisesRegex(exception, regex): + ge = torch.jit.script(python_fn) + # profiling run + with self.assertRaisesRegex(exception, regex): + ge(*inputs) + # optimized run ge(*inputs) - # optimized run - ge(*inputs) - def checkBailouts(self, model, inputs, expected): state = model.get_debug_state() @@ -594,7 +612,7 @@ def input_reduce(input, fn, acc): for g2, g2_ge in zip(grads2, grads2_ge): if g2 is None and g2_ge is None: continue - self.assertTrue(torch.allclose(g2, g2_ge, atol=8e-4, rtol=8e-4)) + self.assertEqual(g2, g2_ge, atol=8e-4, rtol=8e-4) return ge @@ -668,11 +686,13 @@ def wrapper(func): def enable_cpu_fuser(fn): def wrapper(*args, **kwargs): + torch._C._jit_override_can_fuse_on_cpu_legacy(True) torch._C._jit_override_can_fuse_on_cpu(True) torch._C._jit_set_te_must_use_llvm_cpu(False) try: fn(*args, **kwargs) finally: + torch._C._jit_override_can_fuse_on_cpu_legacy(False) torch._C._jit_override_can_fuse_on_cpu(False) torch._C._jit_set_te_must_use_llvm_cpu(True) return wrapper diff --git a/torch/testing/_internal/opinfo_helper.py b/torch/testing/_internal/opinfo_helper.py index 5129af4f99e34..e4a18b48ca7a1 100644 --- a/torch/testing/_internal/opinfo_helper.py +++ b/torch/testing/_internal/opinfo_helper.py @@ -4,21 +4,22 @@ import torch from torch.testing._internal.common_cuda import (TEST_CUDA) -from torch.testing._core import _dispatch_dtypes -from torch.testing import (all_types_and_complex_and, - all_types_and_complex, - all_types_and_half, - all_types, - complex_types, - floating_and_complex_types, - floating_types_and_half, - floating_types, - integral_types, - floating_types_and, - floating_and_complex_types_and, - integral_types_and, - all_types_and, - ) +from torch.testing._internal.common_dtype import ( + all_types_and_complex_and, + all_types_and_complex, + all_types_and_half, + all_types, + complex_types, + floating_and_complex_types, + floating_types_and_half, + floating_types, + integral_types, + floating_types_and, + floating_and_complex_types_and, + integral_types_and, + all_types_and, + _dispatch_dtypes, +) COMPLETE_DTYPES_DISPATCH = ( all_types, diff --git a/torch/utils/_crash_handler.py b/torch/utils/_crash_handler.py index 3d736c3f85ce0..84b345229bde9 100644 --- a/torch/utils/_crash_handler.py +++ b/torch/utils/_crash_handler.py @@ -5,11 +5,10 @@ import torch DEFAULT_MINIDUMP_DIR = "/tmp/pytorch_crashes" +if sys.platform == "win32": + DEFAULT_MINIDUMP_DIR = str(pathlib.Path.home() / "AppData" / "pytorch_crashes") def enable_minidumps(directory=DEFAULT_MINIDUMP_DIR): - if sys.platform != "linux": - raise RuntimeError("Minidump collection is currently only implemented for Linux platforms") - if directory == DEFAULT_MINIDUMP_DIR: pathlib.Path(directory).mkdir(parents=True, exist_ok=True) elif not os.path.exists(directory): diff --git a/torch/utils/bundled_inputs.py b/torch/utils/bundled_inputs.py index bce658b997255..8a6d466f20da4 100644 --- a/torch/utils/bundled_inputs.py +++ b/torch/utils/bundled_inputs.py @@ -21,13 +21,18 @@ class InflatableArg(NamedTuple): the appropriate input. It can use 'value' as an input to the format str. It must result in a value of the same type as 'value'. + 'fmt_fn' is a formatable function code string that is executed to inflate the compressed + data into the appropriate input. It must result in a value of the same type as 'value'. + The function name should be the formatable part of the string. + Note: Only top level InflatableArgs can be inflated. i.e. you cannot place an inflatable arg inside of some other structure. You should instead create an inflatable arg such that the fmt code string returns the full structure of your input. """ value: Any - fmt: str + fmt: str = "{}" + fmt_fn: str = "" def bundle_inputs( @@ -279,13 +284,21 @@ def augment_many_model_functions_with_bundled_inputs( deflated_args = [] parts.append("(") for arg_idx, arg in enumerate(args): - deflated, inflater = _inflate_expr(arg, f"deflated[{inp_idx}][{arg_idx}]") + inflate_helper_fn_name = _get_inflate_helper_fn_name(arg_idx, inp_idx, function_name) + deflated, inflater, helper_definition = _inflate_expr( + arg, + f"deflated[{inp_idx}][{arg_idx}]", + inflate_helper_fn_name, + ) deflated_args.append(deflated) parts.append(f" {inflater},") + if helper_definition: + model.define(textwrap.dedent(helper_definition)) deflated_inputs.append(tuple(deflated_args)) parts.append("),") parts.append("") expr = "\n".join(parts) + # Back-channel return this expr for debugging. if _receive_inflate_expr is not None: _receive_inflate_expr.append(expr) @@ -332,7 +345,6 @@ def get_num_bundled_inputs(self): return len(self.get_all_bundled_inputs_for_forward()) """)) - # Define some high level helper methods that act on all bundled inputs model.define(textwrap.dedent(""" def get_bundled_inputs_functions_and_info(self): @@ -341,27 +353,44 @@ def get_bundled_inputs_functions_and_info(self): return all_inputs """.format(template=get_bundled_inputs_functions_and_info_template))) -def _inflate_expr(arg: T, ref: str) -> Tuple[Union[T, torch.Tensor], str]: +def _inflate_expr( + arg: T, ref: str, inflate_helper_fn_name: str +) -> Tuple[Union[T, torch.Tensor], str, Optional[str]]: # Allow custom inflation expressions any object. # For example, calling custom image-decoding ops. # Or just use "{}" as the format string to ignore size limits. if isinstance(arg, InflatableArg): - return arg.value, arg.fmt.format(ref) + if arg.fmt_fn: + if arg.fmt not in ["{}", ""]: + raise Exception( + f"Bundled input argument at position '{ref}' has " + f"both arg.fmt_fn => \n{arg.fmt_fn} " + f"\n and arg.fmt => {arg.fmt}. " + "Please choose `arg.fmt` if the deflater is straightforward or " + "`arg.fmt_fn` if you need a function." + ) + + helper_definition = arg.fmt_fn.format(inflate_helper_fn_name) + expr = f"self.{inflate_helper_fn_name}({ref})" + + return arg.value, expr, helper_definition + else: + return arg.value, arg.fmt.format(ref), None if isinstance(arg, torch.Tensor): # Small-storage tensors can just be saved directly. if arg.storage().size() <= MAX_RAW_TENSOR_SIZE: - return arg, ref + return arg, ref, None # Small contiguous tensors can be cloned to have small storage. # TODO: Should we do this even for non-contiguous tensors? if arg.is_contiguous() and arg.numel() <= MAX_RAW_TENSOR_SIZE: - return arg.clone(), ref + return arg.clone(), ref, None # Example inputs commonly come from torch.zeros, torch.ones, or torch.full. # These can be represented compactly. for fmt in [torch.contiguous_format, torch.channels_last]: if arg.is_contiguous(memory_format=fmt) and (arg == arg.flatten()[0]).all().item(): return (arg.flatten()[0].clone().expand(*arg.size()), - f"{ref}.contiguous(memory_format={fmt})") + f"{ref}.contiguous(memory_format={fmt})", None) # Prevent big tensors from being bundled by default. # TODO: Provide more useful diagnostics. raise Exception( @@ -370,7 +399,7 @@ def _inflate_expr(arg: T, ref: str) -> Tuple[Union[T, torch.Tensor], str]: f"You probably don't want to bundle this as an input. " ) else: - return arg, ref + return arg, ref, None def _get_bundled_inputs_attributes_and_methods(script_module: torch.jit.ScriptModule) -> Tuple[List[str], List[str]]: methods: List[str] = [] @@ -389,9 +418,37 @@ def _get_bundled_inputs_attributes_and_methods(script_module: torch.jit.ScriptMo methods.append("get_all_bundled_inputs_for_" + function_name) methods.append("_generate_bundled_inputs_for_" + function_name) attributes.append("_bundled_inputs_deflated_" + function_name) + + bundled_inputs_fn = getattr( + script_module, + f"get_all_bundled_inputs_for_{function_name}" + ) + num_bundled_inputs: int = len(bundled_inputs_fn()) + + # Check inflate helper functions for each function, argument and bundled input + func = getattr(script_module, function_name, None) + for arg_idx in range(len(func.schema.arguments) - 1): + for input_idx in range(num_bundled_inputs): + helper_fn_name = _get_inflate_helper_fn_name( + arg_idx=arg_idx, + input_idx=input_idx, + function_name=function_name + ) + # if the arg has an InflatableArg with fmt_fn, add the helper function name + if hasattr(script_module, helper_fn_name): + methods.append(helper_fn_name) + return (methods, attributes) +def _get_inflate_helper_fn_name( + arg_idx: int, + input_idx: int, + function_name: str, +) -> str: + return f"_inflate_helper_for_{function_name}_input_{input_idx}_arg_{arg_idx}" + + def bundle_randn(*size, dtype=None): """Generate a tensor that will be inflated with torch.randn.""" diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index b313423426caa..bb0a85982c665 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -37,6 +37,8 @@ BUILD_SPLIT_CUDA = os.getenv('BUILD_SPLIT_CUDA') or (os.path.exists(os.path.join( TORCH_LIB_PATH, f'{CLIB_PREFIX}torch_cuda_cu{CLIB_EXT}')) and os.path.exists(os.path.join(TORCH_LIB_PATH, f'{CLIB_PREFIX}torch_cuda_cpp{CLIB_EXT}'))) +SUBPROCESS_DECODE_ARGS = ('oem',) if IS_WINDOWS else () + # Taken directly from python stdlib < 3.9 # See https://github.com/pytorch/pytorch/issues/48617 def _nt_quote_args(args: Optional[List[str]]) -> List[str]: @@ -60,7 +62,7 @@ def _find_cuda_home() -> Optional[str]: which = 'where' if IS_WINDOWS else 'which' with open(os.devnull, 'w') as devnull: nvcc = subprocess.check_output([which, 'nvcc'], - stderr=devnull).decode().rstrip('\r\n') + stderr=devnull).decode(*SUBPROCESS_DECODE_ARGS).rstrip('\r\n') cuda_home = os.path.dirname(os.path.dirname(nvcc)) except Exception: # Guess #3 @@ -90,7 +92,7 @@ def _find_rocm_home() -> Optional[str]: ["which hipcc | xargs readlink -f"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True) hipcc, _ = pipe_hipcc.communicate() # this will be either /hip/bin/hipcc or /bin/hipcc - rocm_home = os.path.dirname(os.path.dirname(hipcc.decode().rstrip('\r\n'))) + rocm_home = os.path.dirname(os.path.dirname(hipcc.decode(*SUBPROCESS_DECODE_ARGS).rstrip('\r\n'))) if os.path.basename(rocm_home) == 'hip': rocm_home = os.path.dirname(rocm_home) except Exception: @@ -251,12 +253,12 @@ def check_compiler_ok_for_platform(compiler: str) -> bool: return True which = subprocess.check_output(['which', compiler], stderr=subprocess.STDOUT) # Use os.path.realpath to resolve any symlinks, in particular from 'c++' to e.g. 'g++'. - compiler_path = os.path.realpath(which.decode().strip()) + compiler_path = os.path.realpath(which.decode(*SUBPROCESS_DECODE_ARGS).strip()) # Check the compiler name if any(name in compiler_path for name in _accepted_compilers_for_platform()): return True # If ccache is used the compiler path is /usr/bin/ccache. Check by -v flag. - version_string = subprocess.check_output([compiler, '-v'], stderr=subprocess.STDOUT).decode() + version_string = subprocess.check_output([compiler, '-v'], stderr=subprocess.STDOUT).decode(*SUBPROCESS_DECODE_ARGS) if sys.platform.startswith('linux'): # Check for 'gcc' or 'g++' pattern = re.compile("^COLLECT_GCC=(.*)$", re.MULTILINE) @@ -303,11 +305,11 @@ def check_compiler_abi_compatibility(compiler) -> bool: if sys.platform.startswith('linux'): minimum_required_version = MINIMUM_GCC_VERSION versionstr = subprocess.check_output([compiler, '-dumpfullversion', '-dumpversion']) - version = versionstr.decode().strip().split('.') + version = versionstr.decode(*SUBPROCESS_DECODE_ARGS).strip().split('.') else: minimum_required_version = MINIMUM_MSVC_VERSION compiler_info = subprocess.check_output(compiler, stderr=subprocess.STDOUT) - match = re.search(r'(\d+)\.(\d+)\.(\d+)', compiler_info.decode().strip()) + match = re.search(r'(\d+)\.(\d+)\.(\d+)', compiler_info.decode(*SUBPROCESS_DECODE_ARGS).strip()) version = (0, 0, 0) if match is None else match.groups() except Exception: _, error, _ = sys.exc_info() @@ -767,7 +769,7 @@ def _check_abi(self): def _check_cuda_version(self): if CUDA_HOME: nvcc = os.path.join(CUDA_HOME, 'bin', 'nvcc') - cuda_version_str = subprocess.check_output([nvcc, '--version']).strip().decode() + cuda_version_str = subprocess.check_output([nvcc, '--version']).strip().decode(*SUBPROCESS_DECODE_ARGS) cuda_version = re.search(r'release (\d+[.]\d+)', cuda_version_str) if cuda_version is not None: cuda_str_version = cuda_version.group(1) @@ -1727,7 +1729,7 @@ def _run_ninja_build(build_directory: str, verbose: bool, error_prefix: str) -> # `error` is a CalledProcessError (which has an `ouput`) attribute, but # mypy thinks it's Optional[BaseException] and doesn't narrow if hasattr(error, 'output') and error.output: # type: ignore[union-attr] - message += f": {error.output.decode()}" # type: ignore[union-attr] + message += f": {error.output.decode(*SUBPROCESS_DECODE_ARGS)}" # type: ignore[union-attr] raise RuntimeError(message) from e @@ -1996,7 +1998,7 @@ def sanitize_flags(flags): link_rule = ['rule link'] if IS_WINDOWS: cl_paths = subprocess.check_output(['where', - 'cl']).decode().split('\r\n') + 'cl']).decode(*SUBPROCESS_DECODE_ARGS).split('\r\n') if len(cl_paths) >= 1: cl_path = os.path.dirname(cl_paths[0]).replace(':', '$:') else: diff --git a/torch/utils/data/__init__.py b/torch/utils/data/__init__.py index 1d18b7b030894..ac0c763fe3854 100644 --- a/torch/utils/data/__init__.py +++ b/torch/utils/data/__init__.py @@ -11,9 +11,9 @@ from torch.utils.data.dataset import ( ChainDataset, ConcatDataset, + DataChunk, Dataset, Dataset as MapDataPipe, - DataChunk, IterableDataset, IterableDataset as IterDataPipe, Subset, @@ -34,11 +34,14 @@ runtime_validation, runtime_validation_disabled, ) +from torch.utils.data.dataloader_experimental import DataLoader2 +from torch.utils.data import communication __all__ = ['BatchSampler', 'ChainDataset', 'ConcatDataset', 'DataLoader', + 'DataLoader2', 'Dataset', 'DistributedSampler', 'IterDataPipe', @@ -53,6 +56,7 @@ 'WeightedRandomSampler', '_DatasetKind', 'argument_validation', + 'communication', 'functional_datapipe', 'get_worker_info', 'guaranteed_datapipes_determinism', @@ -68,4 +72,3 @@ ################################################################################ # import subpackage ################################################################################ -from torch.utils.data import datapipes diff --git a/torch/utils/data/communication/__init__.py b/torch/utils/data/communication/__init__.py new file mode 100644 index 0000000000000..88a395e2bddcf --- /dev/null +++ b/torch/utils/data/communication/__init__.py @@ -0,0 +1,5 @@ +from . import eventloop +from . import iter +from . import messages +from . import protocol +from . import queue diff --git a/torch/utils/data/communication/eventloop.py b/torch/utils/data/communication/eventloop.py new file mode 100644 index 0000000000000..75c44c5192313 --- /dev/null +++ b/torch/utils/data/communication/eventloop.py @@ -0,0 +1,41 @@ +import torch +import threading +import pickle + +from torch.utils.data import IterDataPipe, communication + + +def DataPipeToQueuesLoop(source_datapipe, req_queue, res_queue): + if isinstance(source_datapipe, IterDataPipe): + pipe_type = communication.iter + protocol_type = communication.protocol.IterDataPipeQueueProtocolServer + else: + raise Exception('Only supports IterDataPipe, got', source_datapipe) + # pipe_type = communication.map + # protocol_type = communication.protocol.MapDataPipeQueueProtocolServer + + torch.set_num_threads(1) + for _ in pipe_type.DataPipeBehindQueues(source_datapipe, protocol_type(req_queue, res_queue), blocking_request_get=True): + pass + + +def SpawnProcessForDataPipeline(multiprocessing_ctx, datapipe): + req_queue = multiprocessing_ctx.Queue() + res_queue = multiprocessing_ctx.Queue() + process = multiprocessing_ctx.Process( + target=DataPipeToQueuesLoop, args=(datapipe, req_queue, res_queue)) + return process, req_queue, res_queue + + +def SpawnThreadForDataPipeline(datapipe): + req_queue = communication.queue.ThreadingQueue() + res_queue = communication.queue.ThreadingQueue() + + try: + new_datapipe = pickle.loads(pickle.dumps(datapipe)) + except Exception as e: + raise Exception('Unable to pickle DataPipe to make thread local copy', e) + + process = threading.Thread(target=DataPipeToQueuesLoop, args=( + new_datapipe, req_queue, res_queue), daemon=True) + return process, req_queue, res_queue, new_datapipe diff --git a/torch/utils/data/communication/iter.py b/torch/utils/data/communication/iter.py new file mode 100644 index 0000000000000..594a466295a5f --- /dev/null +++ b/torch/utils/data/communication/iter.py @@ -0,0 +1,173 @@ +import time +import types + +from torch.utils.data import IterDataPipe, communication + +DEFAULT_NON_BLOCKING_SLEEP = 0.001 + + +def default_not_available_hook(): + time.sleep(DEFAULT_NON_BLOCKING_SLEEP) + + +class NotAvailable(Exception): + pass + + +class InvalidStateResetRequired(Exception): + """ + Returned by DataPipe when it is expecting to get reset request, + for example RouterDataPipe expecting all workers to request reset' + """ + pass + + +class NonBlocking(IterDataPipe): + not_available_hook = default_not_available_hook + + def __iter__(self): + self.reset_iterator() + return self + + def __next__(self): + while True: + try: + return self.nonblocking_next() + except StopIteration: + raise StopIteration + except NotAvailable: + if NonBlocking.not_available_hook is not None: + NonBlocking.not_available_hook() + + def nonblocking_next(self): + raise NotImplementedError( + "nonblocking_next is not implemented for %s" % self.__class__) + + def reset_iterator(self): + raise NotImplementedError( + "reset_iterator is not implemented for %s" % self.__class__) + + @staticmethod + def register_not_available_hook(hook_function): + NonBlocking.not_available_hook = hook_function + + +def EnsureNonBlockingDataPipe(validated_datapipe): + if not isinstance(validated_datapipe, IterDataPipe): + raise Exception('Not Iterable DataPipe ' + + str(validated_datapipe.__class__)) + if isinstance(validated_datapipe, NonBlocking): + return validated_datapipe + if not hasattr(validated_datapipe, '_as_iterator'): + validated_datapipe._as_iterator = None # type: ignore[attr-defined] + if not hasattr(validated_datapipe, 'nonblocking_next'): + def nonblocking_next(self): + if self._as_iterator is None: + self._as_iterator = iter(self) + return next(self._as_iterator) + validated_datapipe.nonblocking_next = types.MethodType( # type: ignore[attr-defined] + nonblocking_next, validated_datapipe) + if not hasattr(validated_datapipe, 'reset_iterator'): + def reset_iterator(self): + self._as_iterator = None + validated_datapipe.reset_iterator = types.MethodType( # type: ignore[attr-defined] + reset_iterator, validated_datapipe) + return validated_datapipe + + +def DataPipeBehindQueues(source_datapipe, protocol, full_stop=False, blocking_request_get=False): + """ + Indefinitely iterates over req_queue and passing values from source_datapipe to res_queue + If raise_stop is true, raises exception when StopIteration received from the source_datapipe + """ + if not isinstance(protocol, communication.protocol.IterDataPipeQueueProtocolServer): + raise Exception('Expecting IterDataPipeQueueProtocolServer, got', protocol) + source_datapipe = EnsureNonBlockingDataPipe(source_datapipe) + forever = True + while forever: + + try: + # Non-blocking call is Extremely slow here for python.mp, need to figureout good workaround + request = protocol.get_new_request(block=blocking_request_get) + except communication.protocol.EmptyQueue: + yield True + continue + + if isinstance(request, communication.messages.ResetIteratorRequest): + source_datapipe.reset_iterator() + protocol.response_reset() + + elif isinstance(request, communication.messages.TerminateRequest): + forever = False + protocol.response_terminate() + + elif isinstance(request, communication.messages.GetNextRequest): + while forever: + try: + value = source_datapipe.nonblocking_next() + except NotAvailable: + yield True + continue + except StopIteration: + protocol.response_stop() + if full_stop: + forever = False + else: + yield True + break + except InvalidStateResetRequired: + protocol.response_invalid() + if full_stop: + forever = False + else: + yield True + break + protocol.response_next(value) + yield True # Returns control + break + else: + raise Exception('Unrecognized type of request received', request) + + +class QueueWrapper(NonBlocking): + """ + Creates iter.DataPipe which reads data from the DataLoader.Queue + """ + + def __init__(self, protocol, response_wait_time=0.00001): + if not isinstance(protocol, communication.protocol.IterDataPipeQueueProtocolClient): + raise Exception('Got', protocol) + + self.protocol = protocol + self.counter = 0 + self._stop_iteration = False + self._response_wait_time = response_wait_time + + def reset_iterator(self): + self._stop_iteration = False + self.counter = 0 + self.protocol.request_reset() + while True: + try: + self.protocol.get_response_reset() + break + except communication.protocol.EmptyQueue: + if NonBlocking.not_available_hook is not None: + NonBlocking.not_available_hook() + + def nonblocking_next(self): + if self._stop_iteration: + raise Exception( + '`next` or `nonblocking_next` called after receiving StopIteration') + if self.protocol.can_take_request(): + self.protocol.request_next() + try: + response = self.protocol.get_response_next(block=True, timeout=self._response_wait_time) + except communication.protocol.EmptyQueue: + raise NotAvailable + if isinstance(response, communication.messages.StopIterationResponse): + self._stop_iteration = True + raise StopIteration + if isinstance(response, communication.messages.InvalidStateResponse): + raise NotAvailable + return response.value diff --git a/torch/utils/data/communication/messages.py b/torch/utils/data/communication/messages.py new file mode 100644 index 0000000000000..449cf23cfc01c --- /dev/null +++ b/torch/utils/data/communication/messages.py @@ -0,0 +1,75 @@ +class DataLoaderQueueMessage(object): + pass + + +class Request(DataLoaderQueueMessage): + pass + + +class Response(DataLoaderQueueMessage): + pass + + +class ResetIteratorRequest(Request): + pass + + +class ResetIteratorResponse(Response): + pass + + +class TerminateRequest(Request): + pass + + +class TerminateResponse(Response): + pass + + +class LenRequest(Request): + pass + + +class LenResponse(Response): + __slots__ = ('len') + + def __init__(self, len): + self.len = len + + +class GetItemRequest(Request): + __slots__ = ('key') + + def __init__(self, key): + self.key = key + + +class GetItemResponse(Response): + __slots__ = ('key', 'value') + + def __init__(self, key, value): + self.key = key + self.value = value + + +class GetNextRequest(Request): + pass + + +class GetNextResponse(Response): + __slots__ = ('value') + + def __init__(self, value): + self.value = value + + +class StopIterationResponse(Response): + pass + + +class InvalidStateResponse(Response): + """ + Returned by DataPipe when it is expecting to get reset request, + for example RouterDataPipe expecting all workers to request reset' + """ + pass diff --git a/torch/utils/data/communication/protocol.py b/torch/utils/data/communication/protocol.py new file mode 100644 index 0000000000000..68ff335714d3f --- /dev/null +++ b/torch/utils/data/communication/protocol.py @@ -0,0 +1,159 @@ +from torch.utils.data import communication + + +class Protocol(object): + __slots__ = ('request_queue', 'response_queue') + + def __init__(self, request_queue, response_queue): + self.request_queue = request_queue + self.response_queue = response_queue + + +class ProtocolClient(Protocol): + """ + ProtocolClient takes charge of putting requests into req_queue and returning results from res_queue. + """ + _req_sent = None + + def __init__(self, request_queue, response_queue): + self.request_queue = request_queue + self.response_queue = response_queue + self._req_sent = None + + def can_take_request(self): + return self._req_sent is None + + def waiting_for_response(self): + return self._req_sent is not None + + def request_sent(self, request=True): + if not self.can_take_request(): + raise Exception('Protocol only supports one request in the Queue') + self._req_sent = request + + def request_served(self, result=None): + if not self.waiting_for_response(): + raise Exception( + 'Expected no peding requests, but something got served', result) + self._req_sent = None + + +class ProtocolServer(Protocol): + """ + ProtocolServer takes charge of getting requests from req_queue and fetching data from source datapipe. + """ + _req_received = None + + def __init__(self, request_queue, response_queue): + self.request_queue = request_queue + self.response_queue = response_queue + self._req_received = None + + def have_pending_request(self): + return self._req_received is not None + + def get_new_request(self, block=False): + if self.have_pending_request(): + raise Exception( + 'Trying to get next request, while having one unserved') + try: + response = self.request_queue.get(block=block) + except Exception as e: # TODO: Catch only timeout exceptions + raise EmptyQueue('queue is empty') + self._req_received = response + return response + + # TODO: Validate supported requests + + def response_reset(self): + if not self.have_pending_request(): + raise Exception("Attempting to reply with pending request") + if not isinstance(self._req_received, communication.messages.ResetIteratorRequest): + raise Exception( + "Replaying with reset status to other type of message") + self.response_queue.put(communication.messages.ResetIteratorResponse()) + self._req_received = None + + def response_next(self, value): + if not self.have_pending_request(): + raise Exception("Attempting to reply with pending request") + self.response_queue.put(communication.messages.GetNextResponse(value)) + self._req_received = None + + def response_stop(self): + if not self.have_pending_request(): + raise Exception("Attempting to reply with pending request") + self.response_queue.put(communication.messages.StopIterationResponse()) + self._req_received = None + + def response_invalid(self): + if not self.have_pending_request(): + raise Exception("Attempting to reply with pending request") + self.response_queue.put(communication.messages.InvalidStateResponse()) + self._req_received = None + + def response_terminate(self): + if not self.have_pending_request(): + raise Exception("Attempting to reply with pending request") + if not isinstance(self._req_received, communication.messages.TerminateRequest): + raise Exception( + "Replaying with terminate status to other type of message") + self.response_queue.put(communication.messages.TerminateResponse()) + self._req_received = None + + +class MapDataPipeQueueProtocolClient(ProtocolClient): + pass + + +class MapDataPipeQueueProtocolServer(ProtocolServer): + pass + + +class EmptyQueue(Exception): + pass + + +class IterDataPipeQueueProtocolServer(ProtocolServer): + pass + + +class IterDataPipeQueueProtocolClient(ProtocolClient): + def request_reset(self): + if not self.can_take_request(): + raise Exception( + 'Can not reset while we are still waiting response for previous request') + request = communication.messages.ResetIteratorRequest() + self.request_queue.put(request) + self.request_sent(request) + + def request_next(self): + if not self.can_take_request(): + raise Exception( + 'Can not request next item while we are still waiting response for previous request') + request = communication.messages.GetNextRequest() + self.request_queue.put(request) + self.request_sent(request) + + def get_response_reset(self, block=False): + try: + response = self.response_queue.get(block=block) + except Exception as e: # TODO: Catch only timeout exceptions + raise EmptyQueue('queue is empty') + self.request_served(response) + + if not isinstance(response, communication.messages.ResetIteratorResponse): + raise Exception('Invalid response received') + + def get_response_next(self, block=False, timeout=None): + if not self.waiting_for_response(): + raise Exception( + 'Can not expect any response without submitted request') + try: + response = self.response_queue.get(block=block, timeout=timeout) + except Exception as e: # TODO: Catch only timeout exceptions + raise EmptyQueue('queue is empty') + self.request_served(response) + + # TODO(VitalyFedyunin): Add possible response types validation here + return response diff --git a/torch/utils/data/communication/queue.py b/torch/utils/data/communication/queue.py new file mode 100644 index 0000000000000..7717697b0f75d --- /dev/null +++ b/torch/utils/data/communication/queue.py @@ -0,0 +1,50 @@ +import threading +import time + +class LocalQueue(): + ops = 0 + stored = 0 + uid = 0 + empty = 0 + + def __init__(self, name='unnamed'): + self.items = [] + self.name = name + self.uid = LocalQueue.uid + LocalQueue.uid += 1 + + def put(self, item, block=True): + LocalQueue.ops += 1 + LocalQueue.stored += 1 + self.items.append(item) + + def get(self, block=True, timeout=0): + # TODO(VitalyFedyunin): Add support of block and timeout arguments + LocalQueue.ops += 1 + if not len(self.items): + LocalQueue.empty += 1 + raise Exception('LocalQueue is empty') + LocalQueue.stored -= 1 + return self.items.pop() + + +class ThreadingQueue(): + def __init__(self, name='unnamed'): + self.lock = threading.Lock() + self.items = [] + self.name = name + + def put(self, item, block=True): + with self.lock: + self.items.append(item) + + def get(self, block=True, timeout=0): + # TODO(VitalyFedyunin): Add support of block and timeout arguments + while True: + with self.lock: + if len(self.items) > 0: + return self.items.pop() + if not block: + raise Exception("Not available") + # TODO(VitalyFedyunin): Figure out what to do if nothing in the queue + time.sleep(0.000001) diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py index c85296f8f807f..0f46ad283ea5a 100644 --- a/torch/utils/data/dataloader.py +++ b/torch/utils/data/dataloader.py @@ -160,8 +160,8 @@ class DataLoader(Generic[T_co]): __initialized = False def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1, - shuffle: bool = False, sampler: Optional[Sampler[int]] = None, - batch_sampler: Optional[Sampler[Sequence[int]]] = None, + shuffle: bool = False, sampler: Optional[Sampler] = None, + batch_sampler: Optional[Sampler[Sequence]] = None, num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None, pin_memory: bool = False, drop_last: bool = False, timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None, diff --git a/torch/utils/data/dataloader_experimental.py b/torch/utils/data/dataloader_experimental.py new file mode 100644 index 0000000000000..a74c75cd75122 --- /dev/null +++ b/torch/utils/data/dataloader_experimental.py @@ -0,0 +1,157 @@ + +import functools +import time + +from typing import Any, List + +import torch.utils.data.backward_compatibility + +import torch.utils.data.sharding +from torch.utils.data import DataLoader, IterDataPipe, communication +from torch.utils.data.datapipes.iter import IterableWrapper + +class _ThreadingDataLoader2: + + def __init__(self, datapipe, num_workers=0, collate_fn=None): + self.threads = [] + self.datapipes = [] + self.collate_fn = collate_fn + for worker_id in range(num_workers): + (thread, req_queue, res_queue, thread_localdatapipe) = communication.eventloop.SpawnThreadForDataPipeline(datapipe) + torch.utils.data.sharding.apply_sharding(thread_localdatapipe, num_workers, worker_id) + thread.start() + self.threads.append((thread, req_queue, res_queue)) + local_datapipe = communication.iter.QueueWrapper( + communication.protocol.IterDataPipeQueueProtocolClient(req_queue, res_queue)) + self.datapipes.append(local_datapipe) + + def __iter__(self): + not_available = False + forever = True + exclude_datapipes: List[Any] = [] + while len(exclude_datapipes) < len(self.datapipes): + for dp in self.datapipes: + if dp not in exclude_datapipes: + try: + value = dp.nonblocking_next() + yield value + except StopIteration: + exclude_datapipes.append(dp) + except communication.iter.NotAvailable: + not_available = True + if not_available: + time.sleep(0.001) + + def __del__(self): + self._cleanup_all_threads() + + def _cleanup_all_threads(self): + def clean_me(thread, req_queue, res_queue): + req_queue.put(communication.messages.TerminateRequest()) + _ = res_queue.get() + thread.join() + + for thread, req_queue, res_queue in self.threads: + clean_me(thread, req_queue, res_queue) + + +class DataLoader2: + def __new__(cls, + dataset, + batch_size=1, + shuffle=False, + sampler=None, + batch_sampler=None, + num_workers=0, + collate_fn=None, + pin_memory=False, + drop_last=False, + timeout=0, + worker_init_fn=None, + *, + prefetch_factor=2, + persistent_workers=False, + batch_outside_worker=False, + parallelism_mode='mp'): + if isinstance(dataset, IterDataPipe): + data_loader: Any = None + if batch_sampler is not None: + raise Exception( + 'batch_sampler is not yet supported by DataPipes') + if sampler is not None: + raise Exception( + 'sampler is not yet supported by DataPipes') + datapipe = dataset + if shuffle: + datapipe = datapipe.shuffle() + if batch_outside_worker and pin_memory: + raise Exception( + 'pin_memory is not yet compatible with batch_outside_worker') + if not batch_outside_worker: + if batch_size is not None: + datapipe = datapipe.batch(batch_size, drop_last=drop_last) + if collate_fn is None: + collate_fn = torch.utils.data._utils.collate.default_collate + if parallelism_mode == 'mp' or num_workers == 0: + def sharding_worker_init_fn(worker_init_fn, worker_id): + if worker_init_fn is not None: + worker_init_fn(worker_id) + torch.utils.data.backward_compatibility.worker_init_fn( + worker_id) + + my_worker_init_fn = functools.partial( + sharding_worker_init_fn, worker_init_fn) + + data_loader = DataLoader(datapipe, + batch_size=None, # Replaced by .batch DataPipe + shuffle=False, # Replaced by .shuffle DataPipe + sampler=None, + batch_sampler=None, + num_workers=num_workers, + collate_fn=collate_fn, + pin_memory=pin_memory, + drop_last=False, # Replaced by .batch DataPipe + timeout=timeout, + worker_init_fn=my_worker_init_fn, + prefetch_factor=prefetch_factor, + persistent_workers=persistent_workers) + elif parallelism_mode == 'thread': + if collate_fn is not None and not batch_outside_worker: + datapipe = datapipe.map(collate_fn) + if pin_memory: + raise Exception( + 'pin_memory is not yet supported by DataPipes with Threading') + if worker_init_fn is not None: + raise Exception( + 'worker_init_fn is not yet supported by DataPipes with Threading') + data_loader = _ThreadingDataLoader2(datapipe, + num_workers=num_workers, + collate_fn=collate_fn) + else: + raise Exception('Unsupported parallelism mode', parallelism_mode) + if not batch_outside_worker: + return data_loader + else: + if collate_fn is None: + collate_fn = torch.utils.data._utils.collate.default_collate + datapipe = IterableWrapper(data_loader).batch( + batch_size, drop_last=drop_last).map(collate_fn) + return datapipe + else: + if parallelism_mode != 'thread': + raise Exception( + 'thread parallelism mode is not supported for old DataSets') + + return DataLoader(dataset, + batch_size=batch_size, + shuffle=shuffle, + sampler=sampler, + batch_sampler=batch_sampler, + num_workers=num_workers, + collate_fn=collate_fn, + pin_memory=pin_memory, + drop_last=drop_last, + timeout=timeout, + worker_init_fn=worker_init_fn, + prefetch_factor=prefetch_factor, + persistent_workers=persistent_workers) diff --git a/torch/utils/data/datapipes/README.md b/torch/utils/data/datapipes/README.md new file mode 100644 index 0000000000000..69cd56d3cfbd1 --- /dev/null +++ b/torch/utils/data/datapipes/README.md @@ -0,0 +1,103 @@ +The [`datapipes`](https://github.com/pytorch/pytorch/tree/master/torch/utils/data/datapipes) folder holds the implementation of the `IterDataPipe` and `MapDataPipe`. + +This document serves as an entry point for DataPipe implementation. + +## Implementing DataPipe +For the sake of an example, let us implement an `IterDataPipe` to apply a callable over data under [`iter`](https://github.com/pytorch/pytorch/tree/master/torch/utils/data/datapipes/iter). +For `MapDataPipe`, please take reference from files in [map](https://github.com/pytorch/pytorch/tree/master/torch/utils/data/datapipes/map) folder and implement the corresponding `__getitem__` method. + +### Naming +The naming convention for DataPipe is Operation-er and with suffix of `IterDataPipe` because each DataPipe behaves like a container to apply the operation to data yielded from the source DataPipe. +And, when importing the DataPipe into `iter` module under `datapipes`, each DataPipe will be aliased as Op-er without the suffix of `IterDataPipe`. +Please check [`__init__.py`](https://github.com/pytorch/pytorch/blob/master/torch/utils/data/datapipes/iter/__init__.py) in `iter` module for how we aliasing each DataPipe class. +Like the example of `IterDataPipe` to map a function, we are going to name it as `MapperIterDataPipe` and alias it as `iter.Mapper` under `datapipes`. + +### Constructor +As DataSet now constructed by a stack of DataPipe-s, each DataPipe normally takes a source DataPipe as the first argument. +```py +class MapperIterDataPipe(IterDataPipe): + def __init__(self, dp, fn): + super().__init__() + self.dp = dp + self.fn = fn +``` +Note: Avoid loading data from the source DataPipe in `__init__` function, in order to support lazy data loading and save memory. + +### Iterator +For `IterDataPipe`, an `__iter__` function is needed to consume data from the source `IterDataPipe` then apply operation over the data before yield. +```py +class MapperIterDataPipe(IterDataPipe): + ... + + def __iter__(self): + for d in self.dp: + yield self.fn(d) +``` + +### Length +In the most common cases, as the example of `MapperIterDataPipe` above, the `__len__` method of DataPipe should return the length of source DataPipe. +```py +class MapperIterDataPipe(IterDataPipe): + ... + + def __len__(self): + return len(self.dp) +``` +Note that `__len__` method is optional for `IterDataPipe`. +Like `CSVParserIterDataPipe` in the [Using DataPipe sector](#using-datapipe), `__len__` is not implemented because the size of each file streams is unknown for us before loading it. + +Besides, in some special cases, `__len__` method can be provided, but it would either return an integer length or raise Error depending on the arguments of DataPipe. +And, the Error is required to be `TypeError` to support Python's build-in functions like `list(dp)`. +Please check NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] for detailed reason in PyTorch. + +### Registering DataPipe with functional API +Each DataPipe can be registered to support functional API using the decorator `functional_datapipe`. +```py +@functional_datapipe("map") +class MapperIterDataPipe(IterDataPipe): + ... +``` +Then, the stack of DataPipe can be constructed in functional-programming manner. +```py +>>> import torch.utils.data.datapipes as dp +>>> datapipes1 = dp.iter.FileLoader(['a.file', 'b.file']).map(fn=decoder).shuffle().batch(2) + +>>> datapipes2 = dp.iter.FileLoader(['a.file', 'b.file']) +>>> datapipes2 = dp.iter.Mapper(datapipes2) +>>> datapipes2 = dp.iter.Shuffler(datapipes2) +>>> datapipes2 = dp.iter.Batcher(datapipes2, 2) +``` +In the above example, `datapipes1` and `datapipes2` represent the exact same stack of `IterDataPipe`-s. + +## Using DataPipe +For example, we want to load data from CSV files with the following data pipeline: +- List all csv files +- Load csv files +- Parse csv file and yield rows + +To support the above pipeline, `CSVParser` is registered as `parse_csv_files` to consume file streams and expand them as rows. +```py +@functional_datapipe("parse_csv_files") +class CSVParserIterDataPipe(IterDataPipe): + def __init__(self, dp, **fmtparams): + self.dp = dp + self.fmtparams = fmtparams + + def __iter__(self): + for filename, stream in self.dp: + reader = csv.reader(stream, **self.fmtparams) + for row in reader: + yield filename, row +``` +Then, the pipeline can be assembled as following: +```py +>>> import torch.utils.data.datapipes as dp + +>>> FOLDER = 'path/2/csv/folder' +>>> datapipe = dp.iter.FileLister([FOLDER]).filter(fn=lambda filename: filename.endswith('.csv')) +>>> datapipe = dp.iter.FileLoader(datapipe, mode='rt') +>>> datapipe = datapipe.parse_csv_files(delimiter=' ') + +>>> for d in datapipe: # Start loading data +... pass +``` diff --git a/torch/utils/data/datapipes/iter/__init__.py b/torch/utils/data/datapipes/iter/__init__.py index 0bcfdc44c31cf..26d715d310234 100644 --- a/torch/utils/data/datapipes/iter/__init__.py +++ b/torch/utils/data/datapipes/iter/__init__.py @@ -1,38 +1,35 @@ from torch.utils.data.datapipes.iter.callable import ( - CollateIterDataPipe as Collate, - MapIterDataPipe as Map, - TransformsIterDataPipe as Transforms, + CollatorIterDataPipe as Collator, + MapperIterDataPipe as Mapper, ) from torch.utils.data.datapipes.iter.combinatorics import ( SamplerIterDataPipe as Sampler, - ShuffleIterDataPipe as Shuffle, + ShufflerIterDataPipe as Shuffler, ) from torch.utils.data.datapipes.iter.combining import ( - ConcatIterDataPipe as Concat, - ZipIterDataPipe as Zip, + ConcaterIterDataPipe as Concater, + DemultiplexerIterDataPipe as Demultiplexer, + ForkerIterDataPipe as Forker, + MultiplexerIterDataPipe as Multiplexer, + ZipperIterDataPipe as Zipper, +) +from torch.utils.data.datapipes.iter.filelister import ( + FileListerIterDataPipe as FileLister, +) +from torch.utils.data.datapipes.iter.fileloader import ( + FileLoaderIterDataPipe as FileLoader, ) from torch.utils.data.datapipes.iter.grouping import ( - BatchIterDataPipe as Batch, + BatcherIterDataPipe as Batcher, BucketBatcherIterDataPipe as BucketBatcher, - GroupByKeyIterDataPipe as GroupByKey, + GrouperIterDataPipe as Grouper, + UnBatcherIterDataPipe as UnBatcher, ) from torch.utils.data.datapipes.iter.httpreader import ( HTTPReaderIterDataPipe as HttpReader, ) -from torch.utils.data.datapipes.iter.listdirfiles import ( - ListDirFilesIterDataPipe as ListDirFiles, -) -from torch.utils.data.datapipes.iter.loadfilesfromdisk import ( - LoadFilesFromDiskIterDataPipe as LoadFilesFromDisk, -) -from torch.utils.data.datapipes.iter.readfilesfromtar import ( - ReadFilesFromTarIterDataPipe as ReadFilesFromTar, -) -from torch.utils.data.datapipes.iter.readfilesfromzip import ( - ReadFilesFromZipIterDataPipe as ReadFilesFromZip, -) -from torch.utils.data.datapipes.iter.readlinesfromfile import ( - ReadLinesFromFileIterDataPipe as ReadLinesFromFile, +from torch.utils.data.datapipes.iter.linereader import ( + LineReaderIterDataPipe as LineReader, ) from torch.utils.data.datapipes.iter.routeddecoder import ( RoutedDecoderIterDataPipe as RoutedDecoder, @@ -40,29 +37,42 @@ from torch.utils.data.datapipes.iter.selecting import ( FilterIterDataPipe as Filter, ) -from torch.utils.data.datapipes.iter.tobytes import ( - ToBytesIterDataPipe as ToBytes, +from torch.utils.data.datapipes.iter.streamreader import ( + StreamReaderIterDataPipe as StreamReader, +) +from torch.utils.data.datapipes.iter.tararchivereader import ( + TarArchiveReaderIterDataPipe as TarArchiveReader, +) +from torch.utils.data.datapipes.iter.ziparchivereader import ( + ZipArchiveReaderIterDataPipe as ZipArchiveReader, +) +from torch.utils.data.datapipes.iter.utils import ( + IterableWrapperIterDataPipe as IterableWrapper, ) -__all__ = ['Batch', +__all__ = ['Batcher', 'BucketBatcher', - 'Collate', - 'Concat', + 'Collator', + 'Concater', + 'Demultiplexer', + 'FileLister', + 'FileLoader', 'Filter', - 'GroupByKey', + 'Forker', + 'Grouper', 'HttpReader', - 'ListDirFiles', - 'LoadFilesFromDisk', - 'Map', - 'ReadFilesFromTar', - 'ReadFilesFromZip', - 'ReadLinesFromFile', + 'IterableWrapper', + 'LineReader', + 'Mapper', + 'Multiplexer', 'RoutedDecoder', 'Sampler', - 'Shuffle', - 'ToBytes', - 'Transforms', - 'Zip'] + 'Shuffler', + 'StreamReader', + 'TarArchiveReader', + 'UnBatcher', + 'ZipArchiveReader', + 'Zipper'] # Please keep this list sorted assert __all__ == sorted(__all__) diff --git a/torch/utils/data/datapipes/iter/callable.py b/torch/utils/data/datapipes/iter/callable.py index cc0f9e13b3adf..2c5ca3d024392 100644 --- a/torch/utils/data/datapipes/iter/callable.py +++ b/torch/utils/data/datapipes/iter/callable.py @@ -1,5 +1,4 @@ import warnings -import torch.nn as nn from torch.utils.data import IterDataPipe, _utils, functional_datapipe, DataChunk from typing import Callable, Dict, Iterator, Optional, Sized, Tuple, TypeVar @@ -26,20 +25,21 @@ def default_fn(data): @functional_datapipe('map') -class MapIterDataPipe(IterDataPipe[T_co]): - r""" :class:`MapIterDataPipe`. +class MapperIterDataPipe(IterDataPipe[T_co]): + r""" :class:`MapperIterDataPipe`. Iterable DataPipe to run a function over each item from the source DataPipe. The function can be any regular python function or partial object. Lambda function is not recommended as it is not supported by pickle. - args: + + Args: datapipe: Source Iterable DataPipe fn: Function called over each item fn_args: Positional arguments for `fn` fn_kwargs: Keyword arguments for `fn` - nesting_level: Determines which level the fn gets applied to, by default it applies to the top level (= 0) - This also accepts -1 as input to apply the function to the lowest nesting level. It currently doesn't support - argument < -1. + nesting_level: Determines which level the fn gets applied to, by default it applies to the top level (= 0). + This also accepts -1 as input to apply the function to the lowest nesting level. It currently doesn't support + argument < -1. """ datapipe: IterDataPipe fn: Callable @@ -108,15 +108,16 @@ def __setstate__(self, state): @functional_datapipe('collate') -class CollateIterDataPipe(MapIterDataPipe): - r""" :class:`CollateIterDataPipe`. +class CollatorIterDataPipe(MapperIterDataPipe): + r""" :class:`CollatorIterDataPipe`. Iterable DataPipe to collate samples from datapipe to Tensor(s) by `util_.collate.default_collate`, or customized Data Structure by collate_fn. - args: + + Args: datapipe: Iterable DataPipe being collated collate_fn: Customized collate function to collect and combine data or a batch of data. - Default function collates to Tensor(s) based on data type. + Default function collates to Tensor(s) based on data type. fn_args: Positional arguments for `collate_fn` fn_kwargs: Keyword arguments for `collate_fn` @@ -153,35 +154,3 @@ def __init__(self, fn_kwargs: Optional[Dict] = None, ) -> None: super().__init__(datapipe, fn=collate_fn, fn_args=fn_args, fn_kwargs=fn_kwargs) - - -@functional_datapipe('legacy_transforms') -class TransformsIterDataPipe(MapIterDataPipe): - r""" :class:`TransformsIterDataPipe`. - - Iterable DataPipe to use transform(s) from torchvision or torchaudio to transform - data from datapipe. - args: - datapipe: Iterable DataPipe being transformed - transforms: A transform or a sequence of transforms from torchvision or torchaudio. - """ - - def __init__(self, - datapipe: IterDataPipe, - transforms: Callable, - ) -> None: - # Type checking for transforms - transforms_types: Tuple = (nn.Module, ) - try: - # Specific types of transforms other than `nn.Module` from torchvision - import torchvision.transforms as tsfm - transforms_types += (tsfm.Compose, tsfm.RandomChoice, tsfm.RandomOrder, - tsfm.ToPILImage, tsfm.ToTensor, tsfm.Lambda) - except ImportError: - pass - - if not isinstance(transforms, transforms_types): - raise TypeError("`transforms` are required to be a callable from " - "torchvision.transforms or torchaudio.transforms") - - super().__init__(datapipe, fn=transforms) diff --git a/torch/utils/data/datapipes/iter/combinatorics.py b/torch/utils/data/datapipes/iter/combinatorics.py index a8b1e3d9737fa..5e17a3ef56c33 100644 --- a/torch/utils/data/datapipes/iter/combinatorics.py +++ b/torch/utils/data/datapipes/iter/combinatorics.py @@ -1,7 +1,7 @@ import random from torch.utils.data import IterDataPipe, Sampler, SequentialSampler, functional_datapipe -from typing import TypeVar, Type, Iterator, Sized, Optional, Tuple, Dict, List +from typing import Dict, Iterator, List, Optional, Sized, Tuple, Type, TypeVar T_co = TypeVar('T_co', covariant=True) @@ -10,10 +10,11 @@ class SamplerIterDataPipe(IterDataPipe[T_co]): r""" :class:`SamplerIterDataPipe`. Iterable DataPipe to generate sample elements. - args: - datapipe: IterDataPipe sampled from + + Args: + datapipe: IterDataPipe to sample from sampler: Sampler class to genereate sample elements from input DataPipe. - Default is :class:`SequentialSampler` for IterDataPipe + Default is :class:`SequentialSampler` for IterDataPipe """ datapipe: IterDataPipe sampler: Sampler @@ -44,8 +45,8 @@ def __len__(self) -> int: @functional_datapipe('shuffle') -class ShuffleIterDataPipe(IterDataPipe[T_co]): - r""" :class:`ShuffleIterDataPipe` +class ShufflerIterDataPipe(IterDataPipe[T_co]): + r""" :class:`ShufflerIterDataPipe` Iterable DataPipe to shuffle the input DataPipe with a buffer. The buffer with `buffer_size` is filled with elements from the datapipe first. Then, @@ -63,7 +64,7 @@ class ShuffleIterDataPipe(IterDataPipe[T_co]): mode (:attr:`num_worker > 0`), `worker_init_fn` is used to set up a random seed for each worker process. - args: + Args: datapipe: The IterDataPipe being shuffled buffer_size: The buffer size for shuffling (default to 10000) unbatch_level: Specifies if it necessary to unbatch source data before @@ -71,7 +72,6 @@ class ShuffleIterDataPipe(IterDataPipe[T_co]): """ datapipe: IterDataPipe[T_co] buffer_size: int - _buffer: List[T_co] def __init__(self, datapipe: IterDataPipe[T_co], @@ -86,24 +86,24 @@ def __init__(self, else: self.datapipe = datapipe.unbatch(unbatch_level=unbatch_level) self.buffer_size = buffer_size - self._buffer = [] - def buffer_replace(self, x): - idx = random.randint(0, self.buffer_size - 1) - val = self._buffer[idx] - self._buffer[idx] = x + @staticmethod + def buffer_replace(buffer, x): + idx = random.randint(0, len(buffer) - 1) + val = buffer[idx] + buffer[idx] = x return val def __iter__(self) -> Iterator[T_co]: - # TODO: Buffer is global, should be per __iter__ !!! + buffer: List[T_co] = [] for x in self.datapipe: - if len(self._buffer) == self.buffer_size: - yield self.buffer_replace(x) + if len(buffer) == self.buffer_size: + yield ShufflerIterDataPipe.buffer_replace(buffer, x) else: - self._buffer.append(x) - random.shuffle(self._buffer) - while self._buffer: - yield self._buffer.pop() + buffer.append(x) + random.shuffle(buffer) + while buffer: + yield buffer.pop() def __len__(self) -> int: if isinstance(self.datapipe, Sized): diff --git a/torch/utils/data/datapipes/iter/combining.py b/torch/utils/data/datapipes/iter/combining.py index 0693b1f0ad6de..ed1256fa1e757 100644 --- a/torch/utils/data/datapipes/iter/combining.py +++ b/torch/utils/data/datapipes/iter/combining.py @@ -1,17 +1,19 @@ -import functools +import warnings from torch.utils.data import IterDataPipe, functional_datapipe -from typing import Iterator, Optional, Sized, Tuple, TypeVar +from typing import Any, Callable, Iterator, List, Optional, Set, Sized, Tuple, TypeVar, Deque +from collections import deque T_co = TypeVar('T_co', covariant=True) @functional_datapipe('concat') -class ConcatIterDataPipe(IterDataPipe): - r""" :class:`ConcatIterDataPipe`. +class ConcaterIterDataPipe(IterDataPipe): + r""" :class:`ConcaterIterDataPipe`. Iterable DataPipe to concatenate multiple Iterable DataPipes. - args: + + Args: datapipes: Iterable DataPipes being concatenated """ datapipes: Tuple[IterDataPipe] @@ -45,6 +47,7 @@ def __len__(self) -> int: # This is fake class to show API, going to be replaced by the copy from torchdata # TODO(VitalyFedyunin): Replace with valid version, documentation and tests class IterateBuffer(IterDataPipe): + def __init__(self, buffer): self.buffer = buffer @@ -54,55 +57,256 @@ def __iter__(self): @functional_datapipe('fork') -class ForkIterDataPipe(IterDataPipe): +class ForkerIterDataPipe(IterDataPipe): + r""" :class:`ForkerIterDataPipe`. + + Iterable DataPipe to create multiple instances of the same Iterable DataPipe. + + Args: + datapipe: Iterable DataPipe being copied + num_instances: number of instances of the datapipe to create + buffer_size: this restricts how far ahead the leading child DataPipe + can read relative to the slowest child DataPipe + """ + def __new__(cls, datapipe: IterDataPipe, num_instances: int, buffer_size: int = 1000): + container = _ForkerIterDataPipe(datapipe, num_instances, buffer_size) + return [_ChildDataPipe(container, i) for i in range(num_instances)] + + +class _ForkerIterDataPipe(IterDataPipe): + r""" :class:`_ForkerIterDataPipe`. + + Container to hold instance-specific information on behalf of ForkerIterDataPipe. It tracks + the state of its child DataPipes, maintains the buffer, and yields the next value + as requested by the child DataPipes. + """ + def __init__(self, datapipe: IterDataPipe, num_instances: int, buffer_size: int = 1000): + self.main_datapipe = datapipe + self._datapipe_iterator: Optional[Iterator[Any]] = None + self.num_instances = num_instances + self.buffer: Deque = deque() + self.buffer_size = buffer_size + self.child_pointers = [0] * num_instances # Indicate the indices of the next element to get + self.slowest_ptr = 0 + self.leading_ptr = 0 + self.end_ptr: Optional[int] = None + + def __len__(self): + return len(self.main_datapipe) + + def get_next_element_by_instance(self, instance_id: int): + if self._datapipe_iterator is None: + self._datapipe_iterator = iter(self.main_datapipe) + while self.end_ptr is None or self.child_pointers[instance_id] < self.end_ptr: + if not self.buffer or self.child_pointers[instance_id] > self.leading_ptr: + self.leading_ptr = self.child_pointers[instance_id] + if self.leading_ptr - self.slowest_ptr + 1 > self.buffer_size: + raise BufferError("ForkerIterDataPipe buffer overflow," + + f"buffer size {self.buffer_size} is insufficient.") + try: + self.buffer.append(next(self._datapipe_iterator)) + self.child_pointers[instance_id] += 1 + yield self.buffer[-1] + except StopIteration: + self.end_ptr = self.leading_ptr + else: # Child pointer is slower than or equal to the leading_ptr + buffer_index = self.child_pointers[instance_id] - self.slowest_ptr + return_val = self.buffer[buffer_index] + self.child_pointers[instance_id] += 1 + if self.child_pointers[instance_id] - 1 == self.slowest_ptr: + new_min = min(self.child_pointers) # Can optimize by avoiding the call to min() + if self.slowest_ptr < new_min: + self.slowest_ptr = new_min + self.buffer.popleft() + yield return_val + + def is_instance_started(self, instance_id: int) -> bool: + return self.child_pointers[instance_id] != 0 - def __new__(cls, datapipe, instances): - result = [] - buffer = list(datapipe) - return [IterateBuffer(buffer) for i in range(instances)] + def is_every_instance_exhausted(self) -> bool: + return all(self.end_ptr == ptr for ptr in self.child_pointers) + + def reset(self): + self._datapipe_iterator = iter(self.main_datapipe) + self.buffer = deque() + self.child_pointers = [0] * self.num_instances + self.slowest_ptr = 0 + self.leading_ptr = 0 + self.end_ptr = None + +class _ChildDataPipe(IterDataPipe): + r""" :class:`_ChildDataPipe`. + + Iteratable Datapipe that is a child of a main DataPipe. The instance of this class + will pass its instance_id to get the next value from its main DataPipe. + + Args: + main_datapipe: Main DataPipe with a method 'get_next_element_by_instance(instance_id)' + instance_id: integer identifier of this instance + """ + def __init__(self, main_datapipe, instance_id: int): + required_attrs = ["get_next_element_by_instance", "is_instance_started", "is_every_instance_exhausted", "reset"] + required_ops = [getattr(main_datapipe, attr) for attr in required_attrs] + if any(not callable(op) for op in required_ops): + raise NotImplementedError(f"Main Datapipe must have methods {required_attrs} implemented.") + self.main_datapipe = main_datapipe + self.instance_id = instance_id + + def __iter__(self): + if self.main_datapipe.is_instance_started(self.instance_id): # Only reset if the DataPipe started to read + if not self.main_datapipe.is_every_instance_exhausted(): + warnings.warn("Some child DataPipes are not exhausted when __iter__ is called. We are resetting " + "the buffer and each child DataPipe will read from the start again.", UserWarning) + self.main_datapipe.reset() + # We want to separate the code for reset and yield, so that 'reset' exeutes before __next__ is called + return self.get_generator_by_instance(self.instance_id) + + def __len__(self): + return len(self.main_datapipe) + + def get_generator_by_instance(self, instance_id: int): + yield from self.main_datapipe.get_next_element_by_instance(self.instance_id) @functional_datapipe('demux') class DemultiplexerIterDataPipe(IterDataPipe): + r""" :class:`DemultiplexerIterDataPipe`. - def __new__(cls, datapipe, instances, classifier_fn): - result = [] - buffer = list(datapipe) + Iterable DataPipe to split the input DataPipe into multiple child DataPipes, using the given + classification function. A list of the child DataPipes is returned from this operation. + + Args: + datapipe: Iterable DataPipe being filtered + num_instances: number of instances of the DataPipe to create + classifier_fn: a function that maps values to an integer within the range [0, num_instances - 1] or None + drop_none: defaults to False, if True, the function will skip over elements classified as None + buffer_size: this defines the maximum number of inputs that the buffer can hold across all child + DataPipes while waiting for their values to be yielded + """ + def __new__(cls, datapipe: IterDataPipe, num_instances: int, + classifier_fn: Callable[[T_co], int], drop_none: bool = False, buffer_size: int = 1000): + container = _DemultiplexerIterDataPipe(datapipe, num_instances, classifier_fn, drop_none, buffer_size) + return [_ChildDataPipe(container, i) for i in range(num_instances)] - def filter_fn(classifier_fn, i, x): - return classifier_fn(x) == i - return [IterateBuffer(buffer).filter(functools.partial(filter_fn, classifier_fn, i)) for i in range(instances)] + +class _DemultiplexerIterDataPipe(IterDataPipe): + r""" :class:`_DemultiplexerIterDataPipe`. + + Container to hold instance-specific information on behalf of DemultiplexerIterDataPipe. It tracks + the state of its child DataPipes, maintains the buffer, classifies and yields the next correct value + as requested by the child DataPipes. + """ + + def __init__(self, datapipe: IterDataPipe[T_co], num_instances: int, + classifier_fn: Callable[[T_co], int], drop_none: bool, buffer_size: int): + self.main_datapipe = datapipe + self._datapipe_iterator: Optional[Iterator[Any]] = None + self.num_instances = num_instances + self.max_buffer_size = buffer_size + self.current_buffer_usage = 0 + self.child_buffers: List[Deque[T_co]] = [deque() for _ in range(num_instances)] + self.instance_started: List[bool] = [False] * num_instances + self.classifier_fn = classifier_fn + self.drop_none = drop_none + self.main_datapipe_exhausted = False + + def _find_next(self, instance_id: int) -> T_co: + while True: + if self._datapipe_iterator is None: + raise ValueError("_datapipe_iterator has not been set, likely because this private method is called directly " + "without invoking get_next_element_by_instance() first.") + value = next(self._datapipe_iterator) + classification = self.classifier_fn(value) + if classification is None and self.drop_none: + continue + if classification is None or classification >= self.num_instances or classification < 0: + raise ValueError(f"Output of the classification fn should be between 0 and {self.num_instances - 1}. " + + f"{classification} is returned.") + if classification == instance_id: + return value + self.child_buffers[classification].append(value) + self.current_buffer_usage += 1 + if self.current_buffer_usage > self.max_buffer_size: + raise BufferError( + f"DemultiplexerIterDataPipe buffer overflow, buffer size {self.max_buffer_size} is insufficient.") + + def get_next_element_by_instance(self, instance_id: int): + if self._datapipe_iterator is None: + self._datapipe_iterator = iter(self.main_datapipe) + stop = False + self.instance_started[instance_id] = True + while not stop: + if self.child_buffers[instance_id]: + self.current_buffer_usage -= 1 + yield self.child_buffers[instance_id].popleft() + else: + try: + yield self._find_next(instance_id) + except StopIteration: + stop = True + self.main_datapipe_exhausted = True + + def is_instance_started(self, instance_id: int) -> bool: + return self.instance_started[instance_id] + + def is_every_instance_exhausted(self) -> bool: + return self.main_datapipe_exhausted and all(not child_buffer for child_buffer in self.child_buffers) + + def reset(self): + self._datapipe_iterator = iter(self.main_datapipe) + self.current_buffer_usage = 0 + self.child_buffers = [deque() for _ in range(self.num_instances)] + self.instance_started = [False] * self.num_instances + self.main_datapipe_exhausted = False @functional_datapipe('mux') class MultiplexerIterDataPipe(IterDataPipe): + r""" :class:`MultiplexerIterDataPipe`. + Iterable DataPipe that yields one element at a time from each input Iterable DataPipe + (i.e. one element from the 1st input DataPipe, then one element from the 2nd DataPipe in the next iteration, + and so on). It skips over DataPipes that are exhausted, and ends when all input DataPipes are exhausted. + + Args: + datapipes: Iterable DataPipes that will take turn to yield their elements, until they are all exhausted + """ def __init__(self, *datapipes): self.datapipes = datapipes + self.length: Optional[int] = None def __iter__(self): iterators = [iter(x) for x in self.datapipes] - finished = {} - had_more = True - while had_more: - had_more = False + finished: Set[int] = set() + while len(finished) < len(iterators): for i in range(len(iterators)): if i not in finished: try: - value = iterators[i].__next__() - had_more = True + value = next(iterators[i]) yield value except StopIteration: - finished[i] = 1 + finished.add(i) + + def __len__(self): + if self.length is not None: + if self.length == -1: + raise TypeError("{} instance doesn't have valid length".format(type(self).__name__)) + return self.length + if all(isinstance(dp, Sized) for dp in self.datapipes): + self.length = sum(len(dp) for dp in self.datapipes) + else: + self.length = -1 + return len(self) @functional_datapipe('zip') -class ZipIterDataPipe(IterDataPipe[Tuple[T_co]]): - r""" :class:`ZipIterDataPipe`. +class ZipperIterDataPipe(IterDataPipe[Tuple[T_co]]): + r""" :class:`ZipperIterDataPipe`. Iterable DataPipe aggregates elements into a tuple from each of the input DataPipe. The output DataPipe is stopped when the shortest input DataPipe is exhausted. - args: + + Args: *datapipes: Iterable DataPipes being aggregated """ datapipes: Tuple[IterDataPipe] diff --git a/torch/utils/data/datapipes/iter/listdirfiles.py b/torch/utils/data/datapipes/iter/filelister.py similarity index 61% rename from torch/utils/data/datapipes/iter/listdirfiles.py rename to torch/utils/data/datapipes/iter/filelister.py index 91ef8a3b080a4..aef147d2d2941 100644 --- a/torch/utils/data/datapipes/iter/listdirfiles.py +++ b/torch/utils/data/datapipes/iter/filelister.py @@ -2,15 +2,16 @@ from torch.utils.data.datapipes.utils.common import get_file_pathnames_from_root from typing import List, Union, Iterator -class ListDirFilesIterDataPipe(IterDataPipe[str]): - r""" :class:`ListDirFilesIterDataPipe` +class FileListerIterDataPipe(IterDataPipe[str]): + r""" :class:`FileListerIterDataPipe` Iterable DataPipe to load file pathname(s) (path + filename), yield pathname from given disk root dir. - args: - root : root dir - mask : a unix style filter string or string list for filtering file name(s) - abspath : whether to return relative pathname or absolute pathname - length : a nominal length of the datapipe + + Args: + root: Root directory + mask: Unix style filter string or string list for filtering file name(s) + abspath: Whether to return relative pathname or absolute pathname + length: Nominal length of the datapipe """ def __init__( @@ -22,11 +23,11 @@ def __init__( abspath: bool = False, length: int = -1): super().__init__() - self.root : str = root - self.masks : Union[str, List[str]] = masks - self.recursive : bool = recursive - self.abspath : bool = abspath - self.length : int = length + self.root: str = root + self.masks: Union[str, List[str]] = masks + self.recursive: bool = recursive + self.abspath: bool = abspath + self.length: int = length def __iter__(self) -> Iterator[str] : yield from get_file_pathnames_from_root(self.root, self.masks, self.recursive, self.abspath) diff --git a/torch/utils/data/datapipes/iter/loadfilesfromdisk.py b/torch/utils/data/datapipes/iter/fileloader.py similarity index 90% rename from torch/utils/data/datapipes/iter/loadfilesfromdisk.py rename to torch/utils/data/datapipes/iter/fileloader.py index c9dd5daf9a17a..7c048fc054378 100644 --- a/torch/utils/data/datapipes/iter/loadfilesfromdisk.py +++ b/torch/utils/data/datapipes/iter/fileloader.py @@ -5,18 +5,19 @@ from torch.utils.data.datapipes.utils.common import get_file_binaries_from_pathnames -class LoadFilesFromDiskIterDataPipe(IterDataPipe[Tuple[str, IOBase]]): - r""" :class:`LoadFilesFromDiskIterDataPipe`. +class FileLoaderIterDataPipe(IterDataPipe[Tuple[str, IOBase]]): + r""" :class:`FileLoaderIterDataPipe`. Iterable Datapipe to load file streams from given pathnames, yield pathname and file stream in a tuple. - args: + + Args: datapipe: Iterable datapipe that provides pathnames mode: An optional string that specifies the mode in which the file is opened by `open()`. It defaults to 'b' which means open for reading in binary mode. Another option is 't' for text mode - length: a nominal length of the datapipe + length: Nominal length of the datapipe Note: The opened file handles will be closed by Python's GC periodly. Users can choose diff --git a/torch/utils/data/datapipes/iter/grouping.py b/torch/utils/data/datapipes/iter/grouping.py index 1bd8c4cf4c315..d90ad08814ecf 100644 --- a/torch/utils/data/datapipes/iter/grouping.py +++ b/torch/utils/data/datapipes/iter/grouping.py @@ -1,12 +1,10 @@ -import functools -import os import random import warnings from collections import defaultdict from torch.utils.data import IterDataPipe, functional_datapipe, DataChunk -from typing import Any, Callable, Dict, Iterator, List, Optional, Sized, Tuple, TypeVar, DefaultDict +from typing import Any, Callable, DefaultDict, Iterator, List, Optional, Sized, TypeVar T_co = TypeVar('T_co', covariant=True) @@ -30,28 +28,35 @@ def __iter__(self): if i % self.num_of_instances == self.instance_id: yield item + def __len__(self): + if isinstance(self.source_datapipe, Sized): + return len(self.source_datapipe) // self.num_of_instances +\ + (1 if (self.instance_id < len(self.source_datapipe) % self.num_of_instances) else 0) + raise TypeError("{} instance doesn't have valid length".format(type(self).__name__)) + @functional_datapipe('batch') -class BatchIterDataPipe(IterDataPipe[DataChunk[T_co]]): - r""" :class:`BatchIterDataPipe`. +class BatcherIterDataPipe(IterDataPipe[DataChunk]): + r""" :class:`BatcherIterDataPipe`. Iterable DataPipe to create mini-batches of data. An outer dimension will be added as `batch_size` if `drop_last` is set to `True`, or `length % batch_size` for the last batch if `drop_last` is set to `False`. - args: + + Args: datapipe: Iterable DataPipe being batched batch_size: The size of each batch drop_last: Option to drop the last batch if it's not full unbatch_level: Specifies if it necessary to unbatch source data before applying new batching rule """ - datapipe: IterDataPipe[T_co] + datapipe: IterDataPipe batch_size: int drop_last: bool length: Optional[int] def __init__(self, - datapipe: IterDataPipe[T_co], + datapipe: IterDataPipe, batch_size: int, drop_last: bool = False, unbatch_level: int = 0, @@ -68,8 +73,8 @@ def __init__(self, self.length = None self.wrapper_class = DataChunk - def __iter__(self) -> Iterator[DataChunk[T_co]]: - batch: List[T_co] = [] + def __iter__(self) -> Iterator[DataChunk]: + batch: List = [] for x in self.datapipe: batch.append(x) if len(batch) == self.batch_size: @@ -93,18 +98,21 @@ def __len__(self) -> int: @functional_datapipe('unbatch') -class UnBatchIterDataPipe(IterDataPipe): - r""" :class:`UnBatchIterDataPipe`. +class UnBatcherIterDataPipe(IterDataPipe): + r""" :class:`UnBatcherIterDataPipe`. Iterable DataPipe to undo batching of data. In other words, it flattens the data up to the specified level within a batched DataPipe. - args: + + Args: datapipe: Iterable DataPipe being un-batched unbatch_level: Defaults to `1` (only flattening the top level). If set to `2`, it will flatten the top 2 levels, - and `-1` will flatten the entire DataPipe. + and `-1` will flatten the entire DataPipe. """ - def __init__(self, datapipe, unbatch_level: int = 1): + def __init__(self, + datapipe: IterDataPipe, + unbatch_level: int = 1): self.datapipe = datapipe self.unbatch_level = unbatch_level @@ -133,19 +141,20 @@ def _dive(self, element, unbatch_level): else: raise IndexError(f"unbatch_level {self.unbatch_level} exceeds the depth of the DataPipe") -# TODO(ejguan): https://github.com/pytorch/pytorch/issues/63095 + def _in_batch_shuffle_fn(data: DataChunk): - d = list(data) - random.shuffle(d) - return DataChunk(d) + random.shuffle(data) + return data + class BucketBatcherIterDataPipe(IterDataPipe[DataChunk[T_co]]): - r""" :class:`BucketBatcherIterDataPipe`. + r""":class:`BucketBatcherIterDataPipe`. Iterable DataPipe to create mini-batches of data from sorted bucket. An outer dimension will be added as `batch_size` if `drop_last` is set to `True`, or `length % batch_size` for the last batch if `drop_last` is set to `False`. - args: + + Args: datapipe: Iterable DataPipe being batched batch_size: The size of each batch drop_last: Option to drop the last batch if it's not full @@ -225,38 +234,22 @@ def __len__(self) -> int: raise TypeError("{} instance doesn't have valid length".format(type(self).__name__)) -# defaut group key is the file pathname without the extension. -# Assuming the passed in data is a tuple and 1st item is file's pathname. -def default_group_key_fn(dataitem: Tuple[str, Any]): - return os.path.splitext(dataitem[0])[0] - - -def default_sort_data_fn(datalist: List[Tuple[str, Any]]): - txt_ext = ['.json', '.jsn', '.txt', '.text'] - - def cmp_fn(a: Tuple[str, Any], b: Tuple[str, Any]): - a_is_txt = os.path.splitext(a[0])[1] in txt_ext - b_is_txt = os.path.splitext(b[0])[1] in txt_ext - - # if a is txt but b is not, b go front - if a_is_txt and not b_is_txt: - return 1 - # if a is not txt but b is txt, a go front - if not a_is_txt and b_is_txt: - return -1 - # if a and b both are or are not txt, sort in alphabetic order - if a[0] < b[0]: - return -1 - elif a[0] > b[0]: - return 1 - return 0 - - return sorted(datalist, key=functools.cmp_to_key(cmp_fn)) - - @functional_datapipe('groupby') -class GroupByIterDataPipe(IterDataPipe): - # TODO(VtalyFedyunin): Add inline docs and tests (they are partially available in notebooks) +class GrouperIterDataPipe(IterDataPipe[DataChunk]): + r""":class:`GrouperIterDataPipe`. + + Iterable datapipe to group data from input IterDataPipe by keys which are generated from `group_key_fn`, + and yield a DataChunk with size ranging from `guaranteed_group_size` to `group_size`. + + Args: + datapipe: Iterable datapipe to be grouped + group_key_fn: Function used to generate group key from the data of the source datapipe + buffer_size: The size of buffer for ungrouped data + group_size: The size of each group + unbatch_level: Specifies if it necessary to unbatch source data before grouping + guaranteed_group_size: The guaranteed minimum group size + drop_remaining: Specifies if the group smaller than `guaranteed_group_size` will be dropped from buffer + """ def __init__(self, datapipe: IterDataPipe[T_co], group_key_fn: Callable, @@ -309,6 +302,9 @@ def __iter__(self): for x in self.datapipe: key = self.group_key_fn(x) + buffer_elements[key].append(x) + buffer_size += 1 + if self.group_size is not None and self.group_size == len(buffer_elements[key]): yield self.wrapper_class(buffer_elements[key]) buffer_size -= len(buffer_elements[key]) @@ -319,92 +315,7 @@ def __iter__(self): if result_to_yield is not None: yield self.wrapper_class(result_to_yield) - buffer_elements[key].append(x) - buffer_size += 1 - while buffer_size: (result_to_yield, buffer_size) = self._remove_biggest_key(buffer_elements, buffer_size) if result_to_yield is not None: yield self.wrapper_class(result_to_yield) - - -@functional_datapipe('group_by_key') -class GroupByKeyIterDataPipe(IterDataPipe[list]): - r""" :class:`GroupByKeyIterDataPipe`. - - Iterable datapipe to group data from input iterable by keys which are generated from `group_key_fn`, - yields a list with `group_size` items in it, each item in the list is a tuple of key and data - - args: - datapipe: Iterable datapipe that provides data. (typically str key (eg. pathname) and data stream in tuples) - group_size: the size of group - max_buffer_size: the max size of stream buffer which is used to store not yet grouped but iterated data - group_key_fn: a function which is used to generate group key from the data in the input datapipe - sort_data_fn: a function which is used to sort the grouped data before yielding back - length: a nominal length of the datapipe - """ - datapipe: IterDataPipe[Tuple[str, Any]] - group_size: int - max_buffer_size: int - group_key_fn: Callable - sort_data_fn: Callable - curr_buffer_size: int - stream_buffer: Dict[str, List[Tuple[str, Any]]] - length: int - - def __init__( - self, - datapipe: IterDataPipe[Tuple[str, Any]], - *, - group_size: int, - max_buffer_size: Optional[int] = None, - group_key_fn: Callable = default_group_key_fn, - sort_data_fn: Callable = default_sort_data_fn, - length: int = -1): - super().__init__() - - assert group_size > 0 - self.datapipe = datapipe - self.group_size = group_size - - # default max buffer size is group_size * 10 - self.max_buffer_size = max_buffer_size if max_buffer_size is not None else group_size * 10 - assert self.max_buffer_size >= self.group_size - - self.group_key_fn = group_key_fn # type: ignore[assignment] - self.sort_data_fn = sort_data_fn # type: ignore[assignment] - self.curr_buffer_size = 0 - self.stream_buffer = {} - self.length = length - - def __iter__(self) -> Iterator[list]: - if self.group_size == 1: - for data in self.datapipe: - yield [data] - else: - for data in self.datapipe: - key = self.group_key_fn(data) - if key not in self.stream_buffer: - self.stream_buffer[key] = [] - res = self.stream_buffer[key] - res.append(data) - if len(res) == self.group_size: - yield self.sort_data_fn(res) - del self.stream_buffer[key] - self.curr_buffer_size = self.curr_buffer_size - self.group_size + 1 - else: - if self.curr_buffer_size == self.max_buffer_size: - raise OverflowError( - "stream_buffer is overflow, please adjust the order of data " - "in the input datapipe or increase the buffer size!") - self.curr_buffer_size = self.curr_buffer_size + 1 - - if self.curr_buffer_size > 0: - msg = "Not able to group [{}] with group size {}.".format( - ','.join([v[0] for _, vs in self.stream_buffer.items() for v in vs]), str(self.group_size)) - raise RuntimeError(msg) - - def __len__(self) -> int: - if self.length == -1: - raise TypeError("{} instance doesn't have valid length".format(type(self).__name__)) - return self.length diff --git a/torch/utils/data/datapipes/iter/httpreader.py b/torch/utils/data/datapipes/iter/httpreader.py index c663a18cdaab8..0c8e2fc818e9f 100644 --- a/torch/utils/data/datapipes/iter/httpreader.py +++ b/torch/utils/data/datapipes/iter/httpreader.py @@ -1,5 +1,5 @@ from io import IOBase -from typing import Tuple +from typing import Sized, Tuple from urllib.error import HTTPError, URLError import urllib.request as urllib from torch.utils.data import IterDataPipe @@ -10,16 +10,18 @@ class HTTPReaderIterDataPipe(IterDataPipe[Tuple[str, IOBase]]): Iterable DataPipe to load file url(s) (http url(s) pointing to file(s)), yield file url and IO stream in a tuple - args: - timeout : timeout for http request + + Args: + datapipe: Iterable DataPipe providing urls + timeout: Timeout for http request """ - def __init__(self, source_datapipe, timeout=None): - self.source_datapipe = source_datapipe + def __init__(self, datapipe, timeout=None): + self.datapipe = datapipe self.timeout = timeout def __iter__(self): - for furl in self.source_datapipe: + for furl in self.datapipe: try: if self.timeout is None: r = urllib.urlopen(furl) @@ -37,3 +39,8 @@ def __iter__(self): .format(reason=e.reason, url=furl)) except Exception: raise + + def __len__(self) -> int: + if isinstance(self.datapipe, Sized): + return len(self.datapipe) + raise TypeError("{} instance doesn't have valid length".format(type(self).__name__)) diff --git a/torch/utils/data/datapipes/iter/linereader.py b/torch/utils/data/datapipes/iter/linereader.py new file mode 100644 index 0000000000000..04b992d647b77 --- /dev/null +++ b/torch/utils/data/datapipes/iter/linereader.py @@ -0,0 +1,21 @@ +from typing import Tuple +from torch.utils.data import IterDataPipe + + +class LineReaderIterDataPipe(IterDataPipe[Tuple[str, str]]): + r""" :class:`LineReaderIterDataPipe` + + Iterable DataPipe to load file name and stream as source IterDataPipe + and yield filename and line(s). + + Args: + datapipe: Iterable DataPipe providing file name and string file stream + """ + + def __init__(self, datapipe): + self.datapipe = datapipe + + def __iter__(self): + for file_name, stream in self.datapipe: + for line in stream: + yield file_name, line diff --git a/torch/utils/data/datapipes/iter/readlinesfromfile.py b/torch/utils/data/datapipes/iter/readlinesfromfile.py deleted file mode 100644 index c8366af3b475f..0000000000000 --- a/torch/utils/data/datapipes/iter/readlinesfromfile.py +++ /dev/null @@ -1,19 +0,0 @@ -from typing import Tuple -from torch.utils.data import IterDataPipe - - -class ReadLinesFromFileIterDataPipe(IterDataPipe[Tuple[str, str]]): - r""" :class:`ReadLinesFromFileDataPipe` - - Iterable DataPipe to load file names as source iter data pipe - and yield filename and line(s). - """ - - def __init__(self, source_datapipe): - self.source_datapipe = source_datapipe - - def __iter__(self): - for file_name in self.source_datapipe: - with open(file_name) as file: - for line in file: - yield (file_name, line) diff --git a/torch/utils/data/datapipes/iter/routeddecoder.py b/torch/utils/data/datapipes/iter/routeddecoder.py index f149c074e63fe..ea47742f8e80b 100644 --- a/torch/utils/data/datapipes/iter/routeddecoder.py +++ b/torch/utils/data/datapipes/iter/routeddecoder.py @@ -6,7 +6,8 @@ Decoder, basichandlers as decoder_basichandlers, imagehandler as decoder_imagehandler, - extension_extract_fn) + extension_extract_fn +) @functional_datapipe('decode') @@ -15,7 +16,8 @@ class RoutedDecoderIterDataPipe(IterDataPipe[Tuple[str, Any]]): Iterable datapipe to decode binary streams from input DataPipe, yield pathname and decoded data in a tuple. - args: + + Args: datapipe: Iterable datapipe that provides pathname and binary stream in tuples handlers: Optional user defined decoder handlers. If None, basic and image decoder handlers will be set as default. If multiple handles are provided, the priority diff --git a/torch/utils/data/datapipes/iter/selecting.py b/torch/utils/data/datapipes/iter/selecting.py index 46a613a7d91a4..f1889e5d7a8e4 100644 --- a/torch/utils/data/datapipes/iter/selecting.py +++ b/torch/utils/data/datapipes/iter/selecting.py @@ -1,38 +1,63 @@ +import warnings from torch.utils.data import IterDataPipe, functional_datapipe, DataChunk from typing import Callable, TypeVar, Iterator, Optional, Tuple, Dict -from .callable import MapIterDataPipe - T_co = TypeVar('T_co', covariant=True) +try: + import dill + + # XXX: By default, dill writes the Pickler dispatch table to inject its + # own logic there. This globally affects the behavior of the standard library + # pickler for any user who transitively depends on this module! + # Undo this extension to avoid altering the behavior of the pickler globally. + dill.extend(use_dill=False) + DILL_AVAILABLE = True +except ImportError: + DILL_AVAILABLE = False + @functional_datapipe('filter') -class FilterIterDataPipe(MapIterDataPipe): +class FilterIterDataPipe(IterDataPipe[T_co]): r""" :class:`FilterIterDataPipe`. Iterable DataPipe to filter elements from datapipe according to filter_fn. - args: + + Args: datapipe: Iterable DataPipe being filtered filter_fn: Customized function mapping an element to a boolean. fn_args: Positional arguments for `filter_fn` fn_kwargs: Keyword arguments for `filter_fn` drop_empty_batches: By default, drops batch if it is empty after filtering instead of keeping an empty list nesting_level: Determines which level the fn gets applied to, by default it applies to the top level (= 0). - This also accepts -1 as input to apply filtering to the lowest nesting level. It currently doesn't support - argument < -1. + This also accepts -1 as input to apply filtering to the lowest nesting level. + It currently doesn't support argument < -1. """ + datapipe: IterDataPipe + filter_fn: Callable drop_empty_batches: bool def __init__(self, - datapipe: IterDataPipe[T_co], - filter_fn: Callable[..., bool], + datapipe: IterDataPipe, + filter_fn: Callable, fn_args: Optional[Tuple] = None, fn_kwargs: Optional[Dict] = None, drop_empty_batches: bool = True, nesting_level: int = 0, ) -> None: + super().__init__() + self.datapipe = datapipe + # Partial object has no attribute '__name__', but can be pickled + if hasattr(filter_fn, '__name__') and filter_fn.__name__ == '' and not DILL_AVAILABLE: + warnings.warn("Lambda function is not supported for pickle, please use " + "regular python function or functools.partial instead.") + self.filter_fn = filter_fn # type: ignore[assignment] + self.args = () if fn_args is None else fn_args + self.kwargs = {} if fn_kwargs is None else fn_kwargs + if nesting_level < -1: + raise ValueError("nesting_level must be -1 or >= 0") + self.nesting_level = nesting_level self.drop_empty_batches = drop_empty_batches - super().__init__(datapipe, fn=filter_fn, fn_args=fn_args, fn_kwargs=fn_kwargs, nesting_level=nesting_level) def __iter__(self) -> Iterator[T_co]: res: bool @@ -65,7 +90,7 @@ def _applyFilter(self, data, nesting_level): return self._returnIfTrue(data) def _returnIfTrue(self, data): - condition = self.fn(data, *self.args, **self.kwargs) + condition = self.filter_fn(data, *self.args, **self.kwargs) if not isinstance(condition, bool): raise ValueError("Boolean output is required for `filter_fn` of FilterIterDataPipe") if condition: @@ -76,6 +101,17 @@ def _isNonEmpty(self, data): not (isinstance(data, list) and len(data) == 0 and self.drop_empty_batches) return r + def __getstate__(self): + if DILL_AVAILABLE: + dill_function = dill.dumps(self.filter_fn) + else: + dill_function = self.filter_fn + state = (self.datapipe, dill_function, self.args, self.kwargs, self.drop_empty_batches, self.nesting_level) + return state - def __len__(self): - raise TypeError("{} instance doesn't have valid length".format(type(self).__name__)) + def __setstate__(self, state): + (self.datapipe, dill_function, self.args, self.kwargs, self.drop_empty_batches, self.nesting_level) = state + if DILL_AVAILABLE: + self.filter_fn = dill.loads(dill_function) # type: ignore[assignment] + else: + self.filter_fn = dill_function # type: ignore[assignment] diff --git a/torch/utils/data/datapipes/iter/streamreader.py b/torch/utils/data/datapipes/iter/streamreader.py new file mode 100644 index 0000000000000..197fb8e2b3005 --- /dev/null +++ b/torch/utils/data/datapipes/iter/streamreader.py @@ -0,0 +1,26 @@ +from typing import Tuple +from torch.utils.data import IterDataPipe + + +class StreamReaderIterDataPipe(IterDataPipe[Tuple[str, bytes]]): + r""" :class:`StreamReaderIterDataPipe` + + Iterable DataPipe to load IO stream with label name, + and to yield bytes with label name in a tuple + + Args: + datapipe: Iterable DataPipe provides url and byte stream + chunk: Number of bytes to be read from stream per iteration. + If None, all bytes will be read util the EOF. + """ + def __init__(self, datapipe, chunk=None): + self.datapipe = datapipe + self.chunk = chunk + + def __iter__(self): + for furl, stream in self.datapipe: + while True: + d = stream.read(self.chunk) + if not d: + break + yield (furl, d) diff --git a/torch/utils/data/datapipes/iter/readfilesfromtar.py b/torch/utils/data/datapipes/iter/tararchivereader.py similarity index 90% rename from torch/utils/data/datapipes/iter/readfilesfromtar.py rename to torch/utils/data/datapipes/iter/tararchivereader.py index f4566021fcc7f..c34583a4d9420 100644 --- a/torch/utils/data/datapipes/iter/readfilesfromtar.py +++ b/torch/utils/data/datapipes/iter/tararchivereader.py @@ -7,14 +7,16 @@ import tarfile import warnings -class ReadFilesFromTarIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]): - r""":class:`ReadFilesFromTarIterDataPipe`. +class TarArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]): + r""" :class:`TarArchiveReaderIterDataPipe`. Iterable datapipe to extract tar binary streams from input iterable which contains tuples of pathname and tar binary stream, yields pathname and extracted binary stream in a tuple. - args: + + Args: datapipe: Iterable datapipe that provides pathname and tar binary stream in tuples - mode: File mode used by `tarfile.open` to read file object. Mode has to be a string of the form 'filemode[:compression]' + mode: File mode used by `tarfile.open` to read file object. + Mode has to be a string of the form 'filemode[:compression]' length: a nominal length of the datapipe Note: @@ -24,13 +26,13 @@ class ReadFilesFromTarIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]): """ def __init__( self, - datapipe : Iterable[Tuple[str, BufferedIOBase]], + datapipe: Iterable[Tuple[str, BufferedIOBase]], mode: str = "r:*", - length : int = -1 + length: int = -1 ): super().__init__() self.datapipe: Iterable[Tuple[str, BufferedIOBase]] = datapipe - self.mode = mode + self.mode: str = mode self.length: int = length def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]: diff --git a/torch/utils/data/datapipes/iter/tobytes.py b/torch/utils/data/datapipes/iter/tobytes.py deleted file mode 100644 index 21fd82d381bcb..0000000000000 --- a/torch/utils/data/datapipes/iter/tobytes.py +++ /dev/null @@ -1,24 +0,0 @@ -from typing import Tuple -from torch.utils.data import IterDataPipe - - -class ToBytesIterDataPipe(IterDataPipe[Tuple[str, bytes]]): - r""" :class:`ToBytesIterDataPipe` - - Iterable DataPipe to load IO stream with label name, - and to yield bytes with label name in a tuple - args: - chunk : bytes to read from stream on each iteration. - If None, stream reads to the EOF. - """ - def __init__(self, source_datapipe, chunk=None): - self.source_datapipe = source_datapipe - self.chunk = chunk - - def __iter__(self): - for (furl, stream) in self.source_datapipe: - while True: - d = stream.read(self.chunk) - if not d: - break - yield (furl, d) diff --git a/torch/utils/data/datapipes/iter/utils.py b/torch/utils/data/datapipes/iter/utils.py new file mode 100644 index 0000000000000..9ba80e3576f77 --- /dev/null +++ b/torch/utils/data/datapipes/iter/utils.py @@ -0,0 +1,20 @@ +from torch.utils.data import IterDataPipe + + +class IterableWrapperIterDataPipe(IterDataPipe): + r""":class:`IterableWrapperIterDataPipe`. + + Iterable datapipe that wraps an iterable object. + + Args: + iterable: Iterable object to be wrapped into an IterDataPipe + """ + def __init__(self, iterable): + self.iterable = iterable + + def __iter__(self): + for data in self.iterable: + yield data + + def __len__(self): + return len(self.iterable) diff --git a/torch/utils/data/datapipes/iter/readfilesfromzip.py b/torch/utils/data/datapipes/iter/ziparchivereader.py similarity index 87% rename from torch/utils/data/datapipes/iter/readfilesfromzip.py rename to torch/utils/data/datapipes/iter/ziparchivereader.py index edb8320aece9f..881d00598151a 100644 --- a/torch/utils/data/datapipes/iter/readfilesfromzip.py +++ b/torch/utils/data/datapipes/iter/ziparchivereader.py @@ -8,14 +8,15 @@ import zipfile import warnings -class ReadFilesFromZipIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]): - r""" :class:`ReadFilesFromZipIterDataPipe`. +class ZipArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]): + r""" :class:`ZipArchiveReaderIterDataPipe`. Iterable data pipe to extract zip binary streams from input iterable which contains tuples of pathname and zip binary stream, yields pathname and extracted binary stream in a tuple. - args: + + Args: datapipe: Iterable datapipe that provides pathname and zip binary stream in tuples - length: a nominal length of the datapipe + length: Nominal length of the datapipe Note: The opened file handles will be closed automatically if the default DecoderDataPipe @@ -24,12 +25,11 @@ class ReadFilesFromZipIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]): """ def __init__( self, - datapipe : Iterable[Tuple[str, BufferedIOBase]], - length : int = -1): + datapipe: Iterable[Tuple[str, BufferedIOBase]], + length: int = -1): super().__init__() - self.datapipe : Iterable[Tuple[str, BufferedIOBase]] = datapipe - self.length : int = length - + self.datapipe: Iterable[Tuple[str, BufferedIOBase]] = datapipe + self.length: int = length def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]: if not isinstance(self.datapipe, Iterable): @@ -60,7 +60,6 @@ def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]: "Unable to extract files from corrupted zipfile stream {} due to: {}, abort!".format(pathname, e)) raise e - def __len__(self): if self.length == -1: raise TypeError("{} instance doesn't have valid length".format(type(self).__name__)) diff --git a/torch/utils/data/datapipes/map/__init__.py b/torch/utils/data/datapipes/map/__init__.py index b7609957baaa8..5879165aff2eb 100644 --- a/torch/utils/data/datapipes/map/__init__.py +++ b/torch/utils/data/datapipes/map/__init__.py @@ -1,7 +1,6 @@ # Functional DataPipe -from torch.utils.data.datapipes.map.callable import MapMapDataPipe as Map -from torch.utils.data.datapipes.map.combining import \ - (ConcatMapDataPipe as Concat) +from torch.utils.data.datapipes.map.callable import MapperMapDataPipe as Mapper +from torch.utils.data.datapipes.map.combining import ConcaterMapDataPipe as Concater -__all__ = ['Map', 'Concat'] +__all__ = ['Concater', 'Mapper'] diff --git a/torch/utils/data/datapipes/map/callable.py b/torch/utils/data/datapipes/map/callable.py index 00457299316ae..8dbad957e069d 100644 --- a/torch/utils/data/datapipes/map/callable.py +++ b/torch/utils/data/datapipes/map/callable.py @@ -26,8 +26,8 @@ def default_fn(data): @functional_datapipe('map') -class MapMapDataPipe(MapDataPipe[T_co]): - r""":class:`MapMapDataPipe`. +class MapperMapDataPipe(MapDataPipe[T_co]): + r""":class:`MapperMapDataPipe`. Map DataPipe to run a function over each item from the source DataPipe. The function can be any regular python function or partial object. Lambda diff --git a/torch/utils/data/datapipes/map/combining.py b/torch/utils/data/datapipes/map/combining.py index 234d45382efe6..4743c3726b356 100644 --- a/torch/utils/data/datapipes/map/combining.py +++ b/torch/utils/data/datapipes/map/combining.py @@ -5,8 +5,8 @@ @functional_datapipe('concat') -class ConcatMapDataPipe(MapDataPipe): - r""" :class:`ConcatMapDataPipe`. +class ConcaterMapDataPipe(MapDataPipe): + r""" :class:`ConcaterMapDataPipe`. Map DataPipe to concatenate multiple Map DataPipes. The actual index of is the cumulative sum of source datapipes. diff --git a/torch/utils/data/datapipes_tutorial_dev_loaders.ipynb b/torch/utils/data/datapipes_tutorial_dev_loaders.ipynb deleted file mode 100644 index 0a9b834a86862..0000000000000 --- a/torch/utils/data/datapipes_tutorial_dev_loaders.ipynb +++ /dev/null @@ -1,178 +0,0 @@ -{ - "metadata": { - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.10" - }, - "orig_nbformat": 2, - "kernelspec": { - "name": "python3610jvsc74a57bd0eb5e09632d6ea1cbf3eb9da7e37b7cf581db5ed13074b21cc44e159dc62acdab", - "display_name": "Python 3.6.10 64-bit ('dataloader': conda)" - } - }, - "nbformat": 4, - "nbformat_minor": 2, - "cells": [ - { - "source": [ - "## DataPipes development tutorial. Loaders DataPipes." - ], - "cell_type": "markdown", - "metadata": {} - }, - { - "source": [ - "As DataSet now constructed by stacking `DataPipe`-s it is recommended to keep `DataPipe` functionality as primitive as possible. For example loading data from CSV file will look like sequence of DataPipes: ListFiles FileLoader CSVParser.\n", - "\n" - ], - "cell_type": "markdown", - "metadata": {} - }, - { - "source": [ - "`ExampleListFilesDataPipe` scans all files in `root` folder and yields full file names. Avoid loading entire list in `__init__` function to save memory." - ], - "cell_type": "markdown", - "metadata": {} - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import csv\n", - "import io\n", - "import os\n", - "\n", - "from torch.utils.data import IterDataPipe, functional_datapipe\n", - "\n", - "\n", - "class ExampleListFilesDataPipe(IterDataPipe):\n", - " def __init__(self, *, root):\n", - " self.root = root\n", - "\n", - " def __iter__(self):\n", - " for (dirpath, dirnames, filenames) in os.walk(self.root):\n", - " for file_name in filenames:\n", - " yield os.path.join(dirpath, file_name)" - ] - }, - { - "source": [ - "`ExampleFileLoaderDataPipe` registered as `load_files_as_string` consumes file names from source_datapipe and yields file names and file lines." - ], - "cell_type": "markdown", - "metadata": {} - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "@functional_datapipe('load_files_as_string')\n", - "class ExampleFileLoaderDataPipe(IterDataPipe):\n", - " def __init__(self, source_datapipe):\n", - " self.source_datapipe = source_datapipe\n", - "\n", - " def __iter__(self):\n", - " for file_name in self.source_datapipe:\n", - " with open(file_name) as file:\n", - " lines = file.read()\n", - " yield (file_name, lines)\n" - ] - }, - { - "source": [ - "`ExampleCSVParserDataPipe` registered as `parse_csv_files` consumes file lines and expands them as CSV rows." - ], - "cell_type": "markdown", - "metadata": {} - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "@functional_datapipe('parse_csv_files')\n", - "class ExampleCSVParserDataPipe(IterDataPipe):\n", - " def __init__(self, source_datapipe):\n", - " self.source_datapipe = source_datapipe\n", - "\n", - " def __iter__(self):\n", - " for file_name, lines in self.source_datapipe:\n", - " reader = csv.reader(io.StringIO(lines))\n", - " for row in reader:\n", - " yield [file_name] + row\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "['/home/vitaly/dataset/data/datapipes/load/iter/test/example_2.csv', '10', \" 'foo'\"]\n['/home/vitaly/dataset/data/datapipes/load/iter/test/example_2.csv', '11', \" 'bar'\"]\n['/home/vitaly/dataset/data/datapipes/load/iter/test/example_1.csv', '12', \" 'aaaa'\"]\n['/home/vitaly/dataset/data/datapipes/load/iter/test/example_1.csv', '13', \" 'bbbb'\"]\n" - ] - } - ], - "source": [ - "FOLDER = 'define your folder with csv files here'\n", - "FOLDER = '/home/vitaly/dataset/data'\n", - "dp = ExampleListFilesDataPipe(root = FOLDER).filter(lambda filename: filename.endswith('.csv')).load_files_as_string().parse_csv_files()\n", - "\n", - "for data in dp:\n", - " print(data)" - ] - }, - { - "source": [ - "This approach allows to replace any DataPipe to get different functionality. For example you can pick individual files.\n" - ], - "cell_type": "markdown", - "metadata": {} - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "['/home/vitaly/dataset/data/datapipes/load/iter/test/example_1.csv', '12', \" 'aaaa'\"]\n['/home/vitaly/dataset/data/datapipes/load/iter/test/example_1.csv', '13', \" 'bbbb'\"]\n" - ] - } - ], - "source": [ - "FILE = 'define your file with csv data here'\n", - "FILE = '/home/vitaly/dataset/data/datapipes/load/iter/test/example_1.csv'\n", - "dp = ExampleFileLoaderDataPipe([FILE]).parse_csv_files()\n", - "\n", - "for data in dp:\n", - " print(data)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ] -} \ No newline at end of file diff --git a/torch/utils/data/dataset.py b/torch/utils/data/dataset.py index 5b8102c235607..50488d13ae5d3 100644 --- a/torch/utils/data/dataset.py +++ b/torch/utils/data/dataset.py @@ -25,25 +25,17 @@ T = TypeVar('T') -class DataChunk(List[T]): +class DataChunk(list, Generic[T]): def __init__(self, items): + super().__init__(items) self.items = items - def __getitem__(self, key): - return self.items[key] - - def __len__(self): - return len(self.items) - def as_str(self, indent=''): - res = indent + "[" + ", ".join([str(i) for i in iter(self)]) + "]" + res = indent + "[" + ", ".join(str(i) for i in iter(self)) + "]" return res - def __repr__(self): - return self.as_str() - def __iter__(self) -> Iterator[T]: - for i in self.items: + for i in super().__iter__(): yield i def raw_iterator(self): @@ -279,9 +271,8 @@ def cumsum(sequence): def __init__(self, datasets: Iterable[Dataset]) -> None: super(ConcatDataset, self).__init__() - # Cannot verify that datasets is Sized - assert len(datasets) > 0, 'datasets should not be an empty iterable' # type: ignore[arg-type] self.datasets = list(datasets) + assert len(self.datasets) > 0, 'datasets should not be an empty iterable' # type: ignore[arg-type] for d in self.datasets: assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset" self.cumulative_sizes = self.cumsum(self.datasets) diff --git a/torch/utils/hipify/hipify_python.py b/torch/utils/hipify/hipify_python.py index 6697f1e014cf7..ad2903f7ad655 100644 --- a/torch/utils/hipify/hipify_python.py +++ b/torch/utils/hipify/hipify_python.py @@ -750,7 +750,6 @@ def repl(m): or f.startswith("ATen/native/quantized/cuda") or f.startswith("ATen/native/sparse/cuda") or f.startswith("THC/") - or f.startswith("THCUNN/") or (f.startswith("THC") and not f.startswith("THCP")) ): return templ.format(get_hip_file_path(m.group(1), is_pytorch_extension)) diff --git a/ubsan.supp b/ubsan.supp index 62e64b785b94c..395f5208c8437 100644 --- a/ubsan.supp +++ b/ubsan.supp @@ -1 +1,2 @@ vptr:libtorch_python.so +vptr:test_jit