diff --git a/.circleci/config.yml b/.circleci/config.yml
index 4cf4dc4e2c6a..3fceba2db8dc 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -574,7 +574,7 @@ jobs:
             hostname
             export id=$(docker run --env-file "${BASH_ENV}" --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --shm-size=8g --ipc=host --device /dev/kfd --device /dev/dri --group-add video -t -d -w /var/lib/jenkins ${COMMIT_DOCKER_IMAGE})
           else
-            export id=$(docker run --env-file "${BASH_ENV}" --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -t -d -w /var/lib/jenkins ${COMMIT_DOCKER_IMAGE})
+            export id=$(docker run --env-file "${BASH_ENV}" --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --shm-size=1g --ipc=host -t -d -w /var/lib/jenkins ${COMMIT_DOCKER_IMAGE})
          fi
          echo "id=${id}" >> "${BASH_ENV}"
diff --git a/.circleci/verbatim-sources/job-specs/pytorch-job-specs.yml b/.circleci/verbatim-sources/job-specs/pytorch-job-specs.yml
index 8cbb9a4e3f40..99b327c275a0 100644
--- a/.circleci/verbatim-sources/job-specs/pytorch-job-specs.yml
+++ b/.circleci/verbatim-sources/job-specs/pytorch-job-specs.yml
@@ -133,7 +133,7 @@ jobs:
             hostname
             export id=$(docker run --env-file "${BASH_ENV}" --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --shm-size=8g --ipc=host --device /dev/kfd --device /dev/dri --group-add video -t -d -w /var/lib/jenkins ${COMMIT_DOCKER_IMAGE})
           else
-            export id=$(docker run --env-file "${BASH_ENV}" --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -t -d -w /var/lib/jenkins ${COMMIT_DOCKER_IMAGE})
+            export id=$(docker run --env-file "${BASH_ENV}" --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --shm-size=1g --ipc=host -t -d -w /var/lib/jenkins ${COMMIT_DOCKER_IMAGE})
          fi
          echo "id=${id}" >> "${BASH_ENV}"
diff --git a/.github/scripts/generate_binary_build_matrix.py b/.github/scripts/generate_binary_build_matrix.py
new file mode 100644
index 000000000000..d95518bc6dae
--- /dev/null
+++ b/.github/scripts/generate_binary_build_matrix.py
@@ -0,0 +1,86 @@
+#!/usr/bin/env python3
+
+"""Generates a build matrix to be used by GitHub Actions workflows.
+
+Outputs a condensed version of the matrix when run on a pull request: only the
+latest version of Python we support, built for three different architectures:
+  * CPU
+  * Latest CUDA
+  * Latest ROCM
+"""
+
+import json
+import os
+import itertools
+
+CUDA_ARCHES = [
+    "10.1",
+    "10.2",
+    "11.0"
+]
+
+ROCM_ARCHES = [
+    "3.10",
+    "4.0"
+]
+
+FULL_ARCHES = [
+    "cpu",
+    *CUDA_ARCHES,
+    *ROCM_ARCHES
+]
+
+CONTAINER_IMAGES = {
+    **{
+        # TODO: Re-do manylinux CUDA image tagging scheme to be similar to
+        #       ROCM so we don't have to do this replacement
+        gpu_arch: f"pytorch/manylinux-cuda{gpu_arch.replace('.', '')}"
+        for gpu_arch in CUDA_ARCHES
+    },
+    **{
+        gpu_arch: f"pytorch/manylinux-rocm:{gpu_arch}"
+        for gpu_arch in ROCM_ARCHES
+    },
+    "cpu": "pytorch/manylinux-cpu"
+}
+
+FULL_PYTHON_VERSIONS = [
+    "3.6",
+    "3.7",
+    "3.8",
+    "3.9",
+]
+
+
+def is_pull_request():
+    return os.environ.get("GITHUB_HEAD_REF")
+
+def generate_matrix():
+    python_versions = FULL_PYTHON_VERSIONS
+    arches = FULL_ARCHES
+    if is_pull_request():
+        python_versions = [python_versions[-1]]
+        arches = ["cpu", CUDA_ARCHES[-1], ROCM_ARCHES[-1]]
+    matrix = []
+    for item in itertools.product(python_versions, arches):
+        python_version, arch_version = item
+        # Not my favorite code here
+        gpu_arch_type = "cuda"
+        if "rocm" in CONTAINER_IMAGES[arch_version]:
+            gpu_arch_type = "rocm"
+        elif "cpu" in CONTAINER_IMAGES[arch_version]:
+            gpu_arch_type = "cpu"
+        matrix.append({
+            "python_version": python_version,
+            "gpu_arch_type": gpu_arch_type,
+            "gpu_arch_version": arch_version,
+            "container_image": CONTAINER_IMAGES[arch_version]
+        })
+    return json.dumps({"include": matrix})
+
+def main():
+    print(generate_matrix())
+
+if __name__ == "__main__":
+    main()
diff --git a/.github/scripts/generate_pytorch_version.py b/.github/scripts/generate_pytorch_version.py
new file mode 100755
index 000000000000..93fc4ca6db3a
--- /dev/null
+++ b/.github/scripts/generate_pytorch_version.py
@@ -0,0 +1,118 @@
+#!/usr/bin/env python3
+
+import argparse
+import os
+import subprocess
+import re
+
+from datetime import datetime
+from distutils.util import strtobool
+from pathlib import Path
+
+LEADING_V_PATTERN = re.compile("^v")
+TRAILING_RC_PATTERN = re.compile("-rc[0-9]*$")
+LEGACY_BASE_VERSION_SUFFIX_PATTERN = re.compile("a0$")
+
+class NoGitTagException(Exception):
+    pass
+
+def get_pytorch_root():
+    return Path(subprocess.check_output(
+        ['git', 'rev-parse', '--show-toplevel']
+    ).decode('ascii').strip())
+
+def get_tag():
+    root = get_pytorch_root()
+    # We're on a tag
+    am_on_tag = (
+        subprocess.run(
+            ['git', 'describe', '--tags', '--exact'],
+            cwd=root,
+            stdout=subprocess.DEVNULL,
+            stderr=subprocess.DEVNULL
+        ).returncode == 0
+    )
+    tag = ""
+    if am_on_tag:
+        dirty_tag = subprocess.check_output(
+            ['git', 'describe'],
+            cwd=root
+        ).decode('ascii').strip()
+        # Strip leading v that we typically do when we tag branches
+        # ie: v1.7.1 -> 1.7.1
+        tag = re.sub(LEADING_V_PATTERN, "", dirty_tag)
+        # Strip trailing rc pattern
+        # ie: 1.7.1-rc1 -> 1.7.1
+        tag = re.sub(TRAILING_RC_PATTERN, "", tag)
+    return tag
+
+def get_base_version():
+    root = get_pytorch_root()
+    dirty_version = open(root / 'version.txt', 'r').read().strip()
+    # Strips trailing a0 from version.txt, not too sure why it's there in the
+    # first place
+    return re.sub(LEGACY_BASE_VERSION_SUFFIX_PATTERN, "", dirty_version)
+
+class PytorchVersion:
+    def __init__(self, gpu_arch_type, gpu_arch_version, no_build_suffix):
+        self.gpu_arch_type = gpu_arch_type
+        self.gpu_arch_version = gpu_arch_version
+        self.no_build_suffix = no_build_suffix
+
+    def get_post_build_suffix(self):
+        # CUDA 10.2 is the version to be uploaded to PyPI so it doesn't have a
+        # version suffix
+        if ((self.gpu_arch_type == "cuda" and self.gpu_arch_version == "10.2")
+                or self.no_build_suffix):
+            return ""
+        if self.gpu_arch_type == "cuda":
+            return f"+cu{self.gpu_arch_version.replace('.', '')}"
+        return f"+{self.gpu_arch_type}{self.gpu_arch_version}"
+
+    def get_release_version(self):
+        if not get_tag():
+            raise NoGitTagException(
+                "Not on a git tag, are you sure you want a release version?"
+            )
+        return f"{get_tag()}{self.get_post_build_suffix()}"
+
+    def get_nightly_version(self):
+        date_str = datetime.today().strftime('%Y%m%d')
+        build_suffix = self.get_post_build_suffix()
+        return f"{get_base_version()}.dev{date_str}{build_suffix}"
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Generate PyTorch version for binary builds"
+    )
+    parser.add_argument(
+        "--no-build-suffix",
+        type=strtobool,
+        help="Whether or not to add a build suffix (typically +cpu)",
+        default=os.environ.get("NO_BUILD_SUFFIX", False)
+    )
+    parser.add_argument(
+        "--gpu-arch-type",
+        type=str,
+        help="GPU arch you are building for, typically (cpu, cuda, rocm)",
+        default=os.environ.get("GPU_ARCH_TYPE", "cpu")
+    )
+    parser.add_argument(
+        "--gpu-arch-version",
+        type=str,
+        help="GPU arch version, typically (10.2, 4.0), leave blank for CPU",
+        default=os.environ.get("GPU_ARCH_VERSION", "")
+    )
+    args = parser.parse_args()
+    version_obj = PytorchVersion(
+        args.gpu_arch_type,
+        args.gpu_arch_version,
+        args.no_build_suffix
+    )
+    try:
+        print(version_obj.get_release_version())
+    except NoGitTagException:
+        print(version_obj.get_nightly_version())
+
+if __name__ == "__main__":
+    main()
diff --git a/.github/workflows/build_linux_binaries.yml b/.github/workflows/build_linux_binaries.yml
new file mode 100644
index 000000000000..fc8917d74625
--- /dev/null
+++ b/.github/workflows/build_linux_binaries.yml
@@ -0,0 +1,86 @@
+name: Build Linux Wheels
+
+on:
+  # TODO: These are only runnable from workflow_dispatch, we need to eventually add
+  #       a cron
+  # TODO: Add an on_release trigger to build on tags
+  workflow_dispatch:
+
+jobs:
+  generate-build-matrix:
+    if: ${{ github.repository_owner == 'pytorch' }}
+    runs-on: ubuntu-latest
+    outputs:
+      matrix: ${{ steps.set-matrix.outputs.matrix }}
+    container:
+      image: python:3.9
+    steps:
+      - name: Clone pytorch/pytorch
+        uses: actions/checkout@v2
+      - name: Generating build matrix
+        id: set-matrix
+        run: |
+          # outputting for debugging purposes
+          python .github/scripts/generate_binary_build_matrix.py
+          MATRIX=$(python .github/scripts/generate_binary_build_matrix.py)
+          echo "::set-output name=matrix::${MATRIX}"
+  build-wheel:
+    if: ${{ github.repository_owner == 'pytorch' }}
+    needs: generate-build-matrix
+    runs-on: linux.2xlarge
+    strategy:
+      matrix:
+        ${{ fromJson(needs.generate-build-matrix.outputs.matrix) }}
+    container:
+      image: ${{ matrix.container_image }}
+    env:
+      DESIRED_PYTHON: ${{ matrix.python_version }}
+      # TODO: This is a legacy variable that we eventually want to get rid of in
+      #       favor of GPU_ARCH_VERSION
+      DESIRED_CUDA: ${{ matrix.gpu_arch_version }}
+      GPU_ARCH_VERSION: ${{ matrix.gpu_arch_version }}
+      GPU_ARCH_TYPE: ${{ matrix.gpu_arch_type }}
+      PYTORCH_BUILD_NUMBER: 1
+      SKIP_ALL_TESTS: 1
+    steps:
+      - name: Clone pytorch/pytorch
+        uses: actions/checkout@v2
+        with:
+          path: pytorch
+          submodules: recursive
+      - name: Clone pytorch/builder
+        uses: actions/checkout@v2
+        with:
+          repository: pytorch/builder
+          path: builder
+      - name: Generate version string
+        working-directory: pytorch/
+        run: |
+          version=$(.github/scripts/generate_pytorch_version.py)
+          echo "Generated version: ${version}"
+          echo "PYTORCH_BUILD_VERSION=${version}" >> $GITHUB_ENV
+      # TODO: Remove this once we remove the need for the directories to be
+      #       in specific locations
+      - name: Symlink repositories to root directory (for legacy scripts purposes)
+        run: |
+          ln -s $(pwd)/pytorch /pytorch
+          ln -s $(pwd)/builder /builder
+      # TODO: Bundle the correct build script in the base container image so
+      #       that we don't have to do this type of specification
+      - name: Build PyTorch binary (CUDA specific)
+        if: ${{ matrix.gpu_arch_type == 'cuda' }}
+        run: |
+          /builder/manywheel/build.sh
+      - name: Build PyTorch binary (ROCM specific)
+        if: ${{ matrix.gpu_arch_type == 'rocm' }}
+        run: |
+          /builder/manywheel/build_rocm.sh
+      - name: Build PyTorch binary (CPU specific)
+        if: ${{ matrix.gpu_arch_type == 'cpu' }}
+        run: |
+          /builder/manywheel/build_cpu.sh
+      - uses: actions/upload-artifact@v2
+        with:
+          name: pytorch-wheel-py${{ matrix.python_version }}-${{matrix.gpu_arch_type}}-${{ matrix.gpu_arch_version }}
+          path: /remote/**/*.whl
+      # TODO: Add a step here for uploading binaries
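Reviewer note: for context on what the new matrix generator emits, here is a minimal sketch of the condensed matrix it should print on a pull request. The values are derived by tracing the logic in `generate_binary_build_matrix.py` above; this is an illustration, not captured CI output.

```python
import json

# On a PR, GITHUB_HEAD_REF is set, so the script keeps only the newest
# Python (3.9) and the arches ["cpu", CUDA_ARCHES[-1], ROCM_ARCHES[-1]].
expected = {"include": [
    {"python_version": "3.9", "gpu_arch_type": "cpu",
     "gpu_arch_version": "cpu", "container_image": "pytorch/manylinux-cpu"},
    {"python_version": "3.9", "gpu_arch_type": "cuda",
     "gpu_arch_version": "11.0", "container_image": "pytorch/manylinux-cuda110"},
    {"python_version": "3.9", "gpu_arch_type": "rocm",
     "gpu_arch_version": "4.0", "container_image": "pytorch/manylinux-rocm:4.0"},
]}
print(json.dumps(expected, indent=2))
```

This JSON is what the workflow's `fromJson(...)` call consumes as the `strategy.matrix`.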
diff --git a/.jenkins/caffe2/test.sh b/.jenkins/caffe2/test.sh
index e6f43b6452cf..ac131ba738ca 100755
--- a/.jenkins/caffe2/test.sh
+++ b/.jenkins/caffe2/test.sh
@@ -160,7 +160,7 @@ pip install --user pytest-sugar
 if [[ "$BUILD_ENVIRONMENT" == *onnx* ]]; then
   # Check out torch/vision at Jun 11 2020 commit
   # This hash must match one in .jenkins/pytorch/test.sh
-  pip install -q --user git+https://github.com/pytorch/vision.git@e70c91a9ff9b8a20e05c133aec6ec3ed538c32fb
+  pip install -q --user git+https://github.com/pytorch/vision.git@ae0d80b3c52dc98b3a9763bdb974c3ef7b6eb83d
   pip install -q --user ninja
   # JIT C++ extensions require ninja, so put it into PATH.
   export PATH="/var/lib/jenkins/.local/bin:$PATH"
diff --git a/.jenkins/pytorch/common_utils.sh b/.jenkins/pytorch/common_utils.sh
index b28dcb2f41d8..38799ab782de 100644
--- a/.jenkins/pytorch/common_utils.sh
+++ b/.jenkins/pytorch/common_utils.sh
@@ -66,7 +66,7 @@ function get_bazel() {
   chmod +x tools/bazel
 }

-TORCHVISION_COMMIT=e70c91a9ff9b8a20e05c133aec6ec3ed538c32fb
+TORCHVISION_COMMIT=ae0d80b3c52dc98b3a9763bdb974c3ef7b6eb83d

 function install_torchvision() {
   # Check out torch/vision at Jun 11 2020 commit
diff --git a/aten/src/ATen/LegacyTHFunctionsCUDA.h b/aten/src/ATen/LegacyTHFunctionsCUDA.h
index a9004a5b6d01..069a2c1152c4 100644
--- a/aten/src/ATen/LegacyTHFunctionsCUDA.h
+++ b/aten/src/ATen/LegacyTHFunctionsCUDA.h
@@ -20,8 +20,6 @@ namespace cuda {

 Tensor & _th_masked_fill_(Tensor & self, const Tensor & mask, Scalar value);
 Tensor & _th_masked_fill_bool_(Tensor & self, const Tensor & mask, Scalar value);
-Tensor & _th_masked_scatter_(Tensor & self, const Tensor & mask, const Tensor & source);
-Tensor & _th_masked_scatter_bool_(Tensor & self, const Tensor & mask, const Tensor & source);
 Tensor & _th_index_copy_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source);
 Tensor & _th_take_out(Tensor & result, const Tensor & self, const Tensor & index);
 Tensor _th_take(const Tensor & self, const Tensor & index);
diff --git a/aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp b/aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp
index ddae7965dded..b60968a4b041 100644
--- a/aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp
+++ b/aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp
@@ -200,166 +200,6 @@ Tensor & _th_masked_fill_bool_(Tensor & self, const Tensor & mask, Scalar value)
     }
     return self;
 }
-Tensor & _th_masked_scatter_(Tensor & self, const Tensor & mask, const Tensor & source) {
-    // DeviceGuard omitted
-    auto dispatch_scalar_type = infer_scalar_type(self);
-
-    switch (dispatch_scalar_type) {
-        case ScalarType::Bool: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
-            auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_", false, DeviceType::CUDA, ScalarType::Byte);
-            auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
-            THCudaBoolTensor_maskedCopy(globalContext().getTHCState(), self_, mask_, source_);
-            break;
-        }
-        case ScalarType::Byte: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
-            auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_", false, DeviceType::CUDA, ScalarType::Byte);
-            auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
-            THCudaByteTensor_maskedCopy(globalContext().getTHCState(), self_, mask_, source_);
-            break;
-        }
-        case ScalarType::Char: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
-            auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_", false, DeviceType::CUDA, ScalarType::Byte);
-            auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
-            THCudaCharTensor_maskedCopy(globalContext().getTHCState(), self_, mask_, source_);
-            break;
-        }
-        case ScalarType::Double: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
-            auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_", false, DeviceType::CUDA, ScalarType::Byte);
-            auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
-            THCudaDoubleTensor_maskedCopy(globalContext().getTHCState(), self_, mask_, source_);
-            break;
-        }
-        case ScalarType::Float: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
-            auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_", false, DeviceType::CUDA, ScalarType::Byte);
-            auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
-            THCudaTensor_maskedCopy(globalContext().getTHCState(), self_, mask_, source_);
-            break;
-        }
-        case ScalarType::Int: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
-            auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_", false, DeviceType::CUDA, ScalarType::Byte);
-            auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
-            THCudaIntTensor_maskedCopy(globalContext().getTHCState(), self_, mask_, source_);
-            break;
-        }
-        case ScalarType::Long: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
-            auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_", false, DeviceType::CUDA, ScalarType::Byte);
-            auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
-            THCudaLongTensor_maskedCopy(globalContext().getTHCState(), self_, mask_, source_);
-            break;
-        }
-        case ScalarType::Short: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
-            auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_", false, DeviceType::CUDA, ScalarType::Byte);
-            auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
-            THCudaShortTensor_maskedCopy(globalContext().getTHCState(), self_, mask_, source_);
-            break;
-        }
-        case ScalarType::Half: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
-            auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_", false, DeviceType::CUDA, ScalarType::Byte);
-            auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
-            THCudaHalfTensor_maskedCopy(globalContext().getTHCState(), self_, mask_, source_);
-            break;
-        }
-        case ScalarType::BFloat16: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
-            auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_", false, DeviceType::CUDA, ScalarType::Byte);
-            auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_", false, DeviceType::CUDA, dispatch_scalar_type);
-            THCudaBFloat16Tensor_maskedCopy(globalContext().getTHCState(), self_, mask_, source_);
-            break;
-        }
-        default:
-            AT_ERROR("_th_masked_scatter_ not supported on CUDAType for ", dispatch_scalar_type);
-    }
-    return self;
-}
-Tensor & _th_masked_scatter_bool_(Tensor & self, const Tensor & mask, const Tensor & source) {
-    // DeviceGuard omitted
-    auto dispatch_scalar_type = infer_scalar_type(self);
-
-    switch (dispatch_scalar_type) {
-        case ScalarType::Bool: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
-            auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_bool_", false, DeviceType::CUDA, ScalarType::Bool);
-            auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
-            THCudaBoolTensor_maskedCopyBool(globalContext().getTHCState(), self_, mask_, source_);
-            break;
-        }
-        case ScalarType::Byte: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
-            auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_bool_", false, DeviceType::CUDA, ScalarType::Bool);
-            auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
-            THCudaByteTensor_maskedCopyBool(globalContext().getTHCState(), self_, mask_, source_);
-            break;
-        }
-        case ScalarType::Char: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
-            auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_bool_", false, DeviceType::CUDA, ScalarType::Bool);
-            auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
-            THCudaCharTensor_maskedCopyBool(globalContext().getTHCState(), self_, mask_, source_);
-            break;
-        }
-        case ScalarType::Double: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
-            auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_bool_", false, DeviceType::CUDA, ScalarType::Bool);
-            auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
-            THCudaDoubleTensor_maskedCopyBool(globalContext().getTHCState(), self_, mask_, source_);
-            break;
-        }
-        case ScalarType::Float: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
-            auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_bool_", false, DeviceType::CUDA, ScalarType::Bool);
-            auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
-            THCudaTensor_maskedCopyBool(globalContext().getTHCState(), self_, mask_, source_);
-            break;
-        }
-        case ScalarType::Int: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
-            auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_bool_", false, DeviceType::CUDA, ScalarType::Bool);
-            auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
-            THCudaIntTensor_maskedCopyBool(globalContext().getTHCState(), self_, mask_, source_);
-            break;
-        }
-        case ScalarType::Long: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
-            auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_bool_", false, DeviceType::CUDA, ScalarType::Bool);
-            auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
-            THCudaLongTensor_maskedCopyBool(globalContext().getTHCState(), self_, mask_, source_);
-            break;
-        }
-        case ScalarType::Short: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
-            auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_bool_", false, DeviceType::CUDA, ScalarType::Bool);
-            auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
-            THCudaShortTensor_maskedCopyBool(globalContext().getTHCState(), self_, mask_, source_);
-            break;
-        }
-        case ScalarType::Half: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
-            auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_bool_", false, DeviceType::CUDA, ScalarType::Bool);
-            auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
-            THCudaHalfTensor_maskedCopyBool(globalContext().getTHCState(), self_, mask_, source_);
-            break;
-        }
-        case ScalarType::BFloat16: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
-            auto mask_ = checked_dense_tensor_unwrap(mask, "mask", 2, "_th_masked_scatter_bool_", false, DeviceType::CUDA, ScalarType::Bool);
-            auto source_ = checked_dense_tensor_unwrap(source, "source", 3, "_th_masked_scatter_bool_", false, DeviceType::CUDA, dispatch_scalar_type);
-            THCudaBFloat16Tensor_maskedCopyBool(globalContext().getTHCState(), self_, mask_, source_);
-            break;
-        }
-        default:
-            AT_ERROR("_th_masked_scatter_bool_ not supported on CUDAType for ", dispatch_scalar_type);
-    }
-    return self;
-}
 Tensor & _th_index_copy_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
     // DeviceGuard omitted
     auto dispatch_scalar_type = infer_scalar_type(self);
diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index a4b6e4df43f8..47ff70b93231 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -2223,17 +2223,17 @@ Tensor chain_matmul(TensorList matrices) {
 Calculates the Kronecker product between two Tensors.
 */
 Tensor& kron_out(Tensor& result, const Tensor& self, const Tensor& other) {
-  auto maxdim = std::max(self.dim(), other.dim());
-  auto pad_self = maxdim - self.dim();
-  auto pad_other = maxdim - other.dim();
+  int64_t maxdim = std::max(self.dim(), other.dim());
+  int64_t pad_self = maxdim - self.dim();
+  int64_t pad_other = maxdim - other.dim();
   c10::SmallVector a_reshape(2 * maxdim);
   c10::SmallVector b_reshape(2 * maxdim);
   c10::SmallVector result_reshape(maxdim);
-  for (int i = 0; i < maxdim; i++) {
-    a_reshape[2 * i] = i >= pad_self ? self.sizes()[i - pad_self] : 1;
+  for (int64_t i = 0; i < maxdim; i++) {
+    a_reshape[2 * i] = (i >= pad_self ? self.sizes()[i - pad_self] : 1);
     a_reshape[2 * i + 1] = 1;
     b_reshape[2 * i] = 1;
-    b_reshape[2 * i + 1] = i >= pad_other ? other.sizes()[i - pad_other] : 1;
+    b_reshape[2 * i + 1] = (i >= pad_other ? other.sizes()[i - pad_other] : 1);
     result_reshape[i] = a_reshape[2 * i] * b_reshape[2 * i + 1];
   }
   auto self_view = at::_unsafe_view(self, a_reshape);
@@ -2241,8 +2241,14 @@ Tensor& kron_out(Tensor& result, const Tensor& self, const Tensor& other) {
   if (!result.defined()) {
     result = at::_unsafe_view(at::mul(self_view, other_view), result_reshape);
   } else {
-    at::mul_out(result, self_view, other_view);
-    result.resize_(result_reshape);
+    c10::SmallVector mul_shape(2 * maxdim);
+    for (int64_t i = 0; i < maxdim; i++) {
+      mul_shape[2 * i] = a_reshape[2 * i];
+      mul_shape[2 * i + 1] = b_reshape[2 * i + 1];
+    }
+    resize_output(result, result_reshape);
+    auto result_mul = at::_unsafe_view(result, mul_shape);
+    at::mul_out(result_mul, self_view, other_view);
   }
   return result;
 }
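Reviewer note: the `kron_out` change above keeps the interleaved-axes reshape trick but now multiplies into a correctly shaped view of the output. A minimal NumPy sketch of that same construction (an illustrative model, not the ATen code):

```python
import numpy as np

def kron_reference(a, b):
    # Pad both inputs to the same rank, interleave axes as
    # (a0,1,a1,1,...) * (1,b0,1,b1,...), multiply with broadcasting,
    # then collapse each adjacent axis pair -- mirroring a_reshape,
    # b_reshape, and result_reshape in the patch.
    maxdim = max(a.ndim, b.ndim)
    a = a.reshape((1,) * (maxdim - a.ndim) + a.shape)
    b = b.reshape((1,) * (maxdim - b.ndim) + b.shape)
    a_view = a.reshape([d for s in a.shape for d in (s, 1)])
    b_view = b.reshape([d for s in b.shape for d in (1, s)])
    return (a_view * b_view).reshape(
        [a.shape[i] * b.shape[i] for i in range(maxdim)])

x = np.arange(4).reshape(2, 2)
np.testing.assert_allclose(kron_reference(x, np.eye(2)), np.kron(x, np.eye(2)))
```

The bug being fixed is visible in the model: the product naturally has the interleaved `mul_shape`, so writing it into `result` requires viewing `result` with that shape rather than resizing after `mul_out`.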
diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp
index ca311f86091e..9ff806bd5054 100644
--- a/aten/src/ATen/native/UnaryOps.cpp
+++ b/aten/src/ATen/native/UnaryOps.cpp
@@ -650,38 +650,9 @@ Tensor& mvlgamma_(Tensor& self, int64_t p) {
   return self.copy_(args.lgamma_().sum(-1).add_(p * (p - 1) * std::log(c10::pi) / 4.));
 }

-// NB: If you use this macro, you may also need to add a CUDA forwarding
-// stub in CUDAUnaryOps
-
-#define IMPLEMENT_UNARY_OP_CORE(op) \
-  Tensor op(const Tensor& self) { \
-    Tensor result = at::empty({0}, self.options()); \
-    at::op##_out(result, self); \
-    return result; \
-  }
-
-#define IMPLEMENT_UNARY_OP_OUT_INPLACE(op, prefix, device) \
-  Tensor& _##op##__##prefix(Tensor& self) { \
-    return at::op##_out(self, self); \
-  } \
-  Tensor& _##op##_out_##prefix(Tensor& result, const Tensor& self) { \
-    checkDeviceType(#op, result, DeviceType::device); \
-    checkLayout(#op, result, Layout::Strided); \
-    auto iter = TensorIterator::unary_op(result, self); \
-    op##_stub(iter.device_type(), iter); \
-    return result; \
-  }
-
-#define IMPLEMENT_UNARY_OP_VEC(op) \
-  IMPLEMENT_UNARY_OP_CORE(op) \
-  IMPLEMENT_UNARY_OP_OUT_INPLACE(op, cpu, CPU)
-
-#define IMPLEMENT_UNARY_OP_VEC_CUDA(op) \
-  IMPLEMENT_UNARY_OP_CORE(op) \
-  IMPLEMENT_UNARY_OP_OUT_INPLACE(op, cpu, CPU) \
-  IMPLEMENT_UNARY_OP_OUT_INPLACE(op, cuda, CUDA)
-
-IMPLEMENT_UNARY_OP_VEC_CUDA(lgamma)
+Tensor& lgamma_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, lgamma_stub); }
+Tensor lgamma(const Tensor& self) { return unary_op_impl_float(self, lgamma_stub); }
+Tensor& lgamma_(Tensor& self) { return unary_op_impl_(self, at::lgamma_out); }

 DEFINE_DISPATCH(abs_stub);
 DEFINE_DISPATCH(angle_stub);
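Reviewer note: the lgamma rewrite above routes through the `unary_op_impl_float*` helpers, which compute in (and promote inputs to) a floating-point dtype. A small usage sketch of the intended user-visible behavior, assuming the usual float-promotion semantics of those helpers:

```python
import torch

x = torch.arange(1, 5)   # integral input (int64)
y = torch.lgamma(x)      # promoted to a floating dtype by the float-op helpers
print(y)                 # ~tensor([0.0000, 0.0000, 0.6931, 1.7918]) = log((x-1)!)
```

The matching `common_dtype()` change in `UnaryGammaKernels.cu` further down is what lets the CUDA kernel dispatch on the promoted dtype instead of the raw input dtype.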
TORCH_WARN("masked_scatter_ received a mask with dtype torch.uint8, this behavior is now deprecated," \ + "please use a mask with dtype torch.bool instead."); + } + + auto mask_dtype = b_mask.scalar_type(); + if (mask_dtype == ScalarType::Bool) { + masked_scatter_cuda_impl(self, b_mask, source); + } else { + masked_scatter_cuda_impl(self, b_mask, source); + } + + return self; +} + REGISTER_DISPATCH(index_stub, &index_kernel); REGISTER_DISPATCH(index_put_stub, &index_put_kernel); diff --git a/aten/src/ATen/native/cuda/Indexing.cu b/aten/src/ATen/native/cuda/Indexing.cu index 035dc188c81c..6b3304cff421 100644 --- a/aten/src/ATen/native/cuda/Indexing.cu +++ b/aten/src/ATen/native/cuda/Indexing.cu @@ -230,7 +230,7 @@ void index_put_accum_kernel(Tensor & self, const c10::List std::min(std::max(1,nElemBefore), at::cuda::getCurrentDeviceProperties()->maxGridSize[2])); dim3 block(C10_WARP_SIZE, indices_per_block); - AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, value_.scalar_type(), "indexing_backward", [&] { indexing_backward_kernel<<>>( sorted_indices.data_ptr(), diff --git a/aten/src/ATen/native/cuda/LegacyDefinitions.cpp b/aten/src/ATen/native/cuda/LegacyDefinitions.cpp index 1bbe47dbfb2e..735f2c8b2875 100644 --- a/aten/src/ATen/native/cuda/LegacyDefinitions.cpp +++ b/aten/src/ATen/native/cuda/LegacyDefinitions.cpp @@ -61,19 +61,4 @@ Tensor & masked_fill__cuda(Tensor& self, const Tensor & mask, const Tensor & val return self; } -Tensor & masked_scatter__cuda(Tensor& self, const Tensor & mask, const Tensor & source) { - at::assert_no_internal_overlap(self); - Tensor b_mask; - std::tie(b_mask) = expand_inplace(self, mask, "masked_scatter_"); - // As we dispatch on self and TH is type-checked, we need different definitions. - // This can be fixed by moving to ATen. 
diff --git a/aten/src/ATen/native/cuda/LegacyDefinitions.cpp b/aten/src/ATen/native/cuda/LegacyDefinitions.cpp
index 1bbe47dbfb2e..735f2c8b2875 100644
--- a/aten/src/ATen/native/cuda/LegacyDefinitions.cpp
+++ b/aten/src/ATen/native/cuda/LegacyDefinitions.cpp
@@ -61,19 +61,4 @@ Tensor & masked_fill__cuda(Tensor& self, const Tensor & mask, const Tensor & val
   return self;
 }

-Tensor & masked_scatter__cuda(Tensor& self, const Tensor & mask, const Tensor & source) {
-  at::assert_no_internal_overlap(self);
-  Tensor b_mask;
-  std::tie(b_mask) = expand_inplace(self, mask, "masked_scatter_");
-  // As we dispatch on self and TH is type-checked, we need different definitions.
-  // This can be fixed by moving to ATen.
-  if (b_mask.dtype() == at::ScalarType::Byte) {
-    TORCH_WARN("masked_scatter_ received a mask with dtype torch.uint8, this behavior is now deprecated," \
-               "please use a mask with dtype torch.bool instead.");
-    return legacy::cuda::_th_masked_scatter_(self, b_mask, source);
-  } else {
-    return legacy::cuda::_th_masked_scatter_bool_(self, b_mask, source);
-  }
-}
-
 }} // namespace at::native
diff --git a/aten/src/ATen/native/cuda/UnaryGammaKernels.cu b/aten/src/ATen/native/cuda/UnaryGammaKernels.cu
index 97dbeefccc77..cdcf92e719d8 100644
--- a/aten/src/ATen/native/cuda/UnaryGammaKernels.cu
+++ b/aten/src/ATen/native/cuda/UnaryGammaKernels.cu
@@ -41,7 +41,7 @@ void polygamma_kernel_cuda(TensorIterator& iter, int64_t n) {
 }

 void lgamma_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "lgamma_cuda", [&]() {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "lgamma_cuda", [&]() {
     gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
       return ::lgamma(a);
     });
diff --git a/aten/src/ATen/native/mkl/LinearAlgebra.cpp b/aten/src/ATen/native/mkl/LinearAlgebra.cpp
index 0fc22c2c637d..cb14f9ae3333 100644
--- a/aten/src/ATen/native/mkl/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/mkl/LinearAlgebra.cpp
@@ -32,6 +32,36 @@ Tensor& _baddbmm_mkl_(Tensor& self, const Tensor& batch1, const Tensor& batch2,
 namespace at { namespace native {

+static inline void gemm(const CBLAS_TRANSPOSE trans_A, const CBLAS_TRANSPOSE trans_B,
+    const int M, const int N, const int K, const float alpha, const float* A,
+    const int lda, const float* B, const int ldb, const float beta, float* C, const int ldc) {
+  cblas_sgemm(CblasRowMajor, trans_A, trans_B, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
+}
+
+static inline void gemm(const CBLAS_TRANSPOSE trans_A, const CBLAS_TRANSPOSE trans_B,
+    const int M, const int N, const int K, const double alpha, const double* A,
+    const int lda, const double* B, const int ldb, const double beta, double* C, const int ldc) {
+  cblas_dgemm(CblasRowMajor, trans_A, trans_B, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
+}
+
+static inline void gemm(const CBLAS_TRANSPOSE trans_A, const CBLAS_TRANSPOSE trans_B,
+    const int M, const int N, const int K, const c10::complex alpha,
+    const c10::complex* A, const int lda, const c10::complex* B, const int ldb,
+    const c10::complex beta, c10::complex* C, const int ldc) {
+  cblas_cgemm(CblasRowMajor, trans_A, trans_B, M, N, K, reinterpret_cast(&alpha),
+      reinterpret_cast(A), lda, reinterpret_cast(B), ldb,
+      reinterpret_cast(&beta), reinterpret_cast(C), ldc);
+}
+
+static inline void gemm(const CBLAS_TRANSPOSE trans_A, const CBLAS_TRANSPOSE trans_B,
+    const int M, const int N, const int K, const c10::complex alpha,
+    const c10::complex* A, const int lda, const c10::complex* B, const int ldb,
+    const c10::complex beta, c10::complex* C, const int ldc) {
+  cblas_zgemm(CblasRowMajor, trans_A, trans_B, M, N, K, reinterpret_cast(&alpha),
+      reinterpret_cast(A), lda, reinterpret_cast(B), ldb,
+      reinterpret_cast(&beta), reinterpret_cast(C), ldc);
+}
+
 static inline void gemm_batched(const CBLAS_TRANSPOSE trans_A, const CBLAS_TRANSPOSE trans_B,
     const int batch_size, const int M, const int N, const int K, const float alpha,
     const float** A, const int lda, const float** B, const int ldb, const float beta,
@@ -101,6 +131,31 @@ static inline void baddbmm_mkl_template(const Tensor& res, const Tensor& mat1, c
   const int ldb = trans_B == CblasTrans ? mat2_strides[2] : mat2_strides[1];
   const int ldc = res.strides()[1];

+  // avoid using tensor accessor in the case of mat1/mat2 not being transposed
+  // or only transposed in the last two axes
+  const bool canAvoidTensorAccessor = mat1_strides[0] == mat1_sizes[1] * mat1_sizes[2] &&
+      mat2_strides[0] == mat2_sizes[1] * mat2_sizes[2];
+
+  scalar_t* const res_data = static_cast(res.data_ptr());
+
+  if (batch_size == 1) {
+    const scalar_t* A;
+    const scalar_t* B;
+    if (canAvoidTensorAccessor) {
+      scalar_t* mat1_data = static_cast(mat1.data_ptr());
+      scalar_t* mat2_data = static_cast(mat2.data_ptr());
+      A = mat1_data;
+      B = mat2_data;
+    } else {
+      auto mat1_acc = mat1.accessor();
+      auto mat2_acc = mat2.accessor();
+      A = mat1_acc[0].data();
+      B = mat2_acc[0].data();
+    }
+    gemm(trans_A, trans_B, M, N, K, alpha, A, lda, B, ldb, beta, res_data, ldc);
+    return;
+  }
+
   std::vector A;
   A.reserve(batch_size);
   std::vector B;
@@ -110,10 +165,8 @@ static inline void baddbmm_mkl_template(const Tensor& res, const Tensor& mat1, c

   // avoid using tensor accessor in the case of mat1/mat2 not being transposed
   // or only transposed in the last two axis
-  scalar_t* res_data = static_cast(res.data_ptr());
   const auto res_sizes = res.sizes();
-  if (mat1_strides[0] == mat1_sizes[1] * mat1_sizes[2] &&
-      mat2_strides[0] == mat2_sizes[1] * mat2_sizes[2]) {
+  if (canAvoidTensorAccessor) {
     scalar_t* mat1_data = static_cast(mat1.data_ptr());
     scalar_t* mat2_data = static_cast(mat2.data_ptr());
     for (int64_t batch = 0; batch < batch_size; batch++) {
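Reviewer note: the operation being specialized here is `baddbmm`, i.e. `res[b] = beta * res[b] + alpha * (mat1[b] @ mat2[b])`; the patch calls one plain gemm when `batch_size == 1` instead of going through the batched path. A NumPy model of the semantics:

```python
import numpy as np

def baddbmm_reference(res, mat1, mat2, alpha=1.0, beta=1.0):
    # res: (batch, M, N), mat1: (batch, M, K), mat2: (batch, K, N)
    for b in range(res.shape[0]):
        # With batch_size == 1 this body collapses to the single gemm call
        # the patch now makes directly instead of using gemm_batched.
        res[b] = beta * res[b] + alpha * (mat1[b] @ mat2[b])
    return res

res = np.ones((1, 2, 2))
out = baddbmm_reference(res, np.eye(2)[None], np.full((1, 2, 2), 3.0), alpha=2.0)
print(out)  # 1 + 2 * (I @ 3s) = [[7., 7.], [7., 7.]]
```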
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index b9cd22289f85..5f0d21f6b9b2 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -5158,12 +5158,6 @@
   dispatch:
     CPU, CUDA: __irshift__

-- func: lgamma_(Tensor(a!) self) -> Tensor(a!)
-  variants: method
-  dispatch:
-    CPU: _lgamma__cpu
-    CUDA: _lgamma__cuda
-
 - func: atan2_(Tensor(a!) self, Tensor other) -> Tensor(a!)
   variants: method
   dispatch:
@@ -5989,8 +5983,12 @@
 - func: lgamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
   use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
   dispatch:
-    CPU: _lgamma_out_cpu
-    CUDA: _lgamma_out_cuda
+    CPU, CUDA: lgamma_out
+
+- func: lgamma_(Tensor(a!) self) -> Tensor(a!)
+  variants: method
+  dispatch:
+    CPU, CUDA: lgamma_

 - func: lgamma(Tensor self) -> Tensor
   variants: method, function
diff --git a/aten/src/ATen/templates/TensorBody.h b/aten/src/ATen/templates/TensorBody.h
index 731cb79e031f..77b9ae978158 100644
--- a/aten/src/ATen/templates/TensorBody.h
+++ b/aten/src/ATen/templates/TensorBody.h
@@ -147,6 +147,10 @@ class TORCH_API Tensor {
     return impl_;
   }

+  c10::intrusive_ptr unsafeReleaseIntrusivePtr() {
+    return std::move(impl_);
+  }
+
   bool defined() const {
     return impl_;
   }
diff --git a/aten/src/THC/THCTensorMasked.cuh b/aten/src/THC/THCTensorMasked.cuh
index 88f2d78b698d..4e696ba392ce 100644
--- a/aten/src/THC/THCTensorMasked.cuh
+++ b/aten/src/THC/THCTensorMasked.cuh
@@ -25,22 +25,6 @@ struct TensorMaskedFillOp {
   T value;
 };

-template
-struct TensorMaskedCopyOp {
-  TensorMaskedCopyOp(T* s) : in(s) {}
-
-  __device__ inline void operator()(T* out,
-                                    MaskT* mask,
-                                    MaskPrefixSumT* maskPrefixSum) {
-    if (*mask) {
-      *out = in[*maskPrefixSum];
-    }
-  }
-
-  // Where we are copying from
-  T* in;
-};
-
 template
 struct TensorMaskedSelectOp {
   TensorMaskedSelectOp(T* t) : out(t) {}
diff --git a/aten/src/THC/generic/THCTensorMasked.cu b/aten/src/THC/generic/THCTensorMasked.cu
index 4e93ac260e42..4a3c93241aec 100644
--- a/aten/src/THC/generic/THCTensorMasked.cu
+++ b/aten/src/THC/generic/THCTensorMasked.cu
@@ -47,145 +47,4 @@ void THCTensor_(maskedFillByte)(THCState* state,
   THCudaByteTensor_free(state, maskCuda);
 }

-void THCTensor_(maskedCopy)(THCState* state,
-                            THCTensor *tensor, THCudaByteTensor *mask, THCTensor *src)
-{
-  THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, tensor, src, mask));
-  ptrdiff_t maskSize = THCudaByteTensor_nElement(state, mask);
-  ptrdiff_t tensorSize = THCTensor_(nElement)(state, tensor);
-  ptrdiff_t srcSize = THCTensor_(nElement)(state, src);
-
-  // `mask` and `tensor` must have the same number of elements
-  THArgCheck(maskSize == tensorSize, 2,
-             "mask and tensor must have the same number of elements");
-
-  // Determine our output size
-  int64_t totalElements = THTensor_wrap(mask).sum().item();
-
-  // The number of `1` elements present in the mask must be <= the
-  // number of elements available in `src`
-  if (totalElements > srcSize) {
-    THArgCheck(false, 2, "source nElements must be == mask `1` elements");
-  }
-
-  // FIXME: there appears to be a bug in Thrust (CUDA 7.0) for mixed
-  // iterator prefix sums? Convert `mask` to the same datatype as what
-  // we're accumulating the prefix sum in (int64_t) to get around it
-  THCudaLongTensor* maskLong = THCudaLongTensor_new(state);
-  at::IntArrayRef maskSizes = mask->sizes();
-  THCudaLongTensor_resize(state, maskLong, maskSizes, {});
-  THCTensor_(copy)(state, maskLong, mask);
-
-  // Use a prefix sum to determine the output locations of the masked elements
-  THCudaLongTensor* maskPrefixSum = THCudaLongTensor_new(state);
-  THCudaLongTensor_resize(state, maskPrefixSum, maskSizes, {});
-
-  THCThrustAllocator thrustAlloc(state);
-  thrust::device_ptr
-    maskData(THCudaLongTensor_data(state, maskLong));
-  thrust::device_ptr
-    maskPrefixSumData(THCudaLongTensor_data(state, maskPrefixSum));
-
-  thrust::exclusive_scan(
-#if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__
-    thrust::cuda::par(thrustAlloc).on(c10::cuda::getCurrentCUDAStream()),
-#endif
-    maskData,
-    maskData + THCudaLongTensor_nElement(state, maskLong),
-    maskPrefixSumData);
-
-  // We are getting elements from `src` based on an offset from
-  // `maskPrefixSum`, so that should be made contiguous too
-  THCTensor* contigSrc = THCTensor_(newContiguous)(state, src);
-
-  // update `tensor` where `mask` == 1 but pull from `src` at
-  // maskPrefixSum
-  bool status = THC_pointwiseApply3(
-    state, tensor, mask, maskPrefixSum,
-    TensorMaskedCopyOp(
-      THCTensor_(data)(state, contigSrc)));
-
-  THCTensor_(free)(state, contigSrc);
-  THCudaLongTensor_free(state, maskLong);
-  THCudaLongTensor_free(state, maskPrefixSum);
-
-  THArgCheck(status, 2, CUTORCH_DIM_WARNING);
-  THCudaCheck(cudaGetLastError());
-}
-
-void THCTensor_(maskedCopyBool)(THCState* state,
-                                THCTensor *tensor, THCudaBoolTensor *mask, THCTensor *src)
-{
-  THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, tensor, src, mask));
-  ptrdiff_t maskSize = THCudaBoolTensor_nElement(state, mask);
-  ptrdiff_t tensorSize = THCTensor_(nElement)(state, tensor);
-  ptrdiff_t srcSize = THCTensor_(nElement)(state, src);
-
-  // `mask` and `tensor` must have the same number of elements
-  THArgCheck(maskSize == tensorSize, 2,
-             "mask and tensor must have the same number of elements");
-
-  // Determine our output size
-  int64_t totalElements = THTensor_wrap(mask).sum().item();
-
-  // The number of `1` elements present in the mask must be <= the
-  // number of elements available in `src`
-  if (totalElements > srcSize) {
-    THArgCheck(false, 2, "source nElements must be == mask `1` elements");
-  }
-
-  // FIXME: there appears to be a bug in Thrust (CUDA 7.0) for mixed
-  // iterator prefix sums? Convert `mask` to the same datatype as what
-  // we're accumulating the prefix sum in (int64_t) to get around it
-  THCudaLongTensor* maskLong = THCudaLongTensor_new(state);
-  at::IntArrayRef maskSizes = mask->sizes();
-  THCudaLongTensor_resize(state, maskLong, maskSizes, {});
-  THCTensor_(copy)(state, maskLong, mask);
-
-  // Use a prefix sum to determine the output locations of the masked elements
-  THCudaLongTensor* maskPrefixSum = THCudaLongTensor_new(state);
-  THCudaLongTensor_resize(state, maskPrefixSum, maskSizes, {});
-
-  THCThrustAllocator thrustAlloc(state);
-  thrust::device_ptr
-    maskData(THCudaLongTensor_data(state, maskLong));
-  thrust::device_ptr
-    maskPrefixSumData(THCudaLongTensor_data(state, maskPrefixSum));
-
-  thrust::exclusive_scan(
-#if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__
-    thrust::cuda::par(thrustAlloc).on(c10::cuda::getCurrentCUDAStream()),
-#endif
-    maskData,
-    maskData + THCudaLongTensor_nElement(state, maskLong),
-    maskPrefixSumData);
-
-  // We are getting elements from `src` based on an offset from
-  // `maskPrefixSum`, so that should be made contiguous too
-  THCTensor* contigSrc = THCTensor_(newContiguous)(state, src);
-
-  // update `tensor` where `mask` == 1 but pull from `src` at
-  // maskPrefixSum
-  bool status = THC_pointwiseApply3(
-    state, tensor, mask, maskPrefixSum,
-    TensorMaskedCopyOp(
-      THCTensor_(data)(state, contigSrc)));
-
-  THCTensor_(free)(state, contigSrc);
-  THCudaLongTensor_free(state, maskLong);
-  THCudaLongTensor_free(state, maskPrefixSum);
-
-  THArgCheck(status, 2, CUTORCH_DIM_WARNING);
-  THCudaCheck(cudaGetLastError());
-}
-
-void THCTensor_(maskedCopyByte)(THCState* state,
-                                THCTensor *tensor, THByteTensor *mask, THCTensor *src) {
-  THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, tensor, src));
-  THCudaByteTensor* maskCuda = THTensor_wrap(mask).cuda().unsafeReleaseTensorImpl();
-  THCTensor_(copy)(state, maskCuda, mask);
-  THCTensor_(maskedCopy)(state, tensor, maskCuda, src);
-  THCudaByteTensor_free(state, maskCuda);
-}
-
 #endif
diff --git a/aten/src/THC/generic/THCTensorMasked.h b/aten/src/THC/generic/THCTensorMasked.h
index 87e95a344973..16f4a262c8a0 100644
--- a/aten/src/THC/generic/THCTensorMasked.h
+++ b/aten/src/THC/generic/THCTensorMasked.h
@@ -21,23 +21,4 @@ TORCH_CUDA_CU_API void THCTensor_(maskedFillByte)(
     THByteTensor* mask,
     scalar_t value);

-TORCH_CUDA_CU_API void THCTensor_(maskedCopy)(
-    THCState* state,
-    THCTensor* tensor,
-    THCudaByteTensor* mask,
-    THCTensor* src);
-
-TORCH_CUDA_CU_API void THCTensor_(maskedCopyBool)(
-    THCState* state,
-    THCTensor* tensor,
-    THCudaBoolTensor* mask,
-    THCTensor* src);
-
-// FIXME: remove now that we have THCudaByteTensor?
-TORCH_CUDA_CU_API void THCTensor_(maskedCopyByte)(
-    THCState* state,
-    THCTensor* tensor,
-    THByteTensor* mask,
-    THCTensor* src);
-
 #endif
diff --git a/benchmarks/cpp/tensorexpr/bench_reduce.cpp b/benchmarks/cpp/tensorexpr/bench_reduce.cpp
index cd467d74162e..06bc9b055176 100644
--- a/benchmarks/cpp/tensorexpr/bench_reduce.cpp
+++ b/benchmarks/cpp/tensorexpr/bench_reduce.cpp
@@ -1,4 +1,5 @@
 #include
+#include
 #include
 #include
 #include
@@ -365,7 +366,8 @@ BENCHMARK_DEFINE_F(Reduce1D, TeRfactorV1)(benchmark::State& state) {
     te::For* mi = loops[1];
     // TODO: rfactor works on the untransformed var set. This is a problem since we need to
     // look for the loop after Split to rfactor.
-    loop.rfactor(BT->body(), mi->var());
+    auto bt_body = te::NodeFinder::find(loop.root_stmt())[0];
+    loop.rfactor(bt_body, mi->var());
   }

   loop.prepareForCodegen();
@@ -411,7 +413,8 @@ BENCHMARK_DEFINE_F(Reduce1D, TeRfactorV2)(benchmark::State& state) {
   TORCH_CHECK(loops.size() == 2);
   te::For* mo = loops[0];
   te::For* mi = loops[1];
-  loop.rfactor(BT->body(), mi->var());
+  auto bt_body = te::NodeFinder::find(loop.root_stmt())[0];
+  loop.rfactor(bt_body, mi->var());
 }

 {
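Reviewer note: `rfactor`, whose call sites change throughout the tensorexpr diffs below, splits a reduction so one chosen axis accumulates into its own temporary buffer, which a second pass then reduces. A NumPy model of that transformation for a 1-D sum (illustrative; the buffer layout is made up):

```python
import numpy as np

def rfactor_sum(a, inner=8):
    # Split the single reduction loop into (outer, inner) and give the inner
    # axis its own partial buffer -- the "rfactored" reduction. A second pass
    # then reduces the partials. The final value is unchanged.
    assert a.size % inner == 0
    tiles = a.reshape(-1, inner)
    partials = tiles.sum(axis=0)  # one partial sum per inner index
    return partials.sum()         # final reduction over the partials

x = np.arange(64, dtype=np.float32)
assert rfactor_sum(x) == x.sum()
```

The API change itself is mechanical: `rfactor` now takes the reduction `Store` located via `NodeFinder` over the root statement instead of `tensor->body()`.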
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index 62f9e8be3e4c..1b1656d1fffa 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -1484,6 +1484,7 @@ if(BUILD_PYTHON)
   # ---[ Python.
   if(BUILD_CAFFE2)
     add_library(caffe2_pybind11_state MODULE ${Caffe2_CPU_PYTHON_SRCS})
+    target_compile_options(caffe2_pybind11_state PRIVATE "-DUSE_NUMPY")
     if(NOT MSVC)
       set_target_properties(caffe2_pybind11_state PROPERTIES COMPILE_FLAGS "-fvisibility=hidden")
     endif()
@@ -1514,6 +1515,7 @@ if(BUILD_PYTHON)
     if(USE_CUDA)
       add_library(caffe2_pybind11_state_gpu MODULE ${Caffe2_GPU_PYTHON_SRCS})
+      target_compile_options(caffe2_pybind11_state_gpu PRIVATE "-DUSE_NUMPY")
       if(NOT MSVC)
         set_target_properties(caffe2_pybind11_state_gpu PROPERTIES COMPILE_FLAGS "-fvisibility=hidden")
       endif()
@@ -1542,6 +1544,7 @@ if(BUILD_PYTHON)
     if(USE_ROCM)
       add_library(caffe2_pybind11_state_hip MODULE ${Caffe2_HIP_PYTHON_SRCS})
+      target_compile_options(caffe2_pybind11_state_hip PRIVATE "-DUSE_NUMPY")
       if(NOT MSVC)
         target_compile_options(caffe2_pybind11_state_hip PRIVATE ${HIP_CXX_FLAGS} -fvisibility=hidden)
       endif()
diff --git a/caffe2/core/macros.h.in b/caffe2/core/macros.h.in
index dd9f9902be1f..bd9a447b879d 100644
--- a/caffe2/core/macros.h.in
+++ b/caffe2/core/macros.h.in
@@ -44,10 +44,6 @@ static_assert(
 #cmakedefine CAFFE2_USE_NVTX
 #cmakedefine CAFFE2_USE_TRT

-#ifndef USE_NUMPY
-#cmakedefine USE_NUMPY
-#endif
-
 #ifndef EIGEN_MPL2_ONLY
 #cmakedefine EIGEN_MPL2_ONLY
 #endif
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
index e138a86c61da..75afbb8b6cf4 100644
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -1347,7 +1347,6 @@ if(USE_DISTRIBUTED AND USE_TENSORPIPE)
     set(TP_ENABLE_CUDA_IPC ON CACHE BOOL "" FORCE)
   endif()
   set(TP_BUILD_LIBUV ON CACHE BOOL "" FORCE)
-  set(TP_ENABLE_SHM OFF CACHE BOOL "" FORCE)
   set(TP_STATIC_OR_SHARED STATIC CACHE STRING "" FORCE)

   add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/tensorpipe)
@@ -1851,4 +1850,3 @@ if(USE_KINETO)
     set(USE_KINETO OFF)
   endif()
 endif()
-
diff --git a/scripts/onnx/test.sh b/scripts/onnx/test.sh
index 3432ea434928..5e9cfa936064 100755
--- a/scripts/onnx/test.sh
+++ b/scripts/onnx/test.sh
@@ -58,6 +58,7 @@ pytest "${args[@]}" \
   --ignore "$top_dir/test/onnx/test_utility_funs.py" \
   --ignore "$top_dir/test/onnx/test_pytorch_onnx_caffe2.py" \
   --ignore "$top_dir/test/onnx/test_pytorch_onnx_shape_inference.py" \
+  --ignore "$top_dir/test/onnx/test_pytorch_onnx_onnxruntime_cuda.py" \
   "${test_paths[@]}"

 # onnxruntime only support py3
diff --git a/test/cpp/rpc/e2e_test_base.h b/test/cpp/rpc/e2e_test_base.h
index cea5079b1a4e..6526f8795c19 100644
--- a/test/cpp/rpc/e2e_test_base.h
+++ b/test/cpp/rpc/e2e_test_base.h
@@ -40,6 +40,14 @@ class TestE2EBase : public ::testing::Test {
     RpcAgent::setCurrentRpcAgent(rpcAgent);
     std::shared_ptr typeResolver =
         std::make_shared([&](const c10::QualifiedName& qn) {
+          // For the Dict type used for the device map.
+          auto pos = qn.name().find("Dict");
+          if (pos != std::string::npos) {
+            return c10::StrongTypePtr(
+                nullptr,
+                c10::DictType::create(
+                    c10::IntType::create(), c10::IntType::create()));
+          }
           return c10::StrongTypePtr(
               nullptr, c10::TensorType::create(at::Tensor()));
         });
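Reviewer note: a Python sketch of the dispatch the test's type resolver now performs; qualified names mentioning "Dict" resolve to a Dict(int, int) type (the RPC device map), everything else to a Tensor type. Names below are illustrative, not torch APIs:

```python
def resolve_type(qualified_name: str):
    # Toy stand-in for the C++ lambda above.
    if "Dict" in qualified_name:
        return ("Dict", "int", "int")   # device map: Dict[int, int]
    return ("Tensor",)

print(resolve_type("__torch__.Dict"))  # ('Dict', 'int', 'int')
print(resolve_type("my.module.Foo"))   # ('Tensor',)
```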
diff --git a/test/cpp/tensorexpr/test_llvm.cpp b/test/cpp/tensorexpr/test_llvm.cpp
index 7afb839dc7e0..343c965c294c 100644
--- a/test/cpp/tensorexpr/test_llvm.cpp
+++ b/test/cpp/tensorexpr/test_llvm.cpp
@@ -1,17 +1,17 @@
 #ifdef TORCH_ENABLE_LLVM

 #include
-#include "test/cpp/tensorexpr/test_base.h"
-
-#include "test/cpp/tensorexpr/padded_buffer.h"
-#include "test/cpp/tensorexpr/test_utils.h"
-#include "torch/csrc/jit/tensorexpr/eval.h"
-#include "torch/csrc/jit/tensorexpr/ir.h"
-#include "torch/csrc/jit/tensorexpr/ir_printer.h"
-#include "torch/csrc/jit/tensorexpr/ir_simplifier.h"
-#include "torch/csrc/jit/tensorexpr/llvm_codegen.h"
-#include "torch/csrc/jit/tensorexpr/loopnest.h"
-#include "torch/csrc/jit/tensorexpr/tensor.h"
+#include <test/cpp/tensorexpr/test_base.h>
+
+#include <test/cpp/tensorexpr/padded_buffer.h>
+#include <test/cpp/tensorexpr/test_utils.h>
+#include <torch/csrc/jit/tensorexpr/eval.h>
+#include <torch/csrc/jit/tensorexpr/ir.h>
+#include <torch/csrc/jit/tensorexpr/ir_printer.h>
+#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
+#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
+#include <torch/csrc/jit/tensorexpr/loopnest.h>
+#include <torch/csrc/jit/tensorexpr/tensor.h>

 #include
 #include
@@ -1481,7 +1481,8 @@ TEST(LLVM, RFactorReduction) {
   loops = loop.getLoopStmtsFor(b);
   loop_m = loops.at(2);
   loop_n = loops.at(1);
-  loop.rfactor(b->body(), loop_n->var(), loop_n->body());
+  auto b_body = NodeFinder::find(loop.root_stmt())[0];
+  loop.rfactor(b_body, loop_n->var(), loop_n->body());

   loop.prepareForCodegen();
   Stmt* s = loop.root_stmt();
@@ -1522,7 +1523,8 @@ TEST(LLVM, RFactorVectorizedReduction) {
   For* loop_k = loops.at(0);
   For* loop_m = loops.at(1);
   For* loop_n = loops.at(2);
-  loopnest.rfactor(b->body(), loop_n->var());
+  auto b_body = NodeFinder::find(loopnest.root_stmt())[0];
+  loopnest.rfactor(b_body, loop_n->var());

   loops = NodeFinder::find(loopnest.root_stmt());
   loop_k = loops.at(0);
diff --git a/test/cpp/tensorexpr/test_loopnest.cpp b/test/cpp/tensorexpr/test_loopnest.cpp
index 89ad7eb2aecb..6f44b5e2b033 100644
--- a/test/cpp/tensorexpr/test_loopnest.cpp
+++ b/test/cpp/tensorexpr/test_loopnest.cpp
@@ -826,7 +826,7 @@ TEST(LoopNest, ScheduleInlineSimple) {
   });

   LoopNest l1({y});
-  LoopNest l2({y});
+  LoopNest l2(l1);
   l2.computeInline(x->buf());

   l1.prepareForCodegen();
@@ -1156,7 +1156,7 @@ TEST(LoopNest, ScheduleInlineIntrinsics) {
   }

   LoopNest l1({y});
-  LoopNest l2({y});
+  LoopNest l2(l1);
   l2.computeInline(x->buf());

   l1.prepareForCodegen();
@@ -1228,7 +1228,6 @@ TEST(LoopNest, ScheduleSplitAThenInline) {
     return a->call(j + ExprHandle(8));
   });

-  LoopNest loop({b});
   For* i_outer;
   For* i_inner;
@@ -1247,7 +1246,6 @@ TEST(LoopNest, ScheduleSplitBThenInline) {
     return a->call(j + ExprHandle(8));
   });

-  LoopNest loop({b});
   For* i_outer;
   For* i_inner;
@@ -1275,8 +1273,6 @@ TEST(LoopNest, ScheduleSplitTwiceThenInline) {
   Tensor* b = Compute("b", {{2, "j"}}, [&](const VarHandle& j) {
     return a->call(j + ExprHandle(8));
   });
-
-  LoopNest loop({b});
   For* i_outer;
   For* i_inner;
@@ -1296,7 +1292,6 @@ TEST(LoopNest, ScheduleInlineThenSplit) {
     return a->call(j + ExprHandle(8));
   });

-  LoopNest loop({b});
   For* i_outer;
   For* i_inner;
@@ -1325,7 +1320,6 @@ TEST(LoopNest, ScheduleSplitInlineThenSplit) {
     return a->call(j + ExprHandle(8));
   });

-  LoopNest loop({b});
   For* i_outer;
   For* i_inner;
@@ -1357,7 +1351,6 @@ TEST(LoopNest, ScheduleSplitInlineSimplify) {
     return a->call(j) - ExprHandle(1);
   });

-  LoopNest loop({b});
   For* i_outer;
   For* i_inner;
@@ -1714,10 +1707,11 @@ TEST(LoopNest, LoopNestComputeAt_2) {
       c_ref[y * kW + x] = y * x + (y + 1) * x + y * (x + 1) + (y + 1) * (x + 1);
     }
   }
+  LoopNest orig_loopnest({c});

   {
     // First let's try to compute P at axis cy (the outer loop)
-    LoopNest l({c});
+    LoopNest l(orig_loopnest);
     std::vector loops = l.getLoopStmtsFor(c);
     l.computeAt(l.getLoopBodyFor(p), loops[0]);
     l.prepareForCodegen();
@@ -1748,7 +1742,7 @@ TEST(LoopNest, LoopNestComputeAt_2) {
   }
   {
     // Now let's try to compute P at axis cx (the inner loop)
-    LoopNest l({c});
+    LoopNest l(orig_loopnest);
     std::vector loops = l.getLoopStmtsFor(c);
     l.computeAt(l.getLoopBodyFor(p), loops[1]);
     l.prepareForCodegen();
@@ -1823,9 +1817,10 @@ TEST(LoopNest, LoopNestComputeAt_3) {
     }
   }

+  LoopNest orig_loopnest({D});
   {
     // First let's try to compute A at axis dy (the outer loop)
-    LoopNest l({D});
+    LoopNest l(orig_loopnest);
     std::vector loops = l.getLoopStmtsFor(D);
     l.computeAt(l.getLoopBodyFor(A), loops[0]);
     l.prepareForCodegen();
@@ -1861,7 +1856,7 @@ TEST(LoopNest, LoopNestComputeAt_3) {
   }
   {
     // Now let's try to compute A at axis dx (the inner loop)
-    LoopNest l({D});
+    LoopNest l(orig_loopnest);
     std::vector loops = l.getLoopStmtsFor(D);
     l.computeAt(l.getLoopBodyFor(A), loops[1]);
     l.prepareForCodegen();
@@ -1897,10 +1892,6 @@ TEST(LoopNest, LoopNestComputeAt_3) {
   }
 }

-TEST(LoopNest, LoopNestComputeAt_4) {
-  // TODO: Verify that computeAt works with reduction axis
-}
-
 class LoopOrderHelper : public IRVisitor {
   std::stringstream ordering;
@@ -3566,7 +3557,7 @@ TEST(LoopNest, DeadStoreElimination) {
   Stmt* stmt = Block::make({stmt1});

   // Will eliminate if not used by an output.
-  LoopNest loop(stmt, {f.node()}, {}, {});
+  LoopNest loop(stmt, {f.node()}, {});
   loop.eliminateDeadStores();

   std::ostringstream oss;
@@ -3580,7 +3571,7 @@ TEST(LoopNest, DeadStoreElimination) {
   torch::jit::testing::FileCheck().run(expected_ir, oss.str());

   // But won't eliminate if used by different outputs.
-  LoopNest loop2(stmt, {f.node(), g.node()}, {}, {});
+  LoopNest loop2(stmt, {f.node(), g.node()}, {});
   loop2.eliminateDeadStores();

   oss.clear();
@@ -3621,7 +3612,7 @@ TEST(LoopNest, DeadStoreEliminationWithIntermediates) {

   // Will eliminate the write to g, but not f since it used by the producer of
   // h.
-  LoopNest loop(stmt, {h.node()}, {}, {});
+  LoopNest loop(stmt, {h.node()}, {});
   loop.eliminateDeadStores();

   std::ostringstream oss;
@@ -3636,7 +3627,7 @@ TEST(LoopNest, DeadStoreEliminationWithIntermediates) {
   torch::jit::testing::FileCheck().run(expected_ir, oss.str());

   // Sanity check won't eliminate if g is an output.
-  LoopNest loop2(stmt, {h.node(), g.node()}, {}, {});
+  LoopNest loop2(stmt, {h.node(), g.node()}, {});
   loop2.eliminateDeadStores();

   oss.clear();
@@ -3668,7 +3659,7 @@ TEST(LoopNest, CompoundTensorSimple) {
   auto outer_for2 = For::make(x, 0, 10, inner_for2);
   Block* body = Block::make({outer_for1, outer_for2});

-  Tensor* A = new CompoundTensor(a_buf.node(), {i.node(), j.node()}, body);
+  Tensor* A = new Tensor(a_buf.node(), body);

   LoopNest l({A});
   l.prepareForCodegen();
@@ -3707,7 +3698,7 @@ TEST(LoopNest, CompoundTensorUsed) {
   auto outer_for2 = For::make(x, 0, 10, inner_for2);
   Block* body = Block::make({outer_for1, outer_for2});

-  Tensor* A = new CompoundTensor(a_buf.node(), {i.node(), j.node()}, body);
+  Tensor* A = new Tensor(a_buf.node(), body);
   Tensor* B = Compute(
       "B", {{10, "i"}, {3, "j"}}, [&](const VarHandle& i, const VarHandle& j) {
         return A->call(i, j + 1) + A->call(i, j + 2);
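Reviewer note: the DeadStoreElimination tests above exercise the rule that a store survives only if some declared output transitively reads it. A compact Python model of that liveness fixpoint (toy representation; not the TE implementation):

```python
def eliminate_dead_stores(program, outputs):
    # program: list of (written_buf, [read_bufs]) statements. A store is kept
    # only if its buffer is an output or feeds, transitively, a kept buffer.
    live, changed = set(outputs), True
    while changed:
        changed = False
        for target, reads in program:
            if target in live and not set(reads) <= live:
                live |= set(reads)
                changed = True
    return [(t, r) for t, r in program if t in live]

# f feeds h; g is written but never read; h is the only output:
prog = [("f", ["a"]), ("g", ["f"]), ("h", ["f"])]
print(eliminate_dead_stores(prog, outputs={"h"}))  # the store to g is gone
```

This mirrors the test scenarios: with `{h}` as the outputs the write to `g` is eliminated but `f` is kept, and adding `g` to the outputs keeps everything.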
- LoopNest loop2(stmt, {h.node(), g.node()}, {}, {}); + LoopNest loop2(stmt, {h.node(), g.node()}, {}); loop2.eliminateDeadStores(); oss.clear(); @@ -3668,7 +3659,7 @@ TEST(LoopNest, CompoundTensorSimple) { auto outer_for2 = For::make(x, 0, 10, inner_for2); Block* body = Block::make({outer_for1, outer_for2}); - Tensor* A = new CompoundTensor(a_buf.node(), {i.node(), j.node()}, body); + Tensor* A = new Tensor(a_buf.node(), body); LoopNest l({A}); l.prepareForCodegen(); @@ -3707,7 +3698,7 @@ TEST(LoopNest, CompoundTensorUsed) { auto outer_for2 = For::make(x, 0, 10, inner_for2); Block* body = Block::make({outer_for1, outer_for2}); - Tensor* A = new CompoundTensor(a_buf.node(), {i.node(), j.node()}, body); + Tensor* A = new Tensor(a_buf.node(), body); Tensor* B = Compute( "B", {{10, "i"}, {3, "j"}}, [&](const VarHandle& i, const VarHandle& j) { return A->call(i, j + 1) + A->call(i, j + 2); diff --git a/test/cpp/tensorexpr/test_reductions.cpp b/test/cpp/tensorexpr/test_reductions.cpp index f69217df9bde..9c538741d9f4 100644 --- a/test/cpp/tensorexpr/test_reductions.cpp +++ b/test/cpp/tensorexpr/test_reductions.cpp @@ -259,9 +259,9 @@ TEST(Reductions, ReduceMax) { Tensor* m2d = Reduce("max", {{2, "n"}}, Maximum(kFloat), in2_, {{5, "m"}}); - loop = LoopNest({m2d}); - loop.prepareForCodegen(); - s = loop.root_stmt(); + LoopNest loop2({m2d}); + loop2.prepareForCodegen(); + s = loop2.root_stmt(); s = IRSimplifier::simplify(s); SimpleIREvaluator cg2(s, {in2_, m2d}); @@ -372,9 +372,9 @@ TEST(Reductions, ReduceAnyAll) { }, {{10, "j"}}); - loop = LoopNest({allGreaterThan}); - loop.prepareForCodegen(); - s = loop.root_stmt(); + LoopNest loop2({allGreaterThan}); + loop2.prepareForCodegen(); + s = loop2.root_stmt(); s = IRSimplifier::simplify(s); SimpleIREvaluator cg2(s, {b, allGreaterThan, searchValue}); @@ -699,7 +699,8 @@ TEST(Reductions, ReduceRfactor) { LoopNest loop({c}); std::vector loops = loop.getLoopStmtsFor(c); auto v = loops.at(1)->var(); - loop.rfactor(c->body(), v); + auto c_body = NodeFinder::find(loop.root_stmt())[0]; + loop.rfactor(c_body, v); auto rc = NodeFinder::find(loop.root_stmt()); ASSERT_EQ(rc.size(), 2); loop.prepareForCodegen(); @@ -734,7 +735,8 @@ TEST(Reductions, Reduce3DRfactorInternal) { LoopNest loop({c}); std::vector loops = loop.getLoopStmtsFor(c); auto v = loops.at(1)->var(); - loop.rfactor(c->body(), v); + auto c_body = NodeFinder::find(loop.root_stmt())[0]; + loop.rfactor(c_body, v); auto rc = NodeFinder::find(loop.root_stmt()); ASSERT_EQ(rc.size(), 2); loop.prepareForCodegen(); @@ -769,7 +771,8 @@ TEST(Reductions, Reduce3DRfactorInner) { LoopNest loop({c}); std::vector loops = loop.getLoopStmtsFor(c); auto v = loops.at(2)->var(); - loop.rfactor(c->body(), v); + auto c_body = NodeFinder::find(loop.root_stmt())[0]; + loop.rfactor(c_body, v); auto rc = NodeFinder::find(loop.root_stmt()); ASSERT_EQ(rc.size(), 2); loop.prepareForCodegen(); @@ -804,7 +807,8 @@ TEST(Reductions, Reduce3DRfactorOuter) { LoopNest loop({c}); std::vector loops = loop.getLoopStmtsFor(c); auto v = loops.at(0)->var(); - loop.rfactor(c->body(), v); + auto c_body = NodeFinder::find(loop.root_stmt())[0]; + loop.rfactor(c_body, v); auto rc = NodeFinder::find(loop.root_stmt()); ASSERT_EQ(rc.size(), 2); loop.prepareForCodegen(); @@ -841,7 +845,8 @@ TEST(Reductions, Reduce3DRfactorWithOuter) { LoopNest loop({c}); std::vector loops = loop.getLoopStmtsFor(c); auto v = loops.at(3)->var(); - loop.rfactor(c->body(), v); + auto c_body = NodeFinder::find(loop.root_stmt())[0]; + loop.rfactor(c_body, v); auto rc = 
NodeFinder<ReduceOp>::find(loop.root_stmt()); ASSERT_EQ(rc.size(), 2); loop.prepareForCodegen(); @@ -870,12 +875,13 @@ TEST(Reductions, Reduce3DRfactorRepeated) { } Tensor* c = Reduce("sum", {}, Sum(), b, {{m, "m"}, {n, "n"}, {k, "k"}}); + LoopNest orig_loopnest({c}); for (int rVar1 = 0; rVar1 < 3; ++rVar1) { for (int rVar2 = 0; rVar2 < 2; ++rVar2) { std::vector<float> out(1, -1.f); - LoopNest loop({c}); + LoopNest loop(orig_loopnest); auto reduces = NodeFinder<ReduceOp>::find(loop.root_stmt()); ASSERT_EQ(reduces.size(), 1); auto v1 = reduces[0]->reduce_args()[rVar1]; @@ -921,7 +927,8 @@ TEST(Reductions, ReduceRfactorInsertionPoint) { LoopNest loop({c}); std::vector<For*> loops = loop.getLoopStmtsFor(c); auto v = loops.at(0)->var(); - loop.rfactor(c->body(), v, loops.at(0)->body()); + auto c_body = NodeFinder<ReduceOp>::find(loop.root_stmt())[0]; + loop.rfactor(c_body, v, loops.at(0)->body()); auto rc = NodeFinder<ReduceOp>::find(loop.root_stmt()); ASSERT_EQ(rc.size(), 2); loop.prepareForCodegen(); @@ -956,7 +963,8 @@ TEST(Reductions, Reduce3DRfactorInsertionPoint) { LoopNest loop({c}); std::vector<For*> loops = loop.getLoopStmtsFor(c); auto v = loops.at(1)->var(); - loop.rfactor(c->body(), v, loops.at(1)->body()); + auto c_body = NodeFinder<ReduceOp>::find(loop.root_stmt())[0]; + loop.rfactor(c_body, v, loops.at(1)->body()); auto rc = NodeFinder<ReduceOp>::find(loop.root_stmt()); ASSERT_EQ(rc.size(), 2); loop.prepareForCodegen(); @@ -985,13 +993,12 @@ TEST(Reductions, ReduceRepeatedInternalRfactor) { in_, {{2, "a"}, {3, "b"}, {4, "c"}, {5, "d"}, {6, "e"}}); LoopNest refloop({c}); + LoopNest loop(refloop); refloop.prepareForCodegen(); SimpleIREvaluator ref_cg( IRSimplifier::simplify(refloop.root_stmt()), {in_, c}); ref_cg.call({in, ref}); - LoopNest loop({c}); - // rfactor out "c". auto reduces = NodeFinder<ReduceOp>::find(loop.root_stmt()); loop.rfactor(reduces[0], reduces[0]->reduce_args()[3]); @@ -1373,7 +1380,7 @@ TEST(Reductions, ReduceInlineConsumer) { } LoopNest l1({y}); - LoopNest l2({y}); + LoopNest l2(l1); l2.computeInline(x->buf()); l1.prepareForCodegen(); @@ -1431,7 +1438,7 @@ TEST(Reductions, ReduceInlineReducerInternal) { } LoopNest l1({y}); - LoopNest l2({y}); + LoopNest l2(l1); l2.computeInline(x->buf()); l1.prepareForCodegen(); @@ -1863,11 +1870,11 @@ TEST(Reductions, ReductionVectorize) { Tensor* tensor = Reduce("sum", {{8, "m"}}, Sum(), in, {{8, "n"}}); LoopNest l_before({tensor}); + LoopNest l(l_before); l_before.prepareForCodegen(); SimpleIREvaluator cg_before(l_before.root_stmt(), {in, tensor}); cg_before.call({in_, out_before}); - LoopNest l({tensor}); l.vectorize(l.getLoopStmtsFor(tensor)[0]); Stmt* s = l.root_stmt(); @@ -1923,11 +1930,11 @@ TEST(Reductions, ReductionVectorizeRfactor) { Tensor* tensor = Reduce("sum", {}, Sum(), in, {{8, "m"}, {8, "n"}}); LoopNest l_before({tensor}); + LoopNest l(l_before); l_before.prepareForCodegen(); SimpleIREvaluator cg_before(l_before.root_stmt(), {in, tensor}); cg_before.call({in_, out_before}); - LoopNest l({tensor}); ASSERT_THROWS_WITH( l.vectorize(l.getLoopStmtsFor(tensor)[1]), "reduction axis"); @@ -1935,7 +1942,8 @@ TEST(Reductions, ReductionVectorizeRfactor) { // loop.
std::vector<For*> loops = l.getLoopStmtsFor(tensor); auto v = loops.at(1)->var(); - l.rfactor(tensor->body(), v); + auto tensor_body = NodeFinder<ReduceOp>::find(l.root_stmt())[0]; + l.rfactor(tensor_body, v); loops = NodeFinder<For>::find(l.root_stmt()); l.vectorize(loops[2]); diff --git a/test/cpp/tensorexpr/tutorial.cpp b/test/cpp/tensorexpr/tutorial.cpp index 31e05549186e..b935e5d2b6b6 100644 --- a/test/cpp/tensorexpr/tutorial.cpp +++ b/test/cpp/tensorexpr/tutorial.cpp @@ -118,33 +118,53 @@ int main(int argc, char* argv[]) { std::cout << "*** Tensors, Functions, and Placeholders ***" << std::endl; { - // A tensor computation is represented by objects of Tensor class and + // A tensor computation is represented by Tensor class objects and // consists of the following pieces: // - domain, which is specified by a Buf expression - // - an expression (or several expressions if we want to perform several - // independent computations over the same domain) for its elements, as a - // function of indices - // - // TODO: Update this section once Tensor/Function cleanup is done + // - a tensor statement, specified by a Stmt object, describing the + // computation to be performed in this domain + + // Let's start with defining a domain. We do this by creating a Buf object. + + // First, let's specify the sizes: std::vector<const Expr*> dims = { new IntImm(64), new IntImm(32)}; // IntImm stands for Integer Immediate // and represents an integer constant - // Next we need to create arguments. The arguments are Vars, and they play - // role of placeholders. The computation that the tensor would describe - // would use these arguments. + // Now we can create a Buf object by providing a name, dimensions, and a + // data type of the elements: + const Buf* buf = new Buf("X", dims, kInt); + + // Next we need to specify the computation. We can do that either by + // constructing a complete tensor statement for it (statements are + // examined in detail in a subsequent section) or by using a convenience + // method where we specify axes and an element expression for the + // computation. In the latter case a corresponding statement is + // constructed automatically. + + // Let's define two variables, i and j - they will be the axes of our + // computation. const Var* i = new Var("i", kInt); const Var* j = new Var("j", kInt); std::vector<const Var*> args = {i, j}; // Now we can define the body of the tensor computation using these - // arguments. + // variables. What this means is that values in our tensor are: + // X[i, j] = i * j Expr* body = new Mul(i, j); // Finally, we pass all these pieces together to the Tensor constructor: - Tensor* X = new Tensor("X", dims, args, body); + Tensor* X = new Tensor(buf, args, body); std::cout << "Tensor computation: " << *X << std::endl; - // Prints: Tensor computation: Tensor X(i[64], j[32]) = i * j + // Prints: + // Tensor computation: Tensor X[64, 32]: + // for (int i = 0; i < 64; i++) { + // for (int j = 0; j < 32; j++) { + // X[i, j] = i * j; + // } + // } + + // TODO: Add an example of constructing a Tensor with a complete Stmt.
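To make the printed loop nest concrete, here is what the statement generated for X evaluates to, written out in plain Python (an illustration of the semantics only, not the tensorexpr API):

# Mirror of the IR printed above: X[i, j] = i * j over a 64x32 domain.
X = [[0] * 32 for _ in range(64)]
for i in range(64):        # for (int i = 0; i < 64; i++)
    for j in range(32):    #   for (int j = 0; j < 32; j++)
        X[i][j] = i * j    #     X[i, j] = i * j;
assert X[3][5] == 15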
// Similarly to how we provide a more convenient way of using handles for // constructing Exprs, Tensors also have a more convenient API for @@ -155,11 +175,17 @@ int main(int argc, char* argv[]) { {{64, "i"}, {32, "j"}}, [](const VarHandle& i, const VarHandle& j) { return i / j; }); std::cout << "Tensor computation: " << *Z << std::endl; - // Prints: Tensor computation: Tensor Z(i[64], j[32]) = i / j + // Prints: + // Tensor computation: Tensor Z[64, 32]: + // for (int i = 0; i < 64; i++) { + // for (int j = 0; j < 32; j++) { + // Z[i, j] = i / j; + // } + // } // Tensors might access other tensors and external placeholders in their // expressions. It can be done like so: - Placeholder P("P", kFloat, {64, 32}); + Placeholder P("P", kInt, {64, 32}); Tensor* R = Compute( "R", {{64, "i"}, {32, "j"}}, @@ -167,7 +193,13 @@ return Z->call(i, j) * P.load(i, j); }); std::cout << "Tensor computation: " << *R << std::endl; - // Prints: Tensor computation: Tensor R(i[64], j[32]) = Z(i, j) * P[i, j] + // Prints: + // Tensor computation: Tensor R[64, 32]: + // for (int i = 0; i < 64; i++) { + // for (int j = 0; j < 32; j++) { + // R[i, j] = (Z(i, j)) * (P[i, j]); + // } + // } // Placeholders could be thought of as external tensors, i.e. tensors for // which we don't have the element expression. In other words, for `Tensor` @@ -211,8 +243,19 @@ int main(int argc, char* argv[]) { std::cout << "Tensor computation X: " << *X << "Tensor computation Y: " << *Y << std::endl; // Prints: - // Tensor computation X: Tensor X(i[64], j[32]) = (A[i, j]) + (B[i, j]) - // Tensor computation Y: Tensor Y(i[64], j[32]) = sigmoid(X(i, j)) + // Tensor computation X: Tensor X[64, 32]: + // for (int i = 0; i < 64; i++) { + // for (int j = 0; j < 32; j++) { + // X[i, j] = (A[i, j]) + (B[i, j]); + // } + // } + + // Tensor computation Y: Tensor Y[64, 32]: + // for (int i = 0; i < 64; i++) { + // for (int j = 0; j < 32; j++) { + // Y[i, j] = sigmoid(X(i, j)); + // } + // } // Creating a loop nest is quite simple: we just need to specify // the output tensors of our computation, and the LoopNest object will diff --git a/test/jit/test_class_type.py b/test/jit/test_class_type.py index 4d3d73e5f7c7..3a5881f365c1 100644 --- a/test/jit/test_class_type.py +++ b/test/jit/test_class_type.py @@ -11,7 +11,7 @@ # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) -from torch.testing._internal.jit_utils import JitTestCase +from torch.testing._internal.jit_utils import JitTestCase, make_global import torch.testing._internal.jit_utils from torch.testing._internal.common_utils import IS_SANDCASTLE from typing import List, Tuple, Iterable, Optional, Dict @@ -143,12 +143,12 @@ def __init__(self, x): self.attr = x def test_class_type_as_param(self): - global FooTest # see [local resolution in python] - class FooTest(object): # noqa: B903 def __init__(self, x): self.attr = x + make_global(FooTest) # see [local resolution in python] + @torch.jit.script def fn(foo: FooTest) -> torch.Tensor: return foo.attr @@ -279,13 +279,13 @@ def forward(self, a): self.assertEqual(2 * input, output) def test_python_interop(self): - global Foo # see [local resolution in python] - class Foo(object): # noqa: B903 def __init__(self, x, y): self.x = x self.y = y + make_global(Foo) # see [local resolution in python] + @torch.jit.script def use_foo(foo: Foo) -> Foo: return foo @@ -305,13 +305,13 @@ def use_foo(foo: Foo) ->
Foo: self.assertEqual(y, f2.y) def test_class_specialization(self): - global Foo # see [local resolution in python] - class Foo(object): # noqa: B903 def __init__(self, x, y): self.x = x self.y = y + make_global(Foo) # see [local resolution in python] + def use_foo(foo: Foo, foo2: Foo, tup: Tuple[Foo, Foo]) -> torch.Tensor: a, b = tup return foo.x + foo2.y + a.x + b.y @@ -329,8 +329,6 @@ def use_foo(foo: Foo, foo2: Foo, tup: Tuple[Foo, Foo]) -> torch.Tensor: FileCheck().check_count("prim::GetAttr", 4).run(graphstr) def test_class_sorting(self): - global Foo # see [local resolution in python] - class Foo(object): # noqa: B903 def __init__(self, x: int) -> None: self.x = x @@ -342,6 +340,8 @@ def __lt__(self, other) -> bool: def getVal(self): return self.x + make_global(Foo) # see [local resolution in python] + def test(li: List[Foo], reverse: bool = False) -> Tuple[List[int], List[int]]: li_sorted = sorted(li) ret_sorted = torch.jit.annotate(List[int], []) @@ -500,8 +500,6 @@ def forward(self, a): self.assertEqual(3 * input, output) def test_interface(self): - global Foo, Bar, OneTwo, OneTwoThree, OneTwoWrong, NotMember, NotMember2 - @torch.jit.script class Foo(object): def __init__(self): @@ -571,6 +569,8 @@ def one(self, x, y): def two(self, x: int) -> int: return 3 + make_global(Foo, Bar, OneTwo, OneTwoThree, OneTwoWrong, NotMember, NotMember2) + def use_them(x): a = Foo() b = Bar() @@ -652,8 +652,6 @@ def __init__(self): # NamedTuple inheritance errors def test_overloaded_fn(self): - global Foo, MyClass # see [local resolution in python] - @torch.jit.script class Foo(object): def __init__(self, x): @@ -673,6 +671,8 @@ def test_overload(): a = Foo(torch.ones([3, 3])) return len(a), -a * torch.zeros([3, 3]) + make_global(Foo) # see [local resolution in python] + self.checkScript(test_overload, ()) # unary ops tested above @@ -737,6 +737,8 @@ def __call__(self, val: int) -> int: return self.x * val * 3 + make_global(MyClass) # see [local resolution in python] + def add(): return MyClass(4) + 3 def sub(): # noqa: E306 @@ -787,8 +789,6 @@ def test(): return Foo(torch.tensor(1)) + Foo(torch.tensor(1)) def test_cast_overloads(self): - global Foo # see [local resolution in python] - @torch.jit.script class Foo(object): def __init__(self, val: float) -> None: @@ -806,6 +806,8 @@ def __bool__(self): def __str__(self): return str(self.val) + make_global(Foo) # see [local resolution in python] + def test(foo: Foo) -> Tuple[int, float, bool]: if foo: pass @@ -914,8 +916,6 @@ def forward(self, x): self.assertEqual(m.w, m_loaded.w) def test_py_class_to_ivalue_missing_attribute(self): - global Foo # see [local resolution in python] - class Foo(object): i : int f : float @@ -924,6 +924,8 @@ def __init__(self, i : int, f : float): self.i = i self.f = f + make_global(Foo) # see [local resolution in python] + @torch.jit.script def test_fn(x : Foo) -> float: return x.i + x.f @@ -1132,8 +1134,6 @@ def test_staticmethod(self): """ Test static methods on class types. """ - global ClassWithStaticMethod - @torch.jit.script class ClassWithStaticMethod: def __init__(self, a: int, b: int): @@ -1164,14 +1164,14 @@ def create_from(a: int, b: int) -> 'ClassWithStaticMethod': def test_function(a: int, b: int) -> 'ClassWithStaticMethod': return ClassWithStaticMethod.create_from(a, b) + make_global(ClassWithStaticMethod) + self.checkScript(test_function, (1, 2)) def test_classmethod(self): """ Test classmethods on class types.
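The make_global(...) calls that replace the global statements throughout these tests publish a locally-defined class under its own name in the module where it was defined, which is what lets torch.jit.script resolve the annotation on a function parameter. A sketch of the idea (the actual helper lives in torch.testing._internal.jit_utils and may differ in detail):

import sys

def make_global(*objs):
    # Register each object in its defining module's namespace so the
    # TorchScript compiler can look it up by name when resolving
    # annotations such as `def fn(foo: FooTest)`.
    for obj in objs:
        setattr(sys.modules[obj.__module__], obj.__name__, obj)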
""" - global ClassWithClassMethod - @torch.jit.script class ClassWithClassMethod: def __init__(self, a: int): @@ -1184,6 +1184,8 @@ def __eq__(self, other: 'ClassWithClassMethod'): def create(cls, a: int) -> 'ClassWithClassMethod': return cls(a) + make_global(ClassWithClassMethod) + def test_function(a: int) -> 'ClassWithClassMethod': x = ClassWithClassMethod(a) # Support calling classmethod with an instance diff --git a/test/jit/test_enum.py b/test/jit/test_enum.py index b39732d0e9bc..1a5f79d6e3ca 100644 --- a/test/jit/test_enum.py +++ b/test/jit/test_enum.py @@ -9,7 +9,7 @@ # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) -from torch.testing._internal.jit_utils import JitTestCase +from torch.testing._internal.jit_utils import JitTestCase, make_global if __name__ == '__main__': raise RuntimeError("This test file is not meant to be run directly, use:\n\n" @@ -18,24 +18,20 @@ class TestEnum(JitTestCase): def test_enum_value_types(self): - global IntEnum - class IntEnum(Enum): FOO = 1 BAR = 2 - global FloatEnum - class FloatEnum(Enum): FOO = 1.2 BAR = 2.3 - global StringEnum - class StringEnum(Enum): FOO = "foo as in foo bar" BAR = "bar as in foo bar" + make_global(IntEnum, FloatEnum, StringEnum) + @torch.jit.script def supported_enum_types(a: IntEnum, b: FloatEnum, c: StringEnum): return (a.name, b.name, c.name) @@ -46,12 +42,12 @@ def supported_enum_types(a: IntEnum, b: FloatEnum, c: StringEnum): .check("StringEnum") \ .run(str(supported_enum_types.graph)) - global TensorEnum - class TensorEnum(Enum): FOO = torch.tensor(0) BAR = torch.tensor(1) + make_global(TensorEnum) + def unsupported_enum_types(a: TensorEnum): return a.name @@ -59,12 +55,12 @@ def unsupported_enum_types(a: TensorEnum): torch.jit.script(unsupported_enum_types) def test_enum_comp(self): - global Color - class Color(Enum): RED = 1 GREEN = 2 + make_global(Color) + @torch.jit.script def enum_comp(x: Color, y: Color) -> bool: return x == y @@ -75,8 +71,6 @@ def enum_comp(x: Color, y: Color) -> bool: self.assertEqual(enum_comp(Color.RED, Color.GREEN), False) def test_enum_comp_diff_classes(self): - global Foo, Bar - class Foo(Enum): ITEM1 = 1 ITEM2 = 2 @@ -85,6 +79,8 @@ class Bar(Enum): ITEM1 = 1 ITEM2 = 2 + make_global(Foo, Bar) + @torch.jit.script def enum_comp(x: Foo) -> bool: return x == Bar.ITEM1 @@ -98,12 +94,12 @@ def enum_comp(x: Foo) -> bool: self.assertEqual(enum_comp(Foo.ITEM1), False) def test_heterogenous_value_type_enum_error(self): - global Color - class Color(Enum): RED = 1 GREEN = "green" + make_global(Color) + def enum_comp(x: Color, y: Color) -> bool: return x == y @@ -111,12 +107,12 @@ def enum_comp(x: Color, y: Color) -> bool: torch.jit.script(enum_comp) def test_enum_name(self): - global Color - class Color(Enum): RED = 1 GREEN = 2 + make_global(Color) + @torch.jit.script def enum_name(x: Color) -> str: return x.name @@ -131,12 +127,12 @@ def enum_name(x: Color) -> str: self.assertEqual(enum_name(Color.GREEN), Color.GREEN.name) def test_enum_value(self): - global Color - class Color(Enum): RED = 1 GREEN = 2 + make_global(Color) + @torch.jit.script def enum_value(x: Color) -> int: return x.value @@ -151,12 +147,12 @@ def enum_value(x: Color) -> int: self.assertEqual(enum_value(Color.GREEN), Color.GREEN.value) def test_enum_as_const(self): - global Color - class Color(Enum): RED = 1 GREEN = 2 + make_global(Color) + @torch.jit.script def enum_const(x: Color) -> bool: return x == Color.RED @@ -171,12 
+167,12 @@ def enum_const(x: Color) -> bool: self.assertEqual(enum_const(Color.GREEN), False) def test_non_existent_enum_value(self): - global Color - class Color(Enum): RED = 1 GREEN = 2 + make_global(Color) + def enum_const(x: Color) -> bool: if x == Color.PURPLE: return True @@ -187,12 +183,12 @@ def enum_const(x: Color) -> bool: torch.jit.script(enum_const) def test_enum_ivalue_type(self): - global Color - class Color(Enum): RED = 1 GREEN = 2 + make_global(Color) + @torch.jit.script def is_color_enum(x: Any): return isinstance(x, Color) @@ -207,8 +203,6 @@ def is_color_enum(x: Any): self.assertEqual(is_color_enum(1), False) def test_closed_over_enum_constant(self): - global Color - class Color(Enum): RED = 1 GREEN = 2 @@ -240,8 +234,6 @@ def closed_over_aliased_value(): self.assertEqual(closed_over_aliased_value(), Color.RED.value) def test_enum_as_module_attribute(self): - global Color - class Color(Enum): RED = 1 GREEN = 2 @@ -268,8 +260,6 @@ def forward(self): self.assertEqual(scripted(), Color.RED.value) def test_string_enum_as_module_attribute(self): - global Color - class Color(Enum): RED = "red" GREEN = "green" @@ -282,18 +272,19 @@ def __init__(self, e: Color): def forward(self): return (self.e.name, self.e.value) + make_global(Color) m = TestModule(Color.RED) scripted = torch.jit.script(m) self.assertEqual(scripted(), (Color.RED.name, Color.RED.value)) def test_enum_return(self): - global Color - class Color(Enum): RED = 1 GREEN = 2 + make_global(Color) + @torch.jit.script def return_enum(cond: bool): if cond: @@ -305,8 +296,6 @@ def return_enum(cond: bool): self.assertEqual(return_enum(False), Color.GREEN) def test_enum_module_return(self): - global Color - class Color(Enum): RED = 1 GREEN = 2 @@ -319,6 +308,7 @@ def __init__(self, e: Color): def forward(self): return self.e + make_global(Color) m = TestModule(Color.RED) scripted = torch.jit.script(m) @@ -333,8 +323,6 @@ def forward(self): def test_enum_iterate(self): - global Color - class Color(Enum): RED = 1 GREEN = 2 @@ -347,6 +335,7 @@ def iterate_enum(x: Color): res.append(e.value) return res + make_global(Color) scripted = torch.jit.script(iterate_enum) FileCheck() \ diff --git a/test/jit/test_tracer.py b/test/jit/test_tracer.py index 9dff4b0f4549..71f572dedb54 100644 --- a/test/jit/test_tracer.py +++ b/test/jit/test_tracer.py @@ -1877,6 +1877,24 @@ def forward(self, inputs): tm = torch.jit.trace(m, torch.tensor(1.)) self.assertFalse(hasattr(tm, "submod")) + def test_trace_with_conditional_property(self): + class Net(nn.Module): + def __init__(self, attr=None): + super(Net, self).__init__() + if attr is not None: + self._attr = attr + self.attr_name = '_attr' + + @property + def attr(self): + return getattr(self, self.attr_name) + + def forward(self, x): + return x + + x = torch.ones(1) + torch.jit.trace(Net(), x) + class TestMixTracingScripting(JitTestCase): def test_trace_script(self): diff --git a/test/jit/test_with.py b/test/jit/test_with.py index f958dc46c39a..35ff5b959737 100644 --- a/test/jit/test_with.py +++ b/test/jit/test_with.py @@ -4,7 +4,7 @@ from typing import Any, List import torch -from torch.testing._internal.jit_utils import JitTestCase +from torch.testing._internal.jit_utils import JitTestCase, make_global # Make the helper files in test/ importable @@ -29,8 +29,6 @@ def test_with_as(self): Check that with statements that use the 'as' keyword to bind expressions to targets work as expected. 
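For context on the enum tests above: once an Enum class is resolvable at module scope, TorchScript can compile functions that accept it, compare its members, and read .name and .value. A minimal standalone example (the enum is defined at module level here, so no make_global is needed):

from enum import Enum
import torch

class Color(Enum):
    RED = 1
    GREEN = 2

@torch.jit.script
def describe(x: Color):
    # name/value access and member comparison all work inside TorchScript
    return x.name, x.value, x == Color.RED

print(describe(Color.RED))  # ('RED', 1, True)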
""" - global Context - @torch.jit.script class Context(object): """ @@ -50,6 +48,8 @@ def __enter__(self): def __exit__(self, type: Any, value: Any, tb: Any): self.count.sub_(0.3) + make_global(Context) + def test_basic(x: torch.Tensor) -> torch.Tensor: """Basic test with one with-statement.""" @@ -185,8 +185,6 @@ def test_with_no_as(self): Check that with statements that do not use the 'as' keyword to bind expressions to targets work as expected. """ - global Context - @torch.jit.script class Context(object): """ @@ -206,6 +204,8 @@ def __enter__(self): def __exit__(self, type: Any, value: Any, tb: Any): self.count.sub_(0.3) + make_global(Context) + def test_basic(x: torch.Tensor) -> torch.Tensor: """Basic test with one with-statement.""" @@ -341,8 +341,6 @@ def test_with_exceptions(self): Check that exceptions thrown in the bodies of with-statements are handled correctly. """ - global Context - @torch.jit.script class Context(object): """ @@ -362,6 +360,8 @@ def __enter__(self): def __exit__(self, type: Any, value: Any, tb: Any): self.count.sub_(0.3) + make_global(Context) + @torch.jit.script def method_that_raises() -> torch.Tensor: raise Exception("raised exception") diff --git a/test/onnx/expect/TestOperators.test_upsample_nearest_scale.expect b/test/onnx/expect/TestOperators.test_upsample_nearest_scale.expect index 5355daf4f3ca..67d765831c1b 100644 --- a/test/onnx/expect/TestOperators.test_upsample_nearest_scale.expect +++ b/test/onnx/expect/TestOperators.test_upsample_nearest_scale.expect @@ -50,16 +50,16 @@ graph { elem_type: 1 shape { dim { - dim_param: "Upsample4_dim_0" + dim_value: 1 } dim { - dim_param: "Upsample4_dim_1" + dim_value: 2 } dim { - dim_param: "Upsample4_dim_2" + dim_value: 6 } dim { - dim_param: "Upsample4_dim_3" + dim_value: 8 } } } diff --git a/test/onnx/expect/TestOperators.test_upsample_nearest_scale_default_scale_factor.expect b/test/onnx/expect/TestOperators.test_upsample_nearest_scale_default_scale_factor.expect index 5355daf4f3ca..67d765831c1b 100644 --- a/test/onnx/expect/TestOperators.test_upsample_nearest_scale_default_scale_factor.expect +++ b/test/onnx/expect/TestOperators.test_upsample_nearest_scale_default_scale_factor.expect @@ -50,16 +50,16 @@ graph { elem_type: 1 shape { dim { - dim_param: "Upsample4_dim_0" + dim_value: 1 } dim { - dim_param: "Upsample4_dim_1" + dim_value: 2 } dim { - dim_param: "Upsample4_dim_2" + dim_value: 6 } dim { - dim_param: "Upsample4_dim_3" + dim_value: 8 } } } diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index d3d4ebef5f61..0ddfcf1e40ad 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -302,7 +302,6 @@ def forward(self, input): x = torch.tensor([2], dtype=torch.long) self.run_model_test_with_external_data(model, x) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) # Because external data format was released with Opset 9. def test_mobilenet_v2_with_external_data(self): model = torchvision.models.mobilenet_v2(pretrained=True) @@ -407,19 +406,7 @@ def run_word_language_model(self, model_name): # Only support CPU version, since tracer is not working in GPU RNN. 
self.run_test(model, (x, model.hidden)) - @skipIfUnsupportedOpsetVersion([13]) - @skipIfUnsupportedMinOpsetVersion(11) - @disableScriptTest() # Faster RCNN model is not scriptable - def test_faster_rcnn(self): - model = torchvision.models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True, min_size=200, - max_size=300) - model.eval() - x = torch.randn(2, 3, 200, 300, requires_grad=True) - self.run_test(model, (x,), rtol=1e-3, atol=1e-5) - self.run_test(model, (x,), input_names=["images_tensors"], output_names=["outputs"], - dynamic_axes={"images_tensors": [0, 1, 2, 3], "outputs": [0, 1, 2, 3]}, rtol=1e-3, atol=1e-5) - - def get_image_from_url(self, url): + def get_image_from_url(self, url, size=(300, 200)): import os from urllib.parse import urlsplit from urllib import request @@ -434,17 +421,41 @@ def get_image_from_url(self, url): with open(path, 'wb') as f: f.write(data) image = Image.open(path).convert("RGB") - image = image.resize((300, 200), Image.BILINEAR) + + image = image.resize(size, Image.BILINEAR) + to_tensor = transforms.ToTensor() return to_tensor(image) def get_test_images(self): image_url = "http://farm3.staticflickr.com/2469/3915380994_2e611b1779_z.jpg" - image = self.get_image_from_url(url=image_url) - images = [image] - return images + image = self.get_image_from_url(url=image_url, size=(100, 320)) + + image_url2 = "https://pytorch.org/tutorials/_static/img/tv_tutorial/tv_image05.png" + image2 = self.get_image_from_url(url=image_url2, size=(250, 380)) + + return [image], [image2] @skipIfUnsupportedOpsetVersion([13]) + @skipIfUnsupportedMinOpsetVersion(11) + @disableScriptTest() # Faster RCNN model is not scriptable + def test_faster_rcnn(self): + model = torchvision.models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True, min_size=200, + max_size=300) + model.eval() + x = torch.randn(2, 3, 200, 300, requires_grad=True) + self.run_test(model, (x,), rtol=1e-3, atol=1e-5) + self.run_test(model, (x,), input_names=["images_tensors"], output_names=["outputs"], + dynamic_axes={"images_tensors": [0, 1, 2, 3], "outputs": [0, 1, 2, 3]}, rtol=1e-3, atol=1e-5) + dummy_image = [torch.ones(3, 100, 100) * 0.3] + images, test_images = self.get_test_images() + self.run_test(model, (images,), test_with_inputs=[(images,), (test_images,), (dummy_image,)], + input_names=["images_tensors"], output_names=["outputs"], + dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]}, rtol=1e-3, atol=1e-5) + self.run_test(model, (dummy_image,), test_with_inputs=[(dummy_image,), (images,)], + input_names=["images_tensors"], output_names=["outputs"], + dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]}, rtol=1e-3, atol=1e-5) + def test_paste_mask_in_image(self): # disable profiling torch._C._jit_set_profiling_executor(False) @@ -482,11 +493,20 @@ def test_paste_mask_in_image(self): def test_mask_rcnn(self): model = torchvision.models.detection.mask_rcnn.maskrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300) - images = self.get_test_images() + images, test_images = self.get_test_images() self.run_test(model, (images,), rtol=1e-3, atol=1e-5) self.run_test(model, (images,), input_names=["images_tensors"], output_names=["boxes", "labels", "scores", "masks"], dynamic_axes={"images_tensors": [0, 1, 2], "boxes": [0, 1], "labels": [0], "scores": [0], "masks": [0, 1, 2]}, rtol=1e-3, atol=1e-5) + dummy_image = [torch.ones(3, 100, 100) * 0.3] + self.run_test(model, (images,), test_with_inputs=[(images,), (test_images,), (dummy_image,)], + 
input_names=["images_tensors"], output_names=["boxes", "labels", "scores", "masks"], + dynamic_axes={"images_tensors": [0, 1, 2], "boxes": [0, 1], "labels": [0], + "scores": [0], "masks": [0, 1, 2]}, rtol=1e-3, atol=1e-5) + self.run_test(model, (dummy_image,), test_with_inputs=[(dummy_image,), (images,)], + input_names=["images_tensors"], output_names=["boxes", "labels", "scores", "masks"], + dynamic_axes={"images_tensors": [0, 1, 2], "boxes": [0, 1], "labels": [0], + "scores": [0], "masks": [0, 1, 2]}, rtol=1e-3, atol=1e-5) def test_heatmaps_to_keypoints(self): # disable profiling @@ -518,29 +538,46 @@ def test_heatmaps_to_keypoints(self): def test_keypoint_rcnn(self): model = torchvision.models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300) - images = self.get_test_images() + images, test_images = self.get_test_images() self.run_test(model, (images,), rtol=1e-3, atol=1e-5) self.run_test(model, (images,), input_names=["images_tensors"], output_names=["outputs1", "outputs2", "outputs3", "outputs4"], dynamic_axes={"images_tensors": [0, 1, 2]}, rtol=1e-3, atol=1e-5) + dummy_images = [torch.ones(3, 100, 100) * 0.3] + self.run_test(model, (images,), test_with_inputs=[(images,), (test_images,), (dummy_images,)], + input_names=["images_tensors"], output_names=["outputs1", "outputs2", "outputs3", "outputs4"], + dynamic_axes={"images_tensors": [0, 1, 2]}, + rtol=5e-3, atol=1e-5) + self.run_test(model, (dummy_images,), test_with_inputs=[(dummy_images,), (test_images,)], + input_names=["images_tensors"], output_names=["outputs1", "outputs2", "outputs3", "outputs4"], + dynamic_axes={"images_tensors": [0, 1, 2]}, + rtol=5e-3, atol=1e-5) @skipIfUnsupportedOpsetVersion([13]) + @skipIfUnsupportedMinOpsetVersion(11) + @disableScriptTest() + def test_shufflenet_v2_dynamic_axes(self): + model = torchvision.models.shufflenet_v2_x0_5(pretrained=True) + dummy_input = torch.randn(1, 3, 224, 224, requires_grad=True) + test_inputs = torch.randn(3, 3, 224, 224, requires_grad=True) + self.run_test(model, (dummy_input,), test_with_inputs=[(dummy_input,), (test_inputs,)], + input_names=["input_images"], output_names=["outputs"], + dynamic_axes={"input_images": {0: 'batch_size'}, "output": {0: 'batch_size'}}, + rtol=1e-3, atol=1e-5) + @disableScriptTest() def test_word_language_model_RNN_TANH(self): self.run_word_language_model("RNN_TANH") - @skipIfUnsupportedOpsetVersion([13]) @disableScriptTest() def test_word_language_model_RNN_RELU(self): self.run_word_language_model("RNN_RELU") - @skipIfUnsupportedOpsetVersion([13]) @disableScriptTest() def test_word_language_model_LSTM(self): self.run_word_language_model("LSTM") - @skipIfUnsupportedOpsetVersion([13]) @disableScriptTest() def test_word_language_model_GRU(self): self.run_word_language_model("GRU") @@ -689,7 +726,7 @@ def forward(self, input): # Without empty optional arguments dictionary x = torch.randn(2, 3) - self.run_test(NoOptionalModel(), (x,), input_names=['input_x']) + self.run_test(NoOptionalModel(), (x,), input_names=['input_x']) # With empty optional arguments dictionary y = torch.randn(2, 3) self.run_test(NoOptionalModel(), (y, {})) @@ -768,7 +805,6 @@ def forward(self, x, y=None, z=None): z = torch.randn(2, 3) self.run_test(Model(), (x, None, z)) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) def test_cste_script(self): class MyModel(torch.jit.ScriptModule): @@ -1037,44 +1073,37 @@ def forward(self, x): else: self.run_test(Squeeze(d), x1) - @skipIfUnsupportedOpsetVersion([13]) 
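The long run of deletions that follows removes blanket @skipIfUnsupportedOpsetVersion([13]) markers: these ops now export correctly under opset 13, so the tests run there as well. For reference, a skip decorator of this shape can be written as below (a sketch of the pattern, not the exact helper from the ONNX test utilities):

import functools
import unittest

def skipIfUnsupportedOpsetVersion(unsupported_opset_versions):
    # Skip the decorated test when the suite's current opset version is in
    # the given list (self.opset_version is set by the parameterized suite).
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(self, *args, **kwargs):
            if self.opset_version in unsupported_opset_versions:
                raise unittest.SkipTest("op not supported at this opset version")
            return fn(self, *args, **kwargs)
        return wrapper
    return decorator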
def test_squeeze_without_no_op(self): x = torch.randn(2, 1, 4) self.squeeze_model_tests(1, x, None) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) def test_squeeze_dynamic(self): x_squeeze = torch.randn(2, 1, 4) x_noop = torch.randn(2, 2, 3) self.squeeze_model_tests(1, x_squeeze, x_noop) - @skipIfUnsupportedOpsetVersion([13]) def test_squeeze_neg_without_no_op(self): x = torch.randn(2, 1, 4) self.squeeze_model_tests(-2, x, None) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) def test_squeeze_neg(self): x_squeeze = torch.randn(2, 1, 4) x_noop = torch.randn(2, 2, 3) self.squeeze_model_tests(-2, x_squeeze, x_noop) - @skipIfUnsupportedOpsetVersion([13]) def test_squeeze_all_dims(self): x_squeeze = torch.randn(2, 1, 4) x_noop = torch.randn(2, 2, 3) self.squeeze_model_tests(None, x_squeeze, x_noop) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) def test_squeeze_no_op(self): x_noop = torch.randn(2, 1, 4) x_squeeze = torch.randn(2, 2, 1) self.squeeze_model_tests(2, x_noop, x_squeeze) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) def test_squeeze_runtime_dim(self): class Squeeze(torch.nn.Module): @@ -1088,7 +1117,14 @@ def forward(self, d1, d2): self.run_test(Squeeze(), (d1, d4), test_with_inputs=[(d3, d4)]) self.run_test(Squeeze(), (d3, d4), test_with_inputs=[(d1, d3)]) - @skipIfUnsupportedOpsetVersion([13]) + def test_squeeze(self): + class Squeeze(torch.nn.Module): + def forward(self, x): + return torch.squeeze(x, dim=-2) + + x = torch.randn(2, 1, 4) + self.run_test(Squeeze(), x) + def test_unsqueeze(self): class Unsqueeze(torch.nn.Module): def forward(self, x): @@ -1288,7 +1324,6 @@ def forward(self, x, y): y = torch.randn(2, 3, 4) self.run_test(FloorDivModule(), (x, y)) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) def test_floordiv(self): class FloordivModule(torch.nn.Module): @@ -1366,7 +1401,6 @@ def forward(self, x, y): y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.double) self.run_test(torch.jit.script(DivModule()), (x, y)) - @skipIfUnsupportedOpsetVersion([13]) def test_slice_trace(self): class MyModule(torch.nn.Module): def forward(self, x): @@ -1375,7 +1409,6 @@ def forward(self, x): x = torch.randn(3) self.run_test(MyModule(), x) - @skipIfUnsupportedOpsetVersion([13]) def test_slice_neg(self): class NegSlice(torch.nn.Module): def forward(self, x): @@ -1384,7 +1417,6 @@ def forward(self, x): x = torch.randn(3, 4, 5) self.run_test(NegSlice(), x) - @skipIfUnsupportedOpsetVersion([13]) def test_slice_neg_large(self): class NegSlice(torch.nn.Module): def forward(self, x): @@ -1393,7 +1425,6 @@ def forward(self, x): x = torch.randn(3, 4, 5, 6, 7) self.run_test(NegSlice(), x) - @skipIfUnsupportedOpsetVersion([13]) def test_slice_neg_large_negone(self): class NegSlice(torch.nn.Module): def forward(self, x): @@ -1402,7 +1433,6 @@ def forward(self, x): x = torch.randn(3, 4, 5, 6, 7) self.run_test(NegSlice(), x) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) def test_slice_with_input_index(self): class InputIndexSlice(torch.nn.Module): @@ -1414,7 +1444,6 @@ def forward(self, x, y): y = torch.rand((22, 256)) self.run_test(InputIndexSlice(), (x, y)) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(10) @disableScriptTest() # scripting tuple/list append def test_slice_dynamic(self): @@ -1433,7 +1462,6 @@ def forward(self, x): dynamic_axes={'input_1': [0, 1, 2], 'output_1': [0, 1, 
2]}) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(10) def test_slice_dynamic_script(self): class DynamicSliceModel(torch.jit.ScriptModule): @@ -1444,7 +1472,6 @@ def forward(self, x): x = torch.rand(1, 2) self.run_test(DynamicSliceModel(), x) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(10) def test_slice_dynamic_shape_script(self): class DynamicSliceModel(torch.nn.Module): @@ -1454,7 +1481,6 @@ def forward(self, x): x = torch.rand(1, 2, 3, 4) self.run_test(DynamicSliceModel(), x) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(10) @disableScriptTest() # scripting tuple/list append def test_slice_dynamic_to_end(self): @@ -1561,7 +1587,6 @@ def forward(self, end): x = torch.tensor(6.2, dtype=torch.float) self.run_test(ArangeModel(), x) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) def test_size(self): class SizeModel(torch.nn.Module): @@ -1571,7 +1596,6 @@ def forward(self, input): x = torch.randn(5, 3, 2) self.run_test(SizeModel(), x) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) @disableScriptTest() # x.stride() not scriptable def test_as_strided(self): @@ -1586,7 +1610,6 @@ def forward(self, x): x = torch.randn(5, 8, 7) self.run_test(Model(), x) - @skipIfUnsupportedOpsetVersion([13]) @disableScriptTest() # Ellipses followed by tensor indexing not scriptable def test_tensor_index_advanced_indexing_ellipsis(self): class MyModel(torch.nn.Module): @@ -1596,7 +1619,6 @@ def forward(self, input): m1 = torch.randn(3, 4, 5, 6, 7) self.run_test(MyModel(), (m1,)) - @skipIfUnsupportedOpsetVersion([13]) def test_tensor_index_advanced_indexing(self): class MyModel(torch.nn.Module): def forward(self, input): @@ -1617,7 +1639,6 @@ def forward(self, input): self.run_test(MyModel(), (m1,)) - @skipIfUnsupportedOpsetVersion([13]) def test_tensor_index_advanced_indexing_consecutive(self): class MyModel(torch.nn.Module): def forward(self, input): @@ -1626,7 +1647,6 @@ def forward(self, input): m1 = torch.randn(3, 4, 5, 6, 7) self.run_test(MyModel(), (m1,)) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) def test_index_put(self): class IndexPutModel(torch.nn.Module): @@ -1639,7 +1659,6 @@ def forward(self, x, ind, update): update = torch.ones(4) self.run_test(IndexPutModel(), (x, ind, update)) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) def test_index_put_accumulate(self): class IndexPutModel(torch.nn.Module): @@ -1651,7 +1670,6 @@ def forward(self, x, ind, update): update = torch.ones(4) self.run_test(IndexPutModel(), (x, ind, update)) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) def test_index_put_slice_index(self): class IndexPutModel(torch.nn.Module): @@ -1726,7 +1744,6 @@ def forward(self, x, update): update = torch.arange(3 * 5).to(torch.float).view(3, 5) self.run_test(IndexPutModel8(), (x, update)) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) @disableScriptTest() # Ellipses followed by tensor indexing not scriptable def test_index_put_ellipsis(self): @@ -1748,7 +1765,6 @@ def forward(self, x, update): update = torch.randn(4, 1, 3, 2) self.run_test(IndexPutModel2(), (x, update)) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) def test_index_put_loop(self): @torch.jit.script @@ -1831,7 +1847,6 @@ def forward(self, x, ind, data): data = torch.randn(4) self.run_test(CopyModel4(), (x, ind, data)) - 
@skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) @disableScriptTest() # Model not scriptable (output with shape doesn't match the broadcast shape) def test_copy_tracing(self): @@ -1844,7 +1859,6 @@ def forward(self, x, data): update = torch.randn(1, 2) self.run_test(CopyModel(), (x, update)) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) def test_copy_ellipsis(self): class CopyModel(torch.nn.Module): @@ -1860,7 +1874,6 @@ def forward(self, x, update): update = torch.ones(1) self.run_test(CopyModel(), (x, update)) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) # TODO: Limited scripting support with ellipsis indexing. # Due to dependency on input tensor rank being known. @@ -1899,7 +1912,6 @@ def forward(self, x): x = torch.randn(2, 3, 4) self.run_test(Rand(), x) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) def test_random_dynamic_size(self): class RandN(torch.nn.Module): @@ -2062,12 +2074,10 @@ def _interpolate_tests(self, is_upsample): self._interpolate_script(xi, mode_i, False, is_upsample, True) self._interpolate_script(xi, mode_i, False, is_upsample) - @skipIfUnsupportedOpsetVersion([13]) @disableScriptTest() def test_interpolate_upsample(self): self._interpolate_tests(True) - @skipIfUnsupportedOpsetVersion([13]) @disableScriptTest() @skipIfUnsupportedMinOpsetVersion(9) def test_interpolate_function_substitution(self): @@ -2098,13 +2108,11 @@ def forward(self, x): self.run_test(TracingModule(), (x,)) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(10) @disableScriptTest() def test_interpolate_downsample(self): self._interpolate_tests(False) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) @disableScriptTest() def test_interpolate_no_shape(self): @@ -2120,7 +2128,6 @@ def forward(self, x, y): y = torch.randn(16, 16, requires_grad=True) self.run_test(MyModel(), (x, y)) - @skipIfUnsupportedOpsetVersion([13]) def test_interpolate_adaptive_pooling_error(self): x = torch.randn(1, 2, 6, requires_grad=True) with self.assertRaises(RuntimeError) as cm: @@ -2129,7 +2136,6 @@ def test_interpolate_adaptive_pooling_error(self): with self.assertRaises(RuntimeError) as cm: self._interpolate(x, "area", False, True) - @skipIfUnsupportedOpsetVersion([13]) def test_groupnorm(self): model = torch.nn.GroupNorm(3, 6, 0.002) x = torch.randn(4, 6, 180, 180, 180) @@ -2143,7 +2149,6 @@ def test_groupnorm(self): x = torch.randn(4, 6, 180, 180) self.run_test(model, x) - @skipIfUnsupportedOpsetVersion([13]) @disableScriptTest() def test_groupnorm_noaffine(self): model = torch.nn.GroupNorm(4, 8, 0.002, affine=False) @@ -2158,7 +2163,6 @@ def test_groupnorm_noaffine(self): x = torch.randn(4, 6, 180, 180) self.run_test(model, x) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) def test_listunpack(self): class ListUnpack(torch.jit.ScriptModule): @@ -2462,7 +2466,6 @@ def forward(self, input, input2): input2 = torch.arange(24, dtype=torch.uint8).reshape(3, 4, 2) self.run_test(BitshiftModel(), (input, input2)) - @skipIfUnsupportedOpsetVersion([13]) def test_narrow(self): class NarrowModel(torch.nn.Module): def forward(self, input): @@ -2471,7 +2474,6 @@ def forward(self, input): x = torch.randn(3, 3, requires_grad=True) self.run_test(NarrowModel(), x) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) def test_narrow_dynamic(self): class NarrowModel(torch.nn.Module): @@ -2481,7 +2483,6 @@ def 
forward(self, input): x = torch.randn(3, 3, requires_grad=True) self.run_test(NarrowModel(), x) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) def test_index_fill(self): class IndexFillModel(torch.nn.Module): @@ -2492,7 +2493,6 @@ def forward(self, input): x = torch.randn(3, 4, 5, requires_grad=True) self.run_test(IndexFillModel(), x) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) def test_index_copy(self): class IndexCopyModel(torch.nn.Module): @@ -2734,7 +2734,6 @@ def forward(self, input, indices): indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64) self.run_test(GatherModel(), input=(input, indices)) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) def test_expand(self): class ExpandModel(torch.nn.Module): @@ -2806,7 +2805,6 @@ def forward(self, input): x = torch.randn(4, 5, dtype=torch.float) self.run_test(ReducedOpModule(), x) - @skipIfUnsupportedOpsetVersion([13]) def test_reduced_sum(self): return self._test_reduced_ops(op=torch.sum) @@ -2894,7 +2892,6 @@ def forward(self, x): x = torch.randn(3, 4, 5, requires_grad=True) self.run_test(Model(), x) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) @disableScriptTest() # scripting prim_dtype def test_lstm_no_hidden(self): @@ -2909,7 +2906,6 @@ def forward(self, x): input = torch.randn((10, 16, 16)) self.run_test(LSTMModel(), (input,)) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) @disableScriptTest() # scripting prim_dtype def test_lstm_proj_no_hidden(self): @@ -2926,7 +2922,6 @@ def forward(self, x): self.run_test(LSTMModel(), (input,)) @skipIfUnsupportedMinOpsetVersion(9) - @skipIfUnsupportedOpsetVersion([13]) @disableScriptTest() def test_lstm(self): model = torch.nn.LSTM(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False) @@ -2935,7 +2930,6 @@ def test_lstm(self): c0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE) self.run_test(model, (input, (h0, c0))) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) @disableScriptTest() def test_lstm_default_init_state(self): @@ -2943,7 +2937,6 @@ def test_lstm_default_init_state(self): input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE) self.run_test(model, input) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) @disableScriptTest() # LSTMModel model not scriptable def test_lstm_fixed_batch_size(self): @@ -2965,7 +2958,6 @@ def forward(self, input): input2 = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE) self.run_test(LSTMModel(), input, fixed_batch_size=True, test_with_inputs=[input2]) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) @disableScriptTest() def test_lstm_post_fix_init_state(self): @@ -2990,7 +2982,6 @@ def forward(self, input): self.run_test(model, input, dynamic_axes={'input' : {0 : 'seq', 1 : 'batch'}}, test_with_inputs=[input2]) - @skipIfUnsupportedOpsetVersion([13]) @disableScriptTest() def test_lstm_constant_folding(self): class LstmNet(torch.nn.Module): @@ -3018,7 +3009,6 @@ def get_LstmNet_model_and_inputs(input_size, hidden_size, num_layers, batch_size model2, input2 = get_LstmNet_model_and_inputs(5, 4, 3, batch_size2, 7, False) self.run_test(model2, input2, do_constant_folding=True) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) @disableScriptTest() def test_lstm_no_bias(self): @@ -3044,7 +3034,6 @@ def get_LstmNet_model_and_inputs(num_layers, bidirectional): for model, input 
in models_and_inputs: self.run_test(model, input) - @skipIfUnsupportedOpsetVersion([13]) @disableScriptTest() def test_rnn_no_bias(self): def make_model(layers, packed_sequence): @@ -3084,7 +3073,6 @@ def make_input(batch_size, layers, packed_sequence): for model, input in zip(models, inputs): self.run_test(model, input, batch_size=RNN_BATCH_SIZE) - @skipIfUnsupportedOpsetVersion([13]) def test_gru_no_bias(self): class GruNet(torch.nn.Module): def __init__(self, input_size, hidden_size, num_layers, bidirectional): @@ -3114,7 +3102,6 @@ def get_GruNet_model_and_inputs(input_size, hidden_size, num_layers, batch_size, for model, input in models_and_inputs: self.run_test(model, input, do_constant_folding=True) - @skipIfUnsupportedOpsetVersion([13]) def test_gru_constant_folding(self): class GruNet(torch.nn.Module): def __init__(self, input_size, hidden_size, num_layers, bidirectional): @@ -3273,6 +3260,72 @@ def _test_compare_ops(self, model, num_inputs): self.run_test(model, x_float) self.run_test(model, x_int) + @skipIfUnsupportedMinOpsetVersion(9) + def test_logical_and(self): + class AndModel(torch.nn.Module): + def forward(self, x, y): + return torch.logical_and(x, y) + + x = torch.randint(0, 2, (5, 5), dtype=torch.bool) + y = torch.randint(0, 2, (5, 5), dtype=torch.bool) + self.run_test(AndModel(), input=(x, y)) + + x = torch.randint(10, (5, 5), dtype=torch.int32) + y = torch.randint(10, (5, 5), dtype=torch.int32) + self.run_test(AndModel(), input=(x, y)) + + x = torch.randint(10, (5, 5), dtype=torch.double) + y = torch.randint(10, (5, 5), dtype=torch.double) + self.run_test(AndModel(), input=(x, y)) + + x = torch.randint(10, (2, 3, 5), dtype=torch.float32) + y = torch.randint(10, (2, 3, 5), dtype=torch.long) + self.run_test(AndModel(), input=(x, y)) + + @skipIfUnsupportedMinOpsetVersion(9) + def test_logical_or(self): + class OrModel(torch.nn.Module): + def forward(self, x, y): + return torch.logical_or(x, y) + + x = torch.randint(0, 2, (5, 5), dtype=torch.bool) + y = torch.randint(0, 2, (5, 5), dtype=torch.bool) + self.run_test(OrModel(), input=(x, y)) + + x = torch.randint(10, (5, 5), dtype=torch.int32) + y = torch.randint(10, (5, 5), dtype=torch.int32) + self.run_test(OrModel(), input=(x, y)) + + x = torch.randint(10, (5, 5), dtype=torch.double) + y = torch.randint(10, (5, 5), dtype=torch.double) + self.run_test(OrModel(), input=(x, y)) + + x = torch.randint(10, (2, 3, 5), dtype=torch.float32) + y = torch.randint(10, (2, 3, 5), dtype=torch.long) + self.run_test(OrModel(), input=(x, y)) + + @skipIfUnsupportedMinOpsetVersion(9) + def test_logical_xor(self): + class XorModel(torch.nn.Module): + def forward(self, x, y): + return torch.logical_xor(x, y) + + x = torch.randint(0, 2, (5, 5), dtype=torch.bool) + y = torch.randint(0, 2, (5, 5), dtype=torch.bool) + self.run_test(XorModel(), input=(x, y)) + + x = torch.randint(10, (5, 5), dtype=torch.int32) + y = torch.randint(10, (5, 5), dtype=torch.int32) + self.run_test(XorModel(), input=(x, y)) + + x = torch.randint(10, (5, 5), dtype=torch.double) + y = torch.randint(10, (5, 5), dtype=torch.double) + self.run_test(XorModel(), input=(x, y)) + + x = torch.randint(10, (2, 3, 5), dtype=torch.float32) + y = torch.randint(10, (2, 3, 5), dtype=torch.long) + self.run_test(XorModel(), input=(x, y)) + def test_gt(self): class GreaterModel(torch.nn.Module): def forward(self, input, other): @@ -3377,7 +3430,6 @@ def test_argmin_argmax_select_last_index(self): input = torch.ones(7, 3, 5) self._argmin_argmax_model(input) - 
@skipIfUnsupportedOpsetVersion([13]) def test_repeat(self): class RepeatModel(torch.nn.Module): def forward(self, x, y): @@ -3397,7 +3449,6 @@ def forward(self, input): x = torch.randint(10, (4, 2, 3, 4), dtype=torch.int32) self.run_test(ViewModel(), x) - @skipIfUnsupportedOpsetVersion([13]) def test_view_dynamic(self): class ViewModel(torch.nn.Module): def forward(self, input, other): @@ -3407,7 +3458,6 @@ def forward(self, input, other): shape = torch.randn(6, 4) self.run_test(ViewModel(), (x, shape)) - @skipIfUnsupportedOpsetVersion([13]) def test_view_dynamic_zero_dim(self): class ViewModel(torch.nn.Module): def forward(self, input): @@ -3567,7 +3617,6 @@ def forward(self, input): self.run_test(LenModel(), x, input_names=['input'], dynamic_axes={'input': {0: 'seq'}}, test_with_inputs=(torch.randn(5, 5),)) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) def test_len_list(self): class LenListModel(torch.jit.ScriptModule): @@ -3618,7 +3667,6 @@ def forward(self, input): x = torch.randn(5, 4, 3) self.run_test(SplitModel3(), x) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) @disableScriptTest() def test_split_size_as_list(self): @@ -3635,7 +3683,6 @@ def forward(self, input, split_sizes: List[int]): split_sizes = [torch.tensor(2), torch.tensor(4)] self.run_test(SplitModel(), (x, split_sizes)) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) def test_split_size_with_slice(self): class SplitModule(torch.nn.Module): @@ -3668,6 +3715,7 @@ def forward(self, input): self.run_test(SplitModel2(), x) @skipIfUnsupportedMinOpsetVersion(11) + @disableScriptTest() def test_chunk(self): class ChunkModel(torch.nn.Module): def __init__(self): @@ -3706,7 +3754,6 @@ def forward(self, x): x = torch.randn(4, 5, 6) self.run_test(ConcatDynamicModel(), x) - @skipIfUnsupportedOpsetVersion([13]) def test_stack(self): class StackModel(torch.nn.Module): def forward(self, x, y, z): @@ -3717,7 +3764,6 @@ def forward(self, x, y, z): z = torch.randn(3, 4, 5) self.run_test(StackModel(), (x, y, z)) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) def test_stack_dynamic(self): class StackDynamicModel(torch.jit.ScriptModule): @@ -3794,7 +3840,6 @@ def forward(self, x): x = torch.randn(5, 3, 3) self.run_test(model, x) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) def test_loop_multi_dim(self): class LoopMultiDimModel(torch.jit.ScriptModule): @@ -3809,7 +3854,6 @@ def forward(self, x, y): y = torch.ones(1, dtype=torch.long) self.run_test(model, (x, y)) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) def test_list(self): class ListModel(torch.jit.ScriptModule): @@ -3831,7 +3875,6 @@ def forward(self, x): inputs = torch.randn(16, 1) self.run_test(model, inputs) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) def test_tensor_factories(self): class TensorFactory(torch.nn.Module): @@ -3841,7 +3884,6 @@ def forward(self, x): x = torch.randn(2, 3, 4) self.run_test(TensorFactory(), x) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) def test_tensor_factories_script(self): class TensorFactory(torch.jit.ScriptModule): @@ -3852,7 +3894,6 @@ def forward(self, x): x = torch.randn(2, 3, 4) self.run_test(TensorFactory(), x) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) def test_tensor_like_factories_script(self): class TensorFactory(torch.jit.ScriptModule): @@ -3865,7 +3906,6 @@ def 
forward(self, x): x = torch.randn(2, 3, 4) self.run_test(TensorFactory(), x) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) def test_eye(self): class TensorFactory(torch.nn.Module): @@ -3888,7 +3928,6 @@ def forward(self, x): x = torch.randn(2, 3, 4) self.run_test(Zero_(), x) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) def test_new_zeros(self): class Zero_(torch.nn.Module): @@ -3912,7 +3951,6 @@ def forward(self, input): x = torch.randn(2, 3) self.run_test(List(), (x,)) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) @disableScriptTest() def test_list_pass(self): @@ -3952,7 +3990,6 @@ def forward(self, x, y): y = torch.randn(1, 2, 3) self.run_test(List(), (x, y)) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) def test_new_empty(self): class Empty(torch.nn.Module): @@ -3962,7 +3999,6 @@ def forward(self, x): x = torch.randn(2, 3, 4) self.run_test(Empty(), x) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) def test_new_full(self): class Full(torch.nn.Module): @@ -3972,7 +4008,6 @@ def forward(self, x): x = torch.randn(2, 3, 4) self.run_test(Full(), x) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) def test_inplace_list(self): class Arithmetic(torch.jit.ScriptModule): @@ -4069,7 +4104,6 @@ def forward(self, x): x = torch.arange(16).view(2, 2, 4).to(torch.float32) self.run_test(MaskedFillModel2(), x) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) def test_masked_scatter(self): class MaskedScatterModel(torch.nn.Module): @@ -4088,7 +4122,6 @@ def forward(self, x): x = torch.randn(3, 4, 5, requires_grad=True) self.run_test(MaskedSelectModel(), x) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) @disableScriptTest() # dtype not available def test_index_put_to_masked_fill(self): @@ -4103,7 +4136,6 @@ def forward(self, input_mask, some_const): constant = torch.tensor(5, dtype=torch.float) self.run_test(MaskedFillModel(), (mask, constant)) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) @disableScriptTest() # dtype not available def test_index_put_to_masked_scatter(self): @@ -4224,7 +4256,6 @@ def forward(self, x): x = torch.randn(4, 2, 3, requires_grad=True) self.run_test(NormModel(), x) - @skipIfUnsupportedOpsetVersion([13]) def test_unfold(self): class UnfoldModel(torch.nn.Module): def forward(self, x): @@ -4237,7 +4268,6 @@ def forward(self, x): input_names=['x'], test_with_inputs=[y]) - @skipIfUnsupportedOpsetVersion([13]) @skipIfONNXShapeInference(False) def test_unfold_infer_shape(self): class UnfoldModule(torch.jit.ScriptModule): @@ -4253,7 +4283,6 @@ def forward(self, x): x = torch.randn(32, 3, 64) self.run_test(UnfoldModule(), x) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(12) def test_unfold_dynamic_inputs(self): class UnfoldModel(torch.nn.Module): @@ -4263,7 +4292,6 @@ def forward(self, x): x = torch.randn(4, 2, 4, requires_grad=True) self.run_test(UnfoldModel(), x) - @skipIfUnsupportedOpsetVersion([13]) def test_prelu(self): class PReluModel(torch.nn.Module): def __init__(self): @@ -4399,7 +4427,6 @@ def forward(self, input): @disableScriptTest() # error in propagate as assign input shape @skipIfUnsupportedMinOpsetVersion(10) - @skipIfUnsupportedOpsetVersion([12, 13]) # Due to ONNX Loop shape inference issue def test_embedding_bag(self): model = torch.nn.EmbeddingBag(10, 5, mode='sum',
scale_grad_by_freq=True) input = torch.randint(10, (7,)) @@ -4416,7 +4443,6 @@ def test_embedding_bag(self): self.run_test(model, (input)) @skipIfUnsupportedMinOpsetVersion(11) - @skipIfUnsupportedOpsetVersion([12, 13]) # Due to ONNX Loop shape inference issue def test_embedding_bag_1d_per_sample_weights(self): class EmbeddingModel(torch.nn.Module): def forward(self, embedding_matrix, input, offset, weights): @@ -4431,7 +4457,6 @@ def forward(self, embedding_matrix, input, offset, weights): self.run_test(model, (embedding_matrix, x, offset, w)) @skipIfUnsupportedMinOpsetVersion(11) - @skipIfUnsupportedOpsetVersion([12, 13]) # Due to ONNX Loop shape inference issue def test_embedding_bag_2d_per_sample_weights(self): class EmbeddingModel(torch.nn.Module): def forward(self, embedding_matrix, input, weights): @@ -4593,7 +4618,6 @@ def forward(self, input, other): model = MyModule() self.run_test(model, (x, y)) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) def test_ones_bool(self): class MyModule(torch.nn.Module): @@ -4640,7 +4664,6 @@ def test_constant_pad(self): self.run_test(model, x) # Dynamic padding is added in opset 11 - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) @disableScriptTest() # Functional module not scriptable def test_pad_types(self): @@ -4845,7 +4868,6 @@ def test_replication_pad(self): x = torch.randn(2, 2, 4, 4) self.run_test(model, x) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) def test_im2col(self): class Unfold(torch.nn.Module): @@ -4869,7 +4891,6 @@ def forward(self, x): # This test checks output scalar type in the ONNX graph should not be null # https://github.com/pytorch/pytorch/issues/28607 - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(10) def test_trace_script(self): @torch.jit.script @@ -4943,7 +4964,6 @@ def forward(self, *tensor_list): x = torch.randn(3, 4) self.run_test(EinsumModelTranspose(), input=(x,)) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(12) def test_crossentropyloss(self): for ignore_index in [-100, 1]: @@ -5042,7 +5062,6 @@ def forward(self, input, target): self.run_test(CrossEntropyLossMeanWeight(ignore_index), input=(x, y)) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) def test_kldiv_loss(self): @@ -5109,7 +5128,6 @@ def forward(self, input, target): self.run_test(KLDivLossMiniBatchMean(), input=(x, y)) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(12) def test_nllloss(self): class NLLModel(torch.nn.Module): @@ -5130,7 +5148,6 @@ def forward(self, input, target): target[target == 1] = -100 self.run_test(NLLModel(), (input, target)) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(12) def test_nllloss_2d_none(self): class NLLModel(torch.nn.Module): @@ -5152,7 +5169,6 @@ def forward(self, input, target): target[target == 1] = -100 self.run_test(NLLModel(), (input, target)) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(12) def test_nllloss_2d_mean(self): class NLLModel(torch.nn.Module): @@ -5174,7 +5190,6 @@ def forward(self, input, target): target[target == 1] = -100 self.run_test(NLLModel(), (input, target)) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(12) def test_nllloss_2d_sum(self): class NLLModel(torch.nn.Module): @@ -5196,7 +5211,6 @@ def forward(self, input, target): target[target == 1] = -100 self.run_test(NLLModel(), (input, target)) - 
@skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(12) def test_nllloss_2d_mean_weights(self): class NLLModel(torch.nn.Module): @@ -5218,7 +5232,6 @@ def forward(self, input, target): target[target == 1] = -100 self.run_test(NLLModel(), (input, target)) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(12) def test_nllloss_2d_mean_ignore_index(self): class NLLModel(torch.nn.Module): @@ -5237,7 +5250,6 @@ def forward(self, input, target): target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C) self.run_test(NLLModel(), (input, target)) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(12) def test_nllloss_2d_mean_ignore_index_weights(self): class NLLModel(torch.nn.Module): @@ -5256,6 +5268,52 @@ def forward(self, input, target): target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C) self.run_test(NLLModel(), (input, target)) + + @skipIfUnsupportedMinOpsetVersion(12) + def test_binary_cross_entropy_with_logits(self): + x = torch.randn(5) + y = torch.empty(5).random_(2) + self._bce_logits_loss(x, y) + + x = torch.randn(2, 3, 5, 7) + y = torch.empty(2, 3, 5, 7).random_(2) + weight = torch.tensor([2]) + self._bce_logits_loss(x, y, weight) + + x = torch.FloatTensor([[-0.4089, -1.2471, 0.5907], [-0.4897, -0.8267, -0.7349], [0.5241, -0.1246, -0.4751]]) + y = torch.FloatTensor([[0, 1, 1], [0, 0, 1], [1, 0, 1]]) + pos_weight = torch.empty([3]).random_(2) + self._bce_logits_loss(x, y, pos_weight) + + x = torch.randn(3, 3, 4) + y = torch.empty(3, 3, 4).random_(2) + weight = torch.tensor([3]) + pos_weight = torch.empty([3, 4]).random_(2) + self._bce_logits_loss(x, y, weight, pos_weight) + + def _bce_logits_loss(self, x, y, weight=None, pos_weight=None): + class BCEWithLogitsLossNoneWeights(torch.nn.Module): + def forward(self, input, target, weight, pos_weight): + return torch.nn.functional.binary_cross_entropy_with_logits(input, target, weight=weight, + pos_weight=pos_weight, reduction='none') + + self.run_test(BCEWithLogitsLossNoneWeights(), input=(x, y, weight, pos_weight)) + + class BCEWithLogitsLossMeanWeights(torch.nn.Module): + def forward(self, input, target, weight, pos_weight): + return torch.nn.functional.binary_cross_entropy_with_logits(input, target, weight=weight, + pos_weight=pos_weight, reduction='mean') + + self.run_test(BCEWithLogitsLossMeanWeights(), input=(x, y, weight, pos_weight)) + + class BCEWithLogitsLossSumWeights(torch.nn.Module): + def forward(self, input, target, weight, pos_weight): + return torch.nn.functional.binary_cross_entropy_with_logits(input, target, weight=weight, + pos_weight=pos_weight, reduction='sum') + + self.run_test(BCEWithLogitsLossSumWeights(), input=(x, y, weight, pos_weight)) + + def test_torch_mm(self): class M(torch.nn.Module): def forward(self, mat1, mat2): @@ -5266,7 +5324,6 @@ def forward(self, mat1, mat2): mat2 = torch.randn(3, 3) self.run_test(M(), input=(mat1, mat2)) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) # Because where op is not supported for opset < 9. def test_where_with_bool_tensor(self): class M(torch.nn.Module): @@ -5278,7 +5335,6 @@ def forward(self, mat1, mat2): mat2 = torch.ones(2, 3) self.run_test(M(), input=(mat1, mat2)) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(9) # Because where op is not supported for opset < 9. 
def test_where_with_byte_tensor(self): class M(torch.nn.Module): @@ -5477,7 +5533,6 @@ def forward(self, x): @skipIfONNXShapeInference(False) @skipIfUnsupportedMinOpsetVersion(13) - @skipIfUnsupportedOpsetVersion([13]) def test_if_list(self): class IfModel(torch.nn.Module): def forward(self, x, y, cond): @@ -5668,7 +5723,6 @@ def forward(self, input): x = torch.randn(6, 4, 3, 3) self.run_test(FakeQuantizePerTensorModel(), (x)) - @skipIfUnsupportedOpsetVersion([13]) def test_batchnorm_training(self): class MyModule(torch.nn.Module): def __init__(self): @@ -5792,7 +5846,6 @@ def forward(self, x): np.testing.assert_allclose(ratio_pytorch, ratio_ort, rtol=0.01, atol=0.01) - @skipIfUnsupportedOpsetVersion([13]) def test_conv_bn(self): class MyModule(torch.nn.Module): def __init__(self): @@ -5991,7 +6044,6 @@ def forward(self, boxes, scores): self.run_test(Module(), (boxes, scores)) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) def test_clip_boxes_to_image(self): boxes = torch.randn(5, 4) * 500 @@ -6007,7 +6059,7 @@ def forward(self, boxes, size): self.run_test(Module(), (boxes, size), input_names=["boxes", "size"], dynamic_axes={"size": [0, 1]}, - test_with_inputs=[(boxes, size_2)]) + test_with_inputs=[(boxes, size), (boxes, size_2)]) @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) @@ -6049,7 +6101,6 @@ def test_roi_pool(self): model = ops.RoIPool((pool_h, pool_w), 2) self.run_test(model, (x, rois)) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) def test_resize_images(self): class TransformModule(torch.nn.Module): @@ -6064,9 +6115,8 @@ def forward(self, images): input_test = torch.rand(3, 100, 150) self.run_test(TransformModule(), (input,), input_names=["input1"], dynamic_axes={"input1": [0, 1, 2]}, - test_with_inputs=[(input_test,)]) + test_with_inputs=[(input,), (input_test,)]) - @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) def test_transform_images(self): @@ -6080,7 +6130,7 @@ def forward(self, images): input = torch.rand(3, 100, 200), torch.rand(3, 200, 200) input_test = torch.rand(3, 100, 200), torch.rand(3, 200, 200) - self.run_test(TransformModule(), (input,), test_with_inputs=[(input_test,)]) + self.run_test(TransformModule(), (input,), test_with_inputs=[(input,), (input_test,)]) def get_features(self, images): s0, s1 = images.shape[-2:] @@ -6097,6 +6147,7 @@ def get_features(self, images): @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) def test_rpn(self): + class RPNModule(torch.nn.Module): def __init__(self): super(RPNModule, self).__init__() @@ -6119,7 +6170,7 @@ def forward(self, images, features): dynamic_axes={"input1": [0, 1, 2, 3], "input2": [0, 1, 2, 3], "input3": [0, 1, 2, 3], "input4": [0, 1, 2, 3], "input5": [0, 1, 2, 3], "input6": [0, 1, 2, 3]}, - test_with_inputs=[(images2, test_features)], + test_with_inputs=[(images, features), (images2, test_features)], dict_check=False) @skipIfUnsupportedOpsetVersion([13]) @@ -6147,7 +6198,7 @@ def forward(self, input, boxes): boxes1 = torch.rand(6, 4) * 256 boxes1[:, 2:] += boxes1[:, :2] - self.run_test(TransformModule(), (i, [boxes],), test_with_inputs=[(i1, [boxes1],)]) + self.run_test(TransformModule(), (i, [boxes],), test_with_inputs=[(i, [boxes],), (i1, [boxes1],)]) @skipIfUnsupportedOpsetVersion([13]) @skipIfUnsupportedMinOpsetVersion(11) @@ -6182,7 +6233,7 @@ def forward(self, images, features): input_names=["input1", "input2", "input3", "input4", "input5", "input6"], 
dynamic_axes={"input1": [0, 1, 2, 3], "input2": [0, 1, 2, 3], "input3": [0, 1, 2, 3], "input4": [0, 1, 2, 3], "input5": [0, 1, 2, 3], "input6": [0, 1, 2, 3]}, - test_with_inputs=[(images2, test_features)], + test_with_inputs=[(images, features), (images2, test_features)], dict_check=False) @@ -6197,7 +6248,6 @@ def make_test(name, base, layer, bidirectional, initial_state, # Cannot export with older opsets because of 'ConstantFill' op # ConstantFill was a temp op removed at opset 8. This is no longer supported by onnxruntime - @skipIfUnsupportedOpsetVersion([13]) @disableScriptTest() # Test code not scriptable @skipIfUnsupportedMinOpsetVersion(9) def f(self): diff --git a/test/onnx/test_pytorch_onnx_onnxruntime_cuda.py b/test/onnx/test_pytorch_onnx_onnxruntime_cuda.py new file mode 100644 index 000000000000..24017b125b10 --- /dev/null +++ b/test/onnx/test_pytorch_onnx_onnxruntime_cuda.py @@ -0,0 +1,31 @@ +import unittest +import onnxruntime # noqa +import torch + +from test_pytorch_common import skipIfUnsupportedMinOpsetVersion +from test_pytorch_common import skipIfNoCuda + +from test_pytorch_onnx_onnxruntime import TestONNXRuntime + +class TestONNXRuntime_cuda(unittest.TestCase): + from torch.onnx.symbolic_helper import _export_onnx_opset_version + opset_version = _export_onnx_opset_version + keep_initializers_as_inputs = True + use_new_jit_passes = True + onnx_shape_inference = True + + @skipIfUnsupportedMinOpsetVersion(9) + @skipIfNoCuda + def test_gelu_fp16(self): + class GeluModel(torch.nn.Module): + def forward(self, x): + return torch.nn.functional.gelu(x) + + x = torch.randn(2, 4, 5, 6, requires_grad=True, dtype=torch.float16, device=torch.device('cuda')) + self.run_test(GeluModel(), x, rtol=1e-3, atol=1e-5) + +TestONNXRuntime_cuda.setUp = TestONNXRuntime.setUp +TestONNXRuntime_cuda.run_test = TestONNXRuntime.run_test + +if __name__ == '__main__': + unittest.main(TestONNXRuntime_cuda()) diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py index 5c1bfe8b5515..3aadabf85769 100644 --- a/test/onnx/test_utility_funs.py +++ b/test/onnx/test_utility_funs.py @@ -5,7 +5,7 @@ from torch.onnx import utils, OperatorExportTypes, TrainingMode from torch.onnx.symbolic_helper import _set_opset_version, _set_operator_export_type import torch.utils.cpp_extension -from test_pytorch_common import skipIfUnsupportedMinOpsetVersion +from test_pytorch_common import skipIfUnsupportedMinOpsetVersion, skipIfUnsupportedOpsetVersion import caffe2.python.onnx.backend as backend from verify import verify @@ -618,6 +618,8 @@ def forward(self, x): assert next(iter).kind() == "aten::quantize_per_tensor" assert next(iter).kind() == "aten::dequantize" + # prim::ListConstruct is exported as onnx::SequenceConstruct for opset >= 11 + @skipIfUnsupportedOpsetVersion([11, 12]) def test_prim_fallthrough(self): # Test prim op class PrimModule(torch.jit.ScriptModule): diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py index 8d7ba7dcb4e6..b2243eead1d0 100644 --- a/test/quantization/test_quantize_fx.py +++ b/test/quantization/test_quantize_fx.py @@ -1497,7 +1497,7 @@ def forward(self, x): self.checkGraphModeFxOp(model, data, quant_type, quantized_node) @skipIfNoFBGEMM - def test_linear_functional(self): + def test_functional_linear(self): class FuncLinear(torch.nn.Module): def __init__(self, use_bias, has_relu, f_relu): super(FuncLinear, self).__init__() @@ -1595,22 +1595,80 @@ def forward(self, x): quantized_nodes[dim]) @skipIfNoFBGEMM - def 
test_conv2d_functional(self): - for bias in [True, False]: - conv = torch.nn.Conv2d(1, 1, 1, bias=bias) + def test_functional_conv(self): + """ Test for functional conv and functional conv + relu + """ + class FuncConv(torch.nn.Module): + def __init__(self, use_bias, has_relu, f_relu): + super().__init__() + self.w = torch.randn(3, 3, 3, 3) + self.b = torch.randn(3) if use_bias else None + self.stride = (1, 1) + self.padding = (0, 0) + self.dilation = (1, 1) + self.groups = 1 + self.use_bias = use_bias + if has_relu: + if f_relu: + self.relu = F.relu + else: + self.relu = torch.nn.ReLU() + else: + self.relu = torch.nn.Identity() + + def forward(self, x): + x = F.conv2d(x, self.w, self.b, self.stride, self.padding, self.dilation, self.groups) + x = self.relu(x) + return x + + data = (torch.randn((2, 3, 4, 4), dtype=torch.float),) + + quant_type_to_prepare_expected_node_occurrence = { + QuantType.DYNAMIC: {}, # There should be 3 observers: after input, weight and activation. - # No observer after bias. - prepare_expected_node_occurrence = { + QuantType.STATIC: { ns.call_module(torch.quantization.HistogramObserver): 2, ns.call_module(torch.quantization.PerChannelMinMaxObserver): 1, + }, + # There should be 3 observers: after input, weight and activation. + QuantType.QAT: { + ns.call_module(torch.quantization.FakeQuantize): 3, + }, + } + quant_type_to_qconv_fun = { + QuantType.STATIC: ns.call_function(torch.ops.quantized.conv2d), + QuantType.QAT: ns.call_function(torch.ops.quantized.conv2d), + } + quant_type_to_qconv_relu_fun = { + QuantType.STATIC: ns.call_function(torch.ops.quantized.conv2d_relu), + QuantType.QAT: ns.call_function(torch.ops.quantized.conv2d_relu), + } + + options = itertools.product( + self.static_quant_types, + (True, False), # use_bias + (True, False), # has_relu + (True, False), # functional relu + ) + for quant_type, use_bias, has_relu, f_relu in options: + model = FuncConv(use_bias, has_relu, f_relu) + if has_relu: + qconv_fun = quant_type_to_qconv_relu_fun[quant_type] + else: + qconv_fun = quant_type_to_qconv_fun[quant_type] + + convert_node_occurrence = { + ns.call_function(torch.quantize_per_tensor): 1 if quant_type != QuantType.DYNAMIC else 0, + qconv_fun: 1, + ns.call_method("dequantize"): 1 if quant_type != QuantType.DYNAMIC else 0 } - expected_node_occurrence = \ - {ns.call_function(torch.ops.quantized.conv2d): 1} + prepare_expected_node_occurrence = \ + quant_type_to_prepare_expected_node_occurrence[quant_type] self.checkGraphModeFxOp( - conv, (torch.randn(4, 1, 4, 4),), QuantType.STATIC, + model, data, quant_type, qconv_fun, prepare_expected_node_occurrence=prepare_expected_node_occurrence, - expected_node_occurrence=expected_node_occurrence, - ) + expected_node_occurrence=convert_node_occurrence) + @skipIfNoFBGEMM def test_quantized_conv_relu(self): diff --git a/test/test_autograd.py b/test/test_autograd.py index 28ef8d41b346..df588f97701a 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -7377,6 +7377,20 @@ def test_inplace_view_multiple_outputs(self, device): with self.assertRaises(RuntimeError): v1[0].mul_(2) + def test_inplace_view_of_multiple_output_view(self, device): + a = torch.rand(10, device=device, requires_grad=True).clone() + b = a.unbind(0) + c = b[0].view_as(b[0]) + with self.assertRaises(RuntimeError): + c.mul_(2) + + def test_inplace_multiple_output_view_of_view(self, device): + a = torch.rand(10, device=device, requires_grad=True).clone() + b = a.view_as(a) + c = b.unbind(0) + with self.assertRaises(RuntimeError): + c[0].mul_(2) 
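The two new tests above pin down the view-tracking rule this change propagates: outputs of a multi-output view op such as unbind() reject in-place mutation, and so do further views derived from them. A minimal standalone sketch of that behavior, assuming a build that includes this change:

```python
import torch

# Sketch of the rule the new tests assert: unbind() is a multi-output view
# op, so in-place edits on its outputs, or on views derived from them,
# raise instead of silently corrupting gradient tracking.
a = torch.rand(10, requires_grad=True).clone()
b = a.unbind(0)          # multi-output view of `a`
c = b[0].view_as(b[0])   # a further view of one output
try:
    c.mul_(2)
except RuntimeError as e:
    print(f"in-place rejected as expected: {e}")
```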
+ def test_inplace_view_makes_base_require_grad(self, device): # in-place modification to view makes base require grad a = torch.randn(4, 4, device=device, requires_grad=False) diff --git a/test/test_indexing.py b/test/test_indexing.py index b92fd94e8cbd..10e4a9bafe95 100644 --- a/test/test_indexing.py +++ b/test/test_indexing.py @@ -762,9 +762,9 @@ def test_int_indices(self, device): self.assertEqual(v[:, [0, 4, 2]].shape, (5, 3, 3)) self.assertEqual(v[:, [[0, 1], [4, 3]]].shape, (5, 2, 2, 3)) - @dtypes(torch.float, torch.bfloat16, torch.long, torch.bool) - @dtypesIfCPU(torch.float, torch.long, torch.bool, torch.bfloat16) - @dtypesIfCUDA(torch.half, torch.long, torch.bool, torch.bfloat16) + @dtypes(torch.cfloat, torch.cdouble, torch.float, torch.bfloat16, torch.long, torch.bool) + @dtypesIfCPU(torch.cfloat, torch.cdouble, torch.float, torch.long, torch.bool, torch.bfloat16) + @dtypesIfCUDA(torch.cfloat, torch.cdouble, torch.half, torch.long, torch.bool, torch.bfloat16) def test_index_put_src_datatype(self, device, dtype): src = torch.ones(3, 2, 4, device=device, dtype=dtype) vals = torch.ones(3, 2, 4, device=device, dtype=dtype) diff --git a/test/test_jit.py b/test/test_jit.py index 4d37cd0a3ef9..17f7f66ac8a5 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -65,7 +65,7 @@ freeze_rng_state, set_rng_seed, slowTest, TemporaryFileName, skipIfCompiledWithoutNumpy, \ enable_profiling_mode_for_profiling_tests, TEST_MKL, set_default_dtype, num_profiled_runs from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, disable_autodiff_subgraph_inlining, \ - _trace, enable_cpu_fuser_if, do_input_map, get_execution_plan, \ + _trace, enable_cpu_fuser_if, do_input_map, get_execution_plan, make_global, \ execWrapper, _inline_everything, _tmp_donotuse_dont_inline_everything, \ RUN_CUDA from torch.testing._internal.jit_metaprogramming_utils import create_script_fn, nn_functional_tests, get_script_args, \ @@ -6609,8 +6609,6 @@ def bar(c, b): .check("in foo").check("in baz").run(str(cm.exception)) def test_error_stacktrace_interface(self): - global IFace - @torch.jit.script def baz(c, b): return c + b @@ -6634,6 +6632,8 @@ def one(self, x, y): # type: (Tensor, Tensor) -> Tensor pass + make_global(IFace) + @torch.jit.script def as_interface(x): # type: (IFace) -> IFace @@ -15682,7 +15682,7 @@ def check(name): def fn(*inputs, **kwargs): attr = getattr(inputs[0], name) output = attr(*inputs[1:], **kwargs) - return output_process_fn(output) + return output check_types = test_name not in EXCLUDE_TYPE_CHECK # XXX: this test should always run with disable_autodiff_subgraph_inlining(True), @@ -15698,7 +15698,7 @@ def fn(*inputs, **kwargs): traced_fn = create_traced_fn(self, fn) check_against_reference(self, traced_fn, - fn, (self_variable,) + args_variable, kwargs_variable, + fn, output_process_fn, (self_variable,) + args_variable, kwargs_variable, check_types=check_types) if IS_SANDCASTLE: autodiff_nodes = autodiff_nodes + fusible_nodes @@ -15708,9 +15708,9 @@ def fn(*inputs, **kwargs): self.assertAutodiffNode(traced_fn.last_graph, should_autodiff_node, autodiff_nodes, fusible_nodes) if not is_magic_method and test_name not in EXCLUDE_SCRIPT: - script_fn = create_script_fn(self, name, 'method', output_process_fn) + script_fn = create_script_fn(self, name, 'method') check_against_reference(self, script_fn, - fn, (self_variable,) + args_variable, kwargs_variable, + fn, output_process_fn, (self_variable,) + args_variable, kwargs_variable, check_types=check_types) if IS_SANDCASTLE: @@ -15725,21 
+15725,20 @@ def fn(*inputs, **kwargs): # functional interface tests if hasattr(torch, name) and name not in EXCLUDE_FUNCTIONAL: def fn(*inputs, **kwargs): - output = getattr(torch, name)(*inputs, **kwargs) - return output_process_fn(output) + return getattr(torch, name)(*inputs, **kwargs) f_args_variable = (self_variable,) + args_variable f_args_tensor = (self_tensor,) + args_tensor if not is_inplace and test_name not in EXCLUDE_TRACED: check_against_reference(self, - create_traced_fn(self, fn), - fn, f_args_variable, kwargs_variable, check_types=check_types) + create_traced_fn(self, fn), fn, output_process_fn, + f_args_variable, kwargs_variable, check_types=check_types) if not is_inplace and test_name not in EXCLUDE_SCRIPT: check_against_reference(self, - create_script_fn(self, name, 'functional', output_process_fn), - fn, f_args_variable, kwargs_variable, + create_script_fn(self, name, 'functional'), + fn, output_process_fn, f_args_variable, kwargs_variable, check_types=check_types) # alias annotation testing @@ -15781,8 +15780,7 @@ def do_test(self, name=name, args=args, test_name=test_name, check_ad=check_ad): output_variable = getattr(F, name)(self_variable, *args_variable, **kwargs_variable) def fn(*inputs, **kwargs): - output = getattr(F, name)(*inputs, **kwargs) - return output_process_fn(output) + return getattr(F, name)(*inputs, **kwargs) f_args_variable = (self_variable,) + args_variable f_args_tensor = (self_tensor,) + args_tensor @@ -15793,8 +15791,9 @@ def run_test(): # XXX: this test should always run with disable_autodiff_subgraph_inlining(True), # so that we don't regress on autodiff support. with disable_autodiff_subgraph_inlining(): - script_fn = create_script_fn(self, name, 'nn_functional', output_process_fn) - check_against_reference(self, script_fn, fn, f_args_variable, kwargs_variable, no_grad=no_grad) + script_fn = create_script_fn(self, name, 'nn_functional') + check_against_reference(self, script_fn, fn, output_process_fn, + f_args_variable, kwargs_variable, no_grad=no_grad) # For tests we disabled AD subgraph inlining, make sure it's not falling back to autograd if (doAutodiffCheck(test_name)): self.assertAutodiffNode(script_fn.last_graph, should_autodiff_node, autodiff_nodes, fusible_nodes) @@ -15914,7 +15913,8 @@ def create_nn_module(*args, **kwargs): f_args_variable = deepcopy(unpack_variables(args_variable)) # Check against Python module as reference - check_against_reference(self, create_script_module, create_nn_module, f_args_variable, no_grad=no_grad) + check_against_reference(self, create_script_module, create_nn_module, + lambda x: x, f_args_variable, no_grad=no_grad) if 'slowTest' in kwargs: do_test = slowTest(do_test) diff --git a/test/test_linalg.py b/test/test_linalg.py index f2ee0dcaaef9..4f061dccc0de 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -828,6 +828,18 @@ def run_test_skipped_elements(a_shape, b_shape): # run_test_transposed(a_shape, b_shape) run_test_skipped_elements(a_shape, b_shape) + # Test that kron preserves memory format + a = torch.randn(1, 2, 3, 4, dtype=dtype, device=device).contiguous(memory_format=torch.channels_last) + b = torch.randn(1, 2, 3, 4, dtype=dtype, device=device).contiguous(memory_format=torch.channels_last) + c = torch.kron(a, b) + self.assertTrue(c.is_contiguous(memory_format=torch.channels_last)) + torch.kron(a, b, out=c) + self.assertTrue(c.is_contiguous(memory_format=torch.channels_last)) + c = c.contiguous(memory_format=torch.contiguous_format) + torch.kron(a, b, out=c) + 
self.assertTrue(c.is_contiguous(memory_format=torch.contiguous_format)) + + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) def test_kron_empty(self, device, dtype): @@ -4715,7 +4727,7 @@ def test_bmm(self, device, dtype): # undefined bahavior return - num_batches = 10 + batch_sizes = [1, 10] M, N, O = 23, 8, 12 numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32 @@ -4724,17 +4736,18 @@ def test_bmm(self, device, dtype): is_supported = TEST_WITH_ROCM or (CUDA11OrLater and SM53OrLater) if not is_supported: - b1 = torch.randn(num_batches, M, N, device=device).to(dtype) - b2 = torch.randn(num_batches, N, O, device=device).to(dtype) - self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|CUBLAS_STATUS_NOT_SUPPORTED", - lambda: torch.bmm(b1, b2)) + for num_batches in batch_sizes: + b1 = torch.randn(num_batches, M, N, device=device).to(dtype) + b2 = torch.randn(num_batches, N, O, device=device).to(dtype) + self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|CUBLAS_STATUS_NOT_SUPPORTED", + lambda: torch.bmm(b1, b2)) return def invert_perm(p): d = {x: i for i, x in enumerate(p)} return (d[0], d[1], d[2]) - def generate_inputs(): + def generate_inputs(num_batches): # transposed tensors for perm1, perm2 in itertools.product(itertools.permutations((0, 1, 2)), repeat=2): b1 = make_tensor((num_batches, M, N), device, dtype, low=-1, high=1) @@ -4757,21 +4770,22 @@ def generate_inputs(): b2 = torch.randn(shape2, dtype=dtype, device=device) yield b1, b2 - for (b1, b2), perm3 in itertools.product(generate_inputs(), itertools.permutations((0, 1, 2))): - res1 = torch.bmm(b1, b2) - res2 = torch.full((num_batches, M, O), math.nan, dtype=dtype, device=device) \ - .permute(perm3).contiguous().permute(invert_perm(perm3)) - torch.bmm(b1, b2, out=res2) - expect = torch.from_numpy( - b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()).to(device=device, dtype=dtype) - self.assertEqual(expect, res1) - self.assertEqual(expect, res2) + for num_batches in batch_sizes: + for (b1, b2), perm3 in itertools.product(generate_inputs(num_batches), itertools.permutations((0, 1, 2))): + res1 = torch.bmm(b1, b2) + res2 = torch.full((num_batches, M, O), math.nan, dtype=dtype, device=device) \ + .permute(perm3).contiguous().permute(invert_perm(perm3)) + torch.bmm(b1, b2, out=res2) + expect = torch.from_numpy( + b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()).to(device=device, dtype=dtype) + self.assertEqual(expect, res1) + self.assertEqual(expect, res2) - if self.device_type == 'cuda': - # check that mixed arguments are rejected - self.assertRaises(RuntimeError, lambda: torch.bmm(b1, b2.cpu())) - self.assertRaises(RuntimeError, lambda: torch.bmm(b1.cpu(), b2)) - self.assertRaises(RuntimeError, lambda: torch.bmm(b1, b2, out=res2.cpu())) + if self.device_type == 'cuda': + # check that mixed arguments are rejected + self.assertRaises(RuntimeError, lambda: torch.bmm(b1, b2.cpu())) + self.assertRaises(RuntimeError, lambda: torch.bmm(b1.cpu(), b2)) + self.assertRaises(RuntimeError, lambda: torch.bmm(b1, b2, out=res2.cpu())) @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error") @onlyCUDA diff --git a/test/test_ops.py b/test/test_ops.py index bd82aca3820a..a01f96fe877b 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2,7 +2,6 @@ import torch -from torch.testing import floating_and_complex_types_and from torch.testing._internal.common_utils import \ (TestCase, run_tests, IS_SANDCASTLE, clone_input_helper) from 
torch.testing._internal.common_methods_invocations import \ @@ -191,7 +190,8 @@ def check_variant_backward(self, input, forward_result, expected_grad, expected_ # against eager's gold standard op function variant @ops(op_db) def test_variant_consistency_eager(self, device, dtype, op): - samples = op.sample_inputs(device, dtype, requires_grad=True) + test_backward = op.test_complex_grad or not dtype.is_complex + samples = op.sample_inputs(device, dtype, requires_grad=test_backward) if len(samples) == 0: self.skipTest("Skipped! No sample inputs!") @@ -237,7 +237,7 @@ def test_variant_consistency_eager(self, device, dtype, op): self.assertEqual(variant_forward, expected_forward) # Compares variant's backward - if variant is not inplace or op.test_inplace_grad: + if test_backward and (variant is not inplace or op.test_inplace_grad): self.check_variant_backward(sample.input, variant_forward, expected_grad, exception_during_backwards) @@ -247,7 +247,10 @@ def test_variant_consistency_eager(self, device, dtype, op): # TODO WARNING: inplace x {traced, scripted} not currently tested @ops(op_db) def test_variant_consistency_jit(self, device, dtype, op): - samples = op.sample_inputs(device, dtype, requires_grad=True) + test_backward = ( + (dtype.is_complex and op.test_complex_grad) or + (dtype.is_floating_point and (not op.skip_bfloat16_grad or dtype != torch.bfloat16))) + samples = op.sample_inputs(device, dtype, requires_grad=test_backward) if len(samples) == 0: self.skipTest("Skipped! No sample inputs!") @@ -275,32 +278,28 @@ def test_variant_consistency_jit(self, device, dtype, op): # autodiff support. Context manager forces the graph to contain # DifferentiableGraph nodes if they are present with disable_autodiff_subgraph_inlining(): - def fn(*inputs, **kwargs): - output = func(*inputs, **kwargs) - return op.output_func(output) - # bfloat16 grad doesn't work for some operators - dtypes_to_grad_check = floating_and_complex_types_and(torch.half) \ - if op.skip_bfloat16_grad else floating_and_complex_types_and(torch.half, torch.bfloat16) # Check scripted forward, grad, and grad grad - script_fn = create_script_fn(self, name, func_type, op.output_func) + script_fn = create_script_fn(self, name, func_type) check_against_reference(self, script_fn, - fn, + func, + op.output_func, (*sample.input,) + sample.args, sample.kwargs, - no_grad=(dtype not in dtypes_to_grad_check)) + no_grad=not test_backward) # Check traced forward, grad, and grad grad traced_fn = create_traced_fn(self, variant) check_against_reference(self, traced_fn, - fn, + func, + op.output_func, (*sample.input,) + sample.args, sample.kwargs, - no_grad=(dtype not in dtypes_to_grad_check)) + no_grad=not test_backward) # Check alias annotation schema for correctness (make # sure inputs that aren't supposed to be modified aren't) diff --git a/test/test_profiler.py b/test/test_profiler.py index 826a9f5d0b57..9dfe099163a7 100644 --- a/test/test_profiler.py +++ b/test/test_profiler.py @@ -267,7 +267,7 @@ def trace_handler(p): ) as p: for idx in range(8): self.payload() - p.next_step() + p.step() self.assertEqual(called_num[0], 2) diff --git a/test/test_torch.py b/test/test_torch.py index 4ace1012167d..424ab4f40c0f 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -1191,49 +1191,6 @@ def test_scatterReduce(self): for method in ["add", "multiply"]: self._test_scatter_base(self, lambda t: t, 'scatter_', reduction=method) - def test_masked_scatter(self): - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - 
for maskType in [torch.uint8, torch.bool]: - for dt in torch.testing.get_all_dtypes(): - num_copy, num_dest = 3, 10 - dest = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dt) - dest2 = dest.clone() - src = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=dt) - mask = torch.tensor((0, 0, 0, 0, 1, 0, 1, 0, 1, 0), dtype=maskType) - - if dt == torch.bool: - # torch.bool is a special case and is being tested - # in a separate test - continue - - # TODO: update test when masked scatter is supported for complex - if dt == torch.half or dt.is_complex: - self.assertRaises(RuntimeError, lambda: dest.masked_scatter_(mask, src)) - continue - - dest.masked_scatter_(mask, src) - j = 0 - for i in range(num_dest): - if mask[i]: - dest2[i] = src[j] - j += 1 - self.assertEqual(dest, dest2, atol=0, rtol=0) - - # make source bigger than number of 1s in mask - src = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=dt) - dest.masked_scatter_(mask, src) - - # make src smaller. this should fail - src = torch.zeros(num_copy - 1, dtype=dt) - with self.assertRaises(RuntimeError): - dest.masked_scatter_(mask, src) - self.assertEqual(len(w), 27) - - warn = 'masked_scatter_ received a mask with dtype torch.uint8,' - for wi in w: - self.assertEqual(str(wi.message)[0:55], str(warn)) - def test_masked_fill(self): with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") @@ -4521,6 +4478,56 @@ def test_scatter_add_bool(self, device): [False, True, False, True, False], [True, False, True, False, True]], device=device)) + @onlyOnCPUAndCUDA + @dtypes(*torch.testing.get_all_dtypes()) + def test_masked_scatter(self, device, dtype): + dt = dtype + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + for maskType in [torch.uint8, torch.bool]: + num_copy, num_dest = 3, 10 + dest = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dt, device=device) + dest2 = dest.clone() + dest_ones = dest.clone() + dest_ones_expected = dest.clone() + src = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=dt, device=device) + src_ones = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=dt, device=device) + mask = torch.tensor((0, 0, 0, 0, 1, 0, 1, 0, 1, 0), dtype=maskType, device=device) + + if dt == torch.bool: + # torch.bool is a special case and is being tested + # in a separate test + return + + # TODO: update test when masked scatter is supported for complex + # and cpu supports half + if (dt == torch.half and self.device_type == 'cpu') or dt.is_complex: + self.assertRaises(RuntimeError, lambda: dest.masked_scatter_(mask, src)) + return + + dest.masked_scatter_(mask, src) + j = 0 + for i in range(num_dest): + if mask[i]: + dest2[i] = src[j] + dest_ones_expected[i] = src_ones[j] + j += 1 + self.assertEqual(dest, dest2, atol=0, rtol=0) + + dest_ones.masked_scatter_(mask, src_ones) + self.assertEqual(dest_ones, dest_ones_expected, atol=0, rtol=0) + + # make src smaller. 
this should fail + src = torch.zeros(num_copy - 1, dtype=dt, device=device) + with self.assertRaises(RuntimeError): + dest.masked_scatter_(mask, src) + + self.assertEqual(len(w), 3) + + warn = 'masked_scatter_ received a mask with dtype torch.uint8,' + for wi in w: + self.assertEqual(str(wi.message)[0:55], str(warn)) + def test_masked_scatter_bool_tensor(self, device): src = torch.tensor([True, True, True], device=device) dst = torch.tensor([False, False, False], device=device) @@ -6723,12 +6730,6 @@ def inner(self, device, dtype): ('remainder', 'negative_tensor', _small_3d, lambda t, d: [0 - _small_3d(t, d, has_zeros=False)], 1e-1, 1e-2, 1e-5, _signed_types), - ('std', '', _small_3d, lambda t, d: [], 1e-3, 1e-5, 1e-5, _float_types, _cpu_types, False), - ('std', 'dim', _small_3d, lambda t, d: [1], 1e-3, 1e-5, 1e-5, _float_types, _cpu_types, False), - ('std', 'neg_dim', _small_3d, lambda t, d: [-1], 1e-3, 1e-5, 1e-5, _float_types, _cpu_types, False), - ('var', '', _small_3d, lambda t, d: [], 1e-3, 1e-5, 1e-5, _float_types, _cpu_types, False), - ('var', 'dim', _small_3d, lambda t, d: [1], 1e-3, 1e-5, 1e-5, _float_types, _cpu_types, False), - ('var', 'neg_dim', _small_3d, lambda t, d: [-1], 1e-3, 1e-2, 1e-5, torch.testing.get_all_fp_dtypes(), _cpu_types, False), ('ndimension', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), ('nelement', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), ('numel', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), @@ -6854,7 +6855,6 @@ def inner(self, device, dtype): ('round', '', _small_3d, lambda t, d: [], 1e-5, 1e-2, 1e-5, _float_types, [torch.bfloat16]), ('trunc', '', _small_3d, lambda t, d: [], 1e-5, 1e-2, 1e-5, _float_types, [torch.bfloat16]), ('ceil', '', _small_3d, lambda t, d: [], 1e-5, 1e-2, 1e-5, _float_types, [torch.bfloat16]), - ('lgamma', '', _small_3d, lambda t, d: [], 1e-2, 1e-1, 1e-5, _float_types_no_half, [torch.bfloat16]), ] # Creates and decorates a generic test and adds it to the class. diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py index 365c33179206..3497ccd04cc1 100644 --- a/test/test_unary_ufuncs.py +++ b/test/test_unary_ufuncs.py @@ -1684,8 +1684,6 @@ def _medium_2d(dtype, device): _TorchMathTestMeta('frac', reffn='fmod', refargs=lambda x: (x.numpy(), 1)), _TorchMathTestMeta('trunc'), _TorchMathTestMeta('round'), - # FIXME lgamma produces different result compared to scipy at -inf - _TorchMathTestMeta('lgamma', reffn='gammaln', ref_backend='scipy', replace_inf_with_nan=True), _TorchMathTestMeta('polygamma', args=[0], substr='_0', reffn='polygamma', refargs=lambda x: (0, x.numpy()), input_fn=_generate_gamma_input, inputargs=[False], ref_backend='scipy'), diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index a7ebf35cc0b6..6ebae6b80c1f 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -261,7 +261,7 @@ def _replace_overloaded_method_decl(overload_decl: Decl, implementation_def: Def def _jit_pass_lower_all_tuples(graph: Graph) -> None: ... def _jit_pass_onnx_set_dynamic_input_shape(graph: Graph, dynamic_axes: Dict[str, Dict[_int, str]], input_names: List[str]) -> None: ... -def _jit_pass_onnx_graph_shape_type_inference(graph: Graph, opset_version: _int) -> None: ... +def _jit_pass_onnx_graph_shape_type_inference(graph: Graph, paramsDict: Dict[str, IValue], opset_version: _int) -> None: ... 
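The updated stub above adds a paramsDict argument to graph-level ONNX shape inference. A hedged sketch of the call-site shape, assuming only what the stub declares (torch._C is an internal, unstable API, and run_onnx_shape_inference is a hypothetical wrapper name introduced here for illustration):

```python
from typing import Any, Dict

import torch

def run_onnx_shape_inference(graph: "torch._C.Graph",
                             params_dict: Dict[str, Any],
                             opset_version: int) -> None:
    # Hypothetical wrapper: runs the graph-level shape/type inference pass
    # in place; the params dict (the new middle argument in this change)
    # gives the pass access to the exported parameters while inferring.
    torch._C._jit_pass_onnx_graph_shape_type_inference(
        graph, params_dict, opset_version)
```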
def _jit_pass_onnx_assign_output_shape(graph: Graph, tensors: List[Tensor], desc: IODescriptor, onnx_shape_inference: _bool = False) -> None: ... def _jit_pass_onnx_remove_inplace_ops_for_onnx(graph: Graph) -> None: ... def _jit_pass_remove_inplace_ops(graph: Graph) -> None: ... @@ -298,7 +298,7 @@ def _jit_pass_onnx_eliminate_unused_items(graph: Graph, paramsDict: Dict[str, IV def _jit_pass_onnx_cast_all_constant_to_floating(graph: Graph) -> None: ... def _jit_pass_filter_non_tensor_arguments(params: Dict[str, IValue]) -> Dict[str, Tensor]: ... def _jit_decay_packed_param_input_types(graph: Graph) -> None: ... -def _jit_pass_onnx_node_shape_type_inference(n: Node, opset_version: _int) -> None: ... +def _jit_pass_onnx_node_shape_type_inference(n: Node, paramsDict: Dict[str, IValue], opset_version: _int) -> None: ... def _jit_pass_onnx_block( old_block: Block, new_block: Block, diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index be287d0a9a3b..ad7a2cf4ba88 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -241,7 +241,9 @@ def get_annotation_str(annotation): elif isinstance(annotation, ast.Attribute): return '.'.join([get_annotation_str(annotation.value), annotation.attr]) elif isinstance(annotation, ast.Subscript): - return f"{get_annotation_str(annotation.value)}[{get_annotation_str(annotation.slice.value)}]" # type: ignore + # In Python 3.9+ subscript indices are not wrapped in ast.Index + subscript_slice = annotation.slice if sys.version_info >= (3, 9) else annotation.slice.value # type: ignore + return f"{get_annotation_str(annotation.value)}[{get_annotation_str(subscript_slice)}]" elif isinstance(annotation, ast.Tuple): return ','.join([get_annotation_str(elt) for elt in annotation.elts]) elif isinstance(annotation, ast.Constant) or isinstance(annotation, ast.NameConstant): @@ -598,10 +600,11 @@ def _copy_to_script_wrapper(fn): def module_has_exports(mod): for name in dir(mod): - item = getattr(mod, name) - if callable(item): - if get_torchscript_modifier(item) is FunctionModifiers.EXPORT: - return True + if hasattr(mod, name): + item = getattr(mod, name) + if callable(item): + if get_torchscript_modifier(item) is FunctionModifiers.EXPORT: + return True return False def should_drop(fn): diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py index a3d0da1aef9d..0e50dcee702e 100644 --- a/torch/autograd/profiler.py +++ b/torch/autograd/profiler.py @@ -223,7 +223,7 @@ def export_chrome_trace(self, path): '"pid": "CPU functions", ' '"args": {}}, ' % ( - evt.name, + evt.trace_name, evt.time_range.start, evt.time_range.elapsed_us(), evt.thread @@ -241,7 +241,7 @@ def export_chrome_trace(self, path): '"pid": "CPU functions", ' '"id": %s, ' '"cat": "cpu_to_cuda", ' - '"args": {}}, ' % (evt.name, evt.time_range.start, + '"args": {}}, ' % (evt.trace_name, evt.time_range.start, evt.thread, next_id)) f.write('{"name": "%s", ' '"ph": "f", ' @@ -847,10 +847,11 @@ def __init__( self, id, name, thread, start_us, end_us, fwd_thread=None, input_shapes=None, stack=None, scope=0, cpu_memory_usage=0, cuda_memory_usage=0, is_async=False, is_remote=False, sequence_nr=-1, node_id=-1, device_type=DeviceType.CPU, device_index=0, - is_legacy=False, flops=None): + is_legacy=False, flops=None, trace_name=None): self.id: int = id self.node_id: int = node_id self.name: str = name + self.trace_name: str = trace_name if trace_name is not None else self.name self.time_range: Interval = Interval(start_us, end_us) self.thread: int = thread self.fwd_thread: Optional[int] = 
fwd_thread @@ -1101,6 +1102,18 @@ def filter_name(name): ] return name in filtered_out_names +# Demangles and optionally rewrites the provided event name, +# with_wildcard - whether to replace certain numbered event names +# with a wildcard name to aggregate them together in the profiler table +# output +def rewrite_name(name, with_wildcard=False): + string_table = StringTable() + name = string_table[name] + if with_wildcard: + if name.startswith("ProfilerStep#"): + name = "ProfilerStep*" + return name + # Parsing of kineto profiler events def parse_kineto_results(result): # result.events() has most of the events - PyTorch op-level and device-level events @@ -1120,7 +1133,6 @@ def parse_kineto_results(result): assert start_record is not None, "Invalid profiler output, __start_profile is missing" # Create and return FunctionEvent list - string_table = StringTable() function_events = [] cuda_corr_map: Dict[int, List[torch.autograd.KinetoEvent]] = {} for kineto_event in result.events(): @@ -1142,7 +1154,8 @@ def parse_kineto_results(result): is_async = kineto_event.start_thread_id() != kineto_event.end_thread_id() fe = FunctionEvent( id=kineto_event.correlation_id(), - name=string_table[kineto_event.name()], + name=rewrite_name(name=kineto_event.name(), with_wildcard=True), + trace_name=rewrite_name(name=kineto_event.name(), with_wildcard=False), thread=kineto_event.start_thread_id(), start_us=rel_start_us, end_us=rel_end_us, @@ -1193,7 +1206,6 @@ def get_record_key(record): cuda_records = {} functions = [] record_stack = [] - string_table = StringTable() # cuda start events and the overall profiler start event don't happen # at exactly the same time because we need to record an event on each device @@ -1271,7 +1283,8 @@ def adjusted_time(cuda_record, cuda_records_map): fe = FunctionEvent( id=record.handle(), node_id=record.node_id(), - name=string_table[start.name()], + name=rewrite_name(name=start.name(), with_wildcard=True), + trace_name=rewrite_name(name=start.name(), with_wildcard=False), thread=start.thread_id(), start_us=start_record.cpu_elapsed_us(start), end_us=start_record.cpu_elapsed_us(record), @@ -1569,6 +1582,14 @@ def append(s): append(header_sep) + def trim_path(path, src_column_width): + if len(path) > src_column_width: + offset = len(path) - src_column_width + path = path[offset:] + if len(path) > 3: + path = "..." 
+ path[3:] + return path + event_limit = 0 for evt in events: if event_limit == row_limit: @@ -1629,14 +1650,14 @@ def append(s): if has_stack: src_field = "" if len(evt.stack) > 0: - src_field = evt.stack[0][:src_column_width] + src_field = trim_path(evt.stack[0], src_column_width) row_values.append(src_field) append(row_format.format(*row_values)) if has_stack: empty_headers = [""] * (len(headers) - 1) for entry in evt.stack[1:MAX_STACK_ENTRY]: - append(row_format.format(*(empty_headers + [entry[:src_column_width]]))) + append(row_format.format(*(empty_headers + [trim_path(entry, src_column_width)]))) empty_headers.append("") append(row_format.format(*empty_headers)) diff --git a/torch/csrc/autograd/VariableTypeUtils.h b/torch/csrc/autograd/VariableTypeUtils.h index 2894a75fed69..85b83f2aa6ee 100644 --- a/torch/csrc/autograd/VariableTypeUtils.h +++ b/torch/csrc/autograd/VariableTypeUtils.h @@ -145,6 +145,7 @@ inline Tensor as_view(const Tensor & base, const Tensor & tensor, bool is_bw_dif if (base.is_view()) { auto diff_view_meta = static_cast(torch::autograd::impl::get_autograd_meta(base)); const auto& base_bw_info = diff_view_meta->get_backward_view(); + creation_meta = propagate_creation_meta(diff_view_meta->get_creation_meta(), creation_meta); return make_variable_differentiable_view(tensor, base_bw_info.chain(base, tensor, view_func), c10::nullopt, creation_meta, allow_tensor_metadata_change); } else { @@ -188,6 +189,10 @@ inline Tensor as_view(const Tensor & base, const Tensor & tensor, bool is_bw_dif } if (is_fw_differentiable || is_bw_differentiable) { + if (base.is_view()) { + auto diff_view_meta = static_cast(torch::autograd::impl::get_autograd_meta(base)); + creation_meta = propagate_creation_meta(diff_view_meta->get_creation_meta(), creation_meta); + } return make_variable_differentiable_view(tensor, std::move(new_bw_info), std::move(new_fw_info), creation_meta, allow_tensor_metadata_change); } else { @@ -234,6 +239,11 @@ inline std::vector as_view(const Tensor & base, std::vector& ten } } + if ((is_fw_differentiable || is_bw_differentiable) && base.is_view()) { + auto diff_view_meta = static_cast(torch::autograd::impl::get_autograd_meta(base)); + creation_meta = propagate_creation_meta(diff_view_meta->get_creation_meta(), creation_meta); + } + for(Tensor &tensor : tensors) { if (is_fw_differentiable || is_bw_differentiable) { tensor = make_variable_differentiable_view(tensor, new_bw_info, new_fw_info, creation_meta); diff --git a/torch/csrc/autograd/variable.h b/torch/csrc/autograd/variable.h index 9cdf40fe2c63..ad8c1919dee6 100644 --- a/torch/csrc/autograd/variable.h +++ b/torch/csrc/autograd/variable.h @@ -502,6 +502,15 @@ struct TORCH_API ViewInfo { enum class CreationMeta: uint8_t { DEFAULT, IN_CUSTOM_FUNCTION, MULTI_OUTPUT_NODE, NO_GRAD_MODE, MULTI_OUTPUT_SAFE }; +/// Handles correctly propagating CreationMeta when a new view is created from a previous view. +/// In general, we don't want the new view to be _less_ restrictive than the previous view +/// (it's okay to be _more_ restrictive). A CreationMeta value of DEFAULT is currently the least +/// restrictive, as the behavior for all other CreationMeta values is to error out for in-place ops. +/// If this changes, the logic here will need to be updated to properly handle the new semantics. +inline CreationMeta propagate_creation_meta(CreationMeta prev_view_creation_meta, CreationMeta new_view_creation_meta) { + return (new_view_creation_meta == CreationMeta::DEFAULT) ? 
prev_view_creation_meta : new_view_creation_meta; +} + /// Unified function to handle error checking when rebase happens /// indirect=true means that the caller is not doing the inplace, but the inplace happened /// somewhere else. @@ -657,7 +666,7 @@ inline Variable make_variable( bool allow_tensor_metadata_change = true) { if (data.defined()) { if (data.getIntrusivePtr().use_count() == 1 && data.getIntrusivePtr()->unique_version()) { - auto data_impl = data.getIntrusivePtr(); + auto data_impl = data.unsafeReleaseIntrusivePtr(); data_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change); if (requires_grad) { data_impl->set_autograd_meta(std::make_unique(data_impl.get(), requires_grad)); diff --git a/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp b/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp index 509c5c6cbd08..33df6b540000 100644 --- a/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp +++ b/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp @@ -13,10 +13,12 @@ using torch::autograd::variable_list; RecvRpcBackward::RecvRpcBackward( const AutogradMetadata& autogradMetadata, ContextPtr autogradContext, - rpc::worker_id_t fromWorkerId) + rpc::worker_id_t fromWorkerId, + std::unordered_map deviceMap) : autogradMetadata_(autogradMetadata), autogradContext_(std::move(autogradContext)), - fromWorkerId_(fromWorkerId) {} + fromWorkerId_(fromWorkerId), + deviceMap_(std::move(deviceMap)) {} variable_list RecvRpcBackward::apply(variable_list&& grads) { std::vector outputGrads; @@ -49,7 +51,9 @@ variable_list RecvRpcBackward::apply(variable_list&& grads) { auto rpcAgent = rpc::RpcAgent::getCurrentRpcAgent(); auto jitFuture = rpcAgent->send( rpcAgent->getWorkerInfo(fromWorkerId_), - std::move(gradCall).toMessage()); + std::move(gradCall).toMessage(), + rpc::kUnsetRpcTimeout, + deviceMap_); // Record the future in the context. sharedContext->addOutstandingRpc(jitFuture); diff --git a/torch/csrc/distributed/autograd/functions/recvrpc_backward.h b/torch/csrc/distributed/autograd/functions/recvrpc_backward.h index 982e0331c102..69be98c928ef 100644 --- a/torch/csrc/distributed/autograd/functions/recvrpc_backward.h +++ b/torch/csrc/distributed/autograd/functions/recvrpc_backward.h @@ -22,7 +22,8 @@ class TORCH_API RecvRpcBackward : public torch::autograd::Node { explicit RecvRpcBackward( const AutogradMetadata& autogradMetadata, std::shared_ptr autogradContext, - rpc::worker_id_t fromWorkerId); + rpc::worker_id_t fromWorkerId, + std::unordered_map deviceMap); torch::autograd::variable_list apply( torch::autograd::variable_list&& grads) override; @@ -38,6 +39,9 @@ class TORCH_API RecvRpcBackward : public torch::autograd::Node { // The worker id from which the RPC was received. During the backward pass, // we need to propagate the gradients to this workerId. rpc::worker_id_t fromWorkerId_; + + // Device mapping for tensors sent over RPC. 
+ const std::unordered_map deviceMap_; }; } // namespace autograd diff --git a/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp index 7389868d90c2..5aea96fa0c8b 100644 --- a/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp +++ b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp @@ -18,11 +18,13 @@ RpcWithAutograd::RpcWithAutograd( worker_id_t fromWorkerId, MessageType messageType, const AutogradMetadata& autogradMetadata, - rpc::Message&& wrappedMessage) + rpc::Message&& wrappedMessage, + std::unordered_map deviceMap) : fromWorkerId_(fromWorkerId), messageType_(messageType), autogradMetadata_(autogradMetadata), - wrappedMessage_(std::move(wrappedMessage)) { + wrappedMessage_(std::move(wrappedMessage)), + deviceMap_(std::move(deviceMap)) { TORCH_INTERNAL_ASSERT( messageType_ == MessageType::FORWARD_AUTOGRAD_REQ || messageType_ == MessageType::FORWARD_AUTOGRAD_RESP); @@ -36,13 +38,15 @@ RpcWithAutograd::RpcWithAutograd( const AutogradMetadata& autogradMetadata, std::unique_ptr wrappedRpc, MessageType wrappedMessageType, - std::vector tensors) + std::vector tensors, + std::unordered_map deviceMap) : fromWorkerId_(fromWorkerId), messageType_(messageType), autogradMetadata_(autogradMetadata), wrappedRpc_(std::move(wrappedRpc)), wrappedMessageType_(wrappedMessageType), - tensors_(std::move(tensors)) { + tensors_(std::move(tensors)), + deviceMap_(std::move(deviceMap)) { TORCH_INTERNAL_ASSERT(wrappedRpc_ != nullptr, "wrappedRpc cannot be null!"); TORCH_INTERNAL_ASSERT( messageType_ == MessageType::FORWARD_AUTOGRAD_REQ || @@ -56,10 +60,17 @@ Message RpcWithAutograd::toMessageImpl() && { auto payload = std::move(wrappedMessage_).movePayload(); TORCH_INTERNAL_ASSERT(!payload.empty()); + // Convert deviceMap to c10::Dict for serialization. + c10::Dict deviceMap; + for (const auto& mapEntry : deviceMap_) { + deviceMap.insert(mapEntry.first, mapEntry.second); + } + std::vector ivalues{wrappedMessageType, autogradMetadata_.autogradContextId, autogradMetadata_.autogradMessageId, - fromWorkerId_}; + fromWorkerId_, + deviceMap}; // Now pickle using JIT pickler. std::vector tensorTable; @@ -92,12 +103,19 @@ std::unique_ptr RpcWithAutograd::fromMessage( auto tupleElements = rpc::readWrappedPayload(payload, message); // Gather all the fields. - TORCH_INTERNAL_ASSERT(tupleElements.size() == 4); + TORCH_INTERNAL_ASSERT(tupleElements.size() == 5); MessageType wrappedMessageType = static_cast(tupleElements[0].toInt()); AutogradMetadata autogradMetadata( tupleElements[1].toInt(), tupleElements[2].toInt()); worker_id_t workerId = tupleElements[3].toInt(); + auto c10DeviceMap = tupleElements[4].to>(); + + // Convert to regular map. + std::unordered_map deviceMap; + for (const auto& mapEntry : c10DeviceMap) { + deviceMap.insert({mapEntry.key(), mapEntry.value()}); + } // Create new message type and build wrapped RPC. 
Message wrappedMessage( @@ -116,7 +134,8 @@ std::unique_ptr RpcWithAutograd::fromMessage( autogradMetadata, std::move(wrappedRpc), wrappedMessageType, - wrappedMessage.tensors()); + wrappedMessage.tensors(), + deviceMap); } std::vector& RpcWithAutograd::tensors() { @@ -150,6 +169,11 @@ rpc::worker_id_t RpcWithAutograd::fromWorkerId() const { return fromWorkerId_; } +const std::unordered_map& RpcWithAutograd:: + deviceMap() { + return deviceMap_; +} + } // namespace autograd } // namespace distributed } // namespace torch diff --git a/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h index 657d2cf2641f..f4728ea37c63 100644 --- a/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h +++ b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h @@ -18,7 +18,8 @@ class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase { rpc::worker_id_t fromWorkerId, rpc::MessageType messageType, const AutogradMetadata& autogradMetadata, - rpc::Message&& wrappedMessage); + rpc::Message&& wrappedMessage, + std::unordered_map deviceMap = {}); // Used when receiving an RPC over the wire. RpcWithAutograd( @@ -27,7 +28,8 @@ class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase { const AutogradMetadata& autogradMetadata, std::unique_ptr wrappedRpc, rpc::MessageType wrappedMessageType, - std::vector tensors); + std::vector tensors, + std::unordered_map deviceMap = {}); rpc::Message toMessageImpl() && override; @@ -52,6 +54,9 @@ class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase { // Retrieve the worker id from which the RPC originated. rpc::worker_id_t fromWorkerId() const; + // Retrieve the device map. + const std::unordered_map& deviceMap(); + private: // WorkerId from which this RPC originated. This is necessary for knowing // which worker we need to contact during the backward pass. @@ -83,6 +88,9 @@ class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase { // Tensors part of the wrappedRpc that need to be considered for autograd. std::vector tensors_; + + // Device mapping for tensors that are sent across an RPC to another node. + std::unordered_map deviceMap_; }; } // namespace autograd diff --git a/torch/csrc/distributed/autograd/utils.cpp b/torch/csrc/distributed/autograd/utils.cpp index 08bb99471686..747a958948a4 100644 --- a/torch/csrc/distributed/autograd/utils.cpp +++ b/torch/csrc/distributed/autograd/utils.cpp @@ -52,7 +52,8 @@ void addSendRpcBackward( ContextPtr addRecvRpcBackward( const AutogradMetadata& autogradMetadata, std::vector& tensors, - rpc::worker_id_t fromWorkerId) { + rpc::worker_id_t fromWorkerId, + const std::unordered_map& deviceMap) { // Initialize autograd context if necessary. auto& autogradContainer = DistAutogradContainer::getInstance(); auto autogradContext = @@ -61,7 +62,7 @@ ContextPtr addRecvRpcBackward( if (!tensors.empty() && torch::autograd::compute_requires_grad(tensors)) { // Attach the tensors as inputs to the autograd function. 
auto grad_fn = std::make_shared( - autogradMetadata, autogradContext, fromWorkerId); + autogradMetadata, autogradContext, fromWorkerId, deviceMap); for (auto& tensor : tensors) { if (tensor.requires_grad()) { torch::autograd::set_history(tensor, grad_fn); @@ -102,7 +103,8 @@ Message getMessageWithAutograd( const rpc::worker_id_t dstId, torch::distributed::rpc::Message&& wrappedRpcMsg, MessageType msgType, - bool forceGradRecording) { + bool forceGradRecording, + const std::unordered_map& deviceMap) { auto& autogradContainer = DistAutogradContainer::getInstance(); // If there is no valid context and no tensor requires grads, send original @@ -125,7 +127,8 @@ Message getMessageWithAutograd( RpcAgent::getCurrentRpcAgent()->getWorkerInfo().id_, msgType, autogradMetadata, - std::move(wrappedRpcMsg)); + std::move(wrappedRpcMsg), + deviceMap); if (tensorsRequireGrad) { // Record autograd information for 'send'. @@ -149,7 +152,8 @@ std::shared_ptr sendMessageWithAutograd( dst.id_, std::move(wrappedRpcMsg), MessageType::FORWARD_AUTOGRAD_REQ, - forceGradRecording); + forceGradRecording, + agent.getDeviceMap(dst)); std::shared_ptr fut; // If profiler is enabled, wrap this message with profiling metadata that will diff --git a/torch/csrc/distributed/autograd/utils.h b/torch/csrc/distributed/autograd/utils.h index 07ba45ed60d7..013558252fc2 100644 --- a/torch/csrc/distributed/autograd/utils.h +++ b/torch/csrc/distributed/autograd/utils.h @@ -30,7 +30,8 @@ TORCH_API void addSendRpcBackward( TORCH_API ContextPtr addRecvRpcBackward( const AutogradMetadata& autogradMetadata, std::vector& tensors, - rpc::worker_id_t fromWorkerId); + rpc::worker_id_t fromWorkerId, + const std::unordered_map& deviceMap); // This method is a wrapper utility used internally to wrap autograd info // and attach autograd function for each type of rpc call if it has valid @@ -42,7 +43,9 @@ TORCH_API rpc::Message getMessageWithAutograd( const rpc::worker_id_t dstId, rpc::Message&& wrappedRpcMsg, rpc::MessageType msgType, - bool forceGradRecording = false); + bool forceGradRecording = false, + const std::unordered_map& deviceMap = + {}); // Send message after autograd checking TORCH_API std::shared_ptr diff --git a/torch/csrc/distributed/rpc/process_group_agent.cpp b/torch/csrc/distributed/rpc/process_group_agent.cpp index 9c1a703cfa6d..3cd940f3ee49 100644 --- a/torch/csrc/distributed/rpc/process_group_agent.cpp +++ b/torch/csrc/distributed/rpc/process_group_agent.cpp @@ -290,7 +290,8 @@ void ProcessGroupAgent::shutdownImpl() { std::shared_ptr ProcessGroupAgent::send( const WorkerInfo& to, Message&& message, - const float rpcTimeoutSeconds) { + const float rpcTimeoutSeconds, + const std::unordered_map& deviceMap) { // Throw if we previously encountered an exception in ::listenLoop. 
{ std::unique_lock guard(listenLoopExceptionMutex_); diff --git a/torch/csrc/distributed/rpc/process_group_agent.h b/torch/csrc/distributed/rpc/process_group_agent.h index 8d2471a7d113..d1d957a66562 100644 --- a/torch/csrc/distributed/rpc/process_group_agent.h +++ b/torch/csrc/distributed/rpc/process_group_agent.h @@ -91,7 +91,9 @@ class TORCH_API ProcessGroupAgent : public RpcAgent { std::shared_ptr send( const WorkerInfo& to, Message&& message, - const float rpcTimeoutSeconds = kUnsetRpcTimeout) override; + const float rpcTimeoutSeconds = kUnsetRpcTimeout, + const std::unordered_map& deviceMap = + {}) override; // put SendWork into a queue and notify the worker thread virtual void enqueueSend(SendWork work); diff --git a/torch/csrc/distributed/rpc/request_callback_no_python.cpp b/torch/csrc/distributed/rpc/request_callback_no_python.cpp index 09c56dc960c9..46192d7eb317 100644 --- a/torch/csrc/distributed/rpc/request_callback_no_python.cpp +++ b/torch/csrc/distributed/rpc/request_callback_no_python.cpp @@ -345,11 +345,20 @@ void RequestCallbackNoPython::processForwardAutogradReq( const std::shared_ptr& responseFuture) const { auto& rpcWithAutograd = static_cast(rpc); + // Need to reverse the device map for the backward pass of distributed + // autograd. + std::unordered_map reverseDeviceMap; + for (const auto& mapEntry : rpcWithAutograd.deviceMap()) { + reverseDeviceMap.insert({mapEntry.second, mapEntry.first}); + } + + // Attach 'recv' autograd function. auto autogradContext = addRecvRpcBackward( rpcWithAutograd.autogradMetadata(), rpcWithAutograd.tensors(), - rpcWithAutograd.fromWorkerId()); + rpcWithAutograd.fromWorkerId(), + reverseDeviceMap); // For this recv thread on server side, before processRpc(), // set current_context_id_ to be context_id passed from client. // In this way, if there is nested rpc call in python rpc call, original diff --git a/torch/csrc/distributed/rpc/rpc_agent.cpp b/torch/csrc/distributed/rpc/rpc_agent.cpp index 2033b2b771e2..5c9570bcac1d 100644 --- a/torch/csrc/distributed/rpc/rpc_agent.cpp +++ b/torch/csrc/distributed/rpc/rpc_agent.cpp @@ -286,6 +286,12 @@ bool RpcAgent::isGILProfilingEnabled() { return profilingEnabled_.load(); } +std::unordered_map RpcAgent::getDeviceMap( + const WorkerInfo& dest) { + // Default implementation has no device map. + return {}; +} + std::unordered_map RpcAgent::getDebugInfo() { /* This would later include more info other than metrics for eg: may include stack traces for the threads owned by the agent */ diff --git a/torch/csrc/distributed/rpc/rpc_agent.h b/torch/csrc/distributed/rpc/rpc_agent.h index bfc6c38c07a1..956af3da899b 100644 --- a/torch/csrc/distributed/rpc/rpc_agent.h +++ b/torch/csrc/distributed/rpc/rpc_agent.h @@ -160,7 +160,9 @@ class TORCH_API RpcAgent { virtual std::shared_ptr send( const WorkerInfo& to, Message&& message, - const float rpcTimeoutSeconds = kUnsetRpcTimeout) = 0; + const float rpcTimeoutSeconds = kUnsetRpcTimeout, + const std::unordered_map& deviceMap = + {}) = 0; // Retries sending the message up to maxRetries times until an ACK is // receieved. The duration between consecutive sends is increased over @@ -259,6 +261,10 @@ class TORCH_API RpcAgent { // Get the type resolver std::shared_ptr getTypeResolver(); + // Retrieves the device map for the provided destination worker. 
+ virtual std::unordered_map getDeviceMap( + const WorkerInfo& dest); + protected: const WorkerInfo workerInfo_; const std::unique_ptr cb_; diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp index 4f56c916cb98..518fc72e8304 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp @@ -14,6 +14,12 @@ #include #endif +#if TENSORPIPE_HAS_SHM_TRANSPORT +// Needed for ::getpid(), which is used to create a unique address. +#include +#include +#endif + namespace torch { namespace distributed { namespace rpc { @@ -33,6 +39,42 @@ const std::string kClientActiveCalls = "agent.client_active_calls"; const std::string kServerActiveCalls = "agent.server_active_calls"; const std::string kServerActiveAsyncCalls = "agent.server_active_async_calls"; +std::vector getDevicesForTensors( + const std::vector& tensors, + const tensorpipe::DeviceMap& deviceMap, + const std::string& remoteName) { + // If the deviceMap is overridden, use that instead. + const auto errStr = c10::str( + "TensorPipe RPC backend only supports CPU tensors by default, please " + "move your tensors to CPU before sending them over RPC, or call " + "`set_device_map` on `TensorPipeRpcBackendOptions` to explicitly " + "configure device mapping. ", + "Request device mapping is not available for destination ", + remoteName); + std::vector deviceIndices; + deviceIndices.reserve(tensors.size()); + bool hasCudaTensor = false; + for (const auto& t : tensors) { + if (t.device().is_cpu()) { + deviceIndices.push_back(-1); + } else { + const auto deviceIter = deviceMap.find(t.device().index()); + TORCH_CHECK( + deviceIter != deviceMap.end(), + errStr, + " for device ", + t.device(), + " but received a tensor on that device."); + deviceIndices.push_back(deviceIter->second); + hasCudaTensor = true; + } + } + if (!hasCudaTensor) { + deviceIndices.clear(); + } + return deviceIndices; +} + } // namespace // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) @@ -541,7 +583,9 @@ void TensorPipeAgent::pipeWrite( Message&& rpcMessage, std::vector&& devices, std::shared_ptr ctx, - std::function fn) noexcept { + std::function fn, + const std::unordered_map& + deviceMap) noexcept { tensorpipe::Message tpMessage; TensorpipeWriteBuffers tpBuffers; @@ -581,7 +625,7 @@ void TensorPipeAgent::sendCompletedResponseMessage( responseMessage.setId(messageId); std::vector devices; try { - devices = getDevicesForTensors(pipe->getRemoteName(), responseMessage); + devices = getDevicesForRemote(pipe->getRemoteName(), responseMessage); } catch (const std::exception& e) { responseMessage = createExceptionResponse(e.what(), messageId); } @@ -715,7 +759,8 @@ void TensorPipeAgent::respond(std::shared_ptr& pipe) { std::shared_ptr TensorPipeAgent::send( const WorkerInfo& toWorkerInfo, Message&& requestMessage, - const float rpcTimeoutSeconds) { + const float rpcTimeoutSeconds, + const std::unordered_map& deviceMap) { TORCH_CHECK( requestMessage.isRequest(), "TensorPipeAgent::send(..) is only for sending requests."); @@ -755,8 +800,15 @@ std::shared_ptr TensorPipeAgent::send( // Get devices for tensors in the request message. This can throw if device // maps are not configured properly for this request. 
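
The getDevicesForTensors helper hoisted to a free function above applies one rule per tensor: CPU tensors map to -1, a CUDA tensor's device index must appear in the map (otherwise TORCH_CHECK fires with the error text built above), and if no CUDA tensor was seen the whole list is cleared so all-CPU messages keep the legacy empty-devices representation. A simplified stand-alone model of that logic, with plain ints in place of tensors:

    #include <cassert>
    #include <stdexcept>
    #include <string>
    #include <unordered_map>
    #include <vector>

    // Source device per "tensor": -1 means CPU, >= 0 is a CUDA device index.
    std::vector<int> mapDevices(const std::vector<int>& srcDevices,
                                const std::unordered_map<int, int>& deviceMap) {
      std::vector<int> out;
      out.reserve(srcDevices.size());
      bool hasCudaTensor = false;
      for (int d : srcDevices) {
        if (d < 0) {
          out.push_back(-1);  // CPU tensors need no mapping
        } else {
          auto it = deviceMap.find(d);
          if (it == deviceMap.end()) {
            throw std::runtime_error("no device mapping for cuda:" + std::to_string(d));
          }
          out.push_back(it->second);
          hasCudaTensor = true;
        }
      }
      if (!hasCudaTensor) {
        out.clear();  // all-CPU messages keep the legacy empty-devices convention
      }
      return out;
    }

    int main() {
      assert((mapDevices({-1, 0}, {{0, 2}}) == std::vector<int>{-1, 2}));
      assert(mapDevices({-1, -1}, {}).empty());  // no CUDA tensor: list is cleared
    }
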
- auto devices = - getDevicesForTensors(clientPipe.pipe_->getRemoteName(), requestMessage); + std::vector devices; + if (deviceMap.empty()) { + devices = + getDevicesForRemote(clientPipe.pipe_->getRemoteName(), requestMessage); + } else { + // If deviceMap is specified, use that instead. + devices = getDevicesForTensors( + requestMessage.tensors(), deviceMap, clientPipe.pipe_->getRemoteName()); + } futureResponseMessage->jitFuture->addCallback([this]() { TORCH_INTERNAL_ASSERT( @@ -900,7 +952,8 @@ std::shared_ptr TensorPipeAgent::send( std::move(ctx)); } }); - }); + }, + deviceMap); return futureResponseMessage->jitFuture; } @@ -1190,7 +1243,7 @@ void TensorPipeAgent::markFutureWithError( } } -std::vector TensorPipeAgent::getDevicesForTensors( +std::vector TensorPipeAgent::getDevicesForRemote( const std::string& remoteName, const Message& message) const { const auto& deviceMaps = @@ -1216,32 +1269,18 @@ std::vector TensorPipeAgent::getDevicesForTensors( } return {}; } else { - std::vector deviceIndices; - deviceIndices.reserve(message.tensors().size()); - const auto& deviceMap = iter->second; - bool hasCudaTensor = false; - for (const auto& t : message.tensors()) { - if (t.device().is_cpu()) { - deviceIndices.push_back(-1); - } else { - const auto deviceIter = deviceMap.find(t.device().index()); - TORCH_CHECK( - deviceIter != deviceMap.end(), - errStr, - " for device ", - t.device(), - " but received a tensor on that device."); - deviceIndices.push_back(deviceIter->second); - hasCudaTensor = true; - } - } - if (!hasCudaTensor) { - deviceIndices.clear(); - } - return deviceIndices; + return getDevicesForTensors(message.tensors(), iter->second, errStr); } } +tensorpipe::DeviceMap TensorPipeAgent::getDeviceMap(const WorkerInfo& dest) { + auto it = opts_.deviceMaps.find(dest.name_); + if (it == opts_.deviceMaps.end()) { + return {}; + } + return it->second; +} + size_t TensorPipeAgent::timeoutMapSize() { std::unique_lock lock(timeoutMapMutex_); return timeoutMap_.size(); diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.h b/torch/csrc/distributed/rpc/tensorpipe_agent.h index b4d8796aeede..078750385538 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.h +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.h @@ -185,7 +185,9 @@ class TensorPipeAgent : public RpcAgent { std::shared_ptr send( const WorkerInfo& to, Message&& message, - const float rpcTimeoutSeconds = kUnsetRpcTimeout) override; + const float rpcTimeoutSeconds = kUnsetRpcTimeout, + const std::unordered_map& deviceMap = + {}) override; // join() and sync() would be deprecated - // https://github.com/pytorch/pytorch/issues/27647 @@ -209,6 +211,8 @@ class TensorPipeAgent : public RpcAgent { void addGilWaitTime(const std::chrono::microseconds gilWaitTime) override; + tensorpipe::DeviceMap getDeviceMap(const WorkerInfo& dest) override; + using NetworkDataDict = std::unordered_map; @@ -252,7 +256,8 @@ class TensorPipeAgent : public RpcAgent { Message&& message, std::vector&& devices, std::shared_ptr ctx, - std::function) noexcept; + std::function, + const tensorpipe::DeviceMap& deviceMap = {}) noexcept; // Callback of listener accept() void onListenerAccepted( @@ -279,7 +284,7 @@ class TensorPipeAgent : public RpcAgent { uint64_t requestSize, const std::string& destWorkerName); - inline std::vector getDevicesForTensors( + inline std::vector getDevicesForRemote( const std::string& remoteName, const Message& message) const; diff --git a/torch/csrc/distributed/rpc/tensorpipe_utils.cpp 
b/torch/csrc/distributed/rpc/tensorpipe_utils.cpp index 1d17e4451372..9757e0971d2b 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_utils.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_utils.cpp @@ -91,7 +91,8 @@ std::tuple tensorpipeSerialize( // Enforce memory copy if tensor is created from torch::from_blob, means // that the tensor doesn't own the memory. std::string metadata = - deviceIndices.empty() ? "" : std::to_string(deviceIndices[i]); + deviceIndices.empty() || deviceIndices[i] == -1 + ? "" : std::to_string(deviceIndices[i]); if (!tensorData.storageHasDeleter()) { std::vector storageData( diff --git a/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.cpp b/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.cpp index 57dbef3a549b..870f9702ee0e 100644 --- a/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.cpp +++ b/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.cpp @@ -59,7 +59,8 @@ std::unordered_map> FaultyProcessGroupAgent:: std::shared_ptr FaultyProcessGroupAgent::send( const WorkerInfo& to, Message&& message, - const float rpcTimeoutSeconds) { + const float rpcTimeoutSeconds, + const std::unordered_map& deviceMap) { // We only fail control messages that have been specified by the test case. // For all other messages, we just send them without any failures. if (!shouldFailMessage(message.type())) { diff --git a/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.h b/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.h index 8cbe4c9a137d..ce8fee558274 100644 --- a/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.h +++ b/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.h @@ -46,8 +46,9 @@ class FaultyProcessGroupAgent : public ProcessGroupAgent { std::shared_ptr send( const WorkerInfo& to, Message&& message, - const float rpcTimeoutSeconds = - torch::distributed::rpc::kUnsetRpcTimeout) override; + const float rpcTimeoutSeconds = torch::distributed::rpc::kUnsetRpcTimeout, + const std::unordered_map& deviceMap = + {}) override; protected: // This function checks the messageTypesToFail_ to determine whether to use diff --git a/torch/csrc/distributed/rpc/utils.cpp b/torch/csrc/distributed/rpc/utils.cpp index 0f137a72a252..d643ab87b8ea 100644 --- a/torch/csrc/distributed/rpc/utils.cpp +++ b/torch/csrc/distributed/rpc/utils.cpp @@ -174,11 +174,19 @@ std::unique_ptr deserializeResponse( RpcCommandBase& rpc = *rpcPtr; auto& rpcWithAutograd = static_cast(rpc); + // Need to reverse the device map for the backward pass of distributed + // autograd. + std::unordered_map reverseDeviceMap; + for (const auto& mapEntry : rpcWithAutograd.deviceMap()) { + reverseDeviceMap.insert({mapEntry.second, mapEntry.first}); + } + // Attach 'recv' autograd function. 
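
Both receive paths (deserializeResponse here and processForwardAutogradReq earlier) build reverseDeviceMap by swapping keys and values, since gradients travel the same pipe in the opposite direction during the backward pass. The swap only round-trips if the forward map is one-to-one; a sketch:

    #include <cassert>
    #include <unordered_map>

    std::unordered_map<int, int> reverseMap(
        const std::unordered_map<int, int>& forward) {
      std::unordered_map<int, int> reverse;
      for (const auto& entry : forward) {
        // insert() keeps the first value seen for a key, so a forward map that
        // sends two source devices to one target would silently drop an entry.
        reverse.insert({entry.second, entry.first});
      }
      return reverse;
    }

    int main() {
      std::unordered_map<int, int> forward{{0, 1}, {1, 0}};
      auto reverse = reverseMap(forward);
      assert(reverse.at(1) == 0 && reverse.at(0) == 1);
    }
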
addRecvRpcBackward( rpcWithAutograd.autogradMetadata(), rpcWithAutograd.tensors(), - rpcWithAutograd.fromWorkerId()); + rpcWithAutograd.fromWorkerId(), + reverseDeviceMap); wrappedMsgType = rpcWithAutograd.wrappedMessageType(); diff --git a/torch/csrc/jit/passes/onnx/helper.cpp b/torch/csrc/jit/passes/onnx/helper.cpp index a14dcd611dd8..aca08331183c 100644 --- a/torch/csrc/jit/passes/onnx/helper.cpp +++ b/torch/csrc/jit/passes/onnx/helper.cpp @@ -97,5 +97,28 @@ Value* addInputToBlock(Block* block) { return block->addInput(); } +Node* createONNXUnsqueeze( + Graph* graph, + Node* n_to_insert_before, + Value* input, + int axis, + int opset_version) { + Node* unsqueeze_node = graph->create(onnx::Unsqueeze, 1); + unsqueeze_node->addInput(input); + unsqueeze_node->insertBefore(n_to_insert_before); + if (opset_version >= OPSET_VERSION_13) { + // ONNX spec sets `axes` as input for opset >= 13. + Node* unsqueeze_axes = graph->create(onnx::Constant, 1); + unsqueeze_axes->insertBefore(unsqueeze_node); + unsqueeze_axes->t_( + attr::value, at::unsqueeze(at::scalar_to_tensor(at::Scalar(axis)), 0)); + unsqueeze_node->addInput(unsqueeze_axes->output()); + } else { + // ONNX spec sets `axes` as attribute for opset < 13. + unsqueeze_node->is_(attr::axes, {0}); + } + return unsqueeze_node; +} + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/onnx/helper.h b/torch/csrc/jit/passes/onnx/helper.h index e27909ff6362..43989bd8e6c3 100644 --- a/torch/csrc/jit/passes/onnx/helper.h +++ b/torch/csrc/jit/passes/onnx/helper.h @@ -13,6 +13,7 @@ static const int OPSET_VERSION_9 = 9; static const int OPSET_VERSION_10 = 10; static const int OPSET_VERSION_11 = 11; static const int OPSET_VERSION_12 = 12; +static const int OPSET_VERSION_13 = 13; using ValueToParamPairMap = std::map>; @@ -33,5 +34,13 @@ Node* addNodeToBlock(Block* block, Symbol kind, ArrayRef inputs); Value* addInputToBlock(Block* block); TORCH_API c10::optional ONNXTypeToATenType(int32_t onnx_type); + +Node* createONNXUnsqueeze( + Graph* graph, + Node* n_to_insert_before, + Value* input, + int axis, + int opset_version); + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/onnx/peephole.cpp b/torch/csrc/jit/passes/onnx/peephole.cpp index d488201a8f80..8aa07332cc65 100644 --- a/torch/csrc/jit/passes/onnx/peephole.cpp +++ b/torch/csrc/jit/passes/onnx/peephole.cpp @@ -416,10 +416,8 @@ void fixDefaultRNNState( batch_size->addInput(shape_of_input->outputs()[0]); batch_size->addInput(gather_indices->outputs()[0]); - Node* unsqueezed_batch_size = graph->create(onnx::Unsqueeze, 1); - unsqueezed_batch_size->insertBefore(n); - unsqueezed_batch_size->addInput(batch_size->outputs()[0]); - unsqueezed_batch_size->is_(attr::axes, {0}); + Node* unsqueezed_batch_size = + createONNXUnsqueeze(graph, n, batch_size->outputs()[0], 0, opset_version); Node* hidden_size = graph->create(onnx::Constant, 1); hidden_size->insertBefore(n); @@ -440,10 +438,8 @@ void fixDefaultRNNState( ? 
2 : 1))); - Node* unsqueezed_num_directions = graph->create(onnx::Unsqueeze, 1); - unsqueezed_num_directions->insertBefore(n); - unsqueezed_num_directions->addInput(num_directions->outputs()[0]); - unsqueezed_num_directions->is_(attr::axes, {0}); + Node* unsqueezed_num_directions = createONNXUnsqueeze( + graph, n, num_directions->outputs()[0], 0, opset_version); Node* concated_dims = graph->create(onnx::Concat, 1); concated_dims->insertBefore(n); @@ -555,6 +551,65 @@ static void replaceInputWithList(Node* node, size_t i, ArrayRef to) { } } +static void eraseListConstruct(Block* block, int opset_version); + +static void eraseListConstruct(Node* n, int opset_version) { + for (auto b : n->blocks()) { + eraseListConstruct(b, opset_version); + } + std::vector>> replacements; + + auto block = n->owningBlock(); + size_t i = 0; + for (auto* input : n->inputs()) { + if (input->node()->kind() == prim::ListConstruct) { + auto* lc_node = input->node(); + TypePtr elem = + lc_node->output()->type()->cast()->getElementType(); + if (elem->cast()) { + // ListConstruct Int[] output case, we need to transform to ONNX + // Concat to ensure the output is a single tensor(dynamic) type in + // order to be consumed as inputs + std::vector unsqueezed; + Graph* g = block->owningGraph(); + for (auto* input : lc_node->inputs()) { + Node* unsqueezed_node = + createONNXUnsqueeze(g, lc_node, input, 0, opset_version); + unsqueezed.emplace_back(unsqueezed_node->output()); + } + Node* concat_node = g->create(onnx::Concat, 1); + concat_node->i_(attr::axis, 0); + for (auto v : unsqueezed) { + concat_node->addInput(v); + } + concat_node->insertBefore(lc_node); + + // make concat node output as new input, then ListConstruct should + // become dead + replacements.emplace_back( + i, std::vector({concat_node->output()})); + + } else { + if (opset_version >= OPSET_VERSION_11) { + c10::Symbol seq_node_kind = lc_node->inputs().size() > 0 + ? onnx::SequenceConstruct + : onnx::SequenceEmpty; + Node* seq_node = block->owningGraph()->create( + seq_node_kind, {lc_node->inputs()}, 1); + seq_node->insertBefore(lc_node); + seq_node->output()->copyMetadata(lc_node->output()); + lc_node->replaceAllUsesWith(seq_node); + } + } + } + i++; + } + + for (auto ritr = replacements.rbegin(); ritr != replacements.rend(); ++ritr) { + replaceInputWithList(n, std::get<0>(*ritr), std::get<1>(*ritr)); + } +} + static void eraseListConstruct(Block* block, int opset_version) { // TODO: Fix this pass/maybe get rid of this part. // Tensor lists might be used for meshgrid and such ops as well. 
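
eraseListConstruct, now factored out per node above, rewrites an Int[] prim::ListConstruct into one Unsqueeze per element followed by a Concat on axis 0, so downstream ONNX ops receive a single dynamic 1-D tensor instead of a list; createONNXUnsqueeze (helper.cpp, earlier in this diff) emits the opset-appropriate form, with axes as a Constant input for opset >= 13 and as an attribute before that. Modeling tensor values as std::vector, the value-level effect of the rewrite is:

    #include <cassert>
    #include <vector>

    using Tensor1D = std::vector<long>;

    // Unsqueeze a rank-0 value on axis 0: scalar -> tensor of shape [1].
    Tensor1D unsqueeze0(long scalar) {
      return {scalar};
    }

    // Concat on axis 0.
    Tensor1D concat0(const std::vector<Tensor1D>& parts) {
      Tensor1D out;
      for (const auto& p : parts) {
        out.insert(out.end(), p.begin(), p.end());
      }
      return out;
    }

    int main() {
      // ListConstruct(2, 3, 5) becomes Concat(Unsqueeze(2), Unsqueeze(3), Unsqueeze(5)).
      Tensor1D folded = concat0({unsqueeze0(2), unsqueeze0(3), unsqueeze0(5)});
      assert((folded == Tensor1D{2, 3, 5}));
    }
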
@@ -563,71 +618,9 @@ static void eraseListConstruct(Block* block, int opset_version) { Node* n = *it; ++it; - for (auto b : n->blocks()) { - eraseListConstruct(b, opset_version); - } - std::vector>> replacements; - - size_t i = 0; - for (auto* input : n->inputs()) { - if (input->node()->kind() == prim::ListConstruct) { - auto* lc_node = input->node(); - TypePtr elem = - lc_node->output()->type()->cast()->getElementType(); - if (elem->cast()) { - // ListConstruct Int[] output case, we need to transform to ONNX - // Concat to ensure the output is a single tensor(dynamic) type in - // order to be consumed as inputs - std::vector unsqueezed; - Graph* g = block->owningGraph(); - for (auto* input : lc_node->inputs()) { - Node* unsqueezed_node = g->create(onnx::Unsqueeze, 1); - unsqueezed_node->insertBefore(lc_node); - unsqueezed_node->addInput(input); - unsqueezed_node->is_(attr::axes, {0}); - unsqueezed.emplace_back(unsqueezed_node->output()); - } - Node* concat_node = g->create(onnx::Concat, 1); - concat_node->i_(attr::axis, 0); - for (auto v : unsqueezed) { - concat_node->addInput(v); - } - concat_node->insertBefore(lc_node); - - // make concat node output as new input, then ListConstruct should - // become dead - replacements.emplace_back( - i, std::vector({concat_node->output()})); - - } else { - if (opset_version < OPSET_VERSION_11) { - // Tensor lists are used mostly for inputs to cat/stack. They are - // already handled in those symbolics, and should become dead - // afterwards. - replacements.emplace_back( - i, - std::vector( - lc_node->inputs().begin(), lc_node->inputs().end())); - } else { - c10::Symbol seq_node_kind = lc_node->inputs().size() > 0 - ? onnx::SequenceConstruct - : onnx::SequenceEmpty; - Node* seq_node = block->owningGraph()->create( - seq_node_kind, {lc_node->inputs()}, 1); - seq_node->insertBefore(lc_node); - seq_node->output()->copyMetadata(lc_node->output()); - lc_node->replaceAllUsesWith(seq_node); - } - } - } - i++; - } - - for (auto ritr = replacements.rbegin(); ritr != replacements.rend(); - ++ritr) { - replaceInputWithList(n, std::get<0>(*ritr), std::get<1>(*ritr)); - } + eraseListConstruct(n, opset_version); } + eraseListConstruct(block->return_node(), opset_version); } // For ops such as meshgrid where output is a list of Tensors diff --git a/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp b/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp index bc26183a25bb..c9b42d76973a 100644 --- a/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp +++ b/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp @@ -571,29 +571,27 @@ static void PrepareForRemoveMutations(MutationRemover& mr, Block* b) { << "Warning: ONNX Preprocess - Removing mutation on block inputs. " << "This changes graph semantics." 
<< std::endl; + Node* newNode = nullptr; if (input->type()->kind() == TypeKind::ListType) { // Create an aten::list to clone the list in graph inputs - auto newNode = node->owningGraph()->create(aten::list, 1); - newNode->output()->copyMetadata(input); + newNode = node->owningGraph()->create(aten::list, 1); + newNode->output()->setType(input->type()); newNode->addInput(input); - newNode->insertBefore(node); - node->replaceInput(index, newNode->output()); - input->replaceAllUsesAfterNodeWith(node, newNode->output()); + b->prependNode(newNode); } else { // Create an aten::clone to clone the tensor in graph inputs - auto newNode = node->owningGraph()->create(aten::clone, 1); - newNode->output()->copyMetadata(input); + newNode = node->owningGraph()->create(aten::clone, 1); + newNode->output()->setType(input->type()); newNode->addInput(input); auto* noneNode = node->owningGraph()->create(prim::Constant); noneNode->output()->setType(NoneType::get()); newNode->addInput(noneNode->output()); - - newNode->insertBefore(node); + b->prependNode(newNode); noneNode->insertBefore(newNode); - node->replaceInput(index, newNode->output()); - input->replaceAllUsesAfterNodeWith(node, newNode->output()); } + node->replaceInput(index, newNode->output()); + input->replaceAllUsesAfterNodeWith(node, newNode->output()); } } } diff --git a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp index bc1cabf81ddf..07d9340ade3a 100644 --- a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp +++ b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp @@ -201,7 +201,10 @@ bool IsSupportedNode(const Node* n) { return true; } -Value* CloneValueFromListConstruct(Value* v, std::shared_ptr n_graph) { +Value* CloneValueFromListConstruct( + Value* v, + std::shared_ptr n_graph, + int opset_version) { auto lc_node = v->node(); TORCH_INTERNAL_ASSERT(lc_node->kind() == ::c10::prim::ListConstruct); // In jit/passes/onnx/peephole.cpp::eraseListConstruct, @@ -221,12 +224,10 @@ Value* CloneValueFromListConstruct(Value* v, std::shared_ptr n_graph) { // order to be consumed as inputs std::vector unsqueezed; for (auto* input : lc_node->inputs()) { - Node* unsqueezed_node = - n_graph->insertNode(n_graph->create(::c10::onnx::Unsqueeze, 1)); auto new_input = n_graph->addInput(); new_input->copyMetadata(input); - unsqueezed_node->addInput(new_input); - unsqueezed_node->is_(attr::axes, {0}); + Node* unsqueezed_node = createONNXUnsqueeze( + n_graph.get(), n_graph->return_node(), new_input, 0, opset_version); unsqueezed.emplace_back(unsqueezed_node->output()); } Node* concat_node = @@ -258,34 +259,51 @@ Value* CloneValueFromListConstruct(Value* v, std::shared_ptr n_graph) { } // Clone the node n for the new graph. -Node* CloneNodeToGraph(Node* n, std::shared_ptr n_graph) { - auto clone_node = n_graph->createClone(n, [&n_graph](Value* v) { - auto v_n = v->node(); - switch (v_n->kind()) { - case ::c10::onnx::Constant: { - // Clone the input if it is constant. - auto constant_n = n_graph->insertNode( - n_graph->createClone(v_n, [](Value* v) { return v; })); - return constant_n->output(); - } - case ::c10::prim::ListConstruct: { - return CloneValueFromListConstruct(v, n_graph); - } - case ::c10::prim::PackPadded: { - auto input = n_graph->addInput(); - input->copyMetadata(v_n->input(0)); - return input; - } - default: { - // If the input is not constant, we cannot depend on its value - // in shape inference. 
Set it to graph input in the new graph, - // and copy over metadata, such as datatype and shape. - auto input = n_graph->addInput(); - input->copyMetadata(v); - return input; - } - } - }); +Node* CloneNodeToGraph( + Node* n, + std::shared_ptr n_graph, + const ParamMap& params_dict, + int opset_version) { + auto vals_to_params_map = + buildValueToParamsMap(n->owningGraph()->block(), params_dict); + auto clone_node = n_graph->createClone( + n, [&n_graph, &vals_to_params_map, opset_version](Value* v) { + auto v_n = v->node(); + switch (v_n->kind()) { + case ::c10::onnx::Constant: { + // Clone the input if it is constant. + auto constant_n = n_graph->insertNode( + n_graph->createClone(v_n, [](Value* v) { return v; })); + return constant_n->output(); + } + case ::c10::prim::ListConstruct: { + return CloneValueFromListConstruct(v, n_graph, opset_version); + } + case ::c10::prim::PackPadded: { + auto input = n_graph->addInput(); + input->copyMetadata(v_n->input(0)); + return input; + } + default: { + if (vals_to_params_map.find(v) != vals_to_params_map.end()) { + // If the input is a parameter, insert a constant of its value as + // input. + auto val = vals_to_params_map.find(v)->second.second.toTensor(); + return n_graph + ->insertNode(n_graph->create(::c10::onnx::Constant) + ->t_(attr::value, val)) + ->output(); + } else { + // If the input is not constant, we cannot depend on its value + // in shape inference. Set it to graph input in the new graph, + // and copy over metadata, such as datatype and shape. + auto input = n_graph->addInput(); + input->copyMetadata(v); + return input; + } + } + } + }); return clone_node; } @@ -433,19 +451,25 @@ void FetchBlockInputMetadataFromParent(Block* b) { } } -void ONNXShapeTypeInference(Block* b, int opset_version) { +void ONNXShapeTypeInference( + Block* b, + const ParamMap& params_dict, + int opset_version) { FetchBlockInputMetadataFromParent(b); for (auto n : b->nodes()) { for (auto subblock : n->blocks()) { - ONNXShapeTypeInference(subblock, opset_version); + ONNXShapeTypeInference(subblock, params_dict, opset_version); } - ONNXShapeTypeInference(n, opset_version); + ONNXShapeTypeInference(n, params_dict, opset_version); } } } // namespace -void ONNXShapeTypeInference(Node* n, int opset_version) { +void ONNXShapeTypeInference( + Node* n, + const ParamMap& params_dict, + int opset_version) { GRAPH_UPDATE( "Running ONNX shape inference for node: ", n->kind().toDisplayString()); if (!IsSupportedNode(n)) { @@ -454,7 +478,7 @@ void ONNXShapeTypeInference(Node* n, int opset_version) { // Create a Graph containing only the single node n. // This graph is later converted to ONNX to run shape inference. auto n_graph = std::make_shared(); - auto clone_node = CloneNodeToGraph(n, n_graph); + auto clone_node = CloneNodeToGraph(n, n_graph, params_dict, opset_version); n_graph->insertNode(clone_node); // Register all node outputs as graph outputs. @@ -485,12 +509,16 @@ void ONNXShapeTypeInference(Node* n, int opset_version) { } catch (std::runtime_error& ex) { // TODO: include this as warning once we have a more consolidated warning // system. 
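
CloneNodeToGraph now consults params_dict: an input backed by a model parameter is cloned into the single-node inference graph as an onnx::Constant carrying the real tensor, so shape inference can fold through it, while any other input stays a symbolic graph input that only carries metadata. The decision is a map lookup; a toy stand-in with invented names and scalar values in place of tensors:

    #include <iostream>
    #include <string>
    #include <unordered_map>

    // Toy stand-in: parameters map names to (here) scalar values.
    std::string materializeInput(
        const std::string& name,
        const std::unordered_map<std::string, double>& params) {
      auto it = params.find(name);
      if (it != params.end()) {
        // Value is known: inline a constant so shape inference can fold it.
        return "Constant(value=" + std::to_string(it->second) + ")";
      }
      // Value unknown: keep a symbolic input carrying only dtype/shape metadata.
      return "Input(" + name + ")";
    }

    int main() {
      std::unordered_map<std::string, double> params{{"weight", 0.5}};
      std::cout << materializeInput("weight", params) << "\n";  // Constant(...)
      std::cout << materializeInput("x", params) << "\n";       // Input(x)
    }
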
+ GRAPH_DEBUG( + "ONNX shape inference fails with: ", + ex.what(), + " on graph: ", + n_graph->toString()); const char shape_err[] = "ShapeInferenceError"; const char type_err[] = "TypeInferenceError"; if ((strstr(ex.what(), shape_err) == NULL) && (strstr(ex.what(), type_err) == NULL)) throw; - GRAPH_DEBUG("ONNX shape inference fails with: ", ex.what()); } GRAPH_DEBUG( "ONNX graph after shape inference: ", prettyPrint(*model_proto)); @@ -690,8 +718,11 @@ void ONNXAssignOutputShape( Py_DECREF(py_obj); } -void ONNXShapeTypeInference(std::shared_ptr& graph, int opset_version) { - ONNXShapeTypeInference(graph->block(), opset_version); +void ONNXShapeTypeInference( + std::shared_ptr& graph, + const ParamMap& params_dict, + int opset_version) { + ONNXShapeTypeInference(graph->block(), params_dict, opset_version); } } // namespace jit diff --git a/torch/csrc/jit/passes/onnx/shape_type_inference.h b/torch/csrc/jit/passes/onnx/shape_type_inference.h index bac7e2439ca9..69fbff1d175c 100644 --- a/torch/csrc/jit/passes/onnx/shape_type_inference.h +++ b/torch/csrc/jit/passes/onnx/shape_type_inference.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include namespace torch { @@ -34,7 +35,10 @@ TORCH_API void ONNXAssignOutputShape( // The node must have ONNX namespace, and is valid ONNX node accroding to spec. // On successful ONNX shape inference runs, the function updates output types of // n with inferred shape and type. Otherwise n is unchanged. -TORCH_API void ONNXShapeTypeInference(Node* n, int opset_version); +TORCH_API void ONNXShapeTypeInference( + Node* n, + const ParamMap& params_dict, + int opset_version); // Utilize ONNX Shape Inference for graph. // Internally calls ONNXShapeTypeInference for each node, to achieve more @@ -42,6 +46,7 @@ TORCH_API void ONNXShapeTypeInference(Node* n, int opset_version); // the entire graph. 
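
The Python bindings in init.cpp (next hunk) grow a matching params_dict argument so the exporter can forward the model's parameters into shape inference. As a hedged sketch of the binding pattern only, not PyTorch's actual registration code, a pybind11 lambda taking the extra map looks like:

    #include <pybind11/pybind11.h>
    #include <pybind11/stl.h>  // dict <-> std::map conversion

    #include <map>
    #include <string>

    namespace py = pybind11;

    // Stub standing in for the real pass; the signature here is illustrative only.
    void runShapeInference(std::map<std::string, int>& params_dict, int opset_version) {
      (void)params_dict;
      (void)opset_version;
    }

    PYBIND11_MODULE(example, m) {
      m.def(
          "_run_shape_inference",
          [](std::map<std::string, int>& params_dict, int opset_version) {
            runShapeInference(params_dict, opset_version);
          });
    }
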
TORCH_API void ONNXShapeTypeInference( std::shared_ptr& g, + const ParamMap& params_dict, int opset_version); } // namespace jit diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 2a91bd497e7b..197af7179361 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -210,13 +210,17 @@ void initJITBindings(PyObject* module) { PrepareInplaceOpsForONNX) .def( "_jit_pass_onnx_node_shape_type_inference", - [](Node* n, int opset_version) { - ONNXShapeTypeInference(n, opset_version); + [](Node* n, + std::map& params_dict, + int opset_version) { + ONNXShapeTypeInference(n, params_dict, opset_version); }) .def( "_jit_pass_onnx_graph_shape_type_inference", - [](std::shared_ptr& graph, int opset_version) { - ONNXShapeTypeInference(graph, opset_version); + [](std::shared_ptr& graph, + std::map& params_dict, + int opset_version) { + ONNXShapeTypeInference(graph, params_dict, opset_version); }) .def("_jit_pass_onnx_set_dynamic_input_shape", ONNXSetDynamicInputShape) .def("_jit_pass_fuse", FuseGraph) diff --git a/torch/csrc/jit/tensorexpr/expr.h b/torch/csrc/jit/tensorexpr/expr.h index 92aad34d3b7d..9645051e6a9a 100644 --- a/torch/csrc/jit/tensorexpr/expr.h +++ b/torch/csrc/jit/tensorexpr/expr.h @@ -183,11 +183,18 @@ class TORCH_API Buf : public ExprNode { Buf(const std::string& name_hint, const std::vector& dims, - Dtype dtype) - : Buf(new Var(name_hint, kHandle), dims, dtype) {} + Dtype dtype, + const Expr* initializer = nullptr) + : Buf(new Var(name_hint, kHandle), dims, dtype, initializer) {} - Buf(const Var* var, const std::vector& dims, Dtype dtype) - : ExprNodeBase(dtype, kPrimitive), base_handle_(var), dims_(dims) { + Buf(const Var* var, + const std::vector& dims, + Dtype dtype, + const Expr* initializer = nullptr) + : ExprNodeBase(dtype, kPrimitive), + base_handle_(var), + dims_(dims), + initializer_(initializer) { TORCH_CHECK(var); } @@ -207,9 +214,14 @@ class TORCH_API Buf : public ExprNode { dims_ = dims; }; + const Expr* initializer() const { + return initializer_; + }; + private: const Var* base_handle_; std::vector dims_; + const Expr* initializer_; }; class TORCH_API BufHandle : public ExprHandle { diff --git a/torch/csrc/jit/tensorexpr/ir_printer.cpp b/torch/csrc/jit/tensorexpr/ir_printer.cpp index 1df2f96671df..96fb11d3a982 100644 --- a/torch/csrc/jit/tensorexpr/ir_printer.cpp +++ b/torch/csrc/jit/tensorexpr/ir_printer.cpp @@ -596,19 +596,15 @@ std::string to_string(const Tensor* t) { return "(null tensor)\n"; } std::ostringstream oss; - if (!t->body()) { - oss << "Tensor " << t->buf()->name_hint() << " = " << *t->ElementStmt() - << "\n"; - return oss.str(); - } - oss << "Tensor " << t->buf()->name_hint() << "("; - for (size_t i = 0; i < t->ndim(); i++) { + // TODO: move this to Buf printer + oss << "Tensor " << t->buf()->name_hint() << "["; + for (size_t i = 0; i < t->buf()->ndim(); i++) { if (i != 0) { oss << ", "; } - oss << *t->arg(i) << "[" << *t->dim(i) << "]"; + oss << *t->buf()->dim(i); } - oss << ") = " << *t->body() << "\n"; + oss << "]:\n" << *t->stmt() << "\n"; return oss.str(); } } // namespace std diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index 910de06b2693..1da0ab8beefb 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -119,7 +119,7 @@ size_t normalizeAndCheckIndex(int64_t idx, int64_t list_size) { } static at::ScalarType tensorType(Tensor* t) { - return static_cast(t->body()->dtype().scalar_type()); + return 
static_cast(t->buf()->dtype().scalar_type()); } static std::vector computeIndicesToBroadcast( @@ -608,7 +608,7 @@ std::vector TensorExprKernel::valueShape( if (it == tensors_.end()) { return {}; } - return ExprVectorToExprHandleVector(it->second->dims()); + return ExprVectorToExprHandleVector(it->second->buf()->dims()); } Tensor* TensorExprKernel::computeOneOperand( @@ -1125,7 +1125,7 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { case aten::type_as: { auto const& n = v->node(); Tensor* rhs = tensors_.at(n->inputs()[1]->unique()); - auto dtype = rhs->body()->dtype(); + auto dtype = rhs->buf()->dtype(); return computeOneOperand( "aten_type_as", v, [dtype](const ExprHandle& lhs) { return Cast::make(dtype, lhs); @@ -1350,8 +1350,9 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { } break; case aten::lgamma: { - return computeOneOperand( - "aten_lgamma", v, [](const ExprHandle& a) { return lgamma(a); }); + return computeOneOperand("aten_lgamma", v, [](const ExprHandle& a) { + return lgamma(promoteIntegerToDefaultType(a)); + }); } break; case prim::ConstantChunk: { diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp index 71869d1d33d7..7d6e208291ba 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.cpp +++ b/torch/csrc/jit/tensorexpr/loopnest.cpp @@ -24,6 +24,11 @@ namespace torch { namespace jit { namespace tensorexpr { +LoopNest::LoopNest(const LoopNest& other) + : root_stmt_(Stmt::clone(other.root_stmt_)), + output_bufs_(other.output_bufs_), + intermediate_bufs_(other.intermediate_bufs_) {} + class FunctionCallUseCount : public IRVisitor { public: std::unordered_map findUses(Stmt* s) { @@ -424,11 +429,7 @@ class DepTracker : public IRVisitor { public: std::vector findUsedTensors(Tensor* tensor) { used_tensors.clear(); - if (tensor->body()) { - tensor->body()->accept(this); - } else { - tensor->ElementStmt()->accept(this); - } + tensor->stmt()->accept(this); return used_tensors; } @@ -508,7 +509,12 @@ LoopNest::LoopNest(const std::vector& output_tensors) { std::vector loops; for (Tensor* t : tensors_to_compute) { - Stmt* loop = lowerToStmt(t); + Stmt* loop = t->stmt(); + if (loop->get_parent()) { + std::cerr << "Error: creating a loopnest from already used Tensors\n"; + loops = {}; + break; + } // Flatten initializers. if (Block* block = dynamic_cast(loop)) { for (auto* s : block->stmts()) { @@ -532,49 +538,6 @@ LoopNest::LoopNest(const std::vector& output_tensors) { } } -Stmt* LoopNest::lowerToStmt(Tensor* t) { - Stmt* body = t->ElementStmt(); - - // If this Tensor has no functional body, it already has its axes expanded. 
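
The new LoopNest copy constructor at the top of this loopnest.cpp hunk deep-clones root_stmt_ through Stmt::clone while copying the buffer sets, so transformations applied to the copy cannot mutate statements still referenced by the original nest. A toy tree version of the same clone-on-copy idea:

    #include <cassert>
    #include <memory>

    struct Stmt {
      int value = 0;
      std::unique_ptr<Stmt> child;

      // Deep copy: the clone owns a fresh subtree.
      std::unique_ptr<Stmt> clone() const {
        auto s = std::make_unique<Stmt>();
        s->value = value;
        if (child) {
          s->child = child->clone();
        }
        return s;
      }
    };

    struct Nest {
      std::unique_ptr<Stmt> root;
      explicit Nest(std::unique_ptr<Stmt> r) : root(std::move(r)) {}
      Nest(const Nest& other) : root(other.root->clone()) {}  // clone, never alias
    };

    int main() {
      auto root = std::make_unique<Stmt>();
      root->value = 1;
      Nest a(std::move(root));
      Nest b(a);
      b.root->value = 2;           // mutate the copy...
      assert(a.root->value == 1);  // ...the original statement tree is untouched
    }
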
- if (nullptr == t->body()) { - return body; - } - - if (t->ndim() == 0 && t->reduce_ndim() == 0) { - return body; - } - - const Expr* initializer = t->initializer(); - if (initializer) { - buf_initializers_[t->buf()] = initializer; - } - - std::vector indices(t->args().begin(), t->args().end()); - - if (t->reduce_ndim() > 0) { - for (size_t i = 0; i < t->reduce_ndim(); i++) { - // Going in reverse order: from innermost loop to the outermost - size_t dim_index = t->reduce_ndim() - i - 1; - body = new For( - t->reduce_arg(dim_index), - new IntImm(0), - t->reduce_dim(dim_index), - body); - } - if (initializer) { - Store* init = new Store(t->buf(), indices, initializer, new IntImm(1)); - body = new Block({init, body}); - } - } - - for (size_t i = 0; i < t->ndim(); i++) { - // Going in reverse order: from innermost loop to the outermost - size_t dim_index = t->ndim() - i - 1; - body = new For(t->arg(dim_index), new IntImm(0), t->dim(dim_index), body); - } - return body; -} - class FunctionInliner : public IRMutator { public: FunctionInliner(Store* producer, std::unordered_set outputs) @@ -587,6 +550,7 @@ class FunctionInliner : public IRMutator { throw std::logic_error("cannot inline Buf with compound indices"); } index_vars_.insert(index_var); + producer_index_vars_.push_back(index_var); } } @@ -606,9 +570,9 @@ class FunctionInliner : public IRMutator { } std::vector index_vars; - TORCH_INTERNAL_ASSERT(buf->ndim() == t->args().size()); + TORCH_INTERNAL_ASSERT(buf->ndim() == producer_index_vars_.size()); for (size_t i = 0; i < buf->ndim(); i++) { - const Var* func_callee_arg = dynamic_cast(t->arg(i)); + const Var* func_callee_arg = producer_index_vars_.at(i); const Expr* func_caller_param = v->param(i); auto iter = inline_mapping_.find(func_callee_arg); if (iter != inline_mapping_.end()) { @@ -729,6 +693,7 @@ class FunctionInliner : public IRMutator { // Index Vars present in the producer. std::unordered_set index_vars_; + std::vector producer_index_vars_; std::unordered_map inline_mapping_; @@ -2352,8 +2317,9 @@ void LoopNest::rfactor( } std::vector new_dims = {}; - Buf* tmp_buf = - new Buf(new Var("tmp_buf", kHandle), new_dims, reduce_op->dtype()); + const Expr* init = reduce_op->accumulator()->initializer(); + TORCH_INTERNAL_ASSERT(init); + Buf* tmp_buf = new Buf("tmp_buf", new_dims, reduce_op->dtype(), init); auto old_acc = reduce_op->accumulator(); auto new_inner = reduce_op->reduce_args(); @@ -2425,26 +2391,17 @@ void LoopNest::rfactor( throw std::runtime_error("TODO: enable non-root insertion points"); } - // From this point forward any errors cannot be handled silently. - auto init_it = buf_initializers_.find(reduce_op->accumulator()); - if (init_it != buf_initializers_.end()) { - buf_initializers_[tmp_buf] = init_it->second; - Stmt* init_stmt = - new Store(tmp_buf, new_outer, init_it->second, new IntImm(1)); + Stmt* init_stmt = new Store(tmp_buf, new_outer, init, new IntImm(1)); - // Wrap it in any loops lower than the insertion point of the new reduction. - for (auto* ol : output_loops) { - init_stmt = ol->cloneWithNewBody(init_stmt); - } + // Wrap it in any loops lower than the insertion point of the new reduction. 
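
With the initializer now stored on Buf itself (see the new Buf constructor parameter in expr.h earlier in this diff), rfactor recovers it straight from the reduction's accumulator instead of the buf_initializers_ side table, and asserts it is present instead of keeping the old "no initializer" error branch. A toy model, with invented struct layouts standing in for the real IR classes:

    #include <cassert>
    #include <string>

    struct Expr {
      double value;
    };

    struct Buf {
      std::string name;
      const Expr* initializer = nullptr;  // absent unless the buffer is initialized
    };

    struct ReduceOp {
      const Buf* accumulator;
    };

    // rfactor-style use: the temporary buffer inherits the reduction's init value.
    Buf makeTmpBuf(const ReduceOp& op) {
      const Expr* init = op.accumulator->initializer;
      assert(init && "rfactor requires an initialized reduction accumulator");
      return Buf{"tmp_buf", init};
    }

    int main() {
      Expr zero{0.0};
      Buf acc{"sum", &zero};
      ReduceOp op{&acc};
      Buf tmp = makeTmpBuf(op);
      return tmp.initializer == &zero ? 0 : 1;
    }
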
+ for (auto* ol : output_loops) { + init_stmt = ol->cloneWithNewBody(init_stmt); + } - if (output_contains_target) { - parent_block->insert_stmt_before(init_stmt, new_root_for); - } else { - new_root_for->body()->prepend_stmt(init_stmt); - } + if (output_contains_target) { + parent_block->insert_stmt_before(init_stmt, new_root_for); } else { - // We may support this but not possible now. - throw std::runtime_error("can't rfactor reduction with no initializer\n"); + new_root_for->body()->prepend_stmt(init_stmt); } auto second_buf = dynamic_cast(second_reduce->accumulator()); diff --git a/torch/csrc/jit/tensorexpr/loopnest.h b/torch/csrc/jit/tensorexpr/loopnest.h index 3f37468c0a80..7c27ca6968a5 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.h +++ b/torch/csrc/jit/tensorexpr/loopnest.h @@ -28,17 +28,16 @@ class TORCH_API LoopNest { LoopNest(const std::vector& output_tensors); // A constructor for building a LoopNest from a pre-baked Stmt and meta-info - // TODO: Nuke intermediate_bufs_ and possibly buf_initializers from here if - // they can be deduced. + // TODO: Nuke intermediate_bufs_ from here if they can be deduced. LoopNest( Stmt* stmt, const std::unordered_set& output_bufs, - const std::unordered_set& intermediate_bufs, - const std::unordered_map& buf_initializers) + const std::unordered_set& intermediate_bufs) : root_stmt_(stmt), output_bufs_(output_bufs), - intermediate_bufs_(intermediate_bufs), - buf_initializers_(buf_initializers) {} + intermediate_bufs_(intermediate_bufs) {} + + LoopNest(const LoopNest& other); Stmt* root_stmt() const { return root_stmt_; @@ -125,7 +124,6 @@ class TORCH_API LoopNest { private: std::vector findAllNeededTensors( const std::vector& tensors); - Stmt* lowerToStmt(Tensor* t); Stmt* insertAllocFree(Stmt* stmt); Stmt* root_stmt_; @@ -133,8 +131,6 @@ class TORCH_API LoopNest { std::unordered_set input_bufs_; std::unordered_set output_bufs_; std::unordered_set intermediate_bufs_; - // Holds the initializer Expr of buffers that have been initialized. 
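
Loop materialization itself moves into Tensor::constructStmt in the tensor.cpp hunk that follows: the Store is wrapped first by the reduction loops, then by the output loops, each built from the innermost dimension outward, with the optional init Store prepended inside the reduction nest. The reverse wrapping order is the core of it; a string-level toy:

    #include <iostream>
    #include <string>
    #include <vector>

    int main() {
      std::string body = "A[i, j] = f(i, j);";
      std::vector<std::string> vars = {"i", "j"};
      std::vector<int> dims = {64, 32};

      // Going in reverse order: from the innermost loop to the outermost.
      for (size_t k = 0; k < vars.size(); ++k) {
        size_t d = vars.size() - k - 1;
        body = "for (int " + vars[d] + " = 0; " + vars[d] + " < " +
               std::to_string(dims[d]) + "; " + vars[d] + "++) { " + body + " }";
      }
      // Prints: for (int i ...) { for (int j ...) { A[i, j] = f(i, j); } }
      std::cout << body << "\n";
    }
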
- std::unordered_map buf_initializers_; }; TORCH_API Stmt* FlattenIndexes(Stmt* s); diff --git a/torch/csrc/jit/tensorexpr/tensor.cpp b/torch/csrc/jit/tensorexpr/tensor.cpp index d12f6999c8d5..3eec21d13f0b 100644 --- a/torch/csrc/jit/tensorexpr/tensor.cpp +++ b/torch/csrc/jit/tensorexpr/tensor.cpp @@ -8,19 +8,60 @@ namespace torch { namespace jit { namespace tensorexpr { +Stmt* Tensor::constructStmt( + const std::vector& args, + const Expr* body, + const std::vector& reduce_dims, + const std::vector& reduce_args) const { + std::vector indices(args.begin(), args.end()); + + const Expr* mask = new IntImm(1); + Stmt* s = new Store(buf_, indices, body, mask); + + size_t ndim = buf()->ndim(); + size_t reduce_ndim = reduce_dims.size(); + + if (ndim == 0 && reduce_ndim == 0) { + return s; + } + + const Expr* init_expr = buf()->initializer(); + + if (reduce_ndim > 0) { + for (size_t i = 0; i < reduce_ndim; i++) { + // Going in reverse order: from innermost loop to the outermost + size_t dim_index = reduce_ndim - i - 1; + s = new For( + reduce_args[dim_index], new IntImm(0), reduce_dims[dim_index], s); + } + if (init_expr) { + Store* init_stmt = new Store(buf(), indices, init_expr, new IntImm(1)); + s = new Block({init_stmt, s}); + } + } + + for (size_t i = 0; i < ndim; i++) { + // Going in reverse order: from innermost loop to the outermost + size_t dim_index = ndim - i - 1; + s = new For(args[dim_index], new IntImm(0), buf()->dim(dim_index), s); + } + return s; +} + Tensor* Compute( - const std::string& func_name, + const std::string& name, const std::vector& dim_args, const std::function&)>& body_func) { std::vector dims; std::vector args; unpack_dim_args(dim_args, &dims, &args); const Expr* body = body_func(VarVectorToVarHandleVector(args)).node(); - return new Tensor(func_name, dims, args, body); + const Buf* buf = new Buf(name, dims, body->dtype()); + return new Tensor(buf, args, body); } Tensor* Compute( - const std::string& func_name, + const std::string& name, const std::vector& dim_args, const std::function& body_func) { if (dim_args.size() != 1) { @@ -31,11 +72,12 @@ Tensor* Compute( std::vector args; unpack_dim_args(dim_args, &dims, &args); const Expr* body = body_func(VarHandle(args[0])).node(); - return new Tensor(func_name, dims, args, body); + const Buf* buf = new Buf(name, dims, body->dtype()); + return new Tensor(buf, args, body); } Tensor* Compute( - const std::string& func_name, + const std::string& name, const std::vector& dim_args, const std::function& body_func) { @@ -46,11 +88,12 @@ Tensor* Compute( std::vector args; unpack_dim_args(dim_args, &dims, &args); const Expr* body = body_func(VarHandle(args[0]), VarHandle(args[1])).node(); - return new Tensor(func_name, dims, args, body); + const Buf* buf = new Buf(name, dims, body->dtype()); + return new Tensor(buf, args, body); } Tensor* Compute( - const std::string& func_name, + const std::string& name, const std::vector& dim_args, const std::function< ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)>& @@ -64,11 +107,12 @@ Tensor* Compute( const Expr* body = body_func(VarHandle(args[0]), VarHandle(args[1]), VarHandle(args[2])) .node(); - return new Tensor(func_name, dims, args, body); + const Buf* buf = new Buf(name, dims, body->dtype()); + return new Tensor(buf, args, body); } Tensor* Compute( - const std::string& func_name, + const std::string& name, const std::vector& dim_args, const std::function dims; - std::vector args_nodes; - unpack_dim_args(dim_args, &dims, &args_nodes); - auto args = 
VarVectorToVarHandleVector(args_nodes); - const Expr* body = body_func(args[0], args[1], args[2], args[3]).node(); - return new Tensor(func_name, dims, args_nodes, body); -} - -Stmt* Tensor::ElementStmt() const { - std::vector indices; - for (size_t i = 0; i < buf_->ndim(); i++) { - indices.push_back(args_[i]); - } - - const Expr* mask = new IntImm(1); - Stmt* update_stmt = new Store(buf_, indices, body_, mask); - return update_stmt; + std::vector args; + unpack_dim_args(dim_args, &dims, &args); + const Expr* body = body_func( + VarHandle(args[0]), + VarHandle(args[1]), + VarHandle(args[2]), + VarHandle(args[3])) + .node(); + const Buf* buf = new Buf(name, dims, body->dtype()); + return new Tensor(buf, args, body); } Tensor* Reduce( - const std::string& func_name, + const std::string& name, const std::vector& dim_args, const Reducer& reducer, const Placeholder& buffer, const std::vector& reduce_args) { return Reduce( - func_name, + name, dim_args, reducer, [&](ParameterList& p) { return buffer.load(p); }, @@ -112,13 +150,13 @@ Tensor* Reduce( } Tensor* Reduce( - const std::string& func_name, + const std::string& name, const std::vector& dim_args, const Reducer& reducer, Tensor* tensor, const std::vector& reduce_args) { return Reduce( - func_name, + name, dim_args, reducer, [&](ParameterList& p) { return tensor->call(p); }, diff --git a/torch/csrc/jit/tensorexpr/tensor.h b/torch/csrc/jit/tensorexpr/tensor.h index e5e399db348b..609b5c30a839 100644 --- a/torch/csrc/jit/tensorexpr/tensor.h +++ b/torch/csrc/jit/tensorexpr/tensor.h @@ -14,17 +14,10 @@ namespace tensorexpr { class TORCH_API Tensor : KernelScopedObject { public: - Tensor( - const std::string& name, - const std::vector& dims, - const std::vector& args, - const Expr* body) - // TODO: Function should not create buffers, they should be created - // manually before constructing a function. 
- : buf_(new Buf(name, dims, body->dtype())), args_(args), body_(body) {} - - Tensor(Buf* buf, const std::vector& args, const Expr* body) - : buf_(buf), args_(args), body_(body) {} + Tensor(const Buf* buf, const std::vector& args, const Expr* body) + : buf_(buf) { + stmt_ = constructStmt(args, body, {}, {}); + } Tensor( const Buf* buf, @@ -32,71 +25,19 @@ class TORCH_API Tensor : KernelScopedObject { const std::vector& reduce_dims, const std::vector& reduce_args, const Expr* body) - : buf_(buf), - args_(args), - body_(body), - reduce_dims_(reduce_dims), - reduce_args_(reduce_args) {} + : buf_(buf) { + stmt_ = constructStmt(args, body, reduce_dims, reduce_args); + } - virtual ~Tensor() {} + Tensor(const Buf* buf, Stmt* stmt) : buf_(buf), stmt_(stmt) {} - // Wrappers over accessors to fields of the underlying function - const Expr* body() const { - return body_; - } const Buf* buf() const { return buf_; } - size_t ndim() const { - return buf()->ndim(); - } - const Expr* dim(size_t index) const { - if (index >= ndim()) { - throw out_of_range_index(); - } - return buf()->dim(index); - } - std::vector dims() const { - return buf()->dims(); - } - const Var* arg(size_t index) const { - if (index >= ndim()) { - throw out_of_range_index(); - } - return args_[index]; - } - const std::vector& args() const { - return args_; - } - size_t reduce_ndim() const { - return reduce_dims_.size(); - } - std::vector reduce_dims() const { - return reduce_dims_; - } - std::vector reduce_args() const { - return reduce_args_; - } - const Expr* reduce_dim(size_t index) const { - if (index >= reduce_ndim()) { - throw out_of_range_index(); - } - return reduce_dims_[index]; - } - const Var* reduce_arg(size_t index) const { - if (index >= reduce_ndim()) { - throw out_of_range_index(); - } - return reduce_args_[index]; - } - void initializeTo(const Expr* initializer) { - initializer_ = initializer; - } - const Expr* initializer() const { - return initializer_; + Stmt* stmt() const { + return stmt_; } - virtual Stmt* ElementStmt() const; template inline ExprHandle operator()(const Ts&... ts); @@ -106,30 +47,13 @@ class TORCH_API Tensor : KernelScopedObject { inline ExprHandle call(const Ts&... 
ts); private: - const Buf* buf_; - std::vector args_; - const Expr* body_; - std::vector reduce_dims_; - std::vector reduce_args_; - - const Expr* initializer_{nullptr}; -}; - -class TORCH_API CompoundTensor : public Tensor { - public: - CompoundTensor( - const Buf* buf, + Stmt* constructStmt( const std::vector& args, - Stmt* stmt) - : Tensor(buf, args, {}, {}, nullptr), stmt_(stmt) {} - - virtual ~CompoundTensor() {} - - Stmt* ElementStmt() const override { - return stmt_; - } + const Expr* body, + const std::vector& reduce_dims, + const std::vector& reduce_args) const; - private: + const Buf* buf_; Stmt* stmt_; }; @@ -268,12 +192,12 @@ Tensor* Reduce( ExprHandle body = Reducer::getReduceBody(body_func, VarVectorToVarHandleVector(all_vars)); std::vector output_args(vars.begin(), vars.end()); - Buf* func_result = new Buf(func_name, dims, body.dtype()); + const Expr* init_expr = new Cast(body.dtype(), reducer.initializer()); + Buf* func_result = new Buf(func_name, dims, body.dtype(), init_expr); const ReduceOp* reduce_op = reducer(func_result, body, output_args, reduce_vars); Tensor* t = new Tensor(func_result, vars, reduce_dims, reduce_vars, reduce_op); - t->initializeTo(new Cast(body.dtype(), reducer.initializer())); return t; } diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py index 187dcfcb87e2..6b0184f3be4c 100644 --- a/torch/onnx/symbolic_helper.py +++ b/torch/onnx/symbolic_helper.py @@ -10,6 +10,7 @@ import torch.onnx.utils from functools import wraps +from torch._C import OptionalType # Note [Edit Symbolic Files] @@ -321,9 +322,30 @@ def _interpolate_warning(interpolate_mode): "to support Pytorch's behavior (like coordinate_transformation_mode and nearest_mode).\n" "We recommend using opset 11 and above for models using this operator. 
") -def _unsqueeze_helper(g, input, dim): - from torch.onnx.symbolic_opset9 import unsqueeze - return unsqueeze(g, input, dim) +def _unsqueeze_helper(g, input, axes_i): + if _export_onnx_opset_version >= 13: + axes = g.op("Constant", value_t=torch.tensor(axes_i, dtype=torch.long)) + return g.op("Unsqueeze", input, axes) + else: + return g.op("Unsqueeze", input, axes_i=axes_i) + +def _squeeze_helper(g, input, axes_i): + if _export_onnx_opset_version >= 13: + axes = g.op("Constant", value_t=torch.tensor(axes_i, dtype=torch.long)) + return g.op("Squeeze", input, axes) + else: + return g.op("Squeeze", input, axes_i=axes_i) + +def _reducesum_helper(g, input, axes_i=None, keepdims_i=1, noop_with_empty_axes_i=0): + keepdims_i = _maybe_get_const(keepdims_i, 'i') + if _export_onnx_opset_version >= 13: + if axes_i: + if not _is_value(axes_i): + axes_i = g.op("Constant", value_t=torch.tensor(axes_i, dtype=torch.long)) + return g.op("ReduceSum", input, axes_i, keepdims_i=keepdims_i, noop_with_empty_axes_i=noop_with_empty_axes_i) + return g.op("ReduceSum", input, keepdims_i=keepdims_i, noop_with_empty_axes_i=noop_with_empty_axes_i) + else: + return g.op("ReduceSum", input, axes_i=axes_i, keepdims_i=keepdims_i) def _interpolate_size_to_scales(g, input, output_size, dim): output_size = _maybe_get_const(output_size, 'is') @@ -371,7 +393,7 @@ def _interpolate_get_scales(g, scale_factor, dim): if isinstance(scale_factor.type(), torch._C.ListType) or (scale_factor_rank is not None and scale_factor_rank > 0): return g.op("Concat", offsets, scale_factor, axis_i=0) else: - scale_factor = _unsqueeze_helper(g, scale_factor, 0) + scale_factor = _unsqueeze_helper(g, scale_factor, [0]) scale_factor = g.op("Cast", scale_factor, to_i=cast_pytorch_to_onnx["Float"]) scales = [scale_factor for i in range(dim - 2)] scale_factor = g.op("Concat", offsets, *scales, axis_i=0) @@ -400,7 +422,7 @@ def _interpolate_get_scales_and_mode(g, input, size, scale_factor, mode , align_ if not _is_packed_list(size): is_scalar = ((_maybe_get_const(size, 't').dim() == 0)) if is_scalar: - size = _unsqueeze_helper(g, size, 0) + size = _unsqueeze_helper(g, size, [0]) size = [size for i in range(dim - 2)] size = g.op("Concat", *size, axis_i=0) scale_factor = _interpolate_size_to_scales(g, input, size, dim) @@ -409,6 +431,126 @@ def _interpolate_get_scales_and_mode(g, input, size, scale_factor, mode , align_ return scale_factor, mode +def _interpolate_helper(name, dim, interpolate_mode): + def symbolic_fn(g, input, output_size, *args): + scales, align_corners = _get_interpolate_attributes(g, interpolate_mode, args) + align_corners = _maybe_get_scalar(align_corners) + coordinate_transformation_mode = "asymmetric" if interpolate_mode == "nearest" \ + else "align_corners" if align_corners else "pytorch_half_pixel" + + if scales is None: + input_size = g.op("Shape", input) + input_size_beg = _slice_helper(g, input_size, axes=[0], ends=[2], starts=[0]) + output_size = g.op("Cast", output_size, to_i=cast_pytorch_to_onnx['Long']) + output_size = g.op("Concat", input_size_beg, output_size, axis_i=0) + + if _export_onnx_opset_version >= 13: + empty_roi = _optional_input_placeholder_tensor(g) + empty_scales = _optional_input_placeholder_tensor(g) + else: + empty_roi = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32)) + empty_scales = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32)) + + return g.op("Resize", + input, + empty_roi, + empty_scales, + output_size, + coordinate_transformation_mode_s=coordinate_transformation_mode, 
+ cubic_coeff_a_f=-0.75, # only valid when mode="cubic" + mode_s=interpolate_mode, # nearest, linear, or cubic + nearest_mode_s="floor") # only valid when mode="nearest" + else: + if _export_onnx_opset_version >= 13: + empty_roi = _optional_input_placeholder_tensor(g) + else: + empty_roi = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32)) + + return g.op("Resize", + input, + empty_roi, + scales, + coordinate_transformation_mode_s=coordinate_transformation_mode, + cubic_coeff_a_f=-0.75, # only valid when mode="cubic" + mode_s=interpolate_mode, # nearest, linear, or cubic + nearest_mode_s="floor") # only valid when mode="nearest" + return symbolic_fn + + +def __interpolate_helper(g, input, size, scale_factor, mode, align_corners, recompute_scale_factor): + mode = _maybe_get_const(mode, 's') + if 'linear' in mode: + mode = 'linear' + if 'cubic' in mode: + mode = 'cubic' + align_corners = _maybe_get_const(align_corners, 'b') + align_corners = False if not isinstance(align_corners, bool) else align_corners + coordinate_transformation_mode = "asymmetric" if mode == "nearest" \ + else "align_corners" if align_corners else "pytorch_half_pixel" + + if not _is_none(size) : + input_size = g.op("Shape", input) + input_size = _slice_helper(g, input_size, axes=[0], ends=[2], starts=[0]) + # in some cases size is not a packed list but size is a scalar + # We need to also verify that (_maybe_get_const(size, 't').dim() == 0) + # but this information is not always available. Try to get the dim, + # and if not assume that it is not a scalar. + try: + is_scalar = not _is_packed_list(size) and ((_maybe_get_const(size, 't').dim() == 0)) + except AttributeError: + is_scalar = not _is_packed_list(size) + if not is_scalar: + warnings.warn("Cannot verify if the output_size is a scalar " + "while exporting interpolate. 
Assuming that it is not a scalar.") + + if is_scalar: + rank = _get_tensor_rank(input) + if rank is None: + return _unimplemented("interpolate (with a scalar output_size)", + "missing input shape (try giving an array of output_size values)") + size = _unsqueeze_helper(g, size, [0]) + size = [size for i in range(rank - 2)] + size = g.op("Concat", *size, axis_i=0) + size = g.op("Cast", size, to_i=cast_pytorch_to_onnx['Long']) + size = g.op("Concat", input_size, size, axis_i=0) + + if _export_onnx_opset_version >= 13: + empty_roi = _optional_input_placeholder_tensor(g) + empty_scales = _optional_input_placeholder_tensor(g) + else: + empty_roi = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32)) + empty_scales = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32)) + + return g.op("Resize", + input, + empty_roi, + empty_scales, + size, + coordinate_transformation_mode_s=coordinate_transformation_mode, + cubic_coeff_a_f=-0.75, # only valid when mode="cubic" + mode_s=mode, # nearest, linear, or cubic + nearest_mode_s="floor") + else: # if not _is_none(scales) + rank = _get_tensor_rank(input) + if rank is None: + return _unimplemented("interpolate (with scales)", "missing input shape") + + if _export_onnx_opset_version >= 13: + empty_roi = _optional_input_placeholder_tensor(g) + else: + empty_roi = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32)) + + scales = _interpolate_get_scales(g, scale_factor, rank) + return g.op("Resize", + input, + empty_roi, + scales, + coordinate_transformation_mode_s=coordinate_transformation_mode, + cubic_coeff_a_f=-0.75, # only valid when mode="cubic" + mode_s=mode, # nearest, linear, or cubic + nearest_mode_s="floor") # only valid when mode="nearest" + + def _unbind_helper(g, self, dim, _outputs): if _export_onnx_opset_version <= 9: from torch.onnx.symbolic_opset9 import unbind @@ -477,9 +619,9 @@ def _index_fill_reshape_helper(g, self, dim, index): return _unimplemented("index_fill", "input rank not accesible") self_dim = self.type().dim() dim_value = _parse_arg(dim, 'i') - unsqueezed_index = g.op("Unsqueeze", index, axes_i=[i for i in range(self_dim) if i != dim_value]) + unsqueezed_index = _unsqueeze_helper(g, index, [i for i in range(self_dim) if i != dim_value]) expanded_index_shape = scatter(g, g.op("Shape", self), 0, - g.op("Unsqueeze", dim, axes_i=[0]), g.op("Shape", index)) + _unsqueeze_helper(g, dim, [0]), g.op("Shape", index)) expanded_index = expand(g, unsqueezed_index, expanded_index_shape, None) return expanded_index_shape, expanded_index @@ -525,6 +667,12 @@ def _is_split_static(split_size_or_sizes, _outputs): return False return True +def _optional_input_placeholder_tensor(g): + n = g.op("prim::Constant") + n.setType(OptionalType.ofTensor()) + return n + + # --------------------------------------------------------------------- # ONNX operator version # --------------------------------------------------------------------- diff --git a/torch/onnx/symbolic_opset10.py b/torch/onnx/symbolic_opset10.py index 6558df6e3d4c..b7f0bb6167b2 100644 --- a/torch/onnx/symbolic_opset10.py +++ b/torch/onnx/symbolic_opset10.py @@ -136,11 +136,11 @@ def __interpolate(g, input, size, scale_factor, mode , align_corners, recompute_ def _slice(g, input, axes, starts, ends, steps=None, dynamic_slice=False): if dynamic_slice: - starts = g.op("Unsqueeze", starts, axes_i=[0]) - ends = g.op("Unsqueeze", ends, axes_i=[0]) + starts = sym_help._unsqueeze_helper(g, starts, [0]) + ends = sym_help._unsqueeze_helper(g, ends, [0]) if 
isinstance(axes, int): axes = g.op("Constant", value_t=torch.tensor(axes)) - axes = g.op("Unsqueeze", axes, axes_i=[0]) + axes = sym_help._unsqueeze_helper(g, axes, [0]) else: assert len(starts) == len(ends) assert len(starts) == len(axes) @@ -220,24 +220,24 @@ def embedding_bag(g, offsets_extended = g.op("Concat", *offsets_extended, axis_i=0) list_ = [] for i in range(offset_len): - start_ = g.op("Unsqueeze", select(g, offsets_extended, torch.tensor(0), torch.tensor(i)), axes_i=[0]) - end_ = g.op("Unsqueeze", select(g, offsets_extended, torch.tensor(0), torch.tensor(i + 1)), axes_i=[0]) + start_ = sym_help._unsqueeze_helper(g, select(g, offsets_extended, torch.tensor(0), torch.tensor(i)), [0]) + end_ = sym_help._unsqueeze_helper(g, select(g, offsets_extended, torch.tensor(0), torch.tensor(i + 1)), [0]) axes_ = g.op("Constant", value_t=torch.tensor([0])) indices_row = g.op("Slice", indices, start_, end_, axes_) embeddings = g.op("Gather", embedding_matrix, indices_row) if not sym_help._is_none(per_sample_weights): per_sample_weights_row = g.op("Slice", per_sample_weights, start_, end_, axes_) - per_sample_weights_row = g.op("Unsqueeze", per_sample_weights_row, axes_i=[1]) + per_sample_weights_row = sym_help._unsqueeze_helper(g, per_sample_weights_row, [1]) embeddings = g.op("Mul", embeddings, per_sample_weights_row) if mode == 0: - embeddings = g.op("ReduceSum", embeddings, axes_i=[0], keepdims_i=0) + embeddings = sym_help._reducesum_helper(g, embeddings, axes_i=[0], keepdims_i=0) elif mode == 1: embeddings = g.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0) else: embeddings = g.op("ReduceMax", embeddings, axes_i=[0], keepdims_i=0) - embeddings = g.op("Unsqueeze", embeddings, axes_i=[0]) + embeddings = sym_help._unsqueeze_helper(g, embeddings, [0]) list_.append(embeddings) output = g.op("Concat", *list_, axis_i=0) diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py index 85c7bf97c883..3792f77ae377 100644 --- a/torch/onnx/symbolic_opset11.py +++ b/torch/onnx/symbolic_opset11.py @@ -76,7 +76,7 @@ def index_put(g, self, indices_list_value, values, accumulate=False): index = add(g, index, ind) broadcast_index_shape = g.op("Shape", index) indices_list = [ - g.op("Unsqueeze", expand(g, ind, broadcast_index_shape, None), axes_i=[-1]) for ind in indices_list + sym_help._unsqueeze_helper(g, expand(g, ind, broadcast_index_shape, None), [-1]) for ind in indices_list ] index = g.op("Concat", *indices_list, axis_i=-1) else: @@ -180,7 +180,7 @@ def index_put(g, self, indices_list_value, values, accumulate=False): return masked_fill(g, self, bool_inp, values) return masked_scatter(g, self, bool_inp, values) broadcast_index_shape = g.op("Shape", index) - index = g.op("Unsqueeze", index, axes_i=[-1]) + index = sym_help._unsqueeze_helper(g, index, [-1]) sub_data_shape = sym_help._slice_helper( g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[maxsize]) values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0) @@ -208,38 +208,7 @@ def pixel_shuffle(g, self, upscale_factor): def _interpolate(name, dim, interpolate_mode): - def symbolic_fn(g, input, output_size, *args): - scales, align_corners = sym_help._get_interpolate_attributes(g, interpolate_mode, args) - align_corners = sym_help._maybe_get_scalar(align_corners) - coordinate_transformation_mode = "asymmetric" if interpolate_mode == "nearest" \ - else "align_corners" if align_corners else "pytorch_half_pixel" - empty_tensor = g.op("Constant", value_t=torch.tensor([], 
dtype=torch.float32)) - - if scales is None: - input_size = g.op("Shape", input) - input_size_beg = sym_help._slice_helper(g, input_size, axes=[0], ends=[2], starts=[0]) - output_size = g.op("Cast", output_size, to_i=sym_help.cast_pytorch_to_onnx["Long"]) - output_size = g.op("Concat", input_size_beg, output_size, axis_i=0) - scales = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32)) - return g.op("Resize", - input, - empty_tensor, # roi only takes effect whith coordinate_transformation_mode="tf_crop_and_resize" - scales, # scales is not needed since we are sending out_size - output_size, - coordinate_transformation_mode_s=coordinate_transformation_mode, - cubic_coeff_a_f=-0.75, # only valid when mode="cubic" - mode_s=interpolate_mode, # nearest, linear, or cubic - nearest_mode_s="floor") # only valid when mode="nearest" - else: - return g.op("Resize", - input, - empty_tensor, # roi only takes effect with coordinate_transformation_mode="tf_crop_and_resize" - scales, # scales is not needed since we are sending out_size - coordinate_transformation_mode_s=coordinate_transformation_mode, - cubic_coeff_a_f=-0.75, # only valid when mode="cubic" - mode_s=interpolate_mode, # nearest, linear, or cubic - nearest_mode_s="floor") # only valid when mode="nearest" - return symbolic_fn + return sym_help._interpolate_helper(name, dim, interpolate_mode) upsample_nearest1d = _interpolate('upsample_nearest1d', 3, "nearest") @@ -252,66 +221,7 @@ def symbolic_fn(g, input, output_size, *args): def __interpolate(g, input, size, scale_factor, mode, align_corners, recompute_scale_factor): - mode = sym_help._maybe_get_const(mode, 's') - if 'linear' in mode: - mode = 'linear' - if 'cubic' in mode: - mode = 'cubic' - align_corners = sym_help._maybe_get_const(align_corners, 'b') - align_corners = False if not isinstance(align_corners, bool) else align_corners - coordinate_transformation_mode = "asymmetric" if mode == "nearest" \ - else "align_corners" if align_corners else "pytorch_half_pixel" - # roi only takes effect with coordinate_transformation_mode="tf_crop_and_resize" - roi = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32)) - - if not sym_help._is_none(size) : - input_size = g.op("Shape", input) - input_size = sym_help._slice_helper(g, input_size, axes=[0], ends=[2], starts=[0]) - # in some cases size is not a packed list but size is a scalar - # We need to also verify that (sym_help._maybe_get_const(size, 't').dim() == 0) - # but this information is not always available. Try to get the dim, - # and if not assume that it is not a scalar. - try: - is_scalar = not sym_help._is_packed_list(size) and ((sym_help._maybe_get_const(size, 't').dim() == 0)) - except AttributeError: - is_scalar = not sym_help._is_packed_list(size) - if not is_scalar: - warnings.warn("Cannot verify if the output_size is a scalar " - "while exporting interpolate. 
Assuming that it is not a scalar.") - - if is_scalar: - rank = sym_help._get_tensor_rank(input) - if rank is None: - return sym_help._unimplemented("interpolate (with a scalar output_size)", - "missing input shape (try giving an array of output_size values)") - size = unsqueeze(g, size, 0) - size = [size for i in range(rank - 2)] - size = g.op("Concat", *size, axis_i=0) - size = g.op("Cast", size, to_i=sym_help.cast_pytorch_to_onnx['Long']) - size = g.op("Concat", input_size, size, axis_i=0) - scales = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32)) - return g.op("Resize", - input, - roi, - scales, - size, - coordinate_transformation_mode_s=coordinate_transformation_mode, - cubic_coeff_a_f=-0.75, # only valid when mode="cubic" - mode_s=mode, # nearest, linear, or cubic - nearest_mode_s="floor") - else: # if not sym_help._is_none(scales) - rank = sym_help._get_tensor_rank(input) - if rank is None: - return sym_help._unimplemented("interpolate (with scales)", "missing input shape") - scales = sym_help._interpolate_get_scales(g, scale_factor, rank) - return g.op("Resize", - input, - roi, - scales, - coordinate_transformation_mode_s=coordinate_transformation_mode, - cubic_coeff_a_f=-0.75, # only valid when mode="cubic" - mode_s=mode, # nearest, linear, or cubic - nearest_mode_s="floor") # only valid when mode="nearest" + return sym_help.__interpolate_helper(g, input, size, scale_factor, mode, align_corners, recompute_scale_factor) @parse_args('v', 'i', 'v', 'v') def gather(g, self, dim, index, sparse_grad=False): @@ -376,7 +286,7 @@ def _len(g, self): if _is_tensor_list(self) or self.node().kind() == "onnx::SplitToSequence": return g.op("SequenceLength", self) sz_0 = size(g, self, g.op("Constant", value_t=torch.LongTensor([0]))) - return g.op('Squeeze', sz_0, axes_i=[0]) + return sym_help._squeeze_helper(g, sz_0, [0]) def __getitem_(g, self, i): @@ -489,7 +399,7 @@ def split(g, self, split_size_or_sizes, dim, _outputs=None): return split_out # Convert to multiple slice nodes iff number of splits and number of outputs are statically known. if sym_help._is_packed_list(split_size_or_sizes) and len(sym_help._unpack_list(split_size_or_sizes)) == _outputs: - split_sizes = [g.op("Unsqueeze", v, axes_i=[0]) for v in sym_help._unpack_list(split_size_or_sizes)] + split_sizes = [sym_help._unsqueeze_helper(g, v, [0]) for v in sym_help._unpack_list(split_size_or_sizes)] start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long)) axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) res = [] @@ -658,7 +568,7 @@ def squeeze(g, self, dim=None): if_node_outputs = g.op("If", cond) if_node = if_node_outputs.node() if_block = torch.onnx.utils._add_block(if_node) - squeeze_ = if_block.op("Squeeze", self, axes_i=[dim]) + squeeze_ = sym_help._squeeze_helper(if_block, self, [dim]) torch.onnx.utils._add_output_to_block(if_block, squeeze_) else_block = torch.onnx.utils._add_block(if_node) identity_ = else_block.op("Identity", self) @@ -673,13 +583,12 @@ def squeeze(g, self, dim=None): "be exported without the squeeze node. 
If the model is intended to be used with dynamic " + "input shapes, please export with dynamic_axes argument.") return self - return g.op("Squeeze", self, axes_i=[dim]) + return sym_help._squeeze_helper(g, self, [dim]) @parse_args('v', 'i') def unsqueeze(g, self, dim): - return g.op("Unsqueeze", self, axes_i=[dim]) - + return sym_help._unsqueeze_helper(g, self, [dim]) def mm(g, self, other): return g.op("Gemm", self, other, beta_f=0.0, alpha_f=1.0) @@ -782,7 +691,7 @@ def _get_im2col_indices_along_dim(g, input_d, kernel_size_d, dilation_d, padding # Broadcast and add kernel staring positions (indices) with # kernel_grid along dim d, to get block indices along dim d - blocks_d_indices = g.op('Unsqueeze', blocks_d_indices, axes_i=[0]) # Reshape to [1, -1] + blocks_d_indices = sym_help._unsqueeze_helper(g, blocks_d_indices, [0]) # Reshape to [1, -1] kernel_mask = g.op('Reshape', kernel_grid, g.op('Constant', value_t=torch.tensor([-1, 1]))) block_mask = g.op("Add", blocks_d_indices, kernel_mask) @@ -804,8 +713,8 @@ def _get_im2col_output_shape(g, input, kernel_h, kernel_w): g.op("Constant", value_t=torch.tensor(kernel_h * kernel_w))) return g.op("Concat", - g.op("Unsqueeze", batch_dim, axes_i=[0]), - g.op("Unsqueeze", channel_unfolded, axes_i=[0]), + sym_help._unsqueeze_helper(g, batch_dim, [0]), + sym_help._unsqueeze_helper(g, channel_unfolded, [0]), g.op("Constant", value_t=torch.tensor([-1])), axis_i=0) @@ -901,9 +810,9 @@ def embedding_bag(g, loop_condition = g.op("Cast", loop_condition, to_i=9) zero = g.op("Constant", value_t=torch.tensor([0])) - indices_len = g.op("Unsqueeze", - sym_help._size_helper(g, indices, g.op("Constant", value_t=torch.tensor(0))), - axes_i=[0]) + indices_len = sym_help._unsqueeze_helper(g, + sym_help._size_helper(g, indices, g.op("Constant", value_t=torch.tensor(0))), + [0]) if not include_last_offset: offsets = [offsets, indices_len] offsets = g.op("Concat", *offsets, axis_i=0) @@ -923,8 +832,8 @@ def embedding_bag(g, indices_start = loop_block.op("Gather", offsets_starts, block_input_iter, axis_i=0) indices_end = loop_block.op("Gather", offsets_ends, block_input_iter, axis_i=0) - indices_start = loop_block.op("Unsqueeze", indices_start, axes_i=[0]) - indices_end = loop_block.op("Unsqueeze", indices_end, axes_i=[0]) + indices_start = sym_help._unsqueeze_helper(loop_block, indices_start, [0]) + indices_end = sym_help._unsqueeze_helper(loop_block, indices_end, [0]) indices_row = loop_block.op("Slice", indices, indices_start, indices_end, zero) embeddings = loop_block.op("Gather", embedding_matrix, indices_row, axis_i=0) @@ -933,10 +842,10 @@ def embedding_bag(g, indices_start, indices_end, zero) - per_sample_weights_row = loop_block.op("Unsqueeze", per_sample_weights_row, axes_i=[1]) + per_sample_weights_row = sym_help._unsqueeze_helper(loop_block, per_sample_weights_row, [1]) embeddings = loop_block.op("Mul", embeddings, per_sample_weights_row) if mode == 0: - embeddings = loop_block.op("ReduceSum", embeddings, axes_i=[0], keepdims_i=0) + embeddings = sym_help._reducesum_helper(loop_block, embeddings, axes_i=[0], keepdims_i=0) elif mode == 1: embeddings = loop_block.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0) else: diff --git a/torch/onnx/symbolic_opset12.py b/torch/onnx/symbolic_opset12.py index cd67fd508fa2..5a926eef5e1d 100644 --- a/torch/onnx/symbolic_opset12.py +++ b/torch/onnx/symbolic_opset12.py @@ -52,6 +52,34 @@ def nll_loss2d(g, self, target, weight, reduction, ignore_index): return nll_loss(g, self, target, weight, reduction, ignore_index) 
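The hunk below adds an opset 12 symbolic for `binary_cross_entropy_with_logits`, built out of primitive Sigmoid/Log/Sub/Neg/Mul/Add nodes rather than a single fused ONNX op. For intuition, here is a minimal eager-mode sketch of the same decomposition (illustrative only, not part of the patch; the helper name and test tensors are made up):

    import torch
    import torch.nn.functional as F

    def bce_with_logits_reference(x, target, weight=None, pos_weight=None):
        # mirrors the node-by-node expansion the symbolic emits
        sig_x = torch.sigmoid(x)
        log_sig_x = torch.log(sig_x)
        log_1_sub_sig_x = torch.log(1 - sig_x)
        if pos_weight is None:
            loss = -(target * log_sig_x + (1 - target) * log_1_sub_sig_x)
        else:
            loss = -(pos_weight * target * log_sig_x + (1 - target) * log_1_sub_sig_x)
        if weight is not None:
            loss = weight * loss
        return loss.mean()  # the reduction == 1 (mean) branch below

    x = torch.randn(4)
    target = torch.empty(4).random_(2)
    assert torch.allclose(bce_with_logits_reference(x, target),
                          F.binary_cross_entropy_with_logits(x, target), atol=1e-6)

Note that, unlike the fused eager kernel, this expansion computes log(sigmoid(x)) directly, so it is less numerically stable for large-magnitude logits.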
+@parse_args('v', 'v', 'v', 'v', 'i') +def binary_cross_entropy_with_logits(g, input, target, weight, pos_weight, reduction): + from torch.onnx.symbolic_opset9 import sigmoid, log, sub, neg, mul, add + p = g.op("Constant", value_t=torch.tensor([1])) + sig_x = sigmoid(g, input) + log_sig_x = log(g, sig_x) + sub_1_x = sub(g, p, sig_x) + sub_1_y = sub(g, p, target) + log_1_x = log(g, sub_1_x) + if pos_weight is None or sym_help._is_none(pos_weight): + output = neg(g, add(g, mul(g, target, log_sig_x), mul(g, sub_1_y, log_1_x))) + else: + output = neg(g, add(g, mul(g, mul(g, target, log_sig_x), pos_weight), mul(g, sub_1_y, log_1_x))) + + if weight is not None and not sym_help._is_none(weight): + output = mul(g, weight, output) + + reduction = sym_help._maybe_get_const(reduction, 'i') + if reduction == 0: + return output + elif reduction == 1: + return g.op("ReduceMean", output) + elif reduction == 2: + return g.op("ReduceSum", output) + else: + return sym_help._onnx_unsupported("binary_cross_entropy_with_logits with reduction other than none, mean, or sum") + + def celu(g, self, alpha): alpha = sym_help._maybe_get_const(alpha, 'f') # if the input is of type double cast it to float @@ -132,11 +160,11 @@ def unfold(g, input, dimension, size, step): starts = loop_block.op("Gather", low_indices, block_input_iter) ends = loop_block.op("Gather", hi_indices, block_input_iter) axes = loop_block.op("Constant", value_t=torch.tensor([2])) - starts = loop_block.op("Unsqueeze", starts, axes_i=[0]) - ends = loop_block.op("Unsqueeze", ends, axes_i=[0]) + starts = sym_help._unsqueeze_helper(loop_block, starts, [0]) + ends = sym_help._unsqueeze_helper(loop_block, ends, [0]) stack = loop_block.op("Slice", input, starts, ends, axes) - unsqueeze = loop_block.op("Unsqueeze", loop_block.op("Transpose", stack, perm_i=perm), axes_i=[dimension]) + unsqueeze = sym_help._unsqueeze_helper(loop_block, loop_block.op("Transpose", stack, perm_i=perm), [dimension]) unsqueeze_list.append(unsqueeze) concat = loop_block.op("Concat", *unsqueeze_list, axis_i=0) @@ -148,7 +176,7 @@ def unfold(g, input, dimension, size, step): perm = [0, 1, 2, 3, 4] perm[0], perm[dimension + 1] = perm[dimension + 1], perm[0] transpose = g.op("Transpose", loop_output, perm_i=perm) - squeeze = g.op("Squeeze", transpose, axes_i=[0]) + squeeze = sym_help._squeeze_helper(g, transpose, [0]) return squeeze else: diff --git a/torch/onnx/symbolic_opset13.py b/torch/onnx/symbolic_opset13.py index 001a20147c4f..9fffa23a1131 100644 --- a/torch/onnx/symbolic_opset13.py +++ b/torch/onnx/symbolic_opset13.py @@ -2,15 +2,16 @@ # see Note [Edit Symbolic Files] in symbolic_helper.py # This file exports ONNX ops for opset 13 -from torch.onnx.symbolic_helper import _block_list_in_opset import torch import torch.onnx.symbolic_helper as sym_help -from torch.onnx.symbolic_helper import parse_args +from torch.onnx.symbolic_helper import parse_args, _unimplemented +from torch.onnx.symbolic_opset9 import overload_by_arg_count, _maybe_cast_reduce_op_input -block_listed_operators = ['embedding_bag'] -for block_listed_op in block_listed_operators: - vars()[block_listed_op] = _block_list_in_opset(block_listed_op) +# EDITING THIS FILE? READ THIS FIRST! 
+# see Note [Edit Symbolic Files] in symbolic_helper.py + +# This file exports ONNX ops for opset 13 @parse_args('v', 'i', 'none') @@ -38,7 +39,7 @@ def frobenius_norm(g, self, dim=None, keepdim=False): if not sym_help._is_value(dim_val) and len(dim_val) == 0: return g.op("ReduceL2", self, keepdims_i=0) sqr = g.op('Mul', self, self) - sumsqr = g.op('ReduceSum', sqr, dim, keepdims_i=keepdim) + sumsqr = sym_help._reducesum_helper(g, sqr, dim, keepdims_i=keepdim) return g.op('Sqrt', sumsqr) @@ -108,3 +109,36 @@ def unbind(g, self, dim=0, _outputs=None): def glu(g, input, dim): first, second = g.op('Split', input, dim, outputs=2) return g.op('Mul', first, g.op('Sigmoid', second)) + + +def _reduce_op_symbolic(onnx_op_name): + def symbolic(g, self, dim=None, keepdim=None): + self = _maybe_cast_reduce_op_input(g, self) + if dim is None: + # all-reduce path + return g.op(onnx_op_name, self, keepdims_i=0) + else: + keepdim = sym_help._get_const(keepdim, 'i', 'keepdim') + return g.op(onnx_op_name, self, dim, keepdims_i=keepdim) + return symbolic + +def _reduce_with_dtype(onnx_op, name): + symbolic = _reduce_op_symbolic(onnx_op) + + @overload_by_arg_count + def reduce(g, *args, **kwargs): + @parse_args('v', 'none') + def reduce_nodim(g, self, dtype): + if dtype.node().kind() != 'prim::Constant': + return _unimplemented(name, "dtype") + return symbolic(g, self) + + @parse_args('v', 'v', 'i', 'none') + def reduce_dim(g, self, dim, keepdim, dtype): + if dtype.node().kind() != 'prim::Constant': + return _unimplemented(name, "dtype") + return symbolic(g, self, dim, keepdim) + return reduce_nodim, reduce_dim + return reduce + +sum = _reduce_with_dtype('ReduceSum', 'sum') diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index ada731884f76..043fbd041897 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -186,7 +186,7 @@ def cat(g, tensor_list, dim): @parse_args('v', 'i') def stack(g, tensor_list, dim): - unsqueezed = [g.op("Unsqueeze", t, axes_i=[dim]) for t in sym_help._unpack_list(tensor_list)] + unsqueezed = [sym_help._unsqueeze_helper(g, t, [dim]) for t in sym_help._unpack_list(tensor_list)] return g.op("Concat", *unsqueezed, axis_i=dim) @@ -592,7 +592,7 @@ def unbind(g, self, dim=0, _outputs=None): outputs = g.op("Split", self, split_i=[1] * _outputs, axis_i=dim, outputs=_outputs) outputs = [outputs] if _outputs == 1 else outputs - squeezed_outputs = [g.op("Squeeze", out, axes_i=[dim]) for out in outputs] + squeezed_outputs = [sym_help._squeeze_helper(g, out, [dim]) for out in outputs] return squeezed_outputs @@ -605,7 +605,7 @@ def select(g, self, dim, index): else: end_index = index + 1 slice_node = sym_help._slice_helper(g, self, axes=[dim], starts=[index], ends=[end_index]) - return g.op("Squeeze", slice_node, axes_i=[dim]) + return sym_help._squeeze_helper(g, slice_node, [dim]) else: return g.op("Gather", self, index, axis_i=dim) @@ -640,7 +640,7 @@ def squeeze(g, self, dim=None): "is not 1, the ONNX model will return an error. Opset version 11 supports squeezing on " + "non-singleton dimensions, it is recommended to export this model using opset " + "version 11 or higher.") - return g.op("Squeeze", self, axes_i=[squeeze_dim]) + return sym_help._squeeze_helper(g, self, axes_i=[squeeze_dim]) if dim_size > 1: warnings.warn("This model contains a squeeze operation on dimension " + str(squeeze_dim) + ". The size of " + "this dimension in the given input is " + str(dim_size) + ". 
The model will " + @@ -651,12 +651,12 @@ def squeeze(g, self, dim=None): warnings.warn("This model contains a squeeze operation on dimension " + str(squeeze_dim) + ". If the model is " + "intended to be used with dynamic input shapes, please use opset version 11 to export the model.") - return g.op("Squeeze", self, axes_i=[squeeze_dim]) + return sym_help._squeeze_helper(g, self, axes_i=[squeeze_dim]) def prelu(g, self, weight): self_rank = sym_help._get_tensor_rank(self) if self_rank is not None and self_rank > 2: - weight = g.op("Unsqueeze", weight, axes_i=list(range(1, self_rank - 1))) + weight = sym_help._unsqueeze_helper(g, weight, list(range(1, self_rank - 1))) return g.op("PRelu", self, weight) @@ -674,7 +674,7 @@ def floor(g, input): def _len(g, self): sz_0 = size(g, self, g.op("Constant", value_t=torch.LongTensor([0]))) - return g.op('Squeeze', sz_0, axes_i=[0]) + return sym_help._squeeze_helper(g, sz_0, [0]) @parse_args('v', 't', 't') @@ -753,7 +753,7 @@ def softmax(g, input, dim, dtype=None): input = g.op('Sub', input, g.op('ReduceMax', input, axes_i=[dim], keepdims_i=1)) exp = g.op('Exp', input) - sum = g.op('ReduceSum', exp, axes_i=[dim]) + sum = sym_help._reducesum_helper(g, exp, axes_i=[dim]) softmax = g.op('Div', exp, sum) if dtype and dtype.node().kind() != 'prim::Constant': parsed_dtype = sym_help._get_const(dtype, 'i', 'dtype') @@ -1105,6 +1105,21 @@ def __or_(g, input, other): return g.op('Or', input, other) +@wrap_logical_op_with_cast_to_and_from('Bool') +def logical_and(g, input, other): + return g.op('And', input, other) + + +@wrap_logical_op_with_cast_to_and_from('Bool') +def logical_or(g, input, other): + return g.op('Or', input, other) + + +@wrap_logical_op_with_cast_to_and_from('Bool') +def logical_xor(g, input, other): + return g.op('Xor', input, other) + + def __rshift_(g, self, other): # make sure to cast other to self's type # (when self is long, make sure that other is not float) @@ -1356,7 +1371,7 @@ def unfold(g, input, dimension, size, step): ndim = len(sizes) perm = list(range(0, ndim)) perm.append(perm.pop(dimension)) - unsqueeze = [g.op("Unsqueeze", g.op("Transpose", t, perm_i=perm), axes_i=[dimension]) for t in stack] + unsqueeze = [sym_help._unsqueeze_helper(g, g.op("Transpose", t, perm_i=perm), [dimension]) for t in stack] return g.op("Concat", *unsqueeze, axis_i=dimension) else: return _unimplemented("Unfold", "input size not accessible") @@ -1732,14 +1747,14 @@ def eye(g, *args): if len(args) == 5: # aten::eye(n, dtype, layout, device, pin_memory) n, dtype, layout, device, pin_memory = args - dim_size = g.op("Unsqueeze", n, axes_i=[0]) + dim_size = sym_help._unsqueeze_helper(g, n, [0]) shape = g.op("Concat", dim_size, dim_size, axis_i=0) tensor = zeros(g, shape, dtype, layout, device) return g.op("EyeLike", tensor) elif len(args) == 6: # aten::eye(n, m, dtype, layout, device, pin_memory) n, m, dtype, layout, device, pin_memory = args - shape = g.op("Concat", g.op("Unsqueeze", n, axes_i=[0]), g.op("Unsqueeze", m, axes_i=[0]), axis_i=0) + shape = g.op("Concat", sym_help._unsqueeze_helper(g, n, [0]), sym_help._unsqueeze_helper(g, m, [0]), axis_i=0) tensor = zeros(g, shape, dtype, layout, device) return g.op("EyeLike", tensor) else: @@ -1760,9 +1775,9 @@ def slice(g, self, *args): 'is a deprecated experimental op. 
Please use statically allocated ' 'variables or export to a higher opset version.') else: - start_unsqueezed = g.op("Unsqueeze", start, axes_i=[0]) - end_unsqueezed = g.op("Unsqueeze", end, axes_i=[0]) - dim_unsqueezed = g.op("Unsqueeze", dim, axes_i=[0]) + start_unsqueezed = sym_help._unsqueeze_helper(g, start, [0]) + end_unsqueezed = sym_help._unsqueeze_helper(g, end, [0]) + dim_unsqueezed = sym_help._unsqueeze_helper(g, dim, [0]) return g.op("DynamicSlice", self, start_unsqueezed, end_unsqueezed, dim_unsqueezed) else: start = _parse_arg(start, 'i') @@ -1814,7 +1829,7 @@ def unsqueeze(g, self, dim): else: return _unimplemented('unsqueeze', 'negative axis with unknown input rank') - return g.op("Unsqueeze", self, axes_i=[dim]) + return sym_help._unsqueeze_helper(g, self, axes_i=[dim]) @parse_args('v', 'i', 'i', 'none') @@ -1973,7 +1988,7 @@ def transform_weights_no_bias(layer_index): elif variant == 'GRU' or variant == 'LSTM': weight_ih, weight_hh = \ [reform_weights(g, w, hidden_size, reform_permutation) for w in weights] - return tuple(g.op('Unsqueeze', x, axes_i=[0]) for x in (weight_ih, weight_hh)) + return tuple(sym_help._unsqueeze_helper(g, x, [0]) for x in (weight_ih, weight_hh)) def transform_weights(layer_index): weights = layer_weights[layer_index] @@ -1983,7 +1998,7 @@ def transform_weights(layer_index): weight_ih, weight_hh, bias_ih, bias_hh = \ [reform_weights(g, w, hidden_size, reform_permutation) for w in weights] bias_concat = g.op('Concat', bias_ih, bias_hh, axis_i=0) - return tuple(g.op('Unsqueeze', x, axes_i=[0]) for x in (weight_ih, weight_hh, bias_concat)) + return tuple(sym_help._unsqueeze_helper(g, x, [0]) for x in (weight_ih, weight_hh, bias_concat)) def retrieve_state(x, start, end): return x if num_layers == 1 else sym_help._slice_helper(g, x, axes=[0], starts=[start], ends=[end]) @@ -2050,7 +2065,7 @@ def retrieve_state(x, start, end): prev_output = g.op('Transpose', prev_output, perm_i=[0, 2, 1, 3]) prev_output = g.op('Reshape', prev_output, g.op('Constant', value_t=torch.LongTensor([0, 0, -1]))) else: - prev_output = g.op('Squeeze', prev_output, axes_i=[1]) + prev_output = sym_help._squeeze_helper(g, prev_output, [1]) h_outs.append(h_out) if variant == 'LSTM': @@ -2382,8 +2397,8 @@ def gather(g, self, dim, index, sparse_grad=False): values = g.op("Constant", value_t=torch.LongTensor([0, 1])) depth = size(g, self, g.op("Constant", value_t=torch.LongTensor([dim]))) index = g.op("Cast", g.op("OneHot", index, depth, values, axis_i=dim), to_i=sym_help.cast_pytorch_to_onnx[dtype]) - mul = g.op("Mul", g.op("Unsqueeze", self, axes_i=[dim + 1]), index) - return g.op("ReduceSum", mul, axes_i=[dim], keepdims_i=0) + mul = g.op("Mul", sym_help._unsqueeze_helper(g, self, [dim + 1]), index) + return sym_help._reducesum_helper(g, mul, axes_i=[dim], keepdims_i=0) @parse_args('v', 'is', 'b', 'i') @@ -2477,42 +2492,42 @@ def _get_arange_dtype(dtype): if len(args) == 2: # aten::arange(Scalar end, Tensor out) - end = g.op("Unsqueeze", args[0], axes_i=[0]) + end = sym_help._unsqueeze_helper(g, args[0], [0]) dtype = 4 # default to int64 - arange_tensor = g.op("Squeeze", nonzero(g, ones(g, end, dtype, None, None)), axes_i=[1]) + arange_tensor = sym_help._squeeze_helper(g, nonzero(g, ones(g, end, dtype, None, None)), [1]) return g.op("Cast", arange_tensor, to_i=sym_help.scalar_type_to_onnx[dtype]) elif len(args) == 4: # aten::arange(Scalar start, Scalar end, Scalar step, Tensor out) dtype = 4 # default to int64 - step = g.op("Unsqueeze", args[2], axes_i=[0]) - end = g.op("Unsqueeze", 
args[1], axes_i=[0]) - start = g.op("Unsqueeze", args[0], axes_i=[0]) + step = sym_help._unsqueeze_helper(g, args[2], [0]) + end = sym_help._unsqueeze_helper(g, args[1], [0]) + start = sym_help._unsqueeze_helper(g, args[0], [0]) range_tensor = g.op("Div", g.op("Sub", end, start), step) - arange_tensor = g.op("Squeeze", nonzero(g, ones(g, range_tensor, None, None, None)), axes_i=[1]) + arange_tensor = sym_help._squeeze_helper(g, nonzero(g, ones(g, range_tensor, None, None, None)), [1]) arange_tensor = g.op("Add", g.op("Mul", arange_tensor, step), start) return g.op("Cast", arange_tensor, to_i=sym_help.scalar_type_to_onnx[dtype]) elif len(args) == 5: # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) dtype = _get_arange_dtype(args[1]) - end = g.op("Unsqueeze", args[0], axes_i=[0]) - arange_tensor = g.op("Squeeze", nonzero(g, ones(g, end, dtype, *(args[2:]))), axes_i=[1]) + end = sym_help._unsqueeze_helper(g, args[0], [0]) + arange_tensor = sym_help._squeeze_helper(g, nonzero(g, ones(g, end, dtype, *(args[2:]))), [1]) return g.op("Cast", arange_tensor, to_i=sym_help.scalar_type_to_onnx[dtype]) elif len(args) == 6: # aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) dtype = _get_arange_dtype(args[2]) - end = g.op("Unsqueeze", args[1], axes_i=[0]) - start = g.op("Unsqueeze", args[0], axes_i=[0]) + end = sym_help._unsqueeze_helper(g, args[1], [0]) + start = sym_help._unsqueeze_helper(g, args[0], [0]) range_tensor = g.op("Sub", end, start) - arange_tensor = g.op("Add", g.op("Squeeze", nonzero(g, ones(g, range_tensor, dtype, *(args[3:]))), axes_i=[1]), start) + arange_tensor = g.op("Add", sym_help._squeeze_helper(g, nonzero(g, ones(g, range_tensor, dtype, *(args[3:]))), [1]), start) return g.op("Cast", arange_tensor, to_i=sym_help.scalar_type_to_onnx[dtype]) elif len(args) == 7: # aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory) dtype = _get_arange_dtype(args[3]) - step = g.op("Unsqueeze", args[2], axes_i=[0]) - end = g.op("Unsqueeze", args[1], axes_i=[0]) - start = g.op("Unsqueeze", args[0], axes_i=[0]) + step = sym_help._unsqueeze_helper(g, args[2], [0]) + end = sym_help._unsqueeze_helper(g, args[1], [0]) + start = sym_help._unsqueeze_helper(g, args[0], [0]) range_tensor = g.op("Div", g.op("Sub", end, start), step) - arange_tensor = g.op("Squeeze", nonzero(g, ones(g, range_tensor, dtype, *(args[4:]))), axes_i=[1]) + arange_tensor = sym_help._squeeze_helper(g, nonzero(g, ones(g, range_tensor, dtype, *(args[4:]))), [1]) arange_tensor = g.op("Add", g.op("Mul", arange_tensor, step), start) return g.op("Cast", arange_tensor, to_i=sym_help.scalar_type_to_onnx[dtype]) else: @@ -2541,7 +2556,7 @@ def try_mask_to_index(index): warnings.warn("Exporting aten::index operator with indices of type Byte. " "Only 1-D indices are supported. 
In any other case, " "this will produce an incorrect ONNX graph.") - index = squeeze(g, nonzero(g, index), dim=1) + index = sym_help._squeeze_helper(g, nonzero(g, index), [1]) return index indices = [try_mask_to_index(idx) for idx in indices] @@ -2639,7 +2654,7 @@ def try_mask_to_index(index): @parse_args('v', 'is', 'i') def frobenius_norm(g, self, dim=None, keepdim=False): sqr = g.op('Mul', self, self) - sumsqr = g.op('ReduceSum', sqr, axes_i=dim, keepdims_i=keepdim) + sumsqr = sym_help._reducesum_helper(g, sqr, axes_i=dim, keepdims_i=keepdim) return g.op('Sqrt', sumsqr) @@ -2687,10 +2702,9 @@ def remainder(g, input, other): def gelu(g, self): _sqrt2 = 1.4142135623730951 - erf = g.op('Erf', g.op('Div', self, torch.tensor(_sqrt2))) - erf_plusone = add(g, erf, g.op('Constant', value_t=torch.tensor(1, dtype=torch.float))) - return mul(g, mul(g, self, erf_plusone), g.op('Constant', value_t=torch.tensor(0.5, dtype=torch.float))) - + erf = g.op('Erf', g.op('Div', self, torch.tensor(_sqrt2, dtype=torch.double))) + erf_plusone = add(g, erf, g.op('Constant', value_t=torch.tensor(1, dtype=torch.double))) + return mul(g, mul(g, self, erf_plusone), g.op('Constant', value_t=torch.tensor(0.5, dtype=torch.double))) @parse_args('v', 'i', 'v', 'v', 'f', 'i') def group_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled): @@ -2730,7 +2744,7 @@ def group_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled): # Norm has shape [N, C, *] so we reshape weight and bias to [C, *] axes = list(range(1, input_rank - 1)) - return add(g, mul(g, norm, g.op("Unsqueeze", weight, axes_i=axes)), g.op("Unsqueeze", bias, axes_i=axes)) + return add(g, mul(g, norm, sym_help._unsqueeze_helper(g, weight, axes)), sym_help._unsqueeze_helper(g, bias, axes)) @parse_args('v', 'v', 'i') @@ -2805,7 +2819,7 @@ def kl_div(g, input, target, reduction, log_target): elif reduction == 1: return g.op("ReduceMean", output, keepdims_i=0) elif reduction == 2: - return g.op("ReduceSum", output, keepdims_i=0) + return sym_help._reducesum_helper(g, output, keepdims_i=0) else: return sym_help._onnx_unsupported("kl_div with reduction other than none, mean, or sum. 
Please open a bug to " "request ONNX export support for the missing reduction type.") diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 59d45de1f553..a17f2ea2eb2d 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -29,6 +29,8 @@ def is_in_onnx_export(): global __IN_ONNX_EXPORT return __IN_ONNX_EXPORT +# Skip check due to cannot import IValue from torch._C +_params_dict = {} # type: ignore @contextlib.contextmanager def select_model_mode_for_export(model, mode): @@ -207,7 +209,8 @@ def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=Fa torch._C._jit_pass_onnx_scalar_type_analysis(graph) torch._C._jit_pass_lint(graph) - torch._C._jit_pass_onnx_fold_if(graph) + if dynamic_axes is None or not bool(dynamic_axes): + torch._C._jit_pass_onnx_fold_if(graph) from torch.onnx.symbolic_helper import _export_onnx_opset_version torch._C._jit_pass_onnx_peephole(graph, _export_onnx_opset_version, fixed_batch_size) @@ -224,7 +227,7 @@ def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=Fa torch._C._jit_pass_lint(graph) from torch.onnx.symbolic_helper import _onnx_shape_inference, _export_onnx_opset_version if _onnx_shape_inference: - torch._C._jit_pass_onnx_graph_shape_type_inference(graph, _export_onnx_opset_version) + torch._C._jit_pass_onnx_graph_shape_type_inference(graph, params_dict, _export_onnx_opset_version) return graph @@ -358,7 +361,7 @@ def _trace(func, args, operator_export_type, return_outs=False): torch.jit._get_trace_graph(func, args, strict=False, _force_outplace=False, _return_inputs_states=True) warn_on_static_input_change(inputs_states) - trace_graph = _optimize_graph(trace_graph, operator_export_type) + trace_graph = _optimize_graph(trace_graph, operator_export_type, params_dict={}) if return_outs: return trace_graph, torch_out return trace_graph @@ -422,6 +425,11 @@ def _create_jit_graph(model, args, _retain_param_name, use_new_jit_passes): torch._C._jit_pass_onnx_function_substitution(graph) return graph, params, torch_out +def _get_named_param_dict(graph, params): + input_and_param_names = [val.debugName() for val in graph.inputs()] + param_names = input_and_param_names[len(input_and_param_names) - len(params):] + _params_dict = dict(zip(param_names, params)) + return _params_dict def _model_to_graph(model, args, verbose=False, input_names=None, output_names=None, @@ -443,9 +451,7 @@ def _model_to_graph(model, args, verbose=False, _retain_param_name, use_new_jit_passes) - input_and_param_names = [val.debugName() for val in graph.inputs()] - param_names = input_and_param_names[len(input_and_param_names) - len(params):] - params_dict = dict(zip(param_names, params)) + params_dict = _get_named_param_dict(graph, params) graph = _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=_disable_torch_constant_prop, @@ -479,9 +485,7 @@ def _model_to_graph(model, args, verbose=False, flatten_args, _ = torch._C._jit_flatten(args) assert len(params) + len(flatten_args) == sum(1 for _ in graph.inputs()) - input_and_param_names = [val.debugName() for val in graph.inputs()] - param_names = input_and_param_names[len(input_and_param_names) - len(params):] - params_dict = dict(zip(param_names, params)) + params_dict = _get_named_param_dict(graph, params) if training is None or training == TrainingMode.EVAL: params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict) @@ -491,6 +495,9 @@ def _model_to_graph(model, args, verbose=False, _export_onnx_opset_version) 
torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) + if _onnx_shape_inference: + torch._C._jit_pass_onnx_graph_shape_type_inference(graph, params_dict, _export_onnx_opset_version) + params_dict = torch._C._jit_pass_onnx_eliminate_unused_items(graph, params_dict) # For ONNX opset < 9, constants only have three data types: float16, float, double. @@ -878,7 +885,7 @@ def const_if_tensor(arg): from torch.onnx.symbolic_helper import _onnx_shape_inference if _onnx_shape_inference: from torch.onnx.symbolic_helper import _export_onnx_opset_version as opset_version - torch._C._jit_pass_onnx_node_shape_type_inference(n, opset_version) + torch._C._jit_pass_onnx_node_shape_type_inference(n, _params_dict, opset_version) if outputs == 1: return n.output() @@ -1032,7 +1039,7 @@ def _run_symbolic_function(g, n, inputs, env, operator_export_type=OperatorExpor # Process Loop and If after subblock is converted. from torch.onnx.symbolic_helper import _onnx_shape_inference if _onnx_shape_inference: - torch._C._jit_pass_onnx_node_shape_type_inference(new_node, opset_version) + torch._C._jit_pass_onnx_node_shape_type_inference(new_node, _params_dict, opset_version) return new_op_outputs else: symbolic_name = 'prim_' + op_name diff --git a/torch/profiler/__init__.py b/torch/profiler/__init__.py index dabbf91dff90..e0f568d7cc4b 100644 --- a/torch/profiler/__init__.py +++ b/torch/profiler/__init__.py @@ -10,3 +10,4 @@ ''' from .profiler import profile, schedule, ProfilerAction, ProfilerActivity +from torch.autograd import kineto_available diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py index 25bee1c2019f..8024bb727a43 100644 --- a/torch/profiler/profiler.py +++ b/torch/profiler/profiler.py @@ -86,7 +86,7 @@ class profile(object): print(p.key_averages().table( sort_by="self_cuda_time_total", row_limit=-1)) - Usimg the profiler's ``schedule``, ``on_trace_ready`` and ``next_step`` functions: + Using the profiler's ``schedule``, ``on_trace_ready`` and ``step`` functions: .. code-block:: python @@ -96,7 +96,7 @@ class profile(object): def trace_handler(prof): print(prof.key_averages().table( sort_by="self_cuda_time_total", row_limit=-1)) - # prof.export_chrome_trace("/tmp/test_trace_" + str(prof.step()) + ".json") + # prof.export_chrome_trace("/tmp/test_trace_" + str(prof.step_num) + ".json") with torch.profiler.profile( activities=[ @@ -120,7 +120,7 @@ def trace_handler(prof): for iter in range(N): code_iteration_to_profile(iter) # send a signal to the profiler that the next iteration has started - p.next_step() + p.step() """ def __init__( self, @@ -172,7 +172,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.step_rec_fn.__exit__(None, None, None) self._exit_actions() - def next_step(self): + def step(self): """ Signals the profiler that the next profiling step has started. """ @@ -232,12 +232,6 @@ def next_step(self): self.step_rec_fn = prof.record_function("ProfilerStep#" + str(self.step_num)) self.step_rec_fn.__enter__() - def step(self): - """ - Returns the current profiling step. - """ - return self.step_num - def export_chrome_trace(self, path: str): """ Exports the collected trace in Chrome JSON format.
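With `next_step()` renamed to `step()` and the step-counter accessor folded into the `step_num` attribute, a profiling loop against the updated API looks like this (a sketch only; the model and handler are placeholders, and the schedule values are arbitrary):

    import torch
    from torch.profiler import profile, schedule, ProfilerActivity

    def trace_handler(prof):
        # fires once per completed profiling cycle, via on_trace_ready
        print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))

    model = torch.nn.Linear(64, 64)
    inputs = torch.randn(8, 64)

    with profile(
        activities=[ProfilerActivity.CPU],
        schedule=schedule(wait=1, warmup=1, active=2),
        on_trace_ready=trace_handler,
    ) as p:
        for _ in range(8):
            model(inputs)
            p.step()  # was p.next_step() before this patch

The `kineto_available` re-export added to `torch/profiler/__init__.py` lets callers guard such code on Kineto support at runtime.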
diff --git a/torch/quantization/_numeric_suite.py b/torch/quantization/_numeric_suite.py index 100ff54d4436..de0b4083b390 100644 --- a/torch/quantization/_numeric_suite.py +++ b/torch/quantization/_numeric_suite.py @@ -1,10 +1,9 @@ - import torch import torch.nn as nn import torch.nn.quantized as nnq import torch.nn.quantized.dynamic as nnqd from torch.quantization import prepare -from typing import Dict +from typing import Dict, List, Optional, Any, Union, Callable, Set from .quantization_mappings import ( get_default_compare_output_module_list, @@ -18,7 +17,10 @@ } -def _find_match(str_list, key_str, postfix): +def _find_match( + str_list: Union[Dict[str, Any], List[str]], key_str: str, + postfix: str, +) -> Optional[str]: split_str = key_str.split(".") if split_str[-1] == postfix: match_string = "".join(key_str.split(".")[0:-1]) @@ -42,11 +44,14 @@ def _find_match(str_list, key_str, postfix): return s2 if match_string == pattern2: return s2 + return None else: return None -def compare_weights(float_dict, quantized_dict): +def compare_weights( + float_dict: Dict[str, Any], quantized_dict: Dict[str, Any] +) -> Dict[str, Dict[str, torch.Tensor]]: r"""Compare the weights of the float module with its corresponding quantized module. Return a dict with key corresponding to module names and each entry being a dictionary with two keys 'float' and 'quantized', containing the float and @@ -105,7 +110,10 @@ def compare_weights(float_dict, quantized_dict): return weight_dict -def _get_logger_dict_helper(mod, target_dict, prefix=""): +def _get_logger_dict_helper( + mod: nn.Module, target_dict: Dict[str, Any], + prefix: str = "", +) -> None: r"""This is the helper function for get_logger_dict Args: @@ -127,7 +135,7 @@ def get_prefix(prefix): _get_logger_dict_helper(child, target_dict, module_prefix) -def get_logger_dict(mod, prefix=""): +def get_logger_dict(mod: nn.Module, prefix: str = "") -> Dict[str, Dict]: r"""Traverse the modules and save all logger stats into target dict. This is mainly used for quantization accuracy debug. 
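The annotations above pin down the `compare_weights` contract: it takes the float and quantized `state_dict`s and returns, per matched module, a dict with 'float' and 'quantized' weight entries. A usage sketch (the toy model and error metric are illustrative, assuming a standard dynamic-quantization flow):

    import torch
    from torch.quantization import quantize_dynamic
    from torch.quantization._numeric_suite import compare_weights

    # toy float model; quantize_dynamic swaps Linear for its dynamic variant
    float_model = torch.nn.Sequential(torch.nn.Linear(16, 8))
    q_model = quantize_dynamic(float_model, {torch.nn.Linear}, dtype=torch.qint8)

    wt_compare_dict = compare_weights(float_model.state_dict(), q_model.state_dict())
    for key in wt_compare_dict:
        float_w = wt_compare_dict[key]['float']
        quant_w = wt_compare_dict[key]['quantized'].dequantize()
        # relative weight error introduced by quantization
        rel_err = torch.norm(float_w - quant_w) / torch.norm(float_w)
        print(key, tuple(float_w.shape), float(rel_err))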
@@ -195,11 +203,11 @@ def forward(self, x): return x -def _convert_tuple_to_list(t): +def _convert_tuple_to_list(t: Any) -> Any: return list(_convert_tuple_to_list(x) for x in t) if type(t) is tuple else t -def _dequantize_tensor_list(t): +def _dequantize_tensor_list(t: Any) -> Any: return ( list(_dequantize_tensor_list(x) for x in t) if type(t) is list @@ -228,7 +236,7 @@ def __init__(self, q_module, float_module, Logger): self.dequant = nnq.DeQuantize() self.logger = Logger() - def forward(self, *x): + def forward(self, *x) -> torch.Tensor: xl = _convert_tuple_to_list(x) output = self.orig_module(*xl) xl_float = _dequantize_tensor_list(xl) @@ -236,7 +244,7 @@ def forward(self, *x): self.logger(output, shadow_output) return output - def add(self, x, y): + def add(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: output = self.orig_module.add(x, y) x = x.dequantize() y = y.dequantize() @@ -244,14 +252,14 @@ def add(self, x, y): self.logger(output, shadow_output) return output - def add_scalar(self, x, y): + def add_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor: output = self.orig_module.add_scalar(x, y) x = x.dequantize() shadow_output = self.shadow_module.add_scalar(x, y) self.logger(output, shadow_output) return output - def mul(self, x, y): + def mul(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: output = self.orig_module.mul(x, y) x = x.dequantize() y = y.dequantize() @@ -259,21 +267,21 @@ def mul(self, x, y): self.logger(output, shadow_output) return output - def mul_scalar(self, x, y): + def mul_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor: output = self.orig_module.mul_scalar(x, y) x = x.dequantize() shadow_output = self.shadow_module.mul_scalar(x, y) self.logger(output, shadow_output) return output - def cat(self, x, dim=0): + def cat(self, x: List[torch.Tensor], dim: int = 0) -> torch.Tensor: output = self.orig_module.cat(x, dim) x = [y.dequantize() for y in x] shadow_output = self.shadow_module.cat(x, dim) self.logger(output, shadow_output) return output - def add_relu(self, x, y): + def add_relu(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: output = self.orig_module.add_relu(x, y) x = x.dequantize() y = y.dequantize() @@ -282,7 +290,10 @@ def add_relu(self, x, y): return output -def prepare_model_with_stubs(float_module, q_module, module_swap_list, Logger): +def prepare_model_with_stubs( + float_module: nn.Module, q_module: nn.Module, + module_swap_list: Set[type], Logger: Callable, +) -> None: r"""Prepare the model by attaching the float module to its matching quantized module as the shadow if the float module type is in module_swap_list. @@ -322,8 +333,9 @@ def prepare_model_with_stubs(float_module, q_module, module_swap_list, Logger): def compare_model_stub( - float_model, q_model, module_swap_list, *data, Logger=ShadowLogger -): + float_model: nn.Module, q_model: nn.Module, module_swap_list: Set[type], + *data, Logger=ShadowLogger +) -> Dict[str, Dict]: r"""Compare quantized module in a model with its floating point counterpart, feeding both of them the same input. Return a dict with key corresponding to module names and each entry being a dictionary with two keys 'float' and @@ -361,7 +373,9 @@ def compare_model_stub( return ob_dict -def get_matching_activations(float_module, q_module): +def get_matching_activations( + float_module: nn.Module, q_module: nn.Module, +) -> Dict[str, Dict[str, torch.Tensor]]: r"""Find the matching activation between float and quantized modules. 
Args: @@ -387,11 +401,11 @@ def get_matching_activations(float_module, q_module): def prepare_model_outputs( - float_module, - q_module, + float_module: nn.Module, + q_module: nn.Module, Logger=OutputLogger, allow_list=None -): +) -> None: r"""Prepare the model by attaching the logger to both float module and quantized module if they are in the allow_list. @@ -406,9 +420,9 @@ def prepare_model_outputs( allow_list = get_default_compare_output_module_list() qconfig_debug = torch.quantization.QConfig(activation=Logger, weight=None) - float_module.qconfig = qconfig_debug + float_module.qconfig = qconfig_debug # type: ignore prepare(float_module, inplace=True, allow_list=allow_list) - q_module.qconfig = qconfig_debug + q_module.qconfig = qconfig_debug # type: ignore prepare( q_module, inplace=True, @@ -418,12 +432,12 @@ def prepare_model_outputs( def compare_model_outputs( - float_model, - q_model, + float_model: nn.Module, + q_model: nn.Module, *data, Logger=OutputLogger, allow_list=None -): +) -> Dict[str, Dict[str, torch.Tensor]]: r"""Compare output activations between float and quantized models at corresponding locations for the same input. Return a dict with key corresponding to quantized module names and each entry being a dictionary with two keys diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py index e6ac74dbf903..06f15240e761 100644 --- a/torch/quantization/fx/quantization_patterns.py +++ b/torch/quantization/fx/quantization_patterns.py @@ -218,12 +218,35 @@ def __init__(self, quantizer: QuantizerCls, node: Node): def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, debug: bool = False, convert_custom_config_dict: Dict[str, Any] = None) -> Node: + # Supported combinations are: + # quant_type | activation (compute_type) | weight + # static quint8 qint8 + + # tuple (activation_dtype, weight_dtype, compute_dtype) + supported_dtypes = [ + (torch.quint8, torch.qint8, None), + ] + # TODO: debug option for conv module qconfig = quantizer.qconfig_map[node.name] + dtypes = get_qconfig_dtypes(qconfig) + # leave the op unquantized if the dtype combination is not supported + if dtypes not in supported_dtypes: + warnings.warn( + "dtype combination: {} is not " + "supported by Conv " + "supported dtype combinations are: {}".format(dtypes, supported_dtypes)) + if self.relu_node: + conv_out = quantizer.quantized_graph.node_copy(self.conv_node, load_arg(quantized=False)) + relu_args = [conv_out] + relu_args.extend(load_arg(quantized=False)(self.relu_node.args[1:])) + relu_kwargs = load_arg(quantized=False)(self.relu_node.kwargs) + return quantizer.quantized_graph.create_node( + "call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs) + else: + return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False)) + activation_statically_quantized = activation_is_statically_quantized(qconfig) - # only static qunatization (for both ptq and qat) is supported for conv - if not activation_statically_quantized: - return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None)) if self.conv_node.op == 'call_module': # note that relu should already be fused into conv module in the fusion step @@ -246,21 +269,32 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, (load_arg(quantized=True)(self.conv_node.args[0]),), {}) else: # call_function - assert self.conv_node.op == 'call_function' - if self.relu_node is not None: - raise Exception("functional conv + relu is not 
supported yet") + assert self.conv_node.op == "call_function" if debug: args = load_arg(quantized=[0, 1])(self.conv_node.args) args = load_arg(quantized=False)(self.conv_node.args) kwargs = load_arg(quantized=False)(self.conv_node.kwargs) - conv_out = quantizer.quantized_graph.create_node( - 'call_function', torch.nn.functional.conv2d, args, kwargs) - root_module = quantizer.modules[''] - return quantize_node( - root_module, quantizer.quantized_graph, conv_out, quantizer.activation_post_process_map[self.conv_node.name]) + op_out = quantizer.quantized_graph.create_node( + "call_function", torch.nn.functional.conv2d, args, kwargs) + if self.relu_node: + relu_args = [op_out] + relu_args.extend(load_arg(quantized=False)(self.relu_node.args[1:])) + relu_kwargs = load_arg(quantized=False)(self.relu_node.kwargs) + op_out = quantizer.quantized_graph.create_node( + "call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs) + + if activation_statically_quantized: + root_module = quantizer.modules[''] + act_post_process_name = self.relu_node.name if self.relu_node else self.conv_node.name + return quantize_node( + root_module, quantizer.quantized_graph, op_out, + quantizer.activation_post_process_map[act_post_process_name]) + else: + # output for dynamically quantized conv op is not quantized + return op_out else: - assert len(self.conv_node.args) == 7, \ - 'only conv2d calls with all arguments specified is support right now in debug=False option' + assert len(self.conv_node.args) >= 7, \ + "only conv2d calls with all arguments specified is supported right now in debug=False option" args = load_arg(quantized=[0, 1])(self.conv_node.args) # pack weight weight = load_arg(quantized=True)(self.conv_node.args[1]) @@ -268,14 +302,23 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, prepack_args = tuple([weight] + list(other_args)) packed_weight = quantizer.quantized_graph.create_node( 'call_function', torch.ops.quantized.conv2d_prepack, prepack_args, {}) + assert activation_statically_quantized, \ + "currently only static quantization is supported for conv" # construct conv input - conv_input = load_arg(quantized=True)(self.conv_node.args[0]) - activation_post_process = quantizer.activation_post_process_map[self.conv_node.name] - scale, zero_point, _ = get_per_tensor_qparams(activation_post_process) - qconv_args = (conv_input, packed_weight, scale, zero_point) - kwargs = load_arg(quantized=False)(self.conv_node.kwargs) - return quantizer.quantized_graph.create_node( - 'call_function', torch.ops.quantized.conv2d, qconv_args, kwargs) + if activation_statically_quantized: + qconv_op = torch.ops.quantized.conv2d_relu if self.relu_node else torch.ops.quantized.conv2d + conv_input = load_arg(quantized=True)(self.conv_node.args[0]) + act_post_process_name = self.relu_node.name if self.relu_node else self.conv_node.name + activation_post_process = quantizer.activation_post_process_map[act_post_process_name] + scale, zero_point, _ = get_per_tensor_qparams(activation_post_process) + qconv_args = (conv_input, packed_weight, scale, zero_point) + kwargs = load_arg(quantized=False)(self.conv_node.kwargs) + return quantizer.quantized_graph.create_node( + 'call_function', qconv_op, qconv_args, kwargs) + else: + # conv2d_dyanmic branch + raise Exception("Only static quant is supported for conv") + # handle linear, maybe followed by relu @register_quant_pattern(torch.nn.Linear) @@ -316,6 +359,7 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, ] qconfig = 
quantizer.qconfig_map[node.name] dtypes = get_qconfig_dtypes(qconfig) + # leave the op unquantized if the dtype combination is not supported if dtypes not in supported_dtypes: warnings.warn( "dtype combination: {} is not " @@ -412,9 +456,9 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, prepack_op = get_linear_prepack_op_for_dtype(weight_dtype(qconfig)) packed_weight = quantizer.quantized_graph.create_node( 'call_function', prepack_op, prepack_args, {}) - qlinear_op = torch.ops.quantized.linear_relu if self.relu_node else torch.ops.quantized.linear # construct linear input if activation_statically_quantized: + qlinear_op = torch.ops.quantized.linear_relu if self.relu_node else torch.ops.quantized.linear linear_input = load_arg(quantized=True)(self.linear_node.args[0]) act_post_process_name = self.relu_node.name if self.relu_node else self.linear_node.name activation_post_process = \ @@ -484,6 +528,7 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, emb_node = node qconfig = quantizer.qconfig_map[node.name] dtypes = get_qconfig_dtypes(qconfig) + # leave the op unquantized if the dtype combination is not supported if dtypes not in supported_dtypes: warnings.warn( "dtype combination: {} is not " @@ -527,6 +572,7 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, assert node.op == 'call_module' qconfig = quantizer.qconfig_map[node.name] dtypes = get_qconfig_dtypes(qconfig) + # leave the op unquantized if the dtype combination is not supported if dtypes not in supported_dtypes: warnings.warn( "dtype combination: {} is not " diff --git a/torch/testing/_internal/common_jit.py b/torch/testing/_internal/common_jit.py index 8c2b407beea1..a93e13b665be 100644 --- a/torch/testing/_internal/common_jit.py +++ b/torch/testing/_internal/common_jit.py @@ -36,7 +36,7 @@ def check_output_types(self, func, ref_outputs, args, kwargs): 'grid_sample', ]) -def check_against_reference(self, func, reference_func, args, kwargs=None, +def check_against_reference(self, func, reference_func, output_func, args, kwargs=None, allow_unused=True, check_types=True, no_grad=False): kwargs = kwargs if kwargs else {} @@ -72,10 +72,10 @@ def clone_inputs(requires_grad): with enable_profiling_mode_for_profiling_tests(): # test single grad case - outputs = self.runAndSaveRNG(reference_func, recording_inputs, kwargs) + outputs = output_func(self.runAndSaveRNG(reference_func, recording_inputs, kwargs)) grads = torch.autograd.grad(allSum(outputs), recording_tensors, allow_unused=allow_unused) - outputs_test = self.runAndSaveRNG(func, recording_inputs, kwargs) + outputs_test = output_func(self.runAndSaveRNG(func, recording_inputs, kwargs)) grads_test = torch.autograd.grad(allSum(outputs_test), recording_tensors, allow_unused=allow_unused) self.assertEqual(outputs, outputs_test) @@ -84,7 +84,7 @@ def clone_inputs(requires_grad): if self._testMethodName in nn_functional_single_grad: return - outputs = self.runAndSaveRNG(reference_func, recording_inputs, kwargs) + outputs = output_func(self.runAndSaveRNG(reference_func, recording_inputs, kwargs)) l1 = allSum(outputs) grads = torch.autograd.grad(l1, recording_tensors, create_graph=True, allow_unused=allow_unused) @@ -92,7 +92,7 @@ def clone_inputs(requires_grad): l2 = (allSum(grads) * l1) grads2 = torch.autograd.grad(l2, recording_tensors, allow_unused=allow_unused) recording_inputs, recording_tensors = clone_inputs(True) - outputs_test = self.runAndSaveRNG(func, recording_inputs, kwargs) + outputs_test = 
output_func(self.runAndSaveRNG(func, recording_inputs, kwargs)) l1_test = allSum(outputs_test) grads_test = torch.autograd.grad( l1_test, recording_tensors, create_graph=True, allow_unused=allow_unused) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 90aa1468bd4a..324a9346c45b 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -845,6 +845,21 @@ def sample_inputs_linalg_solve(op_info, device, dtype, requires_grad=False): return out +def sample_inputs_std_var(op_info, device, dtype, requires_grad): + tensor_nd = make_tensor((S, S, S), device=device, dtype=dtype, + low=None, high=None, requires_grad=requires_grad) + tensor_1d = make_tensor((S,), device=device, dtype=dtype, + low=None, high=None, requires_grad=requires_grad) + + return [ + SampleInput(tensor_nd), + SampleInput(tensor_nd, kwargs=dict(dim=1)), + SampleInput(tensor_nd, kwargs=dict(dim=1, unbiased=True, keepdim=True)), + SampleInput(tensor_1d, kwargs=dict(dim=0, unbiased=True, keepdim=True)), + SampleInput(tensor_1d, kwargs=dict(dim=0, unbiased=False, keepdim=False)), + ] + + def _sample_inputs_svd(op_info, device, dtype, requires_grad=False, is_linalg_svd=False): """ This function generates input for torch.svd with distinct singular values so that autograd is always stable. @@ -1311,11 +1326,7 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad): supports_tensor_out=False, sample_inputs_func=sample_inputs_slogdet, output_func=itemgetter(1), - decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack], - skips=( - # These tests do not work with output_func=itemgetter(1) - # TODO: remove this once https://github.com/pytorch/pytorch/issues/49326 is resolved - SkipInfo('TestCommon', 'test_variant_consistency_jit'),)), + decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack]), UnaryUfuncInfo('log', ref=np.log, domain=(0, float('inf')), @@ -1437,6 +1448,18 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad): SkipInfo('TestCommon', 'test_variant_consistency_jit', device_type='cuda', dtypes=[torch.float16]), )), + OpInfo('std', + dtypes=floating_types_and(), + dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_std_var, + supports_tensor_out=False, + test_complex_grad=False, + test_inplace_grad=False, + # std has only partial support for complex and half (#51127) + skips=(SkipInfo('TestOpInfo', 'test_unsupported_dtypes', + dtypes=[torch.half, torch.complex64, torch.complex128]),), + assert_autodiffed=True, + ), UnaryUfuncInfo('tan', ref=np.tan, dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), @@ -1726,6 +1749,18 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad): supports_tensor_out=False, test_inplace_grad=False, sample_inputs_func=sample_repeat_tile), + OpInfo('var', + dtypes=floating_types_and(), + dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_std_var, + supports_tensor_out=False, + test_complex_grad=False, + test_inplace_grad=False, + # var has only partial support for complex and half (#51127) + skips=(SkipInfo('TestOpInfo', 'test_unsupported_dtypes', + dtypes=[torch.half, torch.complex64, torch.complex128]),), + assert_autodiffed=True, + ), ] if TEST_SCIPY: @@ -1735,6 +1770,29 @@ def reference_sigmoid(x): return (1 / (1 + np.exp(-x))) return scipy.special.expit(x) + def reference_lgamma(x): + # 
scipy.special.gammaln returns `-inf` when input is `-inf`, + # while PyTorch, C and C++ all return `inf` when input is `-inf`. + # Reference: + # https://en.cppreference.com/w/cpp/numeric/math/lgamma + # https://en.cppreference.com/w/c/numeric/math/lgamma + + # To handle the above discrepancy, + # we replace -inf with inf so values + # that were originally -inf map to inf as expected + if x.dtype.kind == 'f': + x = np.where(x == float('-inf'), np.array(float('inf'), dtype=x.dtype), x) + + out = scipy.special.gammaln(x) + + if x.dtype == np.float16: + # `scipy.special.gammaln` returns output of float32 when input is float16, + # while `torch.lgamma` preserves `float16`. But due to the smaller range of float16, + # the PyTorch version outputs `inf` while SciPy returns finite values. + out = out.astype(np.float16) + + return out + op_db_scipy_reference: List[OpInfo] = [ UnaryUfuncInfo('sigmoid', ref=reference_sigmoid, @@ -1812,6 +1870,27 @@ def reference_sigmoid(x): dtypes=[torch.bfloat16]), ) ), + UnaryUfuncInfo('lgamma', + ref=reference_lgamma, + decorators=(precisionOverride({torch.float16: 7e-1}),), + dtypes=all_types_and(torch.bool), + dtypesIfCPU=all_types_and(torch.bool, torch.bfloat16), + dtypesIfCUDA=all_types_and(torch.bool, torch.half), + skips=( + # Reference: https://github.com/pytorch/pytorch/pull/50140#discussion_r552615345 + SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', + dtypes=[torch.bfloat16]), + # Reference: https://github.com/pytorch/pytorch/pull/50140#issuecomment-756150214 + SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', + dtypes=[torch.float32, torch.float64], active_if=IS_WINDOWS), + # Backward of `lgamma` uses `digamma` but `digamma` + # is not implemented for `BFloat16` + # Error Raised: + # RuntimeError: "digamma" not implemented for 'BFloat16' + SkipInfo('TestCommon', 'test_variant_consistency_jit', + dtypes=[torch.bfloat16]), + ), + safe_casts_outputs=True), OpInfo('xlogy', dtypes=all_types_and(torch.bool), dtypesIfCPU=all_types_and(torch.bool, torch.half, torch.bfloat16), @@ -2294,16 +2373,6 @@ def method_tests(): ('prod', (torch.tensor(0., requires_grad=True)), NO_ARGS, 'scalar_zero'), ('prod', (torch.tensor(0., requires_grad=True)), (0,), 'scalar_dim_zero', (), [0]), ('prod', (torch.tensor(0., requires_grad=True)), (0, True,), 'scalar_keepdim_dim_zero', (), [0]), - ('var', (S, S, S), NO_ARGS, '', (True,)), - ('var', (S, S, S), (1,), 'dim', (True,), [0]), - ('var', (S, S, S), (1, True, True), 'keepdim_dim', (True,), [0]), - ('var', (S,), (0,), 'dim_1d', (True,), [0]), - ('var', (S,), (0, True, True), 'keepdim_dim_1d', (True,), [0]), - ('std', (S, S, S), NO_ARGS, '', (True,)), - ('std', (S, S, S), (1,), 'dim', (True,), [0]), - ('std', (S, S, S), (1, True, True), 'keepdim_dim', (True,), [0]), - ('std', (S,), (0,), 'dim_1d', (True,), [0]), - ('std', (S,), (0, True, True), 'keepdim_dim_1d', (True,), [0]), ('var_mean', (S, S, S), NO_ARGS, ''), ('var_mean', (S, S, S), (1,), 'dim', [0]), ('var_mean', (S, S, S), (1, True, True), 'keepdim_dim', [0]), diff --git a/torch/testing/_internal/distributed/rpc/dist_autograd_test.py b/torch/testing/_internal/distributed/rpc/dist_autograd_test.py index 15d5cfeca214..54e936bf0f0d 100644 --- a/torch/testing/_internal/distributed/rpc/dist_autograd_test.py +++ b/torch/testing/_internal/distributed/rpc/dist_autograd_test.py @@ -2235,3 +2235,34 @@ def test_verify_backend_options(self): self.assertEqual(self.rpc_backend_options.num_send_recv_threads, 8) self.assertEqual(self.rpc_backend_options.num_fail_sends, 3)
self.assertEqual(len(self.rpc_backend_options.messages_to_fail), 4) + +class TensorPipeDistAutogradTest(RpcAgentTestFixture): + + @skip_if_lt_x_gpu(4) + def test_device_maps_backward_pass(self): + options = self.rpc_backend_options + dst = worker_name((self.rank + 1) % self.world_size) + + # The reverse of this device mapping should be used for the backward pass. + options.set_device_map(dst, {self.rank: (self.rank + 1) % self.world_size}) + + rpc.init_rpc( + name=worker_name(self.rank), + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=options, + ) + + t1 = torch.rand(10, device=self.rank, requires_grad=True) + t2 = torch.rand(10, device=self.rank, requires_grad=True) + with dist_autograd.context() as context_id: + res = rpc.rpc_sync(dst, torch.add, args=(t1, t2)) + dist_autograd.backward(context_id, [res.sum()]) + grads = dist_autograd.get_gradients(context_id) + self.assertEqual(torch.ones(10), grads[t1]) + self.assertEqual(torch.ones(10), grads[t2]) + self.assertEqual(t1.device, grads[t1].device) + self.assertEqual(t2.device, grads[t2].device) + + rpc.shutdown() diff --git a/torch/testing/_internal/distributed/rpc_utils.py b/torch/testing/_internal/distributed/rpc_utils.py index d35f3da5d2c2..bdf4bbd6eb78 100644 --- a/torch/testing/_internal/distributed/rpc_utils.py +++ b/torch/testing/_internal/distributed/rpc_utils.py @@ -23,6 +23,7 @@ from torch.testing._internal.distributed.rpc.dist_autograd_test import ( DistAutogradTest, FaultyAgentDistAutogradTest, + TensorPipeDistAutogradTest ) from torch.testing._internal.distributed.rpc.dist_optimizer_test import ( DistOptimizerTest, @@ -139,7 +140,8 @@ class MultiProcess(Flag): # These suites should be standalone, and separate from the ones in the generic # list (not subclasses of those!). TENSORPIPE_TESTS = [ - TensorPipeAgentRpcTest + TensorPipeAgentRpcTest, + TensorPipeDistAutogradTest ] diff --git a/torch/testing/_internal/jit_metaprogramming_utils.py b/torch/testing/_internal/jit_metaprogramming_utils.py index cd134b38aba9..25ab7d1fc3f5 100644 --- a/torch/testing/_internal/jit_metaprogramming_utils.py +++ b/torch/testing/_internal/jit_metaprogramming_utils.py @@ -290,14 +290,15 @@ def gen_script_fn_and_args(method_name, func_type, *args, **kwargs): CU = torch.jit.CompilationUnit(script) return CU.the_method, tensors -# create a script function from (name, func_type, output_process_fn), -# returns a function takes in (args, kwargs) and runs the compiled function and -# then applies the post process fn to the outputs -def create_script_fn(self, method_name, func_type, output_process_fn): +# create a script function from (name, func_type), +# returns a function that takes in (args, kwargs) and runs the compiled function +def create_script_fn(self, method_name, func_type): + # the returned function gives back the raw output; callers apply an + # output_func to it when checking gradients (see check_against_reference) def script_fn(*args, **kwargs): fn, tensors = gen_script_fn_and_args(method_name, func_type, *args, **kwargs) self.assertExportImport(fn.graph, tensors) - output = output_process_fn(fn(*tensors)) + output = fn(*tensors) # skip type annotate function attributes for now, see: https://github.com/python/mypy/issues/2087 script_fn.last_graph = fn.graph_for(*tensors) # type: ignore[attr-defined] return output
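The `create_script_fn`/`check_against_reference` changes above move output post-processing out of the compiled function and into the checker, which is also why the slogdet `SkipInfo` for `test_variant_consistency_jit` could be dropped earlier in this diff. The division of labor, sketched in plain eager code (names here are illustrative; `itemgetter(1)` matches the slogdet OpInfo's `output_func`):

    from operator import itemgetter
    import torch

    def raw_fn(x):
        # returns the full (sign, logabsdet) result, unprocessed
        return torch.slogdet(x)

    output_func = itemgetter(1)  # keep only the differentiable logabsdet part

    x = torch.randn(3, 3, requires_grad=True)
    out = output_func(raw_fn(x))
    out.sum().backward()  # gradients flow through the filtered output
    print(x.grad.shape)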